├── .gitignore ├── LICENSE ├── README.md ├── configs ├── parameters_cdr.yaml └── parameters_gda.yaml ├── data_processing ├── 2017MeshTree.txt ├── __pycache__ │ ├── readers.cpython-36.pyc │ ├── tools.cpython-36.pyc │ └── utils.cpython-36.pyc ├── bin2txt.py ├── convertvec ├── convertvec.c ├── filter_hypernyms.py ├── gda2pubtator.py ├── process.py ├── process_cdr.sh ├── process_gda.sh ├── readers.py ├── reduce_embeds.py ├── split_gda.py ├── statistics.py ├── tools.py ├── train_gda_docs └── utils.py ├── embeds └── PubMed-CDR.txt ├── evaluation └── evaluate.py ├── network.svg ├── requirements.txt └── src ├── __init__.py ├── __pycache__ ├── converter.cpython-36.pyc ├── dataset.cpython-36.pyc ├── loader.cpython-36.pyc ├── reader.cpython-36.pyc └── utils.cpython-36.pyc ├── bin └── run.sh ├── converter.py ├── dataset.py ├── eog.py ├── loader.py ├── nnet ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── attention.cpython-36.pyc │ ├── init_net.cpython-36.pyc │ ├── modules.cpython-36.pyc │ ├── network.cpython-36.pyc │ ├── trainer.cpython-36.pyc │ └── walks.cpython-36.pyc ├── attention.py ├── init_net.py ├── modules.py ├── network.py ├── trainer.py └── walks.py ├── reader.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.idea 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | LICENCES 2 | 3 | NaCTeM Non-Commercial No Redistribution Software Licence 4 | 5 | 1. Definitions 6 | 1.1 In this Agreement unless the context otherwise requires the following 7 | words shall have the following meanings: 8 | "Adaptations" means any adaptation, alteration, addition to, deletion 9 | from, manipulation, or modification of parts of the resource; 10 | "Agreement" means these terms and conditions contained herein; 11 | "Commercial Use" means the use of the whole or parts of the Software 12 | for any reason which generates a profit. 13 | "Intellectual Property" means all Intellectual Property Rights 14 | pertaining to and subsisting in any country throughout the world 15 | including but not by way of limitation patents, trade marks and/or 16 | service marks, designs (whether registered or unregistered), utility 17 | models, all applications for any of the foregoing, copyrights, design 18 | rights, trade or business names and confidential know-how including any 19 | licences in connections with the same. 20 | "NaCTeM" means the National Centre for Text Mining. 21 | "Non-Commercial Purpose" means the use of the Software solely for 22 | internal non-commercial research and academic purposes. Non-Commercial 23 | Purpose excludes, without limitation, any use of the Software, as part 24 | of, or in any way in connection with a product or service which is sold, 25 | offered for sale, licensed, leased, loaned, or rented. 26 | "Software" means the materials distributed with this licence for use by 27 | the Users. 28 | "Software Owner" means the person or body who created the Software and 29 | which owns the Intellectual Property in the Software. 30 | "University" means the University of Manchester and any of its servants, 31 | agents, successors and assigns and is a registered charity under 32 | Section 3 (5)(a) of the Charities Act 1993. 33 | "Users" means any person who is permitted to access the Software. 
34 | 1.2 Words in the singular shall include the plural and vice versa, references 35 | to any gender shall include the others and references to legal persons shall 36 | include natural persons and vice versa. 37 | 1.3 The headings in these conditions are intended for reference only and shall 38 | not affect their construction. 39 | 40 | 2. Permitted Uses 41 | 2.1 The User may use the Software for Non-Commercial Purposes only and may: 42 | 2.1.1 electronically save the whole or any part or parts of the Software; 43 | 2.1.2 print out single copies of the whole or any part or parts of the 44 | Software; 45 | 2.1.3 make Adaptations to any parts of the Software; 46 | 2.1.4 use those parts of the Software which are the Intellectual Property 47 | of a third party in accordance with the licence terms granted by 48 | the respective Software Owner. 49 | 50 | 3. Restrictions 51 | 3.1 The User may not and may not authorise any other third party to: 52 | 3.1.1 sell or resell or otherwise make the information contained in the 53 | Software available in any manner or on any media to any one else 54 | without the written consent of the Software Owner; 55 | 3.1.2 remove, obscure or modify copyright notices, text acknowledging 56 | or other means of identification or disclaimers as they may appear 57 | without prior written permission of the Software Owner; 58 | 3.1.3 use all or any part of the Software for any Commercial Use or for 59 | any purpose other than Non-Commercial Purposes unless with the 60 | written consent of the Software Owner. 61 | 3.1.4 use any Adaptation of the Software except in accordance with all 62 | the terms of this Agreement unless with the written consent of the 63 | Software Owner. 64 | 3.2 This Clause shall survive termination of this Agreement for any reason. 65 | 66 | 4. Acknowledgement and Protection of Intellectual Property Rights 67 | 4.1 The User acknowledges that all copyrights, patent rights, trademarks, 68 | database rights, trade secrets and other intellectual property rights 69 | relating to the Software, are the property of the Software Owner or 70 | duly licensed and that this Agreement does not assign or transfer to 71 | the User any right, title or interest therein except for the right to 72 | use the Software in accordance with the terms and conditions of this 73 | Agreement. 74 | 4.2 The Users shall comply with the terms of the Copyright, Designs and 75 | Patents Act 1988 and in particular, but without limitation, shall 76 | recognise the rights, including moral right and the rights of attribution, 77 | of the Software Owner. Each use or adaptation of the Software shall 78 | make appropriate acknowledgement of the source, title, and copyright 79 | owner. 80 | 81 | 5. Representation, Warranties and Indemnification 82 | 5.1 The University makes no representation or warranty, and expressly 83 | disclaims any liability with respect to the Software including, 84 | but not limited to, errors or omissions contained therein, libel, 85 | defamation, infringements of rights of publicity, privacy, trademark 86 | rights, infringements of third party intellectual property rights, 87 | moral rights, or the disclosure of confidential information. It is 88 | expressly agreed that any use by the Users of the Software is at the 89 | User's sole risk. 
90 | 5.2 The Software is provided on an 'as is' basis and the University 91 | disclaims any and all other warranties, conditions, or representations 92 | (express, implied, oral or written), relating to the Software or any 93 | part thereof, including, without limitation, any and all implied 94 | warranties of title, quality, performance, merchantability or fitness 95 | for a particular purpose. The University further expressly disclaims 96 | any warranty or representation to Users, or to any third party. The 97 | University accepts no liability for loss suffered or incurred by the 98 | User or any third party as a result of their reliance on the Software. 99 | 5.3 Nothing herein shall impose any liability upon the University in respect 100 | of use by the User of the Software and the University gives no indemnity 101 | in respect of any claim by the User or any third party relating to any 102 | action of the User in or arising from the Software. 103 | 5.4 It is the sole responsibility of the User to ensure that he has obtained 104 | any relevant third party permissions for any Adaptations of the Software 105 | made by the User and the User shall be responsible for any and all damages, 106 | liabilities, claims, causes of action, legal fees and costs incurred by 107 | the User in defending against any third party claim of intellectual 108 | property rights infringements or threats of claims thereof with respect 109 | of the use of the Software containing any Adaptations. 110 | 111 | 6. Consequential Loss 112 | 6.1 Neither party shall be liable to the other for any costs, claims, damages 113 | or expenses arising out of any act or omission or any breach of contract 114 | or statutory duty or in tort calculated by reference to profits, income, 115 | production or accruals or loss of such profits, income, production or 116 | accruals or by reference to accrual or such costs, claims, damages or 117 | expenses calculated on a time basis. 118 | 119 | 7. Termination 120 | 7.1 The University shall have the right to terminate this Agreement forthwith 121 | if the User shall have materially breached any of its obligations under 122 | this Agreement or in the event of a breach capable of remedy fails to 123 | remedy the same within thirty (30) days of the giving of notice by the 124 | University to the User of the alleged breach and of the action required 125 | to remedy the same. 126 | 127 | 8. General 128 | 8.1 Delay in exercising, or a failure to exercise, any right or remedy in 129 | connection with this Agreement shall not operate as a waiver of that 130 | right or remedy. A single or partial exercise of any right or remedy 131 | shall not preclude any other or further exercise of that right or remedy, 132 | or the exercise of any other right or remedy. A waiver of a breach of 133 | this Agreement shall not constitute a waiver of any subsequent breach. 134 | 8.2 Each of the parties acknowledges that it is not entering into this 135 | Agreement in reliance upon any representation, warranty, collateral 136 | contract or other assurance (except those set out in this Agreement 137 | and the documents referred to in it) made by or on behalf of any other 138 | party before the execution of this Agreement. 
Each of the parties waives 139 | all rights and remedies which, but for this clause, might otherwise be 140 | available to it in respect to any such representation, warranty, collateral 141 | contract or other assurance, provided that nothing in this Clause 8.2 142 | shall limit or exclude any liability for fraud. 143 | 8.3 Nothing in this Agreement shall constitute or be deemed to constitute a 144 | partnership or other form of joint venture between the parties or constitute 145 | or be deemed to constitute either party the agent or employee of the other 146 | for any purpose whatsoever. 147 | 8.4 No person who is not a party to this Agreement is entitled to enforce 148 | any of its terms, whether under the Contracts (Rights of Third Parties) 149 | Act 1999 or otherwise. 150 | 8.5 The parties intend each provision of this Agreement to be severable and 151 | distinct from the others. If a provision of this Agreement is held to be 152 | illegal, invalid or unenforceable, in whole or in part, the parties intend 153 | that the legality, validity and enforceability of the remainder of this 154 | Agreement shall not be affected. 155 | 156 | 9. Governing Law and Jurisdiction 157 | 9.1 This Agreement is governed by, and shall be interpreted in accordance with, 158 | English law and each party irrevocably submits to the non-exclusive 159 | jurisdiction of the English Courts in relation to all matters arising out 160 | of or in connection with this Agreement. 161 | 162 | 163 | Component Version License Link 164 | 165 | edge-oriented-graph 1.0 edge-oriented-graph https://github.com/fenchri/edge-oriented-graph/ 166 | 167 | edge-oriented-graph License Information 168 | 169 | Copyright (c) 2019, Fenia Christopoulou, National Centre for Text Mining, School of Computer Science, The University of Manchester. 170 | All rights reserved. 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Edge-oriented Graph 2 | Source code for the EMNLP 2019 paper "[Connecting the Dots: Document-level Neural Relation Extraction with Edge-oriented Graphs](https://www.aclweb.org/anthology/D19-1498.pdf)". 3 | 4 |
<p align="center"> 5 | <img src="./network.svg"/> 6 | </p>
7 | 8 | 9 | ### Environment 10 | `$ pip3 install -r requirements.txt` 11 | The model was trained on a Tesla K80 GPU, on Ubuntu 16.04. Results are reproducible with a fixed seed. 12 | 13 | 14 | 15 | ### Reproducibility & Bug Fixes 16 | 17 | In the original code, there was a bug related to the word embedding layer. 18 | If you want to reproduce the results presented in the paper, you need to use the "buggy" code: [reproduceEMNLP](https://github.com/fenchri/edge-oriented-graph/tree/reproduceEMNLP) 19 | Otherwise, we recommend that you use the current version, which achieves higher performance. 20 | 21 | 22 | 23 | ## Datasets & Pre-processing 24 | Download the datasets: 25 | ``` 26 | $ mkdir data && cd data 27 | $ wget https://biocreative.bioinformatics.udel.edu/media/store/files/2016/CDR_Data.zip && unzip CDR_Data.zip && mv CDR_Data CDR 28 | $ wget https://bitbucket.org/alexwuhkucs/gda-extraction/get/fd4a7409365e.zip && unzip fd4a7409365e.zip && mv alexwuhkucs-gda-extraction-fd4a7409365e GDA 29 | $ cd .. 30 | ``` 31 | 32 | Download the GENIA Tagger and Sentence Splitter: 33 | ``` 34 | $ cd data_processing 35 | $ mkdir common && cd common 36 | $ wget http://www.nactem.ac.uk/y-matsu/geniass/geniass-1.00.tar.gz && tar xvzf geniass-1.00.tar.gz 37 | $ cd geniass/ && make && cd .. 38 | $ git clone https://github.com/bornabesic/genia-tagger-py.git 39 | $ cd genia-tagger-py 40 | ``` 41 | Here, you should modify the Makefile inside genia-tagger-py and replace line 3 with `wget http://www.nactem.ac.uk/GENIA/tagger/geniatagger-3.0.2.tar.gz` 42 | ``` 43 | $ make 44 | $ cd ../../ 45 | ``` 46 | 47 | > **Important**: In case the GENIA splitter produces errors (e.g. it cannot find a temp file), make sure you have Ruby installed: `sudo apt-get install ruby-full` 48 | 49 | In order to process the datasets, they should first be transformed into the PubTator format. Then run the processing scripts as follows: 50 | ``` 51 | $ sh process_cdr.sh 52 | $ sh process_gda.sh 53 | ``` 54 | 55 | In order to get the data statistics, run: 56 | ``` 57 | python3 statistics.py --data ../data/CDR/processed/train_filter.data 58 | python3 statistics.py --data ../data/CDR/processed/dev_filter.data 59 | python3 statistics.py --data ../data/CDR/processed/test_filter.data 60 | ``` 61 | This will additionally generate the gold-annotation file in the same folder with the suffix `.gold`. 62 | 63 | 64 | ## Usage 65 | Run the main script for training and testing as follows. Use `--gpu -1` for CPU mode. 66 | 67 | **CDR dataset**: Train the model on the training set and evaluate on the dev set, in order to identify the best training epoch. 68 | For testing, re-run the model on the union of train and dev (`train+dev_filter.data`) until the best epoch and evaluate on the test set. 69 | 70 | **GDA dataset**: Simply train the model on the training set and evaluate on the dev set. Test the saved model on the test set. 71 | 72 | To enable the early stopping criterion, use the `--early_stop` option. 73 | If early stopping is not triggered during training, the maximum epoch (specified in the config file) will be used. 74 | 75 | Otherwise, if you want to train up to a specific epoch, use the `--epoch epochNumber` option without early stopping. 76 | In this case, the maximum number of training epochs is defined by the `--epoch` option.
77 | 78 | For example, in the CDR dataset: 79 | ``` 80 | $ cd src/ 81 | $ python3 eog.py --config ../configs/parameters_cdr.yaml --train --gpu 0 --early_stop # using early stopping 82 | $ python3 eog.py --config ../configs/parameters_cdr.yaml --train --gpu 0 --epoch 15 # train until the 15th epoch *without* early stopping 83 | $ python3 eog.py --config ../configs/parameters_cdr.yaml --train --gpu 0 --epoch 15 --early_stop # set both early stop and max epoch 84 | 85 | $ python3 eog.py --config ../configs/parameters_cdr.yaml --test --gpu 0 86 | ``` 87 | 88 | All necessary parameters can be stored in the yaml files inside the configs directory. 89 | The following parameters can be also directly given as follows: 90 | ``` 91 | usage: eog.py [-h] --config CONFIG [--train] [--test] [--gpu GPU] 92 | [--walks WALKS] [--window WINDOW] [--edges [EDGES [EDGES ...]]] 93 | [--types TYPES] [--context CONTEXT] [--dist DIST] [--example] 94 | [--seed SEED] [--early_stop] [--epoch EPOCH] 95 | 96 | optional arguments: 97 | -h, --help show this help message and exit 98 | --config CONFIG Yaml parameter file 99 | --train Training mode - model is saved 100 | --test Testing mode - needs a model to load 101 | --gpu GPU GPU number 102 | --walks WALKS Number of walk iterations 103 | --window WINDOW Window for training (empty processes the whole 104 | document, 1 processes 1 sentence at a time, etc) 105 | --edges [EDGES [EDGES ...]] 106 | Edge types 107 | --types TYPES Include node types (Boolean) 108 | --context CONTEXT Include MM context (Boolean) 109 | --dist DIST Include distance (Boolean) 110 | --example Show example 111 | --seed SEED Fixed random seed number 112 | --early_stop Use early stopping 113 | --epoch EPOCH Maximum training epoch 114 | ``` 115 | 116 | ### Evaluation 117 | In order to evaluate you need to first generate the gold data format and then use the evaluation script as follows: 118 | ``` 119 | $ cd evaluation/ 120 | $ python3 evaluate.py --pred path_to_predictions_file --gold ../data/CDR/processed/test_filter.gold --label 1:CDR:2 121 | $ python3 evaluate.py --pred path_to_predictions_file --gold ../data/GDA/processed/test.gold --label 1:GDA:2 122 | ``` 123 | 124 | 125 | ### Citation 126 | 127 | If you found this code useful and plan to use it, please cite the following paper =) 128 | ``` 129 | @inproceedings{christopoulou2019connecting, 130 | title = "Connecting the Dots: Document-level Neural Relation Extraction with Edge-oriented Graphs", 131 | author = "Christopoulou, Fenia and Miwa, Makoto and Ananiadou, Sophia", 132 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)", 133 | year = "2019", 134 | publisher = "Association for Computational Linguistics", 135 | pages = "4927--4938" 136 | } 137 | ``` 138 | -------------------------------------------------------------------------------- /configs/parameters_cdr.yaml: -------------------------------------------------------------------------------- 1 | # network 2 | batch: 2 3 | epoch: 50 4 | bilstm_layers: 1 5 | word_dim: 200 6 | lstm_dim: 100 7 | out_dim: 100 8 | type_dim: 10 9 | beta: 0.8 10 | dist_dim: 10 11 | drop_i: 0.5 12 | drop_m: 0.0 13 | drop_o: 0.3 14 | lr: 0.002 15 | gc: 10 16 | reg: 0.0001 17 | opt: adam 18 | patience: 10 19 | unk_w_prob: 0.5 20 | min_w_freq: 1 21 | walks_iter: 4 22 | 23 | # data based 24 | train_data: ../data/CDR/processed/train+dev_filter.data 25 | test_data: 
../data/CDR/processed/test_filter.data 26 | embeds: ../embeds/PubMed-CDR.txt 27 | folder: ../results/cdr-test 28 | save_pred: test 29 | 30 | # options (chosen from parse input otherwise false) 31 | lowercase: false 32 | plot: true 33 | show_class: false 34 | param_avg: true 35 | early_stop: false 36 | save_model: true 37 | types: true 38 | context: true 39 | dist: true 40 | freeze_words: true 41 | 42 | # extra 43 | seed: 0 44 | shuffle_data: true 45 | label2ignore: 1:NR:2 46 | primary_metric: micro_f 47 | direction: l2r+r2l 48 | include_pairs: ['Chemical-Disease', 'Chemical-Chemical', 'Disease-Disease', 'Disease-Chemical'] 49 | classify_pairs: ['Chemical-Disease'] 50 | edges: ['MM', 'ME', 'MS', 'ES', 'SS-ind'] 51 | window: 52 | -------------------------------------------------------------------------------- /configs/parameters_gda.yaml: -------------------------------------------------------------------------------- 1 | # network 2 | batch: 3 3 | epoch: 50 4 | bilstm_layers: 1 5 | word_dim: 200 6 | lstm_dim: 100 7 | out_dim: 100 8 | type_dim: 10 9 | beta: 0.8 10 | dist_dim: 10 11 | drop_i: 0.5 12 | drop_m: 0.0 13 | drop_o: 0.3 14 | lr: 0.002 15 | gc: 10 16 | reg: 0.0001 17 | opt: adam 18 | patience: 5 19 | unk_w_prob: 0.5 20 | min_w_freq: 1 21 | walks_iter: 4 22 | 23 | # data based 24 | train_data: ../data/GDA/processed/train.data 25 | test_data: ../data/GDA/processed/dev.data 26 | embeds: 27 | folder: ../results/gda 28 | save_pred: test 29 | 30 | # options (chosen from parse input otherwise false) 31 | lowercase: false 32 | plot: true 33 | show_class: false 34 | param_avg: true 35 | early_stop: false 36 | save_model: true 37 | types: true 38 | context: true 39 | dist: true 40 | freeze_words: false 41 | 42 | # extra 43 | seed: 0 44 | shuffle_data: true 45 | label2ignore: 1:NR:2 46 | primary_metric: micro_f 47 | direction: l2r+r2l 48 | include_pairs: ['Gene-Disease', 'Disease-Disease', 'Gene-Gene', 'Disease-Gene'] 49 | classify_pairs: ['Gene-Disease'] 50 | edges: ['MM', 'ME', 'MS', 'ES', 'SS-ind'] 51 | window: 52 | -------------------------------------------------------------------------------- /data_processing/__pycache__/readers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/data_processing/__pycache__/readers.cpython-36.pyc -------------------------------------------------------------------------------- /data_processing/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/data_processing/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /data_processing/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/data_processing/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /data_processing/bin2txt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Aug 12 10:32:10 2017 4 | 5 | @author: fenia 6 | """ 7 | 8 | from gensim.models.keyedvectors import KeyedVectors 9 | import sys 10 | 11 | """ 12 | Transform 
from 'bin' to 'txt' word vectors. 13 | Input: the bin file 14 | Output: the txt file 15 | """ 16 | inp = sys.argv[1] 17 | out = ''.join(inp.split('.bin')[:-1])+'.txt' 18 | 19 | model = KeyedVectors.load_word2vec_format(inp, binary=True) 20 | model.save_word2vec_format(out, binary=False) 21 | -------------------------------------------------------------------------------- /data_processing/convertvec: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/data_processing/convertvec -------------------------------------------------------------------------------- /data_processing/convertvec.c: -------------------------------------------------------------------------------- 1 | // Code to convert word2vec vectors between text and binary format 2 | // Created by Marek Rei 3 | // Credits to: https://github.com/marekrei/convertvec 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | const long long max_w = 2000; 12 | 13 | // Convert from text format to binary 14 | void txt2bin(char * input_path, char * output_path){ 15 | FILE * fi = fopen(input_path, "rb"); 16 | FILE * fo = fopen(output_path, "wb"); 17 | 18 | long long words, size; 19 | fscanf(fi, "%lld", &words); 20 | fscanf(fi, "%lld", &size); 21 | fscanf(fi, "%*[ ]"); 22 | fscanf(fi, "%*[\n]"); 23 | 24 | fprintf(fo, "%lld %lld\n", words, size); 25 | 26 | char word[max_w]; 27 | char ch; 28 | float value; 29 | int b, a; 30 | for (b = 0; b < words; b++) { 31 | if(feof(fi)) 32 | break; 33 | 34 | word[0] = 0; 35 | fscanf(fi, "%[^ ]", word); 36 | fscanf(fi, "%c", &ch); 37 | // This kind of whitespace handling is a bit more explicit than usual. 38 | // It allows us to correctly handle special characters that would otherwise be skipped. 
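// (Here "%[^ ]" reads the token up to, but not including, the next space, and the following "%c" consumes just that one separator character.)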
39 | 40 | fprintf(fo, "%s ", word); 41 | 42 | for (a = 0; a < size; a++) { 43 | fscanf(fi, "%s", word); 44 | fscanf(fi, "%*[ ]"); 45 | value = atof(word); 46 | fwrite(&value, sizeof(float), 1, fo); 47 | } 48 | fscanf(fi, "%*[\n]"); 49 | fprintf(fo, "\n"); 50 | } 51 | 52 | fclose(fi); 53 | fclose(fo); 54 | } 55 | 56 | // Convert from binary to text format 57 | void bin2txt(char * input_path, char * output_path){ 58 | FILE * fi = fopen(input_path, "rb"); 59 | FILE * fo = fopen(output_path, "wb"); 60 | 61 | long long words, size; 62 | fscanf(fi, "%lld", &words); 63 | fscanf(fi, "%lld", &size); 64 | fscanf(fi, "%*[ ]"); 65 | fscanf(fi, "%*[\n]"); 66 | 67 | fprintf(fo, "%lld %lld\n", words, size); 68 | 69 | char word[max_w]; 70 | char ch; 71 | float value; 72 | int b, a; 73 | for (b = 0; b < words; b++) { 74 | if(feof(fi)) 75 | break; 76 | 77 | word[0] = 0; 78 | fscanf(fi, "%[^ ]", word); 79 | fscanf(fi, "%c", &ch); 80 | 81 | fprintf(fo, "%s ", word); 82 | for (a = 0; a < size; a++) { 83 | fread(&value, sizeof(float), 1, fi); 84 | fprintf(fo, "%lf ", value); 85 | } 86 | fscanf(fi, "%*[\n]"); 87 | fprintf(fo, "\n"); 88 | } 89 | fclose(fi); 90 | fclose(fo); 91 | } 92 | 93 | int main(int argc, char **argv) { 94 | if (argc < 4) { 95 | printf("USAGE: convertvec method input_path output_path\n"); 96 | printf("Method is either bin2txt or txt2bin\n"); 97 | return 0; 98 | } 99 | 100 | if(strcmp(argv[1], "bin2txt") == 0) 101 | bin2txt(argv[2], argv[3]); 102 | else if(strcmp(argv[1], "txt2bin") == 0) 103 | txt2bin(argv[2], argv[3]); 104 | else 105 | printf("Unknown method: %s\n", argv[1]); 106 | 107 | return 0; 108 | } 109 | -------------------------------------------------------------------------------- /data_processing/filter_hypernyms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Created on 17/04/19 4 | 5 | author: fenia 6 | """ 7 | 8 | import argparse 9 | import codecs 10 | from collections import defaultdict 11 | ''' 12 | Adaptation of https://github.com/patverga/bran/blob/master/src/processing/utils/filter_hypernyms.py 13 | ''' 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-i', '--input_file', required=True, help='input file in 13col tsv') 17 | parser.add_argument('-m', '--mesh_file', required=True, help='mesh file to get hierarchy from') 18 | parser.add_argument('-o', '--output_file', required=True, help='write results to this file') 19 | 20 | args = parser.parse_args() 21 | 22 | 23 | def chunks(l, n): 24 | """ 25 | Yield successive n-sized chunks from l. 
26 | """ 27 | for i in range(0, len(l), n): 28 | assert len(l[i:i + n]) == n 29 | yield l[i:i + n] 30 | 31 | 32 | # read in mesh hierarchy 33 | ent_tree_map = defaultdict(list) 34 | with codecs.open(args.mesh_file, 'r') as f: 35 | lines = [l.rstrip().split('\t') for i, l in enumerate(f) if i > 0] 36 | [ent_tree_map[l[1]].append(l[0]) for l in lines] 37 | 38 | 39 | # read in positive input file and organize by document 40 | print('Loading examples from %s' % args.input_file) 41 | pos_doc_examples = defaultdict(list) 42 | neg_doc_examples = defaultdict(list) 43 | 44 | unfilitered_pos_count = 0 45 | unfilitered_neg_count = 0 46 | text = {} 47 | with open(args.input_file, 'r') as f: 48 | lines = [l.strip().split('\t') for l in f] 49 | 50 | for l in lines: 51 | pmid = l[0] 52 | text[pmid] = pmid+'\t'+l[1] 53 | 54 | for r in chunks(l[2:], 17): 55 | 56 | if r[0] == '1:NR:2': 57 | assert ((r[7] == 'Chemical') and (r[13] == 'Disease')) 58 | neg_doc_examples[pmid].append(r) 59 | unfilitered_neg_count += 1 60 | elif r[0] == '1:CID:2': 61 | assert ((r[7] == 'Chemical') and (r[13] == 'Disease')) 62 | pos_doc_examples[pmid].append(r) 63 | unfilitered_pos_count += 1 64 | 65 | 66 | # iterate over docs 67 | hypo_count = 0 68 | negative_count = 0 69 | 70 | all_pos = 0 71 | with open(args.output_file, 'w') as out_f: 72 | for doc_id in pos_doc_examples.keys(): 73 | towrite = text[doc_id] 74 | 75 | for r in pos_doc_examples[doc_id]: 76 | towrite += '\t' 77 | towrite += '\t'.join(r) 78 | all_pos += len(pos_doc_examples[doc_id]) 79 | 80 | # get nodes for all the positive diseases 81 | pos_e2_examples = [(pos_node, pe) for pe in pos_doc_examples[doc_id] 82 | for pos_node in ent_tree_map[pe[11]]] 83 | 84 | pos_e1_examples = [(pos_node, pe) for pe in pos_doc_examples[doc_id] 85 | for pos_node in ent_tree_map[pe[5]]] 86 | 87 | filtered_neg_exampled = [] 88 | for ne in neg_doc_examples[doc_id]: 89 | neg_e1 = ne[5] 90 | neg_e2 = ne[11] 91 | example_hyponyms = 0 92 | for neg_node in ent_tree_map[ne[11]]: 93 | hyponyms = [pos_node for pos_node, pe in pos_e2_examples 94 | if neg_node in pos_node and neg_e1 == pe[5]] \ 95 | + [pos_node for pos_node, pe in pos_e1_examples 96 | if neg_node in pos_node and neg_e2 == pe[11]] 97 | example_hyponyms += len(hyponyms) 98 | if example_hyponyms == 0: 99 | towrite += '\t'+'\t'.join(ne) 100 | negative_count += 1 101 | else: 102 | ne[0] = 'not_include' # just don't include the negative pairs, but keep the entities 103 | towrite += '\t'+'\t'.join(ne) 104 | hypo_count += example_hyponyms 105 | out_f.write(towrite+'\n') 106 | 107 | print('Mesh entities: %d' % len(ent_tree_map)) 108 | print('Positive Docs: %d' % len(pos_doc_examples)) 109 | print('Negative Docs: %d' % len(neg_doc_examples)) 110 | print('Positive Count: %d Initial Negative Count: %d Final Negative Count: %d Hyponyms: %d' % 111 | (unfilitered_pos_count, unfilitered_neg_count, negative_count, hypo_count)) 112 | print(all_pos) -------------------------------------------------------------------------------- /data_processing/gda2pubtator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 06/09/2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import os, re, sys 10 | import argparse 11 | from tqdm import tqdm 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--input_folder', '-i', type=str) 15 | parser.add_argument('--output_file', '-o', type=str) 16 | args = parser.parse_args() 17 | 18 | if not 
os.path.exists('/'.join(args.output_file.split('/')[:-1])): 19 | os.makedirs('/'.join(args.output_file.split('/')[:-1])) 20 | 21 | abstracts = {} 22 | entities = {} 23 | relations = {} 24 | with open(args.input_folder + 'abstracts.txt', 'r') as infile: 25 | for line in infile: 26 | if line.rstrip().isdigit(): 27 | pmid = line.rstrip() 28 | abstracts[pmid] = [] 29 | relations[pmid] = [] 30 | entities[pmid] = [] 31 | 32 | elif line != '\n': 33 | abstracts[pmid] += [line.rstrip()] 34 | 35 | with open(args.input_folder + 'anns.txt', 'r') as infile: 36 | for line in infile: 37 | line = line.split('\t') 38 | 39 | if line[0].isdigit(): 40 | entities[line[0]] += [tuple(line)] 41 | 42 | with open(args.input_folder + 'labels.csv', 'r') as infile: 43 | for line in infile: 44 | line = line.split(',') 45 | 46 | if line[0].isdigit() and line[3].rstrip() == '1': 47 | line = ','.join(line).rstrip().split(',') 48 | 49 | relations[line[0]] += [tuple([line[0]] + ['GDA'] + line[1:-1])] 50 | 51 | with open(args.output_file, 'w') as outfile: 52 | for d in tqdm(abstracts.keys(), desc='Writing 2 PubTator format'): 53 | if len(abstracts[d]) > 2: 54 | print('something is wrong') 55 | exit(0) 56 | 57 | for i in range(0, len(abstracts[d])): 58 | if i == 0: 59 | outfile.write('{}|t|{}\n'.format(d, abstracts[d][i])) 60 | else: 61 | outfile.write('{}|a|{}\n'.format(d, abstracts[d][i])) 62 | 63 | for e in entities[d]: 64 | outfile.write('{}'.format('\t'.join(e))) 65 | 66 | for r in relations[d]: 67 | outfile.write('{}\n'.format('\t'.join(r))) 68 | outfile.write('\n') 69 | 70 | -------------------------------------------------------------------------------- /data_processing/process.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on %(date)s 5 | 6 | @author: fenia 7 | """ 8 | 9 | import os 10 | import re 11 | from tqdm import tqdm 12 | from recordtype import recordtype 13 | from collections import OrderedDict 14 | import argparse 15 | import pickle 16 | from itertools import permutations, combinations 17 | from tools import sentence_split_genia, tokenize_genia 18 | from tools import adjust_offsets, find_mentions, find_cross, fix_sent_break, convert2sent, generate_pairs 19 | from readers import * 20 | 21 | TextStruct = recordtype('TextStruct', 'pmid txt') 22 | EntStruct = recordtype('EntStruct', 'pmid name off1 off2 type kb_id sent_no word_id bio') 23 | RelStruct = recordtype('RelStruct', 'pmid type arg1 arg2') 24 | 25 | 26 | def main(): 27 | """ 28 | Main processing function 29 | """ 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--input_file', '-i', type=str) 32 | parser.add_argument('--output_file', '-o', type=str) 33 | parser.add_argument('--data', '-d', type=str) 34 | args = parser.parse_args() 35 | 36 | if args.data == 'GDA': 37 | abstracts, entities, relations = readPubTator(args) 38 | type1 = ['Gene'] 39 | type2 = ['Disease'] 40 | 41 | elif args.data == 'CDR': 42 | abstracts, entities, relations = readPubTator(args) 43 | type1 = ['Chemical'] 44 | type2 = ['Disease'] 45 | 46 | else: 47 | print('Dataset non-existent.') 48 | sys.exit() 49 | 50 | if not os.path.exists(args.output_file + '_files/'): 51 | os.makedirs(args.output_file + '_files/') 52 | 53 | # Process 54 | positive, negative = 0, 0 55 | with open(args.output_file + '.data', 'w') as data_out: 56 | pbar = tqdm(list(abstracts.keys())) 57 | for i in pbar: 58 | pbar.set_description("Processing Doc_ID {}".format(i)) 59 | 60 | ''' 
Sentence Split ''' 61 | orig_sentences = [item for sublist in [a.txt.split('\n') for a in abstracts[i]] for item in sublist] 62 | split_sents = sentence_split_genia(orig_sentences) 63 | split_sents = fix_sent_break(split_sents, entities[i]) 64 | with open(args.output_file + '_files/' + i + '.split.txt', 'w') as f: 65 | f.write('\n'.join(split_sents)) 66 | 67 | # adjust offsets 68 | new_entities = adjust_offsets(orig_sentences, split_sents, entities[i], show=False) 69 | 70 | ''' Tokenisation ''' 71 | token_sents = tokenize_genia(split_sents) 72 | with open(args.output_file + '_files/' + i + '.split.tok.txt', 'w') as f: 73 | f.write('\n'.join(token_sents)) 74 | 75 | # adjust offsets 76 | new_entities = adjust_offsets(split_sents, token_sents, new_entities, show=True) 77 | 78 | ''' Find mentions ''' 79 | unique_entities = find_mentions(new_entities) 80 | with open(args.output_file + '_files/' + i + '.mention', 'wb') as f: 81 | pickle.dump(unique_entities, f, pickle.HIGHEST_PROTOCOL) 82 | 83 | ''' Generate Pairs ''' 84 | if i in relations: 85 | pairs = generate_pairs(unique_entities, type1, type2, relations[i]) 86 | else: 87 | pairs = generate_pairs(unique_entities, type1, type2, []) # generate only negative pairs 88 | 89 | # 'pmid type arg1 arg2 dir cross' 90 | data_out.write('{}\t{}'.format(i, '|'.join(token_sents))) 91 | 92 | for args_, p in pairs.items(): 93 | if p.type != '1:NR:2': 94 | positive += 1 95 | elif p.type == '1:NR:2': 96 | negative += 1 97 | 98 | data_out.write('\t{}\t{}\t{}\t{}-{}\t{}-{}'.format(p.type, p.dir, p.cross, p.closest[0].word_id[0], 99 | p.closest[0].word_id[-1]+1, 100 | p.closest[1].word_id[0], 101 | p.closest[1].word_id[-1]+1)) 102 | data_out.write('\t{}\t{}\t{}\t{}\t{}\t{}'.format( 103 | '|'.join([g for g in p.arg1]), 104 | '|'.join([e.name for e in unique_entities[p.arg1]]), 105 | unique_entities[p.arg1][0].type, 106 | ':'.join([str(e.word_id[0]) for e in unique_entities[p.arg1]]), 107 | ':'.join([str(e.word_id[-1] + 1) for e in unique_entities[p.arg1]]), 108 | ':'.join([str(e.sent_no) for e in unique_entities[p.arg1]]))) 109 | 110 | data_out.write('\t{}\t{}\t{}\t{}\t{}\t{}'.format( 111 | '|'.join([g for g in p.arg2]), 112 | '|'.join([e.name for e in unique_entities[p.arg2]]), 113 | unique_entities[p.arg2][0].type, 114 | ':'.join([str(e.word_id[0]) for e in unique_entities[p.arg2]]), 115 | ':'.join([str(e.word_id[-1] + 1) for e in unique_entities[p.arg2]]), 116 | ':'.join([str(e.sent_no) for e in unique_entities[p.arg2]]))) 117 | data_out.write('\n') 118 | print('Total positive pairs:', positive) 119 | print('Total negative pairs:', negative) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /data_processing/process_cdr.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | for d in "Training" "Development" "Test"; 4 | do 5 | python3 process.py --input_file ../data/CDR/CDR.Corpus.v010516/CDR_${d}Set.PubTator.txt \ 6 | --output_file ../data/CDR/processed/${d} \ 7 | --data CDR 8 | 9 | python3 filter_hypernyms.py --mesh_file 2017MeshTree.txt \ 10 | --input_file ../data/CDR/processed/${d}.data \ 11 | --output_file ../data/CDR/processed/${d}_filter.data 12 | done 13 | 14 | mv ../data/CDR/processed/Training.data ../data/CDR/processed/train.data 15 | mv ../data/CDR/processed/Development.data ../data/CDR/processed/dev.data 16 | mv ../data/CDR/processed/Test.data ../data/CDR/processed/test.data 17 | 18 | mv 
../data/CDR/processed/Training_filter.data ../data/CDR/processed/train_filter.data 19 | mv ../data/CDR/processed/Development_filter.data ../data/CDR/processed/dev_filter.data 20 | mv ../data/CDR/processed/Test_filter.data ../data/CDR/processed/test_filter.data 21 | 22 | # merge train and dev 23 | cat ../data/CDR/processed/train_filter.data > ../data/CDR/processed/train+dev_filter.data 24 | cat ../data/CDR/processed/dev_filter.data >> ../data/CDR/processed/train+dev_filter.data 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /data_processing/process_gda.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | for d in "training" "testing"; 4 | do 5 | python3 gda2pubtator.py --input_folder ../data/GDA/${d}_data/ \ 6 | --output_file ../data/GDA/processed/${d}.pubtator 7 | 8 | python3 process.py --input_file ../data/GDA/processed/${d}.pubtator \ 9 | --output_file ../data/GDA/processed/${d} \ 10 | --data GDA 11 | done 12 | 13 | mv ../data/GDA/processed/testing.data ../data/GDA/processed/test.data 14 | mv ../data/GDA/processed/training.data ../data/GDA/processed/train+dev.data 15 | 16 | python3 split_gda.py --input_file ../data/GDA/processed/train+dev.data \ 17 | --output_train ../data/GDA/processed/train.data \ 18 | --output_dev ../data/GDA/processed/dev.data \ 19 | --list train_gda_docs 20 | -------------------------------------------------------------------------------- /data_processing/readers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 14/08/2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import os, re, sys 10 | from utils import replace2symbol, replace2space 11 | from collections import OrderedDict 12 | from tqdm import tqdm 13 | from recordtype import recordtype 14 | 15 | TextStruct = recordtype('TextStruct', 'pmid txt') 16 | EntStruct = recordtype('EntStruct', 'pmid name off1 off2 type kb_id sent_no word_id bio') 17 | RelStruct = recordtype('RelStruct', 'pmid type arg1 arg2') 18 | PairStruct = recordtype('PairStruct', 'pmid type arg1 arg2 dir cross closest') 19 | 20 | 21 | def readPubTator(args): 22 | """ 23 | Read data and store in structs 24 | """ 25 | if not os.path.exists('/'.join(args.output_file.split('/')[:-1])): 26 | os.makedirs('/'.join(args.output_file.split('/')[:-1])) 27 | 28 | abstracts = OrderedDict() 29 | entities = OrderedDict() 30 | relations = OrderedDict() 31 | 32 | with open(args.input_file, 'r') as infile: 33 | for line in tqdm(infile): 34 | 35 | # text 36 | if len(line.rstrip().split('|')) == 3 and \ 37 | (line.strip().split('|')[1] == 't' or line.strip().split('|')[1] == 'a'): 38 | line = line.strip().split('|') 39 | 40 | pmid = line[0] 41 | text = line[2] # .replace('>', '\n') 42 | 43 | # replace weird symbols and spaces 44 | text = replace2symbol(text) 45 | text = replace2space(text) 46 | 47 | if pmid not in abstracts: 48 | abstracts[pmid] = [TextStruct(pmid, text)] 49 | else: 50 | abstracts[pmid] += [TextStruct(pmid, text)] 51 | 52 | # entities 53 | elif len(line.rstrip().split('\t')) == 6: 54 | line = line.strip().split('\t') 55 | pmid = line[0] 56 | offset1 = int(line[1]) 57 | offset2 = int(line[2]) 58 | ent_name = line[3] 59 | ent_type = line[4] 60 | kb_id = line[5].split('|') 61 | 62 | # replace weird symbols and spaces 63 | ent_name = replace2symbol(ent_name) 64 | ent_name = replace2space(ent_name) 65 | 66 | # currently consider each 
possible ID as another entity 67 | for k in kb_id: 68 | if pmid not in entities: 69 | entities[pmid] = [EntStruct(pmid, ent_name, offset1, offset2, ent_type, [k], -1, [], [])] 70 | else: 71 | entities[pmid] += [EntStruct(pmid, ent_name, offset1, offset2, ent_type, [k], -1, [], [])] 72 | 73 | elif len(line.rstrip().split('\t')) == 7: 74 | line = line.strip().split('\t') 75 | pmid = line[0] 76 | offset1 = int(line[1]) 77 | offset2 = int(line[2]) 78 | ent_name = line[3] 79 | ent_type = line[4] 80 | kb_id = line[5].split('|') 81 | extra_ents = line[6].split('|') 82 | 83 | # replace weird symbols and spaces 84 | ent_name = replace2symbol(ent_name) 85 | ent_name = replace2space(ent_name) 86 | for i, e in enumerate(extra_ents): 87 | if pmid not in entities: 88 | entities[pmid] = [EntStruct(pmid, ent_name, offset1, offset2, ent_type, [kb_id[i]], -1, [], [])] 89 | else: 90 | entities[pmid] += [EntStruct(pmid, ent_name, offset1, offset2, ent_type, [kb_id[i]], -1, [], [])] 91 | 92 | # relations 93 | elif len(line.rstrip().split('\t')) == 4: 94 | line = line.strip().split('\t') 95 | pmid = line[0] 96 | rel_type = line[1] 97 | arg1 = tuple((line[2].split('|'))) 98 | arg2 = tuple((line[3].split('|'))) 99 | 100 | if pmid not in relations: 101 | relations[pmid] = [RelStruct(pmid, rel_type, arg1, arg2)] 102 | else: 103 | relations[pmid] += [RelStruct(pmid, rel_type, arg1, arg2)] 104 | 105 | elif line == '\n': 106 | continue 107 | 108 | return abstracts, entities, relations 109 | -------------------------------------------------------------------------------- /data_processing/reduce_embeds.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import numpy as np 4 | import glob 5 | import sys 6 | from collections import OrderedDict 7 | import argparse 8 | 9 | """ 10 | Crop embeddings to the size of the dataset, i.e. keeping only existing words. 11 | """ 12 | 13 | def load_pretrained_embeddings(embeds): 14 | """ 15 | :param params: input parameters 16 | :returns 17 | dictionary with words (keys) and embeddings (values) 18 | """ 19 | if embeds: 20 | E = OrderedDict() 21 | with open(embeds, 'r') as vectors: 22 | for x, line in enumerate(vectors): 23 | if x == 0 and len(line.split()) == 2: 24 | words, num = map(int, line.rstrip().split()) 25 | else: 26 | word = line.rstrip().split()[0] 27 | vec = line.rstrip().split()[1:] 28 | n = len(vec) 29 | if len(vec) != num: 30 | # print('Wrong dimensionality: {} {} != {}'.format(word, len(vec), num)) 31 | continue 32 | else: 33 | E[word] = np.asarray(vec, dtype=np.float32) 34 | print('Pre-trained word embeddings: {} x {}'.format(len(E), n)) 35 | else: 36 | E = OrderedDict() 37 | print('No pre-trained word embeddings loaded.') 38 | return E 39 | 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--full_embeds', type=str) 44 | parser.add_argument('--out_embeds', type=str) 45 | parser.add_argument('--in_data', nargs='+') 46 | args = parser.parse_args() 47 | 48 | words = [] 49 | print('Extracting words from the dataset ... 
', end="") 50 | 51 | for filef in args.in_data: 52 | with open(filef, 'r') as infile: 53 | for line in infile: 54 | line = line.strip().split('\t')[1] 55 | line = line.split('|') 56 | line = [l.split(' ') for l in line] 57 | line = [item for sublist in line for item in sublist] 58 | 59 | for l in line: 60 | words.append(l) 61 | print('Done') 62 | 63 | # make lowercase 64 | words_lower = list(map(lambda x:x.lower(), words)) 65 | 66 | print('Loading embeddings ... ', end="") 67 | embeddings = load_pretrained_embeddings(args.full_embeds) 68 | 69 | print('Writing final embeddings ... ', end="") 70 | words = set(words) 71 | words_lower = set(words_lower) # lowercased 72 | 73 | new_embeds = OrderedDict() 74 | for w in embeddings.keys(): 75 | if (w in words) or (w in words_lower): 76 | new_embeds[w] = embeddings[w] 77 | 78 | with open(args.out_embeds, 'w') as outfile: 79 | for g in new_embeds.keys(): 80 | outfile.write('{} {}\n'.format(g, ' '.join(map(str, list(new_embeds[g]))))) 81 | print('Done') 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /data_processing/split_gda.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 06/09/2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import argparse 10 | import os 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--input_file', type=str) 14 | parser.add_argument('--output_train', type=str) 15 | parser.add_argument('--output_dev', type=str) 16 | parser.add_argument('--list', type=str) 17 | args = parser.parse_args() 18 | 19 | 20 | with open(args.list, 'r') as infile: 21 | docs = [i.rstrip() for i in infile] 22 | docs = frozenset(docs) 23 | 24 | with open(args.input_file, 'r') as infile, open(args.output_train, 'w') as otr, open(args.output_dev, 'w') as odev: 25 | for line in infile: 26 | pmid = line.rstrip().split('\t')[0] 27 | 28 | if pmid in docs: 29 | otr.write(line) 30 | else: 31 | odev.write(line) 32 | 33 | -------------------------------------------------------------------------------- /data_processing/statistics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 09/05/2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import argparse 10 | import numpy as np 11 | from collections import OrderedDict 12 | from recordtype import recordtype 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--data', type=str) 16 | args = parser.parse_args() 17 | 18 | 19 | EntityInfo = recordtype('EntityInfo', 'type mstart mend sentNo') 20 | PairInfo = recordtype('PairInfo', 'type direction cross closeA closeB') 21 | 22 | 23 | def chunks(l, n): 24 | """ 25 | Yield successive n-sized chunks from l. 
26 | """ 27 | for i in range(0, len(l), n): 28 | assert len(l[i:i + n]) == n 29 | yield l[i:i + n] 30 | 31 | 32 | for d, data in zip(['DATA'], [args.data]): 33 | documents = {} 34 | entities = {} 35 | relations = {} 36 | 37 | with open(data, 'r') as infile: 38 | 39 | for line in infile: 40 | line = line.rstrip().split('\t') 41 | pairs = chunks(line[2:], 17) 42 | 43 | id_ = line[0] 44 | 45 | if id_ not in documents: 46 | documents[id_] = [] 47 | 48 | for sent in line[1].split('|'): 49 | documents[id_] += [sent] 50 | 51 | if id_ not in entities: 52 | entities[id_] = OrderedDict() 53 | 54 | if id_ not in relations: 55 | relations[id_] = OrderedDict() 56 | 57 | for p in pairs: 58 | # pairs 59 | if (p[5], p[11]) not in relations[id_]: 60 | relations[id_][(p[5], p[11])] = PairInfo(p[0], p[1], p[2], p[3], p[4]) 61 | else: 62 | print('duplicates!') 63 | 64 | # entities 65 | if p[5] not in entities[id_]: 66 | entities[id_][p[5]] = EntityInfo(p[7], p[8], p[9], p[10]) 67 | 68 | if p[11] not in entities[id_]: 69 | entities[id_][p[11]] = EntityInfo(p[13], p[14], p[15], p[16]) 70 | 71 | docs = len(documents) 72 | pair_types = {} 73 | inter_types = {} 74 | intra_types = {} 75 | ent_types = {} 76 | men_types = {} 77 | dist = {} 78 | for id_ in relations.keys(): 79 | for k, p in relations[id_].items(): 80 | 81 | if p.type not in pair_types: 82 | pair_types[p.type] = 0 83 | pair_types[p.type] += 1 84 | 85 | if p.type not in dist: 86 | dist[p.type] = [] 87 | 88 | if p.type not in inter_types: 89 | inter_types[p.type] = 0 90 | 91 | if p.type not in intra_types: 92 | intra_types[p.type] = 0 93 | 94 | if p.type not in dist: 95 | dist[p.type] = [] 96 | 97 | if p.cross == 'CROSS': 98 | inter_types[p.type] += 1 99 | else: 100 | intra_types[p.type] += 1 101 | 102 | if p.cross == 'CROSS': 103 | dist_ = 10000 104 | 105 | for m1 in entities[id_][k[0]].sentNo.split(':'): 106 | for m2 in entities[id_][k[1]].sentNo.split(':'): 107 | 108 | if abs(int(m1) - int(m2)) < dist_: 109 | dist_ = abs(int(m1) - int(m2)) 110 | 111 | dist[p.type] += [dist_] 112 | 113 | for e in entities[id_].values(): 114 | if e.type not in ent_types: 115 | ent_types[e.type] = 0 116 | ent_types[e.type] += 1 117 | 118 | if e.type not in men_types: 119 | men_types[e.type] = 0 120 | for m in e.mstart.split(':'): 121 | men_types[e.type] += 1 122 | 123 | ents_per_doc = [len(entities[n]) for n in documents.keys()] 124 | ments_per_doc = [np.sum([len(e.sentNo.split(':')) for e in entities[n].values()]) for n in documents.keys()] 125 | ments_per_ent = [[len(e.sentNo.split(':')) for e in entities[n].values()] for n in documents.keys()] 126 | sents_per_doc = [len(s) for s in documents.values()] 127 | sent_len = [len(a.split()) for s in documents.values() for a in s] 128 | 129 | # write data 130 | with open('/'.join(args.data.split('/')[:-1]) + '/' + args.data.split('/')[-1].split('.')[0] + '.gold', 'w') as outfile: 131 | for id_ in relations.keys(): 132 | for k, p in relations[id_].items(): 133 | PairInfo = recordtype('PairInfo', 'type direction cross closeA closeB') 134 | outfile.write('{}|{}|{}|{}|{}\n'.format(id_, k[0], k[1], p.cross, p.type)) 135 | 136 | print(''' 137 | ----------------------- {} ---------------------- 138 | Documents {} 139 | '''.format(d, docs)) 140 | 141 | print(' Pairs') 142 | 143 | for x in ['{:<10}\t{:<5}'.format(k, v) for k, v in sorted(pair_types.items())]: 144 | print(' {}'.format(x)) 145 | print() 146 | 147 | print(' Entities') 148 | for x in ['{:<10}\t{:<5}'.format(k, v) for k, v in sorted(ent_types.items())]: 149 | print(' 
{}'.format(x)) 150 | print() 151 | 152 | print(' Mentions') 153 | for x in ['{:<10}\t{:<5}'.format(k, v) for k, v in sorted(men_types.items())]: 154 | print(' {}'.format(x)) 155 | print() 156 | 157 | print(' Intra Pairs') 158 | for x in ['{:<10}\t{:<5}'.format(k, v) for k, v in sorted(intra_types.items())]: 159 | print(' {}'.format(x)) 160 | print() 161 | 162 | print(' Inter Pairs') 163 | for x in ['{:<10}\t{:<5}'.format(k, v) for k, v in sorted(inter_types.items())]: 164 | print(' {}'.format(x)) 165 | print() 166 | 167 | print(' Average/Max Sentence Distance') 168 | for x in ['{:<10}\t{:.1f}\t{}'.format(k, np.average(v), np.max(v)) for k, v in sorted(dist.items())]: 169 | print(' {}'.format(x)) 170 | print() 171 | 172 | print(''' 173 | Average entites/doc {:.1f} 174 | Max {} 175 | 176 | Average mentions/doc {:.1f} 177 | Max {} 178 | 179 | Average mentions/entity {:.1f} 180 | Max {} 181 | 182 | Average sents/doc {:.1f} 183 | Max {} 184 | 185 | Average/max sent length {:.1f} 186 | Max {} 187 | '''.format(np.average(ents_per_doc), 188 | np.max(ents_per_doc), 189 | np.average(ments_per_doc), 190 | np.max(ments_per_doc), 191 | np.average([item for sublist in ments_per_ent for item in sublist]), 192 | np.max([item for sublist in ments_per_ent for item in sublist]), 193 | np.average(sents_per_doc), 194 | np.max(sents_per_doc), 195 | np.average(sent_len), 196 | np.max(sent_len))) 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | -------------------------------------------------------------------------------- /data_processing/tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on %(date)s 5 | 6 | @author: fenia 7 | """ 8 | 9 | import os 10 | import sys 11 | import re 12 | from recordtype import recordtype 13 | from networkx.algorithms.components.connected import connected_components 14 | from itertools import combinations 15 | import numpy as np 16 | from collections import OrderedDict 17 | from utils import to_graph, to_edges, using_split2 18 | from tqdm import tqdm 19 | sys.path.append('./common/genia-tagger-py/') 20 | from geniatagger import GENIATagger 21 | 22 | pwd = '/'.join(os.path.realpath(__file__).split('/')[:-1]) 23 | 24 | genia_splitter = os.path.join("./common", "geniass") 25 | genia_tagger = GENIATagger(os.path.join("./common", "genia-tagger-py", "geniatagger-3.0.2", "geniatagger")) 26 | 27 | 28 | TextStruct = recordtype('TextStruct', 'pmid txt') 29 | EntStruct = recordtype('EntStruct', 'pmid name off1 off2 type kb_id sent_no word_id bio') 30 | RelStruct = recordtype('RelStruct', 'pmid type arg1 arg2') 31 | PairStruct = recordtype('PairStruct', 'pmid type arg1 arg2 dir cross closest') 32 | 33 | 34 | def generate_pairs(uents, type1, type2, true_rels): 35 | """ 36 | Generate pairs (both positive & negative): 37 | Type1 - Type2 should have 1-1 association, e.g. 
[A, A] [B, C] --> (A,B), (A,C) 38 | Args: 39 | uents: 40 | type1: (list) with entity semantic types 41 | type2: (list) with entity semantic types 42 | true_rels: 43 | """ 44 | pairs = OrderedDict() 45 | combs = combinations(uents, 2) 46 | 47 | unk = 0 48 | total_rels = len(true_rels) 49 | found_rels = 0 50 | 51 | for c in combs: 52 | # all pairs 53 | diff = 99999 54 | 55 | target = [] 56 | for e1 in uents[c[0]]: 57 | for e2 in uents[c[1]]: 58 | # find most close pair to each other 59 | if e1.word_id[-1] <= e2.word_id[0]: 60 | if abs(e2.word_id[0] - e1.word_id[-1]) < diff: 61 | target = [e1, e2] 62 | diff = abs(e2.word_id[0] - e1.word_id[-1]) 63 | else: 64 | if abs(e1.word_id[0] - e2.word_id[-1]) < diff: 65 | target = [e1, e2] 66 | diff = abs(e2.word_id[0] - e1.word_id[-1]) 67 | 68 | if target[0].word_id[-1] <= target[1].word_id[0]: # A before B (in text) 69 | a1 = target[0] 70 | a2 = target[1] 71 | else: # B before A (in text) 72 | a1 = target[1] 73 | a2 = target[0] 74 | 75 | if c[0][0].startswith('UNK:') or c[1][0].startswith('UNK:'): # ignore non-grounded entities 76 | continue 77 | 78 | cross_res = find_cross(c, uents) 79 | not_found_rels = 0 80 | 81 | for tr in true_rels: 82 | 83 | # AB existing relation 84 | if list(set(tr.arg1).intersection(set(c[0]))) and list(set(tr.arg2).intersection(set(c[1]))): 85 | for t1, t2 in zip(type1, type2): 86 | if uents[c[0]][0].type == t1 and uents[c[1]][0].type == t2: 87 | pairs[(c[0], c[1])] = \ 88 | PairStruct(tr.pmid, '1:' + tr.type + ':2', c[0], c[1], 'L2R', cross_res, (a1, a2)) 89 | found_rels += 1 90 | 91 | # BA existing relation 92 | elif list(set(tr.arg1).intersection(set(c[1]))) and list(set(tr.arg2).intersection(set(c[0]))): 93 | for t1, t2 in zip(type1, type2): 94 | if uents[c[1]][0].type == t1 and uents[c[0]][0].type == t2: 95 | pairs[(c[1], c[0])] = \ 96 | PairStruct(tr.pmid, '1:'+tr.type+':2', c[1], c[0], 'R2L', cross_res, (a2, a1)) 97 | found_rels += 1 98 | 99 | # relation not found 100 | else: 101 | not_found_rels += 1 102 | 103 | # this pair does not have a relation 104 | if not_found_rels == total_rels: 105 | for t1, t2 in zip(type1, type2): 106 | if uents[c[0]][0].type == t1 and uents[c[1]][0].type == t2: 107 | pairs[(c[0], c[1])] = PairStruct(a1.pmid, '1:NR:2', c[0], c[1], 'L2R', cross_res, (a1, a2)) 108 | unk += 1 109 | elif uents[c[1]][0].type == t1 and uents[c[0]][0].type == t2: 110 | pairs[(c[1], c[0])] = PairStruct(a1.pmid, '1:NR:2', c[1], c[0], 'R2L', cross_res, (a2, a1)) 111 | unk += 1 112 | 113 | assert found_rels == total_rels, '{} <> {}, {}, {}'.format(found_rels, total_rels, true_rels, pairs) 114 | 115 | # # Checking 116 | # if found_rels != total_rels: 117 | # print('NON-FOUND RELATIONS: {} <> {}'.format(found_rels, total_rels)) 118 | # for p in true_rels: 119 | # if (p.arg1, p.arg2) not in pairs: 120 | # print(p.arg1, p.arg2) 121 | return pairs 122 | 123 | 124 | def convert2sent(arg1, arg2, token_sents): 125 | """ 126 | Convert document info to sentence (for pairs in same sentence). 
127 | Args: 128 | arg1: 129 | arg2: 130 | token_sents: 131 | """ 132 | # make sure they are in the same sentence 133 | assert arg1.sent_no == arg2.sent_no, 'error: entities not in the same sentence' 134 | 135 | toks_per_sent = [] 136 | sent_offs = [] 137 | cnt = 0 138 | for i, s in enumerate(token_sents): 139 | toks_per_sent.append(len(s.split(' '))) 140 | sent_offs.append((cnt, cnt+len(s.split(' '))-1)) 141 | cnt = len(' '.join(token_sents[:i+1]).split(' ')) 142 | 143 | target_sent = token_sents[arg1.sent_no].split(' ') 144 | n = sum(toks_per_sent[0:arg1.sent_no]) 145 | 146 | arg1_span = [a-n for a in arg1.word_id] 147 | arg2_span = [a-n for a in arg2.word_id] 148 | assert target_sent[arg1_span[0]:arg1_span[-1]+1] == \ 149 | ' '.join(token_sents).split(' ')[arg1.word_id[0]:arg1.word_id[-1]+1] 150 | assert target_sent[arg2_span[0]:arg2_span[-1]+1] == \ 151 | ' '.join(token_sents).split(' ')[arg2.word_id[0]:arg2.word_id[-1]+1] 152 | 153 | arg1_n = EntStruct(arg1.pmid, arg1.name, arg1.off1, arg1.off2, arg1.type, 154 | arg1.kb_id, arg1.sent_no, arg1_span, arg1.bio) 155 | arg2_n = EntStruct(arg2.pmid, arg2.name, arg2.off1, arg2.off2, arg2.type, 156 | arg2.kb_id, arg2.sent_no, arg2_span, arg2.bio) 157 | 158 | return arg1_n, arg2_n 159 | 160 | 161 | def find_cross(pair, unique_ents): 162 | """ 163 | Find if the pair is in cross or non-cross sentence. 164 | Args: 165 | pair: (tuple) target pair 166 | unique_ents: (dic) entities based on grounded IDs 167 | Returns: (str) cross/non-cross 168 | """ 169 | non_cross = False 170 | for m1 in unique_ents[pair[0]]: 171 | for m2 in unique_ents[pair[1]]: 172 | if m1.sent_no == m2.sent_no: 173 | non_cross = True 174 | if non_cross: 175 | return 'NON-CROSS' 176 | else: 177 | return 'CROSS' 178 | 179 | 180 | def fix_sent_break(sents, entities): 181 | """ 182 | Fix sentence break + Find sentence of each entity 183 | Args: 184 | sents: (list) sentences 185 | entities: (recordtype) 186 | Returns: (list) sentences with fixed sentence breaks 187 | """ 188 | sents_break = '\n'.join(sents) 189 | 190 | for e in entities: 191 | if '\n' in sents_break[e.off1:e.off2]: 192 | sents_break = sents_break[0:e.off1] + sents_break[e.off1:e.off2].replace('\n', ' ') + sents_break[e.off2:] 193 | return sents_break.split('\n') 194 | 195 | 196 | def find_mentions(entities): 197 | """ 198 | Find unique entities and their mentions 199 | Args: 200 | entities: (dic) a struct for each entity 201 | Returns: (dic) unique entities based on their grounded ID, if -1 ID=UNK:No 202 | """ 203 | equivalents = [] 204 | for e in entities: 205 | if e.kb_id not in equivalents: 206 | equivalents.append(e.kb_id) 207 | 208 | # mention-level data sets 209 | g = to_graph(equivalents) 210 | cc = connected_components(g) 211 | 212 | unique_entities = OrderedDict() 213 | unk_id = 0 214 | for c in cc: 215 | if tuple(c)[0] == '-1': 216 | continue 217 | unique_entities[tuple(c)] = [] 218 | 219 | # consider non-grounded entities as separate entities 220 | for e in entities: 221 | if e.kb_id[0] == '-1': 222 | unique_entities[tuple(('UNK:' + str(unk_id),))] = [e] 223 | unk_id += 1 224 | else: 225 | for ue in unique_entities.keys(): 226 | if list(set(e.kb_id).intersection(set(ue))): 227 | unique_entities[ue] += [e] 228 | 229 | return unique_entities 230 | 231 | 232 | def sentence_split_genia(tabst): 233 | """ 234 | Sentence Splitting Using GENIA sentence splitter 235 | Args: 236 | tabst: (list) title+abstract 237 | 238 | Returns: (list) all sentences in abstract 239 | """ 240 | os.chdir(genia_splitter) 241 | 242 | 
with open('temp_file.txt', 'w') as ofile: 243 | for t in tabst: 244 | ofile.write(t+'\n') 245 | os.system('./geniass temp_file.txt temp_file.split.txt > /dev/null 2>&1') 246 | 247 | split_lines = [] 248 | with open('temp_file.split.txt', 'r') as ifile: 249 | for line in ifile: 250 | line = line.rstrip() 251 | if line != '': 252 | split_lines.append(line.rstrip()) 253 | os.system('rm temp_file.txt temp_file.split.txt') 254 | os.chdir(pwd) 255 | return split_lines 256 | 257 | 258 | def tokenize_genia(sents): 259 | """ 260 | Tokenization using Genia Tokenizer 261 | Args: 262 | sents: (list) sentences 263 | 264 | Returns: (list) tokenized sentences 265 | """ 266 | token_sents = [] 267 | for i, s in enumerate(sents): 268 | tokens = [] 269 | 270 | for word, base_form, pos_tag, chunk, named_entity in genia_tagger.tag(s): 271 | tokens += [word] 272 | 273 | text = [] 274 | for t in tokens: 275 | if t == "'s": 276 | text.append(t) 277 | elif t == "''": 278 | text.append(t) 279 | else: 280 | text.append(t.replace("'", " ' ")) 281 | 282 | text = ' '.join(text) 283 | text = text.replace("-LRB-", '(') 284 | text = text.replace("-RRB-", ')') 285 | text = text.replace("-LSB-", '[') 286 | text = text.replace("-RSB-", ']') 287 | text = text.replace("``", '"') 288 | text = text.replace("`", "'") 289 | text = text.replace("'s", " 's") 290 | text = text.replace('-', ' - ') 291 | text = text.replace('/', ' / ') 292 | text = text.replace('+', ' + ') 293 | text = text.replace('.', ' . ') 294 | text = text.replace('=', ' = ') 295 | text = text.replace('*', ' * ') 296 | if '&' in s: 297 | text = text.replace("&", "&") 298 | else: 299 | text = text.replace("&", "&") 300 | 301 | text = re.sub(' +', ' ', text).strip() # remove continuous spaces 302 | 303 | if "''" in ''.join(s): 304 | pass 305 | else: 306 | text = text.replace("''", '"') 307 | 308 | token_sents.append(text) 309 | return token_sents 310 | 311 | 312 | def adjust_offsets(old_sents, new_sents, old_entities, show=False): 313 | """ 314 | Adjust offsets based on tokenization 315 | Args: 316 | old_sents: (list) old, non-tokenized sentences 317 | new_sents: (list) new, tokenized sentences 318 | old_entities: (dic) entities with old offsets 319 | Returns: 320 | new_entities: (dic) entities with adjusted offsets 321 | abst_seq: (list) abstract sequence with entity tags 322 | """ 323 | cur = 0 324 | new_sent_range = [] 325 | for s in new_sents: 326 | new_sent_range += [(cur, cur + len(s))] 327 | cur += len(s) + 1 328 | 329 | original = " ".join(old_sents) 330 | newtext = " ".join(new_sents) 331 | new_entities = [] 332 | terms = {} 333 | for e in old_entities: 334 | start = int(e.off1) 335 | end = int(e.off2) 336 | 337 | if (start, end) not in terms: 338 | terms[(start, end)] = [[start, end, e.type, e.name, e.pmid, e.kb_id]] 339 | else: 340 | terms[(start, end)].append([start, end, e.type, e.name, e.pmid, e.kb_id]) 341 | 342 | orgidx = 0 343 | newidx = 0 344 | orglen = len(original) 345 | newlen = len(newtext) 346 | 347 | terms2 = terms.copy() 348 | while orgidx < orglen and newidx < newlen: 349 | # print(repr(original[orgidx]), orgidx, repr(newtext[newidx]), newidx) 350 | if original[orgidx] == newtext[newidx]: 351 | orgidx += 1 352 | newidx += 1 353 | elif original[orgidx] == "`" and newtext[newidx] == "'": 354 | orgidx += 1 355 | newidx += 1 356 | elif newtext[newidx] == '\n': 357 | newidx += 1 358 | elif original[orgidx] == '\n': 359 | orgidx += 1 360 | elif newtext[newidx] == ' ': 361 | newidx += 1 362 | elif original[orgidx] == ' ': 363 | orgidx += 1 364 | 
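        # Alignment sketch (hypothetical strings): if the original text is "anti-TNF drug" and the
        # tokenised text is "anti - TNF drug", every character matches except the spaces inserted
        # around '-', which are consumed by the `newtext[newidx] == ' '` branch above. An entity
        # starting at original offset 0 therefore keeps offset 0, while one starting at original
        # offset 9 ("drug") is remapped to offset 11 in the tokenised text.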
elif newtext[newidx] == '\t': 365 | newidx += 1 366 | elif original[orgidx] == '\t': 367 | orgidx += 1 368 | elif newtext[newidx] == '.': 369 | # ignore extra "." for stanford 370 | newidx += 1 371 | else: 372 | print("Non-existent text: %d\t --> %s != %s " % (orgidx, repr(original[orgidx-10:orgidx+10]), 373 | repr(newtext[newidx-10:newidx+10]))) 374 | exit(0) 375 | 376 | starts = [key[0] for key in terms2.keys()] 377 | ends = [key[1] for key in terms2.keys()] 378 | 379 | if orgidx in starts: 380 | tt = [key for key in terms2.keys() if key[0] == orgidx] 381 | for sel in tt: 382 | for l in terms[sel]: 383 | l[0] = newidx 384 | 385 | if orgidx in ends: 386 | tt2 = [key for key in terms2.keys() if key[1] == orgidx] 387 | for sel2 in tt2: 388 | for l in terms[sel2]: 389 | if l[1] == orgidx: 390 | l[1] = newidx 391 | 392 | for t_ in tt2: 393 | del terms2[t_] 394 | 395 | ent_sequences = [] 396 | for ts in terms.values(): 397 | for term in ts: 398 | condition = False 399 | 400 | if newtext[term[0]:term[1]].replace(" ", "").replace("\n", "") != term[3].replace(" ", "").replace('\n', ''): 401 | if newtext[term[0]:term[1]].replace(" ", "").replace("\n", "").lower() == \ 402 | term[3].replace(" ", "").replace('\n', '').lower(): 403 | condition = True 404 | tqdm.write('DOC_ID {}, Lowercase Issue: {} <-> {}'.format(term[4], newtext[term[0]:term[1]], term[3])) 405 | else: 406 | condition = False 407 | else: 408 | condition = True 409 | 410 | if condition: 411 | """ Convert to word Ids """ 412 | tok_seq = [] 413 | span2append = [] 414 | bio = [] 415 | tag = term[2] 416 | 417 | for tok_id, (tok, start, end) in enumerate(using_split2(newtext)): 418 | start = int(start) 419 | end = int(end) 420 | 421 | if (start, end) == (term[0], term[1]): 422 | bio.append('B-' + tag) 423 | tok_seq.append('B-' + tag) 424 | span2append.append(tok_id) 425 | 426 | elif start == term[0] and end < term[1]: 427 | bio.append('B-' + tag) 428 | tok_seq.append('B-' + tag) 429 | span2append.append(tok_id) 430 | 431 | elif start > term[0] and end < term[1]: 432 | bio.append('I-' + tag) 433 | tok_seq.append('I-' + tag) 434 | span2append.append(tok_id) 435 | 436 | elif start > term[0] and end == term[1] and (start != end): 437 | bio.append('I-' + tag) 438 | tok_seq.append('I-' + tag) 439 | span2append.append(tok_id) 440 | 441 | elif len(set(range(start, end)).intersection(set(range(term[0], term[1])))) > 0: 442 | span2append.append(tok_id) 443 | 444 | if show: 445 | tqdm.write('DOC_ID {}, entity: {:<20} ({:4}-{:4}) ' 446 | '<-> token: {:<20} ({:4}-{:4}) <-> final: {:<20}'.format( 447 | term[4], newtext[term[0]:term[1]], term[0], term[1], tok, start, end, 448 | ' '.join(newtext.split(' ')[span2append[0]:span2append[-1]+1]))) 449 | 450 | if not bio: 451 | bio.append('B-' + tag) 452 | tok_seq.append('B-' + tag) 453 | else: 454 | bio.append('I-' + tag) 455 | tok_seq.append('I-' + tag) 456 | 457 | else: 458 | tok_seq.append('O') 459 | 460 | ent_sequences += [tok_seq] 461 | 462 | # inlude all tokens!! 
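                    # BIO sketch (hypothetical mention): for a two-token Chemical mention such as
                    # "carbon monoxide", the loop above emits 'B-Chemical' for the first token,
                    # 'I-Chemical' for the second, and 'O' for every token outside the span.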
463 | if len(span2append) != len(newtext[term[0]:term[1]].split(' ')): 464 | tqdm.write('DOC_ID {}, entity {}, tokens {}\n{}'.format( 465 | term[4], newtext[term[0]:term[1]].split(' '), span2append, newtext)) 466 | 467 | # Find sentence number of each entity 468 | sent_no = [] 469 | for s_no, sr in enumerate(new_sent_range): 470 | if set(np.arange(term[0], term[1])).issubset(set(np.arange(sr[0], sr[1]))): 471 | sent_no += [s_no] 472 | 473 | assert (len(sent_no) == 1), '{} ({}, {}) -- {} -- {} <> {}'.format(sent_no, term[0], term[1], 474 | new_sent_range, 475 | newtext[term[0]:term[1]], term[3]) 476 | 477 | new_entities += [EntStruct(term[4], newtext[term[0]:term[1]], term[0], term[1], term[2], term[5], 478 | sent_no[0], span2append, bio)] 479 | else: 480 | print(newtext, term[3]) 481 | assert False, 'ERROR: {} ({}-{}) <=> {}'.format(repr(newtext[term[0]:term[1]]), term[0], term[1], 482 | repr(term[3])) 483 | return new_entities 484 | -------------------------------------------------------------------------------- /data_processing/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on %(date)s 5 | 6 | @author: fenia 7 | """ 8 | 9 | import networkx 10 | 11 | 12 | def to_graph(l): 13 | """ 14 | https://stackoverflow.com/questions/4842613/merge-lists-that-share-common-elements 15 | """ 16 | G = networkx.Graph() 17 | for part in l: 18 | # each sublist is a bunch of nodes 19 | G.add_nodes_from(part) 20 | # it also implies a number of edges: 21 | G.add_edges_from(to_edges(part)) 22 | return G 23 | 24 | 25 | def to_edges(l): 26 | """ 27 | treat `l` as a Graph and returns it's edges 28 | to_edges(['a','b','c','d']) -> [(a,b), (b,c),(c,d)] 29 | """ 30 | it = iter(l) 31 | last = next(it) 32 | 33 | for current in it: 34 | yield last, current 35 | last = current 36 | 37 | 38 | def using_split2(line, _len=len): 39 | """ 40 | Credits to https://stackoverflow.com/users/1235039/aquavitae 41 | 42 | :param line: sentence 43 | :return: a list of words and their indexes in a string. 
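    Illustrative example (single-space-separated input assumed):
        using_split2('John loves Mary') -> [('John', 0, 4), ('loves', 5, 10), ('Mary', 11, 15)]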
44 | """ 45 | words = line.split(' ') 46 | index = line.index 47 | offsets = [] 48 | append = offsets.append 49 | running_offset = 0 50 | for word in words: 51 | word_offset = index(word, running_offset) 52 | word_len = _len(word) 53 | running_offset = word_offset + word_len 54 | append((word, word_offset, running_offset)) 55 | return offsets 56 | 57 | 58 | def replace2symbol(string): 59 | string = string.replace('”', '"').replace('’', "'").replace('–', '-').replace('‘', "'").replace('‑', '-').replace( 60 | '\x92', "'").replace('»', '"').replace('—', '-').replace('\uf8fe', ' ').replace('«', '"').replace( 61 | '\uf8ff', ' ').replace('£', '#').replace('\u2028', ' ').replace('\u2029', ' ') 62 | 63 | return string 64 | 65 | 66 | def replace2space(string): 67 | spaces = ["\r", '\xa0', '\xe2\x80\x85', '\xc2\xa0', '\u2009', '\u2002', '\u200a', '\u2005', '\u2003', '\u2006', 68 | 'Ⅲ', '…', 'Ⅴ', "\u202f"] 69 | 70 | for i in spaces: 71 | string = string.replace(i, ' ') 72 | return string 73 | 74 | -------------------------------------------------------------------------------- /evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 13/05/2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import argparse 10 | 11 | 12 | def prf(tp, fp, fn): 13 | micro_p = float(tp) / (tp + fp) if (tp + fp != 0) else 0.0 14 | micro_r = float(tp) / (tp + fn) if (tp + fn != 0) else 0.0 15 | micro_f = ((2 * micro_p * micro_r) / (micro_p + micro_r)) if micro_p != 0.0 and micro_r != 0.0 else 0.0 16 | return [micro_p, micro_r, micro_f] 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--gold', type=str) 22 | parser.add_argument('--pred', type=str) 23 | parser.add_argument('--label', type=str) 24 | args = parser.parse_args() 25 | 26 | with open(args.pred) as pred, open(args.gold) as gold: 27 | preds_all = [] 28 | preds_intra = [] 29 | preds_inter = [] 30 | 31 | golds_all = [] 32 | golds_intra = [] 33 | golds_inter = [] 34 | 35 | for line in pred: 36 | line = line.rstrip().split('|') 37 | if line[5] == args.label: 38 | 39 | if (line[0], line[1], line[2], line[3], line[5]) not in preds_all: 40 | preds_all += [(line[0], line[1], line[2], line[3], line[5])] 41 | 42 | if ((line[0], line[1], line[2], line[5]) not in preds_inter) and (line[3] == 'CROSS'): 43 | preds_inter += [(line[0], line[1], line[2], line[5])] 44 | 45 | if ((line[0], line[1], line[2], line[5]) not in preds_intra) and (line[3] == 'NON-CROSS'): 46 | preds_intra += [(line[0], line[1], line[2], line[5])] 47 | 48 | for line2 in gold: 49 | line2 = line2.rstrip().split('|') 50 | 51 | if line2[4] == args.label: 52 | 53 | if (line2[0], line2[1], line2[2], line2[3], line2[4]) not in golds_all: 54 | golds_all += [(line2[0], line2[1], line2[2], line2[3], line2[4])] 55 | 56 | if ((line2[0], line2[1], line2[2], line2[4]) not in golds_inter) and (line2[3] == 'CROSS'): 57 | golds_inter += [(line2[0], line2[1], line2[2], line2[4])] 58 | 59 | if ((line2[0], line2[1], line2[2], line2[4]) not in golds_intra) and (line2[3] == 'NON-CROSS'): 60 | golds_intra += [(line2[0], line2[1], line2[2], line2[4])] 61 | 62 | tp = len([a for a in preds_all if a in golds_all]) 63 | tp_intra = len([a for a in preds_intra if a in golds_intra]) 64 | tp_inter = len([a for a in preds_inter if a in golds_inter]) 65 | 66 | fp = len([a for a in preds_all if a not in golds_all]) 67 | fp_intra = len([a for a in preds_intra if a not in golds_intra]) 
68 | fp_inter = len([a for a in preds_inter if a not in golds_inter]) 69 | 70 | fn = len([a for a in golds_all if a not in preds_all]) 71 | fn_intra = len([a for a in golds_intra if a not in preds_intra]) 72 | fn_inter = len([a for a in golds_inter if a not in preds_inter]) 73 | 74 | r1 = prf(tp, fp, fn) 75 | r2 = prf(tp_intra, fp_intra, fn_intra) 76 | r3 = prf(tp_inter, fp_inter, fn_inter) 77 | 78 | print(' TOTAL\tTP\tFP\tFN') 79 | print('Overall P/R/F1\t{:.4f}\t{:.4f}\t{:.4f}\t| {}\t{}\t{}\t{}'.format(r1[0], r1[1], r1[2], 80 | tp + fn, tp, fp, fn)) 81 | print('INTRA P/R/F1\t{:.4f}\t{:.4f}\t{:.4f}\t| {}\t{}\t{}\t{}'.format(r2[0], r2[1], r2[2], 82 | tp_intra + fn_intra, 83 | tp_intra, fp_intra, 84 | fn_intra)) 85 | print('INTER P/R/F1\t{:.4f}\t{:.4f}\t{:.4f}\t| {}\t{}\t{}\t{}'.format(r3[0], r3[1], r3[2], 86 | tp_inter + fn_inter, 87 | tp_inter, fp_inter, 88 | fn_inter)) 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.1.0 2 | numpy==1.16.3 3 | matplotlib 4 | tqdm 5 | recordtype 6 | yamlordereddictloader 7 | tabulate 8 | networkx 9 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__init__.py -------------------------------------------------------------------------------- /src/__pycache__/converter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__pycache__/converter.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__pycache__/loader.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/reader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__pycache__/reader.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /src/bin/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | expA=("ME MS ES SS-ind" 4 | "MM MS ES SS-ind" 5 | "MM ME ES SS-ind" 6 | "MM ME 
MS SS-ind" 7 | "MM ME MS ES" 8 | "MM ME MS ES SS-ind" 9 | "MM ME" 10 | "ES SS-ind") 11 | 12 | #options=("--context" 13 | # "--types" 14 | # "--distances" 15 | # "--context --types" 16 | # "--context --distances" 17 | # "--distances --types") 18 | 19 | 20 | options=("--types --dist") 21 | 22 | 23 | for o in "${options[@]}": 24 | do 25 | counter=0 26 | for w in 1 2 3 4; 27 | do 28 | # EOG + SS-ind 29 | python3 eog.py --config ../configs/parameters_cdr.yaml --train \ 30 | --edges MM ME MS ES SS-ind \ 31 | --walks ${w} "${o}" \ 32 | --gpu ${counter} & 33 | counter=$((counter+1)) 34 | 35 | # EOG + SS 36 | python3 eog.py --config ../configs/parameters_cdr.yaml --train \ 37 | --edges MM ME MS ES SS \ 38 | --walks ${w} "${o}" \ 39 | --gpu ${counter} & 40 | counter=$((counter+1)) 41 | done 42 | # wait 43 | done 44 | 45 | 46 | wait 47 | 48 | 49 | for w in 1 2 3 4; 50 | do 51 | # rest 52 | counter=0 53 | for edg in "${expA[@]}"; 54 | do 55 | python3 eog.py --config ../configs/parameters_cdr.yaml --train \ 56 | --edges "${edg}" \ 57 | --walks ${w} "${o}" \ 58 | --gpu ${counter} & 59 | counter=$((counter+1)) 60 | done 61 | wait 62 | done 63 | 64 | 65 | # fully connected 66 | for w in 1 2 3 4; 67 | do 68 | counter=0 69 | for edg in "${expA[@]}"; 70 | do 71 | python3 eog.py --config ../configs/parameters_cdr.yaml --train \ 72 | --edges "FULL" \ 73 | --walks ${w} "${o}" \ 74 | --gpu ${counter} & 75 | counter=$((counter+1)) 76 | done 77 | done 78 | 79 | 80 | # sentence 81 | for w in 1 2 3 4; 82 | do 83 | counter=0 84 | for edg in "${expA[@]}"; 85 | do 86 | python3 eog.py --config ../configs/parametes_cdr.yaml --train \ 87 | --edges MM ME MS ES SS-ind \ 88 | --walks ${w} \ 89 | --gpu ${counter} \ 90 | --walks ${w} "${o}" \ 91 | --window 1 & 92 | counter=$((counter+1)) 93 | done 94 | done 95 | 96 | wait 97 | 98 | 99 | 100 | # no inference 101 | #counter=0 102 | #for w in 0 1 2 3 4 5; 103 | #do 104 | # python3 eog.py --config ../configs/parameters_cdr.yaml --train \ 105 | # --edges "EE" \ 106 | # --walks ${w} \ 107 | # --gpu ${counter} & 108 | # counter=$((counter+1)) 109 | #done 110 | 111 | 112 | # sentence (different options, EOG model) 113 | #python3 eog.py --config ../configs/parameters_cdr.yaml --train \ 114 | # --edges MM ME MS ES SS-ind \ 115 | # --walks 3 \ 116 | # --types \ 117 | # --context \ 118 | # --window 1 \ 119 | # --gpu 0 & 120 | # 121 | #python3 eog.py --config ../configs/parameters_cdr.yaml --train \ 122 | # --edges MM ME MS ES SS-ind \ 123 | # --walks 3 \ 124 | # --types \ 125 | # --window 1 \ 126 | # --gpu 1 & 127 | # 128 | #python3 eog.py --config ../configs/parameters_cdr.yaml --train \ 129 | # --edges MM ME MS ES SS-ind \ 130 | # --walks 3 \ 131 | # --context \ 132 | # --window 1 \ 133 | # --gpu 2 & 134 | # 135 | #python3 eog.py --config ../configs/parameters_cdr.yaml --train \ 136 | # --edges MM ME MS ES SS-ind \ 137 | # --walks 3 \ 138 | # --window 1 \ 139 | # --gpu 3 & 140 | 141 | 142 | -------------------------------------------------------------------------------- /src/converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import torch 4 | import six 5 | import numpy 6 | 7 | 8 | def to_device(device, x): 9 | if device is None: 10 | return torch.as_tensor(x).long() 11 | return torch.as_tensor(x).long().to(device) 12 | 13 | 14 | def concat_examples(batch, device=None, padding=-1): 15 | assert device is None or isinstance(device, torch.device) 16 | if len(batch) == 0: 17 | raise ValueError('batch is empty') 18 | 
19 | first_elem = batch[0] 20 | 21 | if isinstance(first_elem, tuple): 22 | result = [] 23 | if not isinstance(padding, tuple): 24 | padding = [padding] * len(first_elem) 25 | 26 | for i in six.moves.range(len(first_elem)): 27 | result.append(to_device(device, _concat_arrays( 28 | [example[i] for example in batch], padding[i]))) 29 | 30 | return tuple(result) 31 | 32 | elif isinstance(first_elem, dict): 33 | result = {} 34 | if not isinstance(padding, dict): 35 | padding = {key: padding for key in first_elem} 36 | 37 | for key in first_elem: 38 | result[key] = to_device(device, _concat_arrays( 39 | [example[key] for example in batch], padding[key])) 40 | 41 | return result 42 | 43 | else: 44 | return to_device(device, _concat_arrays(batch, padding)) 45 | 46 | 47 | def _concat_arrays(arrays, padding): 48 | # Convert `arrays` to numpy.ndarray if `arrays` consists of the built-in 49 | # types such as int, float or list. 50 | if not isinstance(arrays[0], type(torch.get_default_dtype())): 51 | arrays = numpy.asarray(arrays) 52 | 53 | if padding is not None: 54 | arr_concat = _concat_arrays_with_padding(arrays, padding) 55 | else: 56 | arr_concat = numpy.concatenate([array[None] for array in arrays]) 57 | 58 | return arr_concat 59 | 60 | 61 | def _concat_arrays_with_padding(arrays, padding): 62 | shape = numpy.array(arrays[0].shape, dtype=int) 63 | for array in arrays[1:]: 64 | if numpy.any(shape != array.shape): 65 | numpy.maximum(shape, array.shape, shape) 66 | shape = tuple(numpy.insert(shape, 0, len(arrays))) 67 | 68 | result = numpy.full(shape, padding, dtype=arrays[0].dtype) 69 | for i in six.moves.range(len(arrays)): 70 | src = arrays[i] 71 | slices = tuple(slice(dim) for dim in src.shape) 72 | result[(i,) + slices] = src 73 | 74 | return result 75 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 13/06/2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import random 10 | random.seed(0) 11 | import numpy as np 12 | np.random.seed(0) 13 | import numpy as np 14 | from tqdm import tqdm 15 | from collections import OrderedDict 16 | from torch.utils.data import Dataset 17 | 18 | 19 | class DocRelationDataset: 20 | """ 21 | My simpler converter approach, stores everything in a list and then just iterates to create batches. 
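    Each prepared example (see __call__) is a dict with the keys 'ents', 'rels', 'dist', 'text',
    'info', 'adjacency', 'section', 'word_sec' and 'words'.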
22 | """ 23 | def __init__(self, loader, data_type, params, mappings): 24 | self.unk_w_prob = params['unk_w_prob'] 25 | self.mappings = mappings 26 | self.loader = loader 27 | self.data_type = data_type 28 | self.edges = params['edges'] 29 | self.data = [] 30 | 31 | def __len__(self): 32 | return len(self.data) 33 | 34 | def __call__(self): 35 | pbar = tqdm(self.loader.documents.keys()) 36 | for pmid in pbar: 37 | pbar.set_description(' Preparing {} data - PMID {}'.format(self.data_type.upper(), pmid)) 38 | 39 | # TEXT 40 | doc = [] 41 | for sentence in self.loader.documents[pmid]: 42 | sent = [] 43 | if self.data_type == 'train': 44 | for w, word in enumerate(sentence): 45 | if (word in self.mappings.singletons) and (random.uniform(0, 1) < float(self.unk_w_prob)): 46 | sent += [self.mappings.word2index['']] # UNK words = singletons for train 47 | else: 48 | sent += [self.mappings.word2index[word]] 49 | 50 | else: 51 | 52 | for w, word in enumerate(sentence): 53 | if word in self.mappings.word2index: 54 | sent += [self.mappings.word2index[word]] 55 | else: 56 | sent += [self.mappings.word2index['']] 57 | assert len(sentence) == len(sent), '{}, {}'.format(len(sentence), len(sent)) 58 | doc += [sent] 59 | 60 | # ENTITIES [id, type, start, end] + NODES [id, type, start, end, node_type_id] 61 | nodes = [] 62 | ent = [] 63 | for id_, (e, i) in enumerate(self.loader.entities[pmid].items()): 64 | nodes += [[id_, self.mappings.type2index[i.type], int(i.mstart.split(':')[0]), 65 | int(i.mend.split(':')[0]), i.sentNo.split(':')[0], 0]] 66 | 67 | for id_, (e, i) in enumerate(self.loader.entities[pmid].items()): 68 | for sent_id, m1, m2 in zip(i.sentNo.split(':'), i.mstart.split(':'), i.mend.split(':')): 69 | ent += [[id_, self.mappings.type2index[i.type], int(m1), int(m2), int(sent_id)]] 70 | nodes += [[id_, self.mappings.type2index[i.type], int(m1), int(m2), int(sent_id), 1]] 71 | 72 | for s, sentence in enumerate(self.loader.documents[pmid]): 73 | nodes += [[s, s, s, s, s, 2]] 74 | 75 | nodes = np.array(nodes, 'i') 76 | ent = np.array(ent, 'i') 77 | 78 | # RELATIONS 79 | ents_keys = list(self.loader.entities[pmid].keys()) # in order 80 | trel = -1 * np.ones((len(ents_keys), len(ents_keys)), 'i') 81 | rel_info = np.empty((len(ents_keys), len(ents_keys)), dtype='object_') 82 | for id_, (r, i) in enumerate(self.loader.pairs[pmid].items()): 83 | if i.type == 'not_include': 84 | continue 85 | trel[ents_keys.index(r[0]), ents_keys.index(r[1])] = self.mappings.rel2index[i.type] 86 | rel_info[ents_keys.index(r[0]), ents_keys.index(r[1])] = OrderedDict( 87 | [('pmid', pmid), 88 | ('sentA', self.loader.entities[pmid][r[0]].sentNo), 89 | ('sentB', 90 | self.loader.entities[pmid][r[1]].sentNo), 91 | ('doc', self.loader.documents[pmid]), 92 | ('entA', self.loader.entities[pmid][r[0]]), 93 | ('entB', self.loader.entities[pmid][r[1]]), 94 | ('rel', self.mappings.rel2index[i.type]), 95 | ('dir', i.direction), 96 | ('cross', i.cross)]) 97 | 98 | ####################### 99 | # DISTANCES 100 | ####################### 101 | xv, yv = np.meshgrid(np.arange(nodes.shape[0]), np.arange(nodes.shape[0]), indexing='ij') 102 | 103 | r_id, c_id = nodes[xv, 5], nodes[yv, 5] 104 | r_Eid, c_Eid = nodes[xv, 0], nodes[yv, 0] 105 | r_Sid, c_Sid = nodes[xv, 4], nodes[yv, 4] 106 | r_Ms, c_Ms = nodes[xv, 2], nodes[yv, 2] 107 | r_Me, c_Me = nodes[xv, 3]-1, nodes[yv, 3]-1 108 | 109 | ignore_pos = self.mappings.n_dist 110 | self.mappings.dist2index[ignore_pos] = ignore_pos 111 | self.mappings.index2dist[ignore_pos] = ignore_pos 112 | 113 | 
dist = np.full((r_id.shape[0], r_id.shape[0]), ignore_pos, 'i') 114 | 115 | # MM: mention-mention 116 | a_start = np.where((r_id == 1) & (c_id == 1), r_Ms, -1) 117 | a_end = np.where((r_id == 1) & (c_id == 1), r_Me, -1) 118 | b_start = np.where((r_id == 1) & (c_id == 1), c_Ms, -1) 119 | b_end = np.where((r_id == 1) & (c_id == 1), c_Me, -1) 120 | 121 | dist = np.where((a_end < b_start) & (a_end != -1) & (b_start != -1), abs(b_start - a_end), dist) 122 | dist = np.where((b_end < a_start) & (b_end != -1) & (a_start != -1), abs(b_end - a_start), dist) 123 | 124 | # nested (find the distance between their last words) 125 | dist = np.where((b_start <= a_start) & (b_end >= a_end) 126 | & (b_start != -1) & (a_end != -1) & (b_end != -1) & (a_start != -1), abs(b_end-a_end), dist) 127 | dist = np.where((b_start >= a_start) & (b_end <= a_end) 128 | & (b_start != -1) & (a_end != -1) & (b_end != -1) & (a_start != -1), abs(a_end-b_end), dist) 129 | 130 | # diagonal 131 | dist[np.arange(nodes.shape[0]), np.arange(nodes.shape[0])] = 0 132 | 133 | # limit max distance according to training set 134 | dist = np.where(dist > self.mappings.max_distance, self.mappings.max_distance, dist) 135 | 136 | # restrictions: to MM pairs in the same sentence 137 | dist = np.where(((r_id == 1) & (c_id == 1) & (r_Sid == c_Sid)), dist, ignore_pos) 138 | 139 | # SS: sentence-sentence 140 | dist = np.where(((r_id == 2) & (c_id == 2)), abs(c_Sid - r_Sid), dist) 141 | 142 | ####################### 143 | # GRAPH CONNECTIONS 144 | ####################### 145 | adjacency = np.full((r_id.shape[0], r_id.shape[0]), 0, 'i') 146 | 147 | if 'FULL' in self.edges: 148 | adjacency = np.full(adjacency.shape, 1, 'i') 149 | 150 | if 'MM' in self.edges: 151 | # mention-mention 152 | adjacency = np.where((r_id == 1) & (c_id == 1) & (r_Sid == c_Sid), 1, adjacency) # in same sentence 153 | 154 | if ('EM' in self.edges) or ('ME' in self.edges): 155 | # entity-mention 156 | adjacency = np.where((r_id == 0) & (c_id == 1) & (r_Eid == c_Eid), 1, adjacency) # belongs to entity 157 | adjacency = np.where((r_id == 1) & (c_id == 0) & (r_Eid == c_Eid), 1, adjacency) 158 | 159 | if 'SS' in self.edges: 160 | # sentence-sentence (in order) 161 | adjacency = np.where((r_id == 2) & (c_id == 2) & (r_Sid == c_Sid - 1), 1, adjacency) 162 | adjacency = np.where((r_id == 2) & (c_id == 2) & (c_Sid == r_Sid - 1), 1, adjacency) 163 | 164 | if 'SS-ind' in self.edges: 165 | # sentence-sentence (direct + indirect) 166 | adjacency = np.where((r_id == 2) & (c_id == 2), 1, adjacency) 167 | 168 | if ('MS' in self.edges) or ('SM' in self.edges): 169 | # mention-sentence 170 | adjacency = np.where((r_id == 1) & (c_id == 2) & (r_Sid == c_Sid), 1, adjacency) # belongs to sentence 171 | adjacency = np.where((r_id == 2) & (c_id == 1) & (r_Sid == c_Sid), 1, adjacency) 172 | 173 | if ('ES' in self.edges) or ('SE' in self.edges): 174 | # entity-sentence 175 | for x, y in zip(xv.ravel(), yv.ravel()): 176 | if nodes[x, 5] == 0 and nodes[y, 5] == 2: # this is an entity-sentence edge 177 | z = np.where((r_Eid == nodes[x, 0]) & (r_id == 1) & (c_id == 2) & (c_Sid == nodes[y, 4])) 178 | 179 | # at least one M in S 180 | temp_ = np.where((r_id == 1) & (c_id == 2) & (r_Sid == c_Sid), 1, adjacency) 181 | temp_ = np.where((r_id == 2) & (c_id == 1) & (r_Sid == c_Sid), 1, temp_) 182 | adjacency[x, y] = 1 if (temp_[z] == 1).any() else 0 183 | adjacency[y, x] = 1 if (temp_[z] == 1).any() else 0 184 | 185 | if 'EE' in self.edges: 186 | adjacency = np.where((r_id == 0) & (c_id == 0), 1, adjacency) 187 | 
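            # Example (illustrative): with edges = ['MM', 'ME', 'MS', 'ES', 'SS-ind'], a mention is linked
            # to same-sentence mentions, to its entity node and to its sentence node; an entity is linked
            # to every sentence containing at least one of its mentions; all sentence nodes are linked to
            # each other; and no EE (entity-entity) edges are added.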
188 | # self-loops = 0 [always] 189 | adjacency[np.arange(r_id.shape[0]), np.arange(r_id.shape[0])] = 0 190 | 191 | dist = list(map(lambda y: self.mappings.dist2index[y], dist.ravel().tolist())) # map 192 | dist = np.array(dist, 'i').reshape((nodes.shape[0], nodes.shape[0])) 193 | 194 | if (trel == -1).all(): # no relations --> ignore 195 | continue 196 | 197 | self.data += [{'ents': ent, 'rels': trel, 'dist': dist, 'text': doc, 'info': rel_info, 198 | 'adjacency': adjacency, 199 | 'section': np.array([len(self.loader.entities[pmid].items()), ent.shape[0], len(doc)], 'i'), 200 | 'word_sec': np.array([len(s) for s in doc]), 201 | 'words': np.hstack([np.array(s, 'i') for s in doc])}] 202 | return self.data 203 | -------------------------------------------------------------------------------- /src/eog.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 21-Feb-2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import torch 10 | import random 11 | import numpy as np 12 | from dataset import DocRelationDataset 13 | from loader import DataLoader, ConfigLoader 14 | from nnet.trainer import Trainer 15 | from utils import setup_log, save_model, load_model, plot_learning_curve, load_mappings 16 | 17 | 18 | def set_seed(seed): 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU 22 | np.random.seed(seed) # Numpy module 23 | random.seed(seed) # Python random module 24 | torch.backends.cudnn.benchmark = False 25 | torch.backends.cudnn.deterministic = True 26 | 27 | 28 | def train(parameters): 29 | model_folder = setup_log(parameters, 'train') 30 | 31 | set_seed(parameters['seed']) 32 | 33 | ################################### 34 | # Data Loading 35 | ################################### 36 | print('Loading training data ...') 37 | train_loader = DataLoader(parameters['train_data'], parameters) 38 | train_loader(embeds=parameters['embeds']) 39 | train_data = DocRelationDataset(train_loader, 'train', parameters, train_loader).__call__() 40 | 41 | print('\nLoading testing data ...') 42 | test_loader = DataLoader(parameters['test_data'], parameters) 43 | test_loader() 44 | test_data = DocRelationDataset(test_loader, 'test', parameters, train_loader).__call__() 45 | 46 | ################################### 47 | # Training 48 | ################################### 49 | trainer = Trainer(train_loader, parameters, {'train': train_data, 'test': test_data}, model_folder) 50 | trainer.run() 51 | 52 | if parameters['plot']: 53 | plot_learning_curve(trainer, model_folder) 54 | 55 | if parameters['save_model']: 56 | save_model(model_folder, trainer, train_loader) 57 | 58 | 59 | def test(parameters): 60 | model_folder = setup_log(parameters, 'test') 61 | 62 | print('\nLoading mappings ...') 63 | train_loader = load_mappings(model_folder) 64 | 65 | print('\nLoading testing data ...') 66 | test_loader = DataLoader(parameters['test_data'], parameters) 67 | test_loader() 68 | test_data = DocRelationDataset(test_loader, 'test', parameters, train_loader).__call__() 69 | 70 | m = Trainer(train_loader, parameters, {'train': [], 'test': test_data}, model_folder) 71 | trainer = load_model(model_folder, m) 72 | trainer.eval_epoch(final=True, save_predictions=True) 73 | 74 | 75 | def main(): 76 | config = ConfigLoader() 77 | parameters = config.load_config() 78 | 79 | if parameters['train']: 80 | train(parameters) 81 | 82 | elif parameters['test']: 83 | 
test(parameters) 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | 89 | -------------------------------------------------------------------------------- /src/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 21-Feb-2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import numpy as np 10 | import argparse 11 | import yaml 12 | import yamlordereddictloader 13 | from collections import OrderedDict 14 | from reader import read, read_subdocs 15 | 16 | 17 | def str2bool(i): 18 | if isinstance(i, bool): 19 | return i 20 | if i.lower() in ('yes', 'true', 't', 'y', '1'): 21 | return True 22 | elif i.lower() in ('no', 'false', 'f', 'n', '0'): 23 | return False 24 | else: 25 | raise argparse.ArgumentTypeError('Boolean value expected.') 26 | 27 | 28 | class ConfigLoader: 29 | def __init__(self): 30 | pass 31 | 32 | @staticmethod 33 | def load_cmd(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--config', type=str, required=True, help='Yaml parameter file') 36 | parser.add_argument('--train', action='store_true', help='Training mode - model is saved') 37 | parser.add_argument('--test', action='store_true', help='Testing mode - needs a model to load') 38 | parser.add_argument('--gpu', type=int, help='GPU number') 39 | parser.add_argument('--walks', type=int, help='Number of walk iterations') 40 | parser.add_argument('--window', type=int, help='Window for training (empty processes the whole document, ' 41 | '1 processes 1 sentence at a time, etc)') 42 | parser.add_argument('--edges', nargs='*', help='Edge types') 43 | parser.add_argument('--types', type=str2bool, help='Include node types (Boolean)') 44 | parser.add_argument('--context', type=str2bool, help='Include MM context (Boolean)') 45 | parser.add_argument('--dist', type=str2bool, help='Include distance (Boolean)') 46 | parser.add_argument('--example', help='Show example', action='store_true') 47 | parser.add_argument('--seed', help='Fixed random seed number', type=int) 48 | parser.add_argument('--early_stop', action='store_true', help='Use early stopping') 49 | parser.add_argument('--epoch', type=int, help='Maximum training epoch') 50 | return parser.parse_args() 51 | 52 | def load_config(self): 53 | inp = self.load_cmd() 54 | with open(vars(inp)['config'], 'r') as f: 55 | parameters = yaml.load(f, Loader=yamlordereddictloader.Loader) 56 | 57 | parameters = dict(parameters) 58 | if not inp.train and not inp.test: 59 | print('Please specify train/test mode.') 60 | sys.exit(0) 61 | 62 | parameters['train'] = inp.train 63 | parameters['test'] = inp.test 64 | parameters['gpu'] = inp.gpu 65 | parameters['example'] = inp.example 66 | 67 | if inp.walks and inp.walks >= 0: 68 | parameters['walks_iter'] = inp.walks 69 | 70 | if inp.edges: 71 | parameters['edges'] = inp.edges 72 | 73 | if inp.types != None: 74 | parameters['types'] = inp.types 75 | 76 | if inp.dist != None: 77 | parameters['dist'] = inp.dist 78 | 79 | if inp.window: 80 | parameters['window'] = inp.window 81 | 82 | if inp.context != None: 83 | parameters['context'] = inp.context 84 | 85 | if inp.seed: 86 | parameters['seed'] = inp.seed 87 | 88 | if inp.epoch: 89 | parameters['epoch'] = inp.epoch 90 | 91 | if inp.early_stop: 92 | parameters['early_stop'] = True 93 | 94 | return parameters 95 | 96 | 97 | class DataLoader: 98 | def __init__(self, input_file, params): 99 | self.input = input_file 100 | self.params = params 101 | 102 | self.pre_words = [] 
103 | self.pre_embeds = OrderedDict() 104 | self.max_distance = -9999999999 105 | self.singletons = [] 106 | self.label2ignore = -1 107 | self.ign_label = self.params['label2ignore'] 108 | 109 | self.word2index, self.index2word, self.n_words, self.word2count = {'': 0}, {0: ''}, 1, {'': 1} 110 | self.type2index, self.index2type, self.n_type, self.type2count = {'': 0, '': 1, '': 2}, \ 111 | {0: '', 1: '', 2: ''}, 3, \ 112 | {'': 1, '': 1, '': 1} 113 | self.rel2index, self.index2rel, self.n_rel, self.rel2count = {}, {}, 0, {} 114 | self.dist2index, self.index2dist, self.n_dist, self.dist2count = {}, {}, 0, {} 115 | self.documents, self.entities, self.pairs = OrderedDict(), OrderedDict(), OrderedDict() 116 | 117 | def find_ignore_label(self): 118 | """ 119 | Find relation Id to ignore 120 | """ 121 | for key, val in self.index2rel.items(): 122 | if val == self.ign_label: 123 | self.label2ignore = key 124 | assert self.label2ignore != -1 125 | 126 | @staticmethod 127 | def check_nested(p): 128 | starts1 = list(map(int, p[8].split(':'))) 129 | ends1 = list(map(int, p[9].split(':'))) 130 | 131 | starts2 = list(map(int, p[14].split(':'))) 132 | ends2 = list(map(int, p[15].split(':'))) 133 | 134 | for s1, e1, s2, e2 in zip(starts1, ends1, starts2, ends2): 135 | if bool(set(np.arange(s1, e1)) & set(np.arange(s2, e2))): 136 | print('nested pair', p) 137 | 138 | def find_singletons(self, min_w_freq=1): 139 | """ 140 | Find items with frequency <= 2 and based on probability 141 | """ 142 | self.singletons = frozenset([elem for elem, val in self.word2count.items() 143 | if (val <= min_w_freq) and elem != '']) 144 | 145 | def add_relation(self, rel): 146 | if rel not in self.rel2index: 147 | self.rel2index[rel] = self.n_rel 148 | self.rel2count[rel] = 1 149 | self.index2rel[self.n_rel] = rel 150 | self.n_rel += 1 151 | else: 152 | self.rel2count[rel] += 1 153 | 154 | def add_word(self, word): 155 | if word not in self.word2index: 156 | self.word2index[word] = self.n_words 157 | self.word2count[word] = 1 158 | self.index2word[self.n_words] = word 159 | self.n_words += 1 160 | else: 161 | self.word2count[word] += 1 162 | 163 | def add_type(self, type): 164 | if type not in self.type2index: 165 | self.type2index[type] = self.n_type 166 | self.type2count[type] = 1 167 | self.index2type[self.n_type] = type 168 | self.n_type += 1 169 | else: 170 | self.type2count[type] += 1 171 | 172 | def add_dist(self, dist): 173 | if dist not in self.dist2index: 174 | self.dist2index[dist] = self.n_dist 175 | self.dist2count[dist] = 1 176 | self.index2dist[self.n_dist] = dist 177 | self.n_dist += 1 178 | else: 179 | self.dist2count[dist] += 1 180 | 181 | def add_sentence(self, sentence): 182 | for word in sentence: 183 | self.add_word(word) 184 | 185 | def add_document(self, document): 186 | for sentence in document: 187 | self.add_sentence(sentence) 188 | 189 | def load_embeds(self, word_dim): 190 | """ 191 | Load pre-trained word embeddings if specified 192 | """ 193 | self.pre_embeds = OrderedDict() 194 | with open(self.params['embeds'], 'r') as vectors: 195 | for x, line in enumerate(vectors): 196 | 197 | if x == 0 and len(line.split()) == 2: 198 | words, num = map(int, line.rstrip().split()) 199 | else: 200 | word = line.rstrip().split()[0] 201 | vec = line.rstrip().split()[1:] 202 | 203 | n = len(vec) 204 | if n != word_dim: 205 | print('Wrong dimensionality! 
-- line No{}, word: {}, len {}'.format(x, word, n)) 206 | continue 207 | self.add_word(word) 208 | self.pre_embeds[word] = np.asarray(vec, 'f') 209 | self.pre_words = [w for w, e in self.pre_embeds.items()] 210 | print(' Found pre-trained word embeddings: {} x {}'.format(len(self.pre_embeds), word_dim), end="") 211 | 212 | def find_max_length(self, lengths): 213 | self.max_distance = max(lengths) - 1 214 | 215 | def read_n_map(self): 216 | """ 217 | Read input. 218 | Lengths is the max distance for each document 219 | """ 220 | if not self.params['window']: 221 | lengths, sents, self.documents, self.entities, self.pairs = \ 222 | read(self.input, self.documents, self.entities, self.pairs) 223 | else: 224 | lengths, sents, self.documents, self.entities, self.pairs = \ 225 | read_subdocs(self.input, self.params['window'], self.documents, self.entities, self.pairs) 226 | 227 | self.find_max_length(lengths) 228 | 229 | # map types and positions and relation types 230 | for did, d in self.documents.items(): 231 | self.add_document(d) 232 | 233 | for did, e in self.entities.items(): 234 | for k, v in e.items(): 235 | self.add_type(v.type) 236 | 237 | for dist in np.arange(0, self.max_distance+1): 238 | self.add_dist(dist) 239 | 240 | for did, p in self.pairs.items(): 241 | for k, v in p.items(): 242 | if v.type == 'not_include': 243 | continue 244 | self.add_relation(v.type) 245 | assert len(self.entities) == len(self.documents) == len(self.pairs) 246 | 247 | def statistics(self): 248 | """ 249 | Print statistics for the dataset 250 | """ 251 | print(' Documents: {:<5}\n Words: {:<5}'.format(len(self.documents), self.n_words)) 252 | 253 | print(' Relations: {}'.format(sum([v for k, v in self.rel2count.items()]))) 254 | for k, v in sorted(self.rel2count.items()): 255 | print('\t{:<10}\t{:<5}\tID: {}'.format(k, v, self.rel2index[k])) 256 | 257 | print(' Entities: {}'.format(sum([len(e) for e in self.entities.values()]))) 258 | for k, v in sorted(self.type2count.items()): 259 | print('\t{:<10}\t{:<5}\tID: {}'.format(k, v, self.type2index[k])) 260 | 261 | print(' Singletons: {}/{}'.format(len(self.singletons), self.n_words)) 262 | 263 | def __call__(self, embeds=None): 264 | self.read_n_map() 265 | self.find_ignore_label() 266 | self.find_singletons(self.params['min_w_freq']) # words with freq=1 267 | self.statistics() 268 | if embeds: 269 | self.load_embeds(self.params['word_dim']) 270 | print(' --> Words + Pre-trained: {:<5}'.format(self.n_words)) 271 | -------------------------------------------------------------------------------- /src/nnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__init__.py -------------------------------------------------------------------------------- /src/nnet/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/nnet/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/attention.cpython-36.pyc 
-------------------------------------------------------------------------------- /src/nnet/__pycache__/init_net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/init_net.cpython-36.pyc -------------------------------------------------------------------------------- /src/nnet/__pycache__/modules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/modules.cpython-36.pyc -------------------------------------------------------------------------------- /src/nnet/__pycache__/network.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/network.cpython-36.pyc -------------------------------------------------------------------------------- /src/nnet/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /src/nnet/__pycache__/walks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/walks.cpython-36.pyc -------------------------------------------------------------------------------- /src/nnet/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Created on 29/03/19 4 | 5 | author: fenia 6 | """ 7 | 8 | import torch 9 | from torch import nn, torch 10 | import math 11 | 12 | 13 | class Dot_Attention(nn.Module): 14 | """ 15 | Adaptation from "Attention is all you need". 16 | Here the query is the target pair and the keys/values are the words of the sentence. 17 | The dimensionality of the queries and the values should be the same. 
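    Expected shapes, as used in network.py: queries (mentions, dim), values (mentions, words, dim)
    holding the word states of each mention's sentence, and the returned weights (mentions, words).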
18 | """ 19 | def __init__(self, input_size, device=-1, scale=False): 20 | 21 | super(Dot_Attention, self).__init__() 22 | 23 | self.softmax = nn.Softmax(dim=2) 24 | self.scale = scale 25 | if scale: 26 | self.sc = 1.0 / math.sqrt(input_size) 27 | self.device = device 28 | 29 | def create_mask(self, alpha, size_, lengths, idx_): 30 | """ Put 1 in valid tokens """ 31 | mention_sents = torch.index_select(lengths, 0, idx_[:, 4]) 32 | 33 | # mask padded words (longer that sentence length) 34 | tempa = torch.arange(size_).unsqueeze(0).repeat(alpha.shape[0], 1).to(self.device) 35 | mask_a = torch.ge(tempa, mention_sents[:, None]) 36 | 37 | # mask tokens that are used as queries 38 | tempb = torch.arange(lengths.size(0)).unsqueeze(0).repeat(alpha.shape[0], 1).to(self.device) # m x sents 39 | sents = torch.where(torch.lt(tempb, idx_[:, 4].unsqueeze(1)), 40 | lengths.unsqueeze(0).repeat(alpha.shape[0], 1), 41 | torch.zeros_like(lengths.unsqueeze(0).repeat(alpha.shape[0], 1))) 42 | 43 | total_l = torch.cumsum(sents, dim=1)[:, -1] 44 | mask_b = torch.ge(tempa, (idx_[:, 2] - total_l)[:, None]) & torch.lt(tempa, (idx_[:, 3] - total_l)[:, None]) 45 | 46 | mask = ~(mask_a | mask_b) 47 | del tempa, tempb, total_l 48 | return mask 49 | 50 | def forward(self, queries, values, idx, lengths): 51 | """ 52 | a = softmax( q * H^T ) 53 | v = a * H 54 | """ 55 | alpha = torch.matmul(queries.unsqueeze(1), values.transpose(1, 2)) 56 | 57 | if self.scale: 58 | alpha = alpha * self.sc 59 | 60 | mask_ = self.create_mask(alpha, values.size(1), lengths, idx) 61 | alpha = torch.where(mask_.unsqueeze(1), 62 | alpha, 63 | torch.as_tensor([float('-inf')]).to(self.device)) 64 | 65 | # in case all context words are masked (e.g. only 2 mentions in the sentence) -- a naive fix 66 | alpha = torch.where(torch.isinf(alpha).all(dim=2, keepdim=True), torch.full_like(alpha, 1.0), alpha) 67 | alpha = self.softmax(alpha) 68 | alpha = torch.where(torch.eq(alpha, 1/alpha.shape[2]).all(dim=2, keepdim=True), torch.zeros_like(alpha), alpha) 69 | 70 | alpha = torch.squeeze(alpha, 1) 71 | return alpha 72 | -------------------------------------------------------------------------------- /src/nnet/init_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 21-Feb-2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import torch 10 | from torch import nn 11 | from nnet.modules import EmbedLayer, Encoder, Classifier 12 | from nnet.attention import Dot_Attention 13 | from nnet.walks import WalkLayer 14 | 15 | 16 | class BaseNet(nn.Module): 17 | def __init__(self, params, pembeds, sizes=None, maps=None, lab2ign=None): 18 | super(BaseNet, self).__init__() 19 | 20 | self.edg = ['MM', 'SS', 'ME', 'MS', 'ES', 'EE'] 21 | 22 | self.dims = {} 23 | for k in self.edg: 24 | self.dims[k] = 4 * params['lstm_dim'] 25 | 26 | self.device = torch.device("cuda:{}".format(params['gpu']) if params['gpu'] != -1 else "cpu") 27 | 28 | self.encoder = Encoder(input_size=params['word_dim'], 29 | rnn_size=params['out_dim'], 30 | num_layers=1, 31 | bidirectional=True, 32 | dropout=0.0) 33 | 34 | self.word_embed = EmbedLayer(num_embeddings=sizes['word_size'], 35 | embedding_dim=params['word_dim'], 36 | dropout=params['drop_i'], 37 | ignore=None, 38 | freeze=params['freeze_words'], 39 | pretrained=pembeds, 40 | mapping=maps['word2idx']) 41 | 42 | if params['dist']: 43 | self.dims['MM'] += params['dist_dim'] 44 | self.dims['SS'] += params['dist_dim'] 45 | self.dist_embed = 
EmbedLayer(num_embeddings=sizes['dist_size'] + 1, 46 | embedding_dim=params['dist_dim'], 47 | dropout=0.0, 48 | ignore=sizes['dist_size'], 49 | freeze=False, 50 | pretrained=None, 51 | mapping=None) 52 | 53 | if params['context']: 54 | self.dims['MM'] += (2 * params['lstm_dim']) 55 | self.attention = Dot_Attention(input_size=2 * params['lstm_dim'], 56 | device=self.device, 57 | scale=False) 58 | 59 | if params['types']: 60 | for k in self.edg: 61 | self.dims[k] += (2 * params['type_dim']) 62 | 63 | self.type_embed = EmbedLayer(num_embeddings=3, 64 | embedding_dim=params['type_dim'], 65 | dropout=0.0, 66 | freeze=False, 67 | pretrained=None, 68 | mapping=None) 69 | 70 | self.reduce = nn.ModuleDict() 71 | for k in self.edg: 72 | if k != 'EE': 73 | self.reduce.update({k: nn.Linear(self.dims[k], params['out_dim'], bias=False)}) 74 | elif (('EE' in params['edges']) or ('FULL' in params['edges'])) and (k == 'EE'): 75 | self.ee = True 76 | self.reduce.update({k: nn.Linear(self.dims[k], params['out_dim'], bias=False)}) 77 | else: 78 | self.ee = False 79 | 80 | if params['walks_iter'] and params['walks_iter'] > 0: 81 | self.walk = WalkLayer(input_size=params['out_dim'], 82 | iters=params['walks_iter'], 83 | beta=params['beta'], 84 | device=self.device) 85 | 86 | self.classifier = Classifier(in_size=params['out_dim'], 87 | out_size=sizes['rel_size'], 88 | dropout=params['drop_o']) 89 | self.loss = nn.CrossEntropyLoss() 90 | 91 | # hyper-parameters for tuning 92 | self.beta = params['beta'] 93 | self.dist_dim = params['dist_dim'] 94 | self.type_dim = params['type_dim'] 95 | self.drop_i = params['drop_i'] 96 | self.drop_o = params['drop_o'] 97 | self.gradc = params['gc'] 98 | self.learn = params['lr'] 99 | self.reg = params['reg'] 100 | self.out_dim = params['out_dim'] 101 | 102 | # other parameters 103 | self.mappings = {'word': maps['word2idx'], 'type': maps['type2idx'], 'dist': maps['dist2idx']} 104 | self.inv_mappings = {'word': maps['idx2word'], 'type': maps['idx2type'], 'dist': maps['idx2dist']} 105 | self.word_dim = params['word_dim'] 106 | self.lstm_dim = params['lstm_dim'] 107 | self.walks_iter = params['walks_iter'] 108 | self.rel_size = sizes['rel_size'] 109 | self.types = params['types'] 110 | self.ignore_label = lab2ign 111 | self.context = params['context'] 112 | self.dist = params['dist'] 113 | 114 | -------------------------------------------------------------------------------- /src/nnet/modules.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 06-Mar-2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import torch 10 | from torch import nn, torch 11 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 12 | 13 | 14 | class EmbedLayer(nn.Module): 15 | def __init__(self, num_embeddings, embedding_dim, dropout, ignore=None, freeze=False, pretrained=None, mapping=None): 16 | """ 17 | Args: 18 | num_embeddings: (tensor) number of unique items 19 | embedding_dim: (int) dimensionality of vectors 20 | dropout: (float) dropout rate 21 | trainable: (bool) train or not 22 | pretrained: (dict) pretrained embeddings 23 | mapping: (dict) mapping of items to unique ids 24 | """ 25 | super(EmbedLayer, self).__init__() 26 | 27 | self.embedding_dim = embedding_dim 28 | self.num_embeddings = num_embeddings 29 | self.freeze = freeze 30 | 31 | self.embedding = nn.Embedding(num_embeddings=num_embeddings, 32 | embedding_dim=embedding_dim, 33 | padding_idx=ignore) 34 | 35 | if 
pretrained: 36 | self.load_pretrained(pretrained, mapping) 37 | self.embedding.weight.requires_grad = not freeze 38 | 39 | self.drop = nn.Dropout(dropout) 40 | 41 | def load_pretrained(self, pretrained, mapping): 42 | """ 43 | Args: 44 | weights: (dict) keys are words, values are vectors 45 | mapping: (dict) keys are words, values are unique ids 46 | 47 | Returns: updates the embedding matrix with pre-trained embeddings 48 | """ 49 | for word in mapping.keys(): 50 | if word in pretrained: 51 | self.embedding.weight.data[mapping[word], :] = torch.from_numpy(pretrained[word]) 52 | elif word.lower() in pretrained: 53 | self.embedding.weight.data[mapping[word], :] = torch.from_numpy(pretrained[word.lower()]) 54 | 55 | assert (self.embedding.weight[mapping['and']].to('cpu').data.numpy() == pretrained['and']).all(), \ 56 | 'ERROR: Embeddings not assigned' 57 | 58 | def forward(self, xs): 59 | """ 60 | Args: 61 | xs: (tensor) batchsize x word_ids 62 | 63 | Returns: (tensor) batchsize x word_ids x dimensionality 64 | """ 65 | embeds = self.embedding(xs) 66 | if self.drop.p > 0: 67 | embeds = self.drop(embeds) 68 | 69 | return embeds 70 | 71 | 72 | class Encoder(nn.Module): 73 | def __init__(self, input_size, rnn_size, num_layers, bidirectional, dropout): 74 | """ 75 | Wrapper for LSTM encoder 76 | Args: 77 | input_size (int): the size of the input features 78 | rnn_size (int): 79 | num_layers (int): 80 | bidirectional (bool): 81 | dropout (float): 82 | Returns: outputs, last_outputs 83 | - **outputs** of shape `(batch, seq_len, hidden_size)`: 84 | tensor containing the output features `(h_t)` 85 | from the last layer of the LSTM, for each t. 86 | - **last_outputs** of shape `(batch, hidden_size)`: 87 | tensor containing the last output features 88 | from the last layer of the LSTM, for each t=seq_len. 89 | """ 90 | super(Encoder, self).__init__() 91 | 92 | self.enc = nn.LSTM(input_size=input_size, 93 | hidden_size=rnn_size, 94 | num_layers=num_layers, 95 | bidirectional=bidirectional, 96 | dropout=dropout, 97 | batch_first=True) 98 | 99 | # the dropout "layer" for the output of the RNN 100 | self.drop = nn.Dropout(dropout) 101 | 102 | # define output feature size 103 | self.feature_size = rnn_size 104 | 105 | if bidirectional: 106 | self.feature_size *= 2 107 | 108 | @staticmethod 109 | def sort(lengths): 110 | sorted_len, sorted_idx = lengths.sort() # indices that result in sorted sequence 111 | _, original_idx = sorted_idx.sort(0, descending=True) 112 | reverse_idx = torch.linspace(lengths.size(0) - 1, 0, lengths.size(0)).long() # for big-to-small 113 | 114 | return sorted_idx, original_idx, reverse_idx 115 | 116 | def forward(self, embeds, lengths, hidden=None): 117 | """ 118 | This is the heart of the model. This function, defines how the data 119 | passes through the network. 
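        In short: the sequences are sorted by length, packed with pack_padded_sequence, run through
        the LSTM, unpacked with pad_packed_sequence, and finally restored to their original order.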
120 | Args: 121 | embs (tensor): word embeddings 122 | lengths (list): the lengths of each sentence 123 | Returns: the logits for each class 124 | """ 125 | # sort sequence 126 | sorted_idx, original_idx, reverse_idx = self.sort(lengths) 127 | 128 | # pad - sort - pack 129 | embeds = nn.utils.rnn.pad_sequence(embeds, batch_first=True, padding_value=0) 130 | embeds = embeds[sorted_idx][reverse_idx] # big-to-small 131 | packed = pack_padded_sequence(embeds, list(lengths[sorted_idx][reverse_idx].data), batch_first=True) 132 | 133 | self.enc.flatten_parameters() 134 | out_packed, _ = self.enc(packed, hidden) 135 | 136 | # unpack 137 | outputs, _ = pad_packed_sequence(out_packed, batch_first=True) 138 | 139 | # apply dropout to the outputs of the RNN 140 | outputs = self.drop(outputs) 141 | 142 | # unsort the list 143 | outputs = outputs[reverse_idx][original_idx][reverse_idx] 144 | return outputs 145 | 146 | 147 | class Classifier(nn.Module): 148 | def __init__(self, in_size, out_size, dropout): 149 | """ 150 | Args: 151 | in_size: input tensor dimensionality 152 | out_size: outpout tensor dimensionality 153 | dropout: dropout rate 154 | """ 155 | super(Classifier, self).__init__() 156 | 157 | self.drop = nn.Dropout(dropout) 158 | self.lin = nn.Linear(in_features=in_size, 159 | out_features=out_size, 160 | bias=True) 161 | 162 | def forward(self, xs): 163 | """ 164 | Args: 165 | xs: (tensor) batchsize x * x features 166 | 167 | Returns: (tensor) batchsize x * x class_size 168 | """ 169 | if self.drop.p > 0: 170 | xs = self.drop(xs) 171 | 172 | xs = self.lin(xs) 173 | return xs 174 | 175 | 176 | 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /src/nnet/network.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 25-Feb-2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.nn.utils.rnn import pad_sequence 12 | from nnet.init_net import BaseNet 13 | 14 | 15 | class EOG(BaseNet): 16 | def input_layer(self, words_): 17 | """ 18 | Word Embedding Layer 19 | """ 20 | word_vec = self.word_embed(words_) 21 | return word_vec 22 | 23 | def encoding_layer(self, word_vec, word_sec): 24 | """ 25 | Encoder Layer -> Encode sequences using BiLSTM. 26 | """ 27 | ys = self.encoder(torch.split(word_vec, word_sec.tolist(), dim=0), word_sec) 28 | return ys 29 | 30 | def graph_layer(self, encoded_seq, info, word_sec, section, positions): 31 | """ 32 | Graph Layer -> Construct a document-level graph 33 | The graph edges hold representations for the connections between the nodes. 
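        Node-type codes used throughout (see pair_ids below): 0 = entity, 1 = mention, 2 = sentence;
        nodes are ordered entity -> mention -> sentence within each document of the batch.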
34 | Args: 35 | encoded_seq: Encoded sequence, shape (sentences, words, dimension) 36 | info: (Tensor, 5 columns) entity_id, entity_type, start_wid, end_wid, sentence_id 37 | word_sec: (Tensor) number of words per sentence 38 | section: (Tensor ) #entities/#mentions/#sentences per batch 39 | positions: distances between nodes (only M-M and S-S) 40 | 41 | Returns: (Tensor) graph, (Tensor) tensor_mapping, (Tensors) indices, (Tensor) node information 42 | """ 43 | # SENTENCE NODES 44 | sentences = torch.mean(encoded_seq, dim=1) # sentence nodes (avg of sentence words) 45 | 46 | # MENTION & ENTITY NODES 47 | temp_ = torch.arange(word_sec.max()).unsqueeze(0).repeat(sentences.size(0), 1).to(self.device) 48 | remove_pad = (temp_ < word_sec.unsqueeze(1)) 49 | 50 | mentions = self.merge_tokens(info, encoded_seq, remove_pad) # mention nodes 51 | entities = self.merge_mentions(info, mentions) # entity nodes 52 | 53 | # all nodes in order: entities - mentions - sentences 54 | nodes = torch.cat((entities, mentions, sentences), dim=0) # e + m + s (all) 55 | nodes_info = self.node_info(section, info) # info/node: node type | semantic type | sentence ID 56 | 57 | if self.types: # + node types 58 | nodes = torch.cat((nodes, self.type_embed(nodes_info[:, 0])), dim=1) 59 | 60 | # re-order nodes per document (batch) 61 | nodes = self.rearrange_nodes(nodes, section) 62 | nodes = self.split_n_pad(nodes, section, pad=0) 63 | 64 | nodes_info = self.rearrange_nodes(nodes_info, section) 65 | nodes_info = self.split_n_pad(nodes_info, section, pad=-1) 66 | 67 | # create initial edges (concat node representations) 68 | r_idx, c_idx = torch.meshgrid(torch.arange(nodes.size(1)).to(self.device), 69 | torch.arange(nodes.size(1)).to(self.device)) 70 | graph = torch.cat((nodes[:, r_idx], nodes[:, c_idx]), dim=3) 71 | r_id, c_id = nodes_info[..., 0][:, r_idx], nodes_info[..., 0][:, c_idx] # node type indicators 72 | 73 | # pair masks 74 | pid = self.pair_ids(r_id, c_id) 75 | 76 | # Linear reduction layers 77 | reduced_graph = torch.where(pid['MS'].unsqueeze(-1), self.reduce['MS'](graph), 78 | torch.zeros(graph.size()[:-1] + (self.out_dim,)).to(self.device)) 79 | reduced_graph = torch.where(pid['ME'].unsqueeze(-1), self.reduce['ME'](graph), reduced_graph) 80 | reduced_graph = torch.where(pid['ES'].unsqueeze(-1), self.reduce['ES'](graph), reduced_graph) 81 | 82 | if self.dist: 83 | dist_vec = self.dist_embed(positions) # distances 84 | reduced_graph = torch.where(pid['SS'].unsqueeze(-1), 85 | self.reduce['SS'](torch.cat((graph, dist_vec), dim=3)), reduced_graph) 86 | else: 87 | reduced_graph = torch.where(pid['SS'].unsqueeze(-1), self.reduce['SS'](graph), reduced_graph) 88 | 89 | if self.context and self.dist: 90 | m_cntx = self.attention(mentions, encoded_seq[info[:, 4]], info, word_sec) 91 | m_cntx = self.prepare_mention_context(m_cntx, section, r_idx, c_idx, 92 | encoded_seq[info[:, 4]], pid, nodes_info) 93 | 94 | reduced_graph = torch.where(pid['MM'].unsqueeze(-1), 95 | self.reduce['MM'](torch.cat((graph, dist_vec, m_cntx), dim=3)), reduced_graph) 96 | 97 | elif self.context: 98 | m_cntx = self.attention(mentions, encoded_seq[info[:, 4]], info, word_sec) 99 | m_cntx = self.prepare_mention_context(m_cntx, section, r_idx, c_idx, 100 | encoded_seq[info[:, 4]], pid, nodes_info) 101 | 102 | reduced_graph = torch.where(pid['MM'].unsqueeze(-1), 103 | self.reduce['MM'](torch.cat((graph, m_cntx), dim=3)), reduced_graph) 104 | 105 | elif self.dist: 106 | reduced_graph = torch.where(pid['MM'].unsqueeze(-1), 107 | 
self.reduce['MM'](torch.cat((graph, dist_vec), dim=3)), reduced_graph) 108 | 109 | else: 110 | reduced_graph = torch.where(pid['MM'].unsqueeze(-1), self.reduce['MM'](graph), reduced_graph) 111 | 112 | if self.ee: 113 | reduced_graph = torch.where(pid['EE'].unsqueeze(-1), self.reduce['EE'](graph), reduced_graph) 114 | 115 | mask = self.get_nodes_mask(section.sum(dim=1)) 116 | return reduced_graph, (r_idx, c_idx), nodes_info, mask 117 | 118 | def prepare_mention_context(self, m_cntx, section, r_idx, c_idx, s_seq, pid, nodes_info): 119 | """ 120 | Estimate attention scores for each pair 121 | (a1 + a2)/2 * sentence_words 122 | """ 123 | # "fake" mention weight nodes 124 | m_cntx = torch.cat((torch.zeros(section.sum(dim=0)[0], m_cntx.size(1)).to(self.device), 125 | m_cntx, 126 | torch.zeros(section.sum(dim=0)[2], m_cntx.size(1)).to(self.device)), dim=0) 127 | m_cntx = self.rearrange_nodes(m_cntx, section) 128 | m_cntx = self.split_n_pad(m_cntx, section, pad=0) 129 | m_cntx = torch.div(m_cntx[:, r_idx] + m_cntx[:, c_idx], 2) 130 | 131 | # mask non-MM pairs 132 | # mask invalid weights (i.e. M-M not in the same sentence) 133 | mask_ = torch.eq(nodes_info[..., 2][:, r_idx], nodes_info[..., 2][:, c_idx]) & pid['MM'] 134 | m_cntx = torch.where(mask_.unsqueeze(-1), m_cntx, torch.zeros_like(m_cntx)) 135 | 136 | # "fake" mention sentences nodes 137 | sents = torch.cat((torch.zeros(section.sum(dim=0)[0], m_cntx.size(3), s_seq.size(2)).to(self.device), 138 | s_seq, 139 | torch.zeros(section.sum(dim=0)[2], m_cntx.size(3), s_seq.size(2)).to(self.device)), dim=0) 140 | sents = self.rearrange_nodes(sents, section) 141 | sents = self.split_n_pad(sents, section, pad=0) 142 | m_cntx = torch.matmul(m_cntx, sents) 143 | return m_cntx 144 | 145 | @staticmethod 146 | def pair_ids(r_id, c_id): 147 | pids = { 148 | 'EE': ((r_id == 0) & (c_id == 0)), 149 | 'MM': ((r_id == 1) & (c_id == 1)), 150 | 'SS': ((r_id == 2) & (c_id == 2)), 151 | 'ES': (((r_id == 0) & (c_id == 2)) | ((r_id == 2) & (c_id == 0))), 152 | 'MS': (((r_id == 1) & (c_id == 2)) | ((r_id == 2) & (c_id == 1))), 153 | 'ME': (((r_id == 1) & (c_id == 0)) | ((r_id == 0) & (c_id == 1))) 154 | } 155 | return pids 156 | 157 | @staticmethod 158 | def rearrange_nodes(nodes, section): 159 | """ 160 | Re-arrange nodes so that they are in 'Entity - Mention - Sentence' order for each document (batch) 161 | """ 162 | tmp1 = section.t().contiguous().view(-1).long().to(nodes.device) 163 | tmp3 = torch.arange(section.numel()).view(section.size(1), 164 | section.size(0)).t().contiguous().view(-1).long().to(nodes.device) 165 | tmp2 = torch.arange(section.sum()).to(nodes.device).split(tmp1.tolist()) 166 | tmp2 = pad_sequence(tmp2, batch_first=True, padding_value=-1)[tmp3].view(-1) 167 | tmp2 = tmp2[(tmp2 != -1).nonzero().squeeze()] # remove -1 (padded) 168 | 169 | nodes = torch.index_select(nodes, 0, tmp2) 170 | return nodes 171 | 172 | @staticmethod 173 | def split_n_pad(nodes, section, pad=None): 174 | nodes = torch.split(nodes, section.sum(dim=1).tolist()) 175 | nodes = pad_sequence(nodes, batch_first=True, padding_value=pad) 176 | return nodes 177 | 178 | @staticmethod 179 | def get_nodes_mask(nodes_size): 180 | """ 181 | Create mask for padded nodes 182 | """ 183 | n_total = torch.arange(nodes_size.max()).to(nodes_size.device) 184 | idx_r, idx_c, idx_d = torch.meshgrid(n_total, n_total, n_total) 185 | 186 | # masks for padded elements (1 in valid, 0 in padded) 187 | ns = nodes_size[:, None, None, None] 188 | mask3d = ~(torch.ge(idx_r, ns) | torch.ge(idx_c, ns) | 
torch.ge(idx_d, ns)) 189 | return mask3d 190 | 191 | def node_info(self, section, info): 192 | """ 193 | Col 0: node type | Col 1: semantic type | Col 2: sentence id 194 | """ 195 | typ = torch.repeat_interleave(torch.arange(3).to(self.device), section.sum(dim=0)) # node types (0,1,2) 196 | rows_ = torch.bincount(info[:, 0]).cumsum(dim=0).sub(1) 197 | stypes = torch.neg(torch.ones(section[:, 2].sum())).to(self.device).long() # semantic type sentences = -1 198 | all_types = torch.cat((info[:, 1][rows_], info[:, 1], stypes), dim=0) 199 | sents_ = torch.arange(section.sum(dim=0)[2]).to(self.device) 200 | sent_id = torch.cat((info[:, 4][rows_], info[:, 4], sents_), dim=0) # sent_id 201 | return torch.cat((typ.unsqueeze(-1), all_types.unsqueeze(-1), sent_id.unsqueeze(-1)), dim=1) 202 | 203 | def estimate_loss(self, pred_pairs, truth): 204 | """ 205 | Softmax cross entropy loss. 206 | Args: 207 | pred_pairs (Tensor): Un-normalized pairs (# pairs, classes) 208 | truth (Tensor): Ground-truth labels (# pairs, id) 209 | 210 | Returns: (Tensor) loss, (Tensors) TP/FP/FN 211 | """ 212 | mask = torch.ne(truth, -1) 213 | truth = truth[mask] 214 | pred_pairs = pred_pairs[mask] 215 | 216 | assert (truth != -1).all() 217 | loss = self.loss(pred_pairs, truth) 218 | 219 | predictions = F.softmax(pred_pairs, dim=1).data.argmax(dim=1) 220 | stats = self.count_predictions(predictions, truth) 221 | return loss, stats, predictions 222 | 223 | @staticmethod 224 | def merge_mentions(info, mentions): 225 | """ 226 | Merge mentions into entities; 227 | Find which rows (mentions) have the same entity id and average them 228 | """ 229 | m_ids, e_ids = torch.broadcast_tensors(info[:, 0].unsqueeze(0), 230 | torch.arange(0, max(info[:, 0]) + 1).unsqueeze(-1).to(info.device)) 231 | index_m = torch.eq(m_ids, e_ids).type('torch.FloatTensor').to(info.device) 232 | entities = torch.div(torch.matmul(index_m, mentions), torch.sum(index_m, dim=1).unsqueeze(-1)) # average 233 | return entities 234 | 235 | @staticmethod 236 | def merge_tokens(info, enc_seq, rm_pad): 237 | """ 238 | Merge tokens into mentions; 239 | Find which tokens belong to a mention (based on start-end ids) and average them 240 | """ 241 | enc_seq = enc_seq[rm_pad] 242 | start, end, w_ids = torch.broadcast_tensors(info[:, 2].unsqueeze(-1), 243 | info[:, 3].unsqueeze(-1), 244 | torch.arange(0, enc_seq.shape[0]).unsqueeze(0).to(info.device)) 245 | index_t = (torch.ge(w_ids, start) & torch.lt(w_ids, end)).float().to(info.device) 246 | mentions = torch.div(torch.matmul(index_t, enc_seq), torch.sum(index_t, dim=1).unsqueeze(-1)) # average 247 | return mentions 248 | 249 | @staticmethod 250 | def select_pairs(combs, nodes_info, idx): 251 | """ 252 | Select (entity node) pairs for classification based on input parameter restrictions (i.e. their entity type). 
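combs is a flat tensor of allowed semantic-type ids, consumed two entries (head type, tail type) at a time. A tiny sketch with hypothetical type ids:

    >>> import torch
    >>> combs = torch.tensor([2, 3, 3, 2])    # e.g. both directions of a hypothetical type pair
    >>> list(torch.split(combs, 2, dim=0))
    [tensor([2, 3]), tensor([3, 2])]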
253 | """ 254 | combs = torch.split(combs, 2, dim=0) 255 | sel = torch.zeros(nodes_info.size(0), nodes_info.size(1), nodes_info.size(1)).to(nodes_info.device) 256 | 257 | a_ = nodes_info[..., 0][:, idx[0]] 258 | b_ = nodes_info[..., 0][:, idx[1]] 259 | c_ = nodes_info[..., 1][:, idx[0]] 260 | d_ = nodes_info[..., 1][:, idx[1]] 261 | for ca, cb in combs: 262 | condition1 = torch.eq(a_, 0) & torch.eq(b_, 0) # needs to be an entity node (id=0) 263 | condition2 = torch.eq(c_, ca) & torch.eq(d_, cb) # valid pair semantic types 264 | sel = torch.where(condition1 & condition2, torch.ones_like(sel), sel) 265 | return sel.nonzero().unbind(dim=1) 266 | 267 | def count_predictions(self, y, t): 268 | """ 269 | Count number of TP, FP, FN, TN for each relation class 270 | """ 271 | label_num = torch.as_tensor([self.rel_size]).long().to(self.device) 272 | ignore_label = torch.as_tensor([self.ignore_label]).long().to(self.device) 273 | 274 | mask_t = torch.eq(t, ignore_label).view(-1) # where the ground truth needs to be ignored 275 | mask_p = torch.eq(y, ignore_label).view(-1) # where the predicted needs to be ignored 276 | 277 | true = torch.where(mask_t, label_num, t.view(-1).long().to(self.device)) # ground truth 278 | pred = torch.where(mask_p, label_num, y.view(-1).long().to(self.device)) # output of NN 279 | 280 | tp_mask = torch.where(torch.eq(pred, true), true, label_num) 281 | fp_mask = torch.where(torch.ne(pred, true), pred, label_num) 282 | fn_mask = torch.where(torch.ne(pred, true), true, label_num) 283 | 284 | tp = torch.bincount(tp_mask, minlength=self.rel_size + 1)[:self.rel_size] 285 | fp = torch.bincount(fp_mask, minlength=self.rel_size + 1)[:self.rel_size] 286 | fn = torch.bincount(fn_mask, minlength=self.rel_size + 1)[:self.rel_size] 287 | tn = torch.sum(mask_t & mask_p) 288 | return {'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn} 289 | 290 | def forward(self, batch): 291 | """ 292 | Network Forward computation. 
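The batch dictionary is built by Trainer.convert_batch and is expected to carry at least the keys used below: 'words', 'word_sec', 'entities', 'section', 'distances', 'adjacency', 'pairs4class' and 'relations'.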
293 | Args: 294 | batch: dictionary with tensors 295 | Returns: (Tensors) loss, statistics, predictions, index 296 | """ 297 | # Word Embeddings 298 | word_vec = self.input_layer(batch['words']) 299 | 300 | # Encoder 301 | encoded_seq = self.encoding_layer(word_vec, batch['word_sec']) 302 | 303 | # Graph 304 | graph, pindex, nodes_info, mask = self.graph_layer(encoded_seq, batch['entities'], batch['word_sec'], 305 | batch['section'], batch['distances']) 306 | 307 | # Inference/Walks 308 | if self.walks_iter and self.walks_iter > 0: 309 | graph = self.walk(graph, adj_=batch['adjacency'], mask_=mask) 310 | 311 | # Classification 312 | select = self.select_pairs(batch['pairs4class'], nodes_info, pindex) 313 | graph = self.classifier(graph[select]) 314 | 315 | loss, stats, preds = self.estimate_loss(graph, batch['relations'][select].long()) 316 | return loss, stats, preds, select 317 | -------------------------------------------------------------------------------- /src/nnet/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 27-Feb-2019 5 | 6 | author: fenia 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | from time import time 13 | import itertools 14 | import copy 15 | import datetime 16 | import random 17 | from random import shuffle 18 | from utils import print_results, write_preds, write_errors, print_options 19 | from converter import concat_examples 20 | from torch import autograd 21 | from nnet.network import EOG 22 | from torch import nn, optim 23 | import sys 24 | torch.set_printoptions(profile="full") 25 | np.set_printoptions(threshold=np.inf) 26 | 27 | 28 | class Trainer: 29 | def __init__(self, loader, params, data, model_folder): 30 | """ 31 | Trainer object. 
32 | 33 | Args: 34 | loader: loader object that holds information for training data 35 | params (dict): model parameters 36 | data (dict): 'train' and 'test' data 37 | """ 38 | self.data = data 39 | self.params = params 40 | self.rel_size = loader.n_rel 41 | self.loader = loader 42 | self.model_folder = model_folder 43 | 44 | self.device = torch.device("cuda:{}".format(params['gpu']) if params['gpu'] != -1 else "cpu") 45 | self.gc = params['gc'] 46 | self.epoch = params['epoch'] 47 | self.example = params['example'] 48 | self.pa = params['param_avg'] 49 | self.es = params['early_stop'] 50 | self.primary_metric = params['primary_metric'] 51 | self.show_class = params['show_class'] 52 | self.window = params['window'] 53 | self.preds_file = os.path.join(model_folder, params['save_pred']) 54 | self.best_epoch = 0 55 | 56 | self.train_res = {'loss': [], 'score': []} 57 | self.test_res = {'loss': [], 'score': []} 58 | 59 | self.max_patience = self.params['patience'] 60 | self.cur_patience = 0 61 | self.best_score = 0.0 62 | self.best_epoch = 0 63 | 64 | # pairs while training 65 | self.pairs4train = [] 66 | for i in params['include_pairs']: 67 | m, n = i.split('-') 68 | self.pairs4train += [self.loader.type2index[m], self.loader.type2index[n]] 69 | 70 | # pairs to classify 71 | self.pairs4class = [] 72 | for i in params['classify_pairs']: 73 | m, n = i.split('-') 74 | self.pairs4class += [self.loader.type2index[m], self.loader.type2index[n]] 75 | 76 | # parameter averaging 77 | if params['param_avg']: 78 | self.averaged_params = {} 79 | 80 | self.model = self.init_model() 81 | self.optimizer = self.set_optimizer(self.model) 82 | 83 | def init_model(self): 84 | model_0 = EOG(self.params, self.loader.pre_embeds, 85 | sizes={'word_size': self.loader.n_words, 'dist_size': self.loader.n_dist, 86 | 'type_size': self.loader.n_type, 'rel_size': self.loader.n_rel}, 87 | maps={'word2idx': self.loader.word2index, 'idx2word': self.loader.index2word, 88 | 'rel2idx': self.loader.rel2index, 'idx2rel': self.loader.index2rel, 89 | 'type2idx': self.loader.type2index, 'idx2type': self.loader.index2type, 90 | 'dist2idx': self.loader.dist2index, 'idx2dist': self.loader.index2dist}, 91 | lab2ign=self.loader.label2ignore) 92 | 93 | # GPU/CPU 94 | if self.params['gpu'] != -1: 95 | torch.cuda.set_device(self.device) 96 | model_0.to(self.device) 97 | return model_0 98 | 99 | def set_optimizer(self, model_0): 100 | # OPTIMIZER 101 | # do not regularize biases 102 | params2reg = [] 103 | params0reg = [] 104 | for p_name, p_value in model_0.named_parameters(): 105 | if '.bias' in p_name: 106 | params0reg += [p_value] 107 | else: 108 | params2reg += [p_value] 109 | assert len(params0reg) + len(params2reg) == len(list(model_0.parameters())) 110 | groups = [dict(params=params2reg), dict(params=params0reg, weight_decay=.0)] 111 | optimizer = optim.Adam(groups, lr=self.params['lr'], weight_decay=self.params['reg'], amsgrad=True) 112 | 113 | # Train Model 114 | print_options(self.params) 115 | for p_name, p_value in model_0.named_parameters(): 116 | if p_value.requires_grad: 117 | print(p_name) 118 | return optimizer 119 | 120 | @staticmethod 121 | def iterator(x, shuffle_=False, batch_size=1): 122 | """ 123 | Create a new iterator for this epoch. 124 | Shuffle the data if specified. 125 | """ 126 | if shuffle_: 127 | shuffle(x) 128 | new = [x[i:i+batch_size] for i in range(0, len(x), batch_size)] 129 | return new 130 | 131 | def run(self): 132 | """ 133 | Main Training Loop. 
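Per epoch: train_epoch, optional parameter averaging, eval_epoch and an early-stopping check via epoch_checking; after the loop the parameters are (optionally) averaged up to the best epoch and a final eval_epoch writes out the predictions.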
134 | """ 135 | print('\n======== START TRAINING: {} ========\n'.format( 136 | datetime.datetime.now().strftime("%d-%m-%y_%H:%M:%S"))) 137 | 138 | random.shuffle(self.data['train']) # shuffle training data at least once 139 | for epoch in range(1, self.epoch+1): 140 | self.train_epoch(epoch) 141 | 142 | if self.pa: 143 | self.parameter_averaging() 144 | 145 | self.eval_epoch() 146 | 147 | stop = self.epoch_checking(epoch) 148 | if stop and self.es: 149 | break 150 | 151 | if self.pa: 152 | self.parameter_averaging(reset=True) 153 | 154 | print('Best epoch: {}'.format(self.best_epoch)) 155 | if self.pa: 156 | self.parameter_averaging(epoch=self.best_epoch) 157 | self.eval_epoch(final=True, save_predictions=True) 158 | 159 | print('\n======== END TRAINING: {} ========\n'.format( 160 | datetime.datetime.now().strftime("%d-%m-%y_%H:%M:%S"))) 161 | 162 | def train_epoch(self, epoch): 163 | """ 164 | Evaluate the model on the train set. 165 | """ 166 | t1 = time() 167 | output = {'tp': [], 'fp': [], 'fn': [], 'tn': [], 'loss': [], 'preds': []} 168 | train_info = [] 169 | 170 | self.model = self.model.train() 171 | train_iter = self.iterator(self.data['train'], batch_size=self.params['batch'], 172 | shuffle_=self.params['shuffle_data']) 173 | for batch_idx, batch in enumerate(train_iter): 174 | batch = self.convert_batch(batch) 175 | 176 | with autograd.detect_anomaly(): 177 | self.optimizer.zero_grad() 178 | loss, stats, predictions, select = self.model(batch) 179 | loss.backward() # backward computation 180 | 181 | nn.utils.clip_grad_norm_(self.model.parameters(), self.gc) # gradient clipping 182 | self.optimizer.step() # update 183 | 184 | output['loss'] += [loss.item()] 185 | output['tp'] += [stats['tp'].to('cpu').data.numpy()] 186 | output['fp'] += [stats['fp'].to('cpu').data.numpy()] 187 | output['fn'] += [stats['fn'].to('cpu').data.numpy()] 188 | output['tn'] += [stats['tn'].to('cpu').data.numpy()] 189 | output['preds'] += [predictions.to('cpu').data.numpy()] 190 | train_info += [batch['info'][select[0].to('cpu').data.numpy(), 191 | select[1].to('cpu').data.numpy(), 192 | select[2].to('cpu').data.numpy()]] 193 | 194 | t2 = time() 195 | if self.window: 196 | total_loss, scores = self.subdocs_performance(output['loss'], output['preds'], train_info) 197 | else: 198 | total_loss, scores = self.performance(output) 199 | 200 | self.train_res['loss'] += [total_loss] 201 | self.train_res['score'] += [scores[self.primary_metric]] 202 | print('Epoch: {:02d} | TRAIN | LOSS = {:.05f}, '.format(epoch, total_loss), end="") 203 | print_results(scores, [], self.show_class, t2-t1) 204 | 205 | def eval_epoch(self, final=False, save_predictions=False): 206 | """ 207 | Evaluate the model on the test set. 208 | No backward computation is allowed. 
209 | """ 210 | t1 = time() 211 | output = {'tp': [], 'fp': [], 'fn': [], 'tn': [], 'loss': [], 'preds': []} 212 | test_info = [] 213 | 214 | self.model = self.model.eval() 215 | test_iter = self.iterator(self.data['test'], batch_size=self.params['batch'], shuffle_=False) 216 | for batch_idx, batch in enumerate(test_iter): 217 | batch = self.convert_batch(batch) 218 | 219 | with torch.no_grad(): 220 | loss, stats, predictions, select = self.model(batch) 221 | 222 | output['loss'] += [loss.item()] 223 | output['tp'] += [stats['tp'].to('cpu').data.numpy()] 224 | output['fp'] += [stats['fp'].to('cpu').data.numpy()] 225 | output['fn'] += [stats['fn'].to('cpu').data.numpy()] 226 | output['tn'] += [stats['tn'].to('cpu').data.numpy()] 227 | output['preds'] += [predictions.to('cpu').data.numpy()] 228 | test_info += [batch['info'][select[0].to('cpu').data.numpy(), 229 | select[1].to('cpu').data.numpy(), 230 | select[2].to('cpu').data.numpy()]] 231 | t2 = time() 232 | 233 | # estimate performance 234 | if self.window: 235 | total_loss, scores = self.subdocs_performance(output['loss'], output['preds'], test_info) 236 | else: 237 | total_loss, scores = self.performance(output) 238 | 239 | if not final: 240 | self.test_res['loss'] += [total_loss] 241 | self.test_res['score'] += [scores[self.primary_metric]] 242 | print(' TEST | LOSS = {:.05f}, '.format(total_loss), end="") 243 | print_results(scores, [], self.show_class, t2-t1) 244 | print() 245 | 246 | if save_predictions: 247 | write_preds(output['preds'], test_info, self.preds_file, map_=self.loader.index2rel) 248 | write_errors(output['preds'], test_info, self.preds_file, map_=self.loader.index2rel) 249 | 250 | def parameter_averaging(self, epoch=None, reset=False): 251 | """ 252 | Perform parameter averaging. 253 | For each epoch, average the parameters up to this epoch and then evaluate on test set. 254 | If 'reset' option: use the last epoch parameters for the next epock 255 | """ 256 | for p_name, p_value in self.model.named_parameters(): 257 | if p_name not in self.averaged_params: 258 | self.averaged_params[p_name] = [] 259 | 260 | if reset: 261 | p_new = copy.deepcopy(self.averaged_params[p_name][-1]) # use last epoch param 262 | 263 | elif epoch: 264 | p_new = np.mean(self.averaged_params[p_name][:epoch], axis=0) # estimate average until this epoch 265 | 266 | else: 267 | self.averaged_params[p_name].append(p_value.data.to('cpu').numpy()) 268 | p_new = np.mean(self.averaged_params[p_name], axis=0) # estimate average 269 | 270 | # assign to array 271 | if self.device != 'cpu': 272 | p_value.data = torch.from_numpy(p_new).to(self.device) 273 | else: 274 | p_value.data = torch.from_numpy(p_new) 275 | 276 | def epoch_checking(self, epoch): 277 | """ 278 | Perform early stopping. 
279 | If performance does not improve for a number of consecutive epochs ("max_patience") 280 | then stop the training and keep the best epoch: stopped_epoch - max_patience 281 | 282 | Args: 283 | epoch (int): current training epoch 284 | 285 | Returns: (int) best_epoch, (bool) stop 286 | """ 287 | if self.test_res['score'][-1] > self.best_score: # improvement 288 | self.best_score = self.test_res['score'][-1] 289 | self.cur_patience = 0 290 | if self.es: 291 | self.best_epoch = epoch 292 | else: 293 | self.cur_patience += 1 294 | if not self.es: 295 | self.best_epoch = epoch 296 | 297 | if epoch % 5 == 0 and self.es: 298 | print('Current best {} score {:.6f} @ epoch {}\n'.format(self.params['primary_metric'], 299 | self.best_score, self.best_epoch)) 300 | 301 | if self.max_patience == self.cur_patience and self.es: # early stop must happen 302 | self.best_epoch = epoch - self.max_patience 303 | return True 304 | else: 305 | return False 306 | 307 | @staticmethod 308 | def performance(stats): 309 | """ 310 | Estimate total loss for an epoch. 311 | Calculate Micro and Macro P/R/F1 scores & Accuracy. 312 | Returns: (float) average loss, (float) micro and macro P/R/F1 313 | """ 314 | def fbeta_score(precision, recall, beta=1.0): 315 | beta_square = beta * beta 316 | if (precision != 0.0) and (recall != 0.0): 317 | res = ((1 + beta_square) * precision * recall / (beta_square * precision + recall)) 318 | else: 319 | res = 0.0 320 | return res 321 | 322 | def prf1(tp_, fp_, fn_, tn_): 323 | tp_ = np.sum(tp_, axis=0) 324 | fp_ = np.sum(fp_, axis=0) 325 | fn_ = np.sum(fn_, axis=0) 326 | tn_ = np.sum(tn_, axis=0) 327 | 328 | atp = np.sum(tp_) 329 | afp = np.sum(fp_) 330 | afn = np.sum(fn_) 331 | atn = np.sum(tn_) 332 | 333 | micro_p = (1.0 * atp) / (atp + afp) if (atp + afp != 0) else 0.0 334 | micro_r = (1.0 * atp) / (atp + afn) if (atp + afn != 0) else 0.0 335 | micro_f = fbeta_score(micro_p, micro_r) 336 | 337 | pp = [0] 338 | rr = [0] 339 | ff = [0] 340 | macro_p = np.mean(pp) 341 | macro_r = np.mean(rr) 342 | macro_f = np.mean(ff) 343 | 344 | acc = (atp + atn) / (atp + atn + afp + afn) if (atp + atn + afp + afn) else 0.0 345 | return {'acc': acc, 346 | 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f': micro_f, 347 | 'macro_p': macro_p, 'macro_r': macro_r, 'macro_f': macro_f, 348 | 'tp': atp, 'true': atp+afn, 'pred': atp+afp, 'total': (atp + atn + afp + afn)} 349 | 350 | fin_loss = sum(stats['loss']) / len(stats['loss']) 351 | scores = prf1(stats['tp'], stats['fp'], stats['fn'], stats['tn']) 352 | return fin_loss, scores 353 | 354 | def subdocs_performance(self, loss_, preds, info): 355 | pairs = {} 356 | for p, i in zip(preds, info): 357 | i = [i_ for i_ in i if i_] 358 | assert len(p) == len(i) 359 | 360 | # pmid, e1, e2, pred, truth 361 | for k, j in zip(p, i): 362 | if (j['pmid'].split('__')[0], j['entA'].id, j['entB'].id, j['rel']) not in pairs: 363 | pairs[(j['pmid'].split('__')[0], j['entA'].id, j['entB'].id, j['rel'])] = [] 364 | pairs[(j['pmid'].split('__')[0], j['entA'].id, j['entB'].id, j['rel'])] += [k] 365 | 366 | res = [] 367 | tr = [] 368 | for l in pairs.keys(): 369 | res += [self.loader.label2ignore if all([c == self.loader.label2ignore for c in pairs[l]]) 370 | else not self.loader.label2ignore] 371 | tr += [l[3]] 372 | 373 | rsize = self.rel_size 374 | 375 | mask_t = np.equal(tr, self.loader.label2ignore) 376 | mask_p = np.equal(res, self.loader.label2ignore) 377 | 378 | true = np.where(mask_t, rsize, tr) 379 | pred = np.where(mask_p, rsize, res) 380 | 381 | tp_mask = 
np.where(pred == true, true, rsize) 382 | fp_mask = np.where(pred != true, pred, rsize) 383 | fn_mask = np.where(pred != true, true, rsize) 384 | 385 | tp = np.bincount(tp_mask, minlength=rsize + 1)[:rsize] 386 | fp = np.bincount(fp_mask, minlength=rsize + 1)[:rsize] 387 | fn = np.bincount(fn_mask, minlength=rsize + 1)[:rsize] 388 | tn = np.sum(mask_t & mask_p) 389 | return self.performance({'loss': loss_, 'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn}) 390 | 391 | def convert_batch(self, batch): 392 | new_batch = {'entities': []} 393 | ent_count, sent_count, word_count = 0, 0, 0 394 | full_text = [] 395 | 396 | # TODO make this faster 397 | for i, b in enumerate(batch): 398 | current_text = list(itertools.chain.from_iterable(b['text'])) 399 | full_text += current_text 400 | 401 | temp = [] 402 | for e in b['ents']: 403 | # token ids are correct 404 | assert full_text[(e[2] + word_count):(e[3] + word_count)] == current_text[e[2]:e[3]], \ 405 | '{} != {}'.format(full_text[(e[2] + word_count):(e[3] + word_count)], current_text[e[2]:e[3]]) 406 | temp += [[e[0] + ent_count, e[1], e[2] + word_count, e[3] + word_count, e[4] + sent_count]] 407 | 408 | new_batch['entities'] += [np.array(temp)] 409 | word_count += sum([len(s) for s in b['text']]) 410 | ent_count = max([t[0] for t in temp]) + 1 411 | sent_count += len(b['text']) 412 | 413 | new_batch['entities'] = torch.as_tensor(np.concatenate(new_batch['entities'], axis=0)).long().to(self.device) 414 | 415 | batch_ = [{k: v for k, v in b.items() if (k != 'info' and k != 'text')} for b in batch] 416 | converted_batch = concat_examples(batch_, device=self.device, padding=-1) 417 | 418 | converted_batch['adjacency'][converted_batch['adjacency'] == -1] = 0 419 | converted_batch['dist'][converted_batch['dist'] == -1] = self.loader.n_dist 420 | 421 | new_batch['adjacency'] = converted_batch['adjacency'].byte() 422 | new_batch['distances'] = converted_batch['dist'] 423 | new_batch['relations'] = converted_batch['rels'] 424 | new_batch['section'] = converted_batch['section'] 425 | new_batch['word_sec'] = converted_batch['word_sec'][converted_batch['word_sec'] != -1].long() 426 | new_batch['words'] = converted_batch['words'][converted_batch['words'] != -1].long() 427 | new_batch['pairs4class'] = torch.as_tensor(self.pairs4class).long().to(self.device) 428 | new_batch['info'] = np.stack([np.array(np.pad(b['info'], 429 | ((0, new_batch['section'][:, 0].sum(dim=0).item() - b['info'].shape[0]), 430 | (0, new_batch['section'][:, 0].sum(dim=0).item() - b['info'].shape[0])), 431 | 'constant', 432 | constant_values=(-1, -1))) for b in batch], axis=0) 433 | 434 | if self.example: 435 | for i, b in enumerate(batch): 436 | print('===== DOCUMENT NO {} ====='.format(i)) 437 | for s in b['text']: 438 | print(' '.join([self.loader.index2word[t] for t in s])) 439 | print(b['ents']) 440 | 441 | print(new_batch['relations'][i]) 442 | print(new_batch['adjacency'][i]) 443 | print(np.array([self.loader.index2dist[p] for p in 444 | new_batch['distances'][i].to('cpu').data.numpy().ravel()]).reshape( 445 | new_batch['distances'][i].shape)) 446 | sys.exit() 447 | return new_batch 448 | -------------------------------------------------------------------------------- /src/nnet/walks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 06/03/19 5 | 6 | author: fenia 7 | """ 8 | 9 | import torch 10 | from torch import nn 11 | torch.set_printoptions(profile="full") 12 | 13 | 14 | 
class WalkLayer(nn.Module): 15 | def __init__(self, input_size, iters=0, beta=0.9, device=-1): 16 | """ 17 | Walk Layer --> Walk on the edges of a predefined graph 18 | Args: 19 | input_size (int): input dimensionality 20 | iters (int): number of iterations --> 2^{iters} = walks-length 21 | beta (float): weight shorter/longer walks 22 | Return: 23 | pairs (Tensor): final pair representations 24 | size (batch * nodes * nodes, features) 25 | """ 26 | super(WalkLayer, self).__init__() 27 | 28 | self.W = nn.Parameter(nn.init.normal_(torch.empty(input_size, input_size)), requires_grad=True) 29 | self.sigmoid = nn.Sigmoid() 30 | self.beta = beta 31 | self.iters = iters 32 | self.device = device 33 | 34 | @staticmethod 35 | def init_graph(graph, adj): 36 | """ 37 | Initialize graph with 0 connections based on the adjacency matrix 38 | """ 39 | graph = torch.where(adj.unsqueeze(-1), graph, torch.zeros_like(graph)) 40 | return graph 41 | 42 | @staticmethod 43 | def mask_invalid_paths(graph, mask3d): 44 | """ 45 | Mask invalid paths 46 | *(any node) -> A -> A 47 | A -> A -> *(any node) 48 | A -> *(any node) -> A 49 | 50 | Additionally mask paths that involve padded entities as intermediate nodes 51 | -inf so that sigmoid returns 0 52 | """ 53 | items = range(graph.size(1)) 54 | graph[:, :, items, items] = float('-inf') # *->A->A 55 | graph[:, items, items] = float('-inf') # A->A->* 56 | graph[:, items, :, items] = float('-inf') # A->*->A (self-connection) 57 | 58 | graph = torch.where(mask3d.unsqueeze(-1), graph, torch.as_tensor([float('-inf')]).to(graph.device)) # padded 59 | graph = torch.where(torch.eq(graph, 0.0).all(dim=4, keepdim=True), 60 | torch.as_tensor([float('-inf')]).to(graph.device), 61 | graph) # remaining (make sure the whole representation is zero) 62 | return graph 63 | 64 | def generate(self, old_graph): 65 | """ 66 | Walk-generation: Combine consecutive edges. 67 | Returns: previous graph, 68 | extended graph with intermediate node connections (dim=2) 69 | """ 70 | graph = torch.matmul(old_graph, self.W[None, None]) # (B, I, I, D) 71 | graph = torch.einsum('bijk, bjmk -> bijmk', graph, old_graph) # (B, I, I, I, D) -> dim=2 intermediate node 72 | return old_graph, graph 73 | 74 | def aggregate(self, old_graph, new_graph): 75 | """ 76 | Walk-aggregation: Combine multiple paths via intermediate nodes. 77 | """ 78 | # if the new representation is zero (i.e. 
impossible path), keep the original one --> [beta = 1] 79 | beta_mat = torch.where(torch.isinf(new_graph).all(dim=2), 80 | torch.ones_like(old_graph), 81 | torch.full_like(old_graph, self.beta)) 82 | 83 | new_graph = self.sigmoid(new_graph) 84 | new_graph = torch.sum(new_graph, dim=2) # non-linearity & sum pooling 85 | new_graph = torch.lerp(new_graph, old_graph, weight=beta_mat) 86 | return new_graph 87 | 88 | def forward(self, graph, adj_=None, mask_=None): 89 | graph = self.init_graph(graph, adj_) 90 | 91 | for _ in range(0, self.iters): 92 | old_graph, graph = self.generate(graph) 93 | graph = self.mask_invalid_paths(graph, mask_) 94 | graph = self.aggregate(old_graph, graph) 95 | return graph 96 | -------------------------------------------------------------------------------- /src/reader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 19/06/2019 5 | 6 | author: fenia 7 | """ 8 | 9 | from collections import OrderedDict 10 | from recordtype import recordtype 11 | import numpy as np 12 | 13 | 14 | EntityInfo = recordtype('EntityInfo', 'id type mstart mend sentNo') 15 | PairInfo = recordtype('PairInfo', 'type direction cross') 16 | 17 | 18 | def chunks(l, n): 19 | """ 20 | Successive n-sized chunks from l. 21 | """ 22 | res = [] 23 | for i in range(0, len(l), n): 24 | assert len(l[i:i + n]) == n 25 | res += [l[i:i + n]] 26 | return res 27 | 28 | 29 | def overlap_chunk(chunk=1, lst=None): 30 | if len(lst) <= chunk: 31 | return [lst] 32 | else: 33 | return [lst[i:i + chunk] for i in range(0, len(lst)-chunk+1, 1)] 34 | 35 | 36 | def read_subdocs(input_file, window, documents, entities, relations): 37 | """ 38 | Read documents as sub-documents of N consecutive sentences. 
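Sub-documents are built with overlap_chunk (defined above); for example, a 5-sentence document with window=3 yields the overlapping chunks:

    >>> overlap_chunk(chunk=3, lst=['s0', 's1', 's2', 's3', 's4'])
    [['s0', 's1', 's2'], ['s1', 's2', 's3'], ['s2', 's3', 's4']]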
39 | Args: 40 | input_file: file with documents 41 | """ 42 | lost_pairs, total_pairs = 0, 0 43 | lengths = [] 44 | sents = [] 45 | with open(input_file, 'r') as infile: 46 | for line in infile: 47 | line = line.rstrip().split('\t') 48 | pmid = line[0] 49 | text = line[1] 50 | prs = chunks(line[2:], 17) # all the pairs in the document 51 | 52 | sentences = text.split('|') # document sentences 53 | all_sent_lengths = [len(s.split(' ')) for s in sentences] # document sentence lengths 54 | 55 | sent_chunks = overlap_chunk(chunk=window, lst=sentences) # split document into sub-documents 56 | 57 | unique_pairs = [] 58 | for num, sent in enumerate(sent_chunks): 59 | sent_ids = list(np.arange(int(window)) + num) 60 | 61 | sub_pmid = pmid+'__'+str(num) 62 | 63 | if sub_pmid not in documents: 64 | documents[sub_pmid] = [t.split(' ') for t in sent] 65 | 66 | if sub_pmid not in entities: 67 | entities[sub_pmid] = OrderedDict() 68 | 69 | if sub_pmid not in relations: 70 | relations[sub_pmid] = OrderedDict() 71 | 72 | lengths += [max([len(d) for d in documents[sub_pmid]])] 73 | sents += [len(sent)] 74 | 75 | for p in prs: 76 | # entities 77 | for (ent, typ_, start, end, sn) in [(p[5], p[7], p[8], p[9], p[10]), 78 | (p[11], p[13], p[14], p[15], p[16])]: 79 | 80 | if ent not in entities[sub_pmid]: 81 | s_ = list(map(int, sn.split(':'))) # doc-level ids 82 | m_s_ = list(map(int, start.split(':'))) 83 | m_e_ = list(map(int, end.split(':'))) 84 | assert len(s_) == len(m_s_) == len(m_e_) 85 | 86 | sent_no_new = [] 87 | mstart_new = [] 88 | mend_new = [] 89 | for n, (old_s, old_ms, old_me) in enumerate(zip(s_, m_s_, m_e_)): 90 | if old_s in sent_ids: 91 | sub_ = sum(all_sent_lengths[0:old_s]) 92 | 93 | assert sent[old_s-num] == sentences[old_s] 94 | assert sent[old_s-num].split(' ')[(old_ms-sub_):(old_me-sub_)] == \ 95 | ' '.join(sentences).split(' ')[old_ms:old_me] 96 | sent_no_new += [old_s - num] 97 | mstart_new += [old_ms - sub_] 98 | mend_new += [old_me - sub_] 99 | 100 | if sent_no_new and mstart_new and mend_new: 101 | entities[sub_pmid][ent] = EntityInfo(ent, typ_, 102 | ':'.join(map(str, mstart_new)), 103 | ':'.join(map(str, mend_new)), 104 | ':'.join(map(str, sent_no_new))) 105 | 106 | for p in prs: 107 | # pairs 108 | if (p[5] in entities[sub_pmid]) and (p[11] in entities[sub_pmid]): 109 | if (p[5], p[11]) not in relations[sub_pmid]: 110 | relations[sub_pmid][(p[5], p[11])] = PairInfo(p[0], p[1], p[2]) 111 | 112 | if (pmid, p[5], p[11]) not in unique_pairs: 113 | unique_pairs += [(pmid, p[5], p[11])] 114 | 115 | if len(prs) != len(unique_pairs): 116 | for x in prs: 117 | if (pmid, x[5], x[11]) not in unique_pairs: 118 | if x[0] != '1:NR:2' and x[0] != 'not_include': 119 | lost_pairs += 1 120 | print('--> Lost pair {}, {}, {}: {} {}'.format(pmid, x[5], x[11], x[10], x[16])) 121 | else: 122 | if x[0] != '1:NR:2' and x[0] != 'not_include': 123 | total_pairs += 1 124 | 125 | todel = [] 126 | for pmid, d in relations.items(): 127 | if not relations[pmid]: 128 | todel += [pmid] 129 | 130 | for pmid in todel: 131 | del documents[pmid] 132 | del entities[pmid] 133 | del relations[pmid] 134 | 135 | print('LOST PAIRS: {}/{}'.format(lost_pairs, total_pairs)) 136 | assert len(entities) == len(documents) == len(relations) 137 | return lengths, sents, documents, entities, relations 138 | 139 | 140 | def read(input_file, documents, entities, relations): 141 | """ 142 | Read the full document at a time. 
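Each input line is tab-separated: the PMID, the document text (sentences joined with '|', tokens with ' '), followed by pair entries of 17 tab-separated fields each (hence chunks(line[2:], 17) below).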
143 | """ 144 | lengths = [] 145 | sents = [] 146 | with open(input_file, 'r') as infile: 147 | for line in infile: 148 | line = line.rstrip().split('\t') 149 | pmid = line[0] 150 | text = line[1] 151 | prs = chunks(line[2:], 17) 152 | 153 | if pmid not in documents: 154 | documents[pmid] = [t.split(' ') for t in text.split('|')] 155 | 156 | if pmid not in entities: 157 | entities[pmid] = OrderedDict() 158 | 159 | if pmid not in relations: 160 | relations[pmid] = OrderedDict() 161 | 162 | # max intra-sentence length and max inter-sentence length 163 | lengths += [max([len(s) for s in documents[pmid]] + [len(documents[pmid])])] 164 | sents += [len(text.split('|'))] 165 | 166 | allp = 0 167 | for p in prs: 168 | if (p[5], p[11]) not in relations[pmid]: 169 | relations[pmid][(p[5], p[11])] = PairInfo(p[0], p[1], p[2]) 170 | allp += 1 171 | else: 172 | print('duplicates!') 173 | 174 | # entities 175 | if p[5] not in entities[pmid]: 176 | entities[pmid][p[5]] = EntityInfo(p[5], p[7], p[8], p[9], p[10]) 177 | 178 | if p[11] not in entities[pmid]: 179 | entities[pmid][p[11]] = EntityInfo(p[11], p[13], p[14], p[15], p[16]) 180 | 181 | assert len(relations[pmid]) == allp 182 | 183 | todel = [] 184 | for pmid, d in relations.items(): 185 | if not relations[pmid]: 186 | todel += [pmid] 187 | 188 | for pmid in todel: 189 | del documents[pmid] 190 | del entities[pmid] 191 | del relations[pmid] 192 | 193 | return lengths, sents, documents, entities, relations 194 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on 22/02/19 5 | 6 | author: fenia 7 | """ 8 | 9 | import sys 10 | import os 11 | from tabulate import tabulate 12 | import itertools 13 | import numpy as np 14 | import pickle as pkl 15 | import torch 16 | import matplotlib 17 | matplotlib.use('Agg') 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | def solve(A, B): 22 | A = list(map(int, A)) 23 | B = list(map(int, B)) 24 | m = len(A) 25 | n = len(B) 26 | A.sort() 27 | B.sort() 28 | a = 0 29 | b = 0 30 | result = sys.maxsize 31 | 32 | while a < m and b < n: 33 | if abs(A[a] - B[b]) < result: 34 | result = abs(A[a] - B[b]) 35 | 36 | # Move Smaller Value 37 | if A[a] < B[b]: 38 | a += 1 39 | else: 40 | b += 1 41 | # return final sma result 42 | return result 43 | 44 | 45 | def write_errors(preds, info, ofile, map_=None): 46 | """ Write model errors to file """ 47 | print('Saving predictions ... 
', end="") 48 | with open(ofile+'.errors', 'w') as outfile: 49 | for p, i in zip(preds, info): 50 | i = [i_ for i_ in i if i_] 51 | assert len(p) == len(i) 52 | 53 | for k, j in zip(p, i): 54 | if k != j['rel']: 55 | outfile.write('Prediction: {} \t Truth: {} \t Type: {} \n'.format(map_[k], map_[j['rel']], j['cross'])) 56 | doc = [it for items in j['doc'] for it in items] 57 | outfile.write('{}\n{}\n'.format(j['pmid'], ' '.join(doc))) 58 | 59 | gg1 = ' | '.join([' '.join(doc[int(m1):int(m2)]) for m1,m2 in 60 | zip(j['entA'].mstart.split(':'), j['entA'].mend.split(':'))]) 61 | gg2 = ' | '.join([' '.join(doc[int(m1):int(m2)]) for m1, m2 in 62 | zip(j['entB'].mstart.split(':'), j['entB'].mend.split(':'))]) 63 | 64 | outfile.write('Arg1: {} | {}\n'.format(j['entA'].id, gg1)) 65 | outfile.write('Arg2: {} | {}\n'.format(j['entB'].id, gg2)) 66 | outfile.write('Distance: {}\n'.format(solve(j['sentA'].split(':'), j['sentB'].split(':')))) 67 | outfile.write('\n') 68 | print('DONE') 69 | 70 | 71 | def write_preds(preds, info, ofile, map_=None): 72 | """ Write predictions to file """ 73 | print('Saving errors ... ', end="") 74 | with open(ofile+'.preds', 'w') as outfile: 75 | for p, i in zip(preds, info): 76 | i = [i_ for i_ in i if i_] 77 | assert len(p) == len(i) 78 | 79 | for k, j in zip(p, i): 80 | # pmid, e1, e2, pred, truth 81 | if map_[k] == '1:NR:2': 82 | pass 83 | else: 84 | outfile.write('{}\n'.format('|'.join([j['pmid'].split('__')[0], 85 | j['entA'].id, j['entB'].id, j['cross'], 86 | str(solve(j['sentA'].split(':'), j['sentB'].split(':'))), 87 | map_[k]]))) 88 | print('DONE') 89 | 90 | 91 | def plot_learning_curve(trainer, model_folder): 92 | """ 93 | Plot the learning curves for training and test set (loss and primary score measure) 94 | 95 | Args: 96 | trainer (Class): trainer object 97 | model_folder (str): folder to save figures 98 | """ 99 | x = list(map(int, np.arange(len(trainer.train_res['loss'])))) 100 | fig = plt.figure() 101 | plt.subplot(2, 1, 1) 102 | plt.plot(x, trainer.train_res['loss'], 'b', label='train') 103 | plt.plot(x, trainer.test_res['loss'], 'g', label='test') 104 | plt.legend() 105 | plt.ylabel('Loss') 106 | plt.yticks(np.arange(0, 1, 0.1)) 107 | 108 | plt.subplot(2, 1, 2) 109 | plt.plot(x, trainer.train_res['score'], 'b', label='train') 110 | plt.plot(x, trainer.test_res['score'], 'g', label='test') 111 | plt.legend() 112 | plt.ylabel('F1-score') 113 | plt.xlabel('Epochs') 114 | plt.yticks(np.arange(0, 1, 0.1)) 115 | 116 | fig.savefig(model_folder + '/learn_curves.png', bbox_inches='tight') 117 | 118 | 119 | def print_results(scores, scores_class, show_class, time): 120 | """ 121 | Print class-wise results. 
122 | 123 | Args: 124 | scores (dict): micro and macro scores 125 | scores_class: score per class 126 | show_class (bool): show or not 127 | time: time 128 | """ 129 | 130 | def indent(txt, spaces=18): 131 | return "\n".join(" " * spaces + ln for ln in txt.splitlines()) 132 | 133 | if show_class: 134 | # print results for every class 135 | scores_class.append(['-----', None, None, None]) 136 | scores_class.append(['macro score', scores['macro_p'], scores['macro_r'], scores['macro_f']]) 137 | scores_class.append(['micro score', scores['micro_p'], scores['micro_r'], scores['micro_f']]) 138 | print(' | {}\n'.format(humanized_time(time))) 139 | print(indent(tabulate(scores_class, 140 | headers=['Class', 'P', 'R', 'F1'], 141 | tablefmt='orgtbl', 142 | floatfmt=".4f", 143 | missingval=""))) 144 | print() 145 | else: 146 | print('ACC = {:.04f} , ' 147 | 'MICRO P/R/F1 = {:.04f}\t{:.04f}\t{:.04f} | '.format(scores['acc'], scores['micro_p'], scores['micro_r'], 148 | scores['micro_f']), end="") 149 | 150 | l = ':<7' # +str(len(str(scores['total']))) 151 | s = 'TP/ACTUAL/PRED = {'+l+'}/{'+l+'}/{'+l+'}, TOTAL {'+l+'}' 152 | print(s.format(scores['tp'], scores['true'], scores['pred'], scores['total']), end="") 153 | print(' | {}'.format(humanized_time(time))) 154 | 155 | 156 | class Tee(object): 157 | """ 158 | Object to print stdout to a file. 159 | """ 160 | def __init__(self, *files): 161 | self.files = files 162 | 163 | def write(self, obj): 164 | for f_ in self.files: 165 | f_.write(obj) 166 | f_.flush() # If you want the output to be visible immediately 167 | 168 | def flush(self): 169 | for f_ in self.files: 170 | f_.flush() 171 | 172 | 173 | def humanized_time(second): 174 | """ 175 | :param second: time in seconds 176 | :return: human readable time (hours, minutes, seconds) 177 | """ 178 | m, s = divmod(second, 60) 179 | h, m = divmod(m, 60) 180 | return "%dh %02dm %02ds" % (h, m, s) 181 | 182 | 183 | def setup_log(params, mode): 184 | """ 185 | Setup .log file to record training process and results. 
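The model directory name encodes the batch size, the walk length and the active options, e.g. (hypothetically) 'b2-walks4_EE_MM_SS_ME_MS_ES_context_types_dist' for batch=2, walks_iter=2 with the context, node-type and distance options enabled.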
186 | 187 | Args: 188 | params (dict): model parameters 189 | 190 | Returns: 191 | model_folder (str): model directory 192 | """ 193 | if params['walks_iter'] == 0: 194 | length = 1 195 | elif not params['walks_iter']: 196 | length = 1 197 | else: 198 | length = 2**params['walks_iter'] 199 | # folder_name = 'b{}-wd{}-ld{}-od{}-td{}-beta{}-pd{}-di{}-do{}-lr{}-gc{}-r{}-p{}-walks{}'.format( 200 | # params['batch'], params['word_dim'], params['lstm_dim'], params['out_dim'], params['type_dim'], 201 | # params['pos_dim'], params['beta'], params['pos_dim'], params['drop_i'], params['drop_o'], params['lr'], 202 | # params['gc'], params['reg'], params['patience'], length) 203 | 204 | folder_name = 'b{}-walks{}'.format(params['batch'], length) 205 | 206 | folder_name += '_'+'_'.join(params['edges']) 207 | 208 | if params['context']: 209 | folder_name += '_context' 210 | 211 | if params['types']: 212 | folder_name += '_types' 213 | 214 | if params['dist']: 215 | folder_name += '_dist' 216 | 217 | if params['freeze_words']: 218 | folder_name += '_freeze' 219 | 220 | if params['window']: 221 | folder_name += '_win'+str(params['window']) 222 | 223 | model_folder = params['folder'] + '/' + folder_name 224 | if not os.path.exists(model_folder): 225 | os.makedirs(model_folder) 226 | log_file = model_folder + '/info_'+mode+'.log' 227 | 228 | f = open(log_file, 'w') 229 | sys.stdout = Tee(sys.stdout, f) 230 | return model_folder 231 | 232 | 233 | def observe(model): 234 | """ 235 | Observe model parameters: name, range of matrices & gradients 236 | 237 | Args 238 | model: specified model object 239 | """ 240 | for name, param in model.named_parameters(): 241 | p_data, p_grad = param.data, param.grad.data 242 | print('Name: {:<30}\tRange of data: [{:.4f}, {:.4f}]\tRange of gradient: [{:.4f}, {:.4f}]'.format(name, 243 | np.min(p_data.data.to('cpu').numpy()), 244 | np.max(p_data.data.to('cpu').numpy()), 245 | np.min(p_grad.data.to('cpu').numpy()), 246 | np.max(p_grad.data.to('cpu').numpy()))) 247 | print('--------------------------------------') 248 | 249 | 250 | def save_model(model_folder, trainer, loader): 251 | print('\nSaving the model & the parameters ...') 252 | # save mappings 253 | with open(os.path.join(model_folder, 'mappings.pkl'), 'wb') as f: 254 | pkl.dump(loader, f, pkl.HIGHEST_PROTOCOL) 255 | torch.save(trainer.model.state_dict(), os.path.join(model_folder, 're.model')) 256 | 257 | 258 | def load_model(model_folder, trainer): 259 | print('\nLoading model & parameters ...') 260 | trainer.model.load_state_dict(torch.load(os.path.join(model_folder, 're.model'), 261 | map_location=trainer.model.device)) 262 | return trainer 263 | 264 | 265 | def load_mappings(model_folder): 266 | with open(os.path.join(model_folder, 'mappings.pkl'), 'rb') as f: 267 | loader = pkl.load(f) 268 | return loader 269 | 270 | 271 | def print_options(params): 272 | print('''\nParameters: 273 | - Train Data {} 274 | - Test Data {} 275 | - Embeddings {}, Freeze: {} 276 | - Save folder {} 277 | 278 | - batchsize {} 279 | - Walks iteration {} -> Length = {} 280 | - beta {} 281 | 282 | - Context {} 283 | - Node Type {} 284 | - Distances {} 285 | - Edge Types {} 286 | - Window {} 287 | 288 | - Epoch {} 289 | - UNK word prob {} 290 | - Parameter Average {} 291 | - Early stop {} -> Patience = {} 292 | - Regularization {} 293 | - Gradient Clip {} 294 | - Dropout I/O {}/{} 295 | - Learning rate {} 296 | - Seed {} 297 | '''.format(params['train_data'], params['test_data'], params['embeds'], params['freeze_words'], 298 | 
params['folder'], 299 | params['batch'], 300 | params['walks_iter'], 2 ** params['walks_iter'] if params['walks_iter'] else 0, params['beta'], 301 | params['context'], params['types'], params['dist'], params['edges'], 302 | params['window'], params['epoch'], 303 | params['unk_w_prob'], params['param_avg'], 304 | params['early_stop'], params['patience'], 305 | params['reg'], params['gc'], params['drop_i'], params['drop_o'], params['lr'], params['seed'])) 306 | --------------------------------------------------------------------------------