├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── parameters_cdr.yaml
│   └── parameters_gda.yaml
├── data_processing
│   ├── 2017MeshTree.txt
│   ├── __pycache__
│   │   ├── readers.cpython-36.pyc
│   │   ├── tools.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── bin2txt.py
│   ├── convertvec
│   ├── convertvec.c
│   ├── filter_hypernyms.py
│   ├── gda2pubtator.py
│   ├── process.py
│   ├── process_cdr.sh
│   ├── process_gda.sh
│   ├── readers.py
│   ├── reduce_embeds.py
│   ├── split_gda.py
│   ├── statistics.py
│   ├── tools.py
│   ├── train_gda_docs
│   └── utils.py
├── embeds
│   └── PubMed-CDR.txt
├── evaluation
│   └── evaluate.py
├── network.svg
├── requirements.txt
└── src
    ├── __init__.py
    ├── __pycache__
    │   ├── converter.cpython-36.pyc
    │   ├── dataset.cpython-36.pyc
    │   ├── loader.cpython-36.pyc
    │   ├── reader.cpython-36.pyc
    │   └── utils.cpython-36.pyc
    ├── bin
    │   └── run.sh
    ├── converter.py
    ├── dataset.py
    ├── eog.py
    ├── loader.py
    ├── nnet
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-36.pyc
    │   │   ├── attention.cpython-36.pyc
    │   │   ├── init_net.cpython-36.pyc
    │   │   ├── modules.cpython-36.pyc
    │   │   ├── network.cpython-36.pyc
    │   │   ├── trainer.cpython-36.pyc
    │   │   └── walks.cpython-36.pyc
    │   ├── attention.py
    │   ├── init_net.py
    │   ├── modules.py
    │   ├── network.py
    │   ├── trainer.py
    │   └── walks.py
    ├── reader.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.idea
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | LICENCES
2 |
3 | NaCTeM Non-Commercial No Redistribution Software Licence
4 |
5 | 1. Definitions
6 | 1.1 In this Agreement unless the context otherwise requires the following
7 | words shall have the following meanings:
8 | "Adaptations" means any adaptation, alteration, addition to, deletion
9 | from, manipulation, or modification of parts of the resource;
10 | "Agreement" means these terms and conditions contained herein;
11 | "Commercial Use" means the use of the whole or parts of the Software
12 | for any reason which generates a profit.
13 | "Intellectual Property" means all Intellectual Property Rights
14 | pertaining to and subsisting in any country throughout the world
15 | including but not by way of limitation patents, trade marks and/or
16 | service marks, designs (whether registered or unregistered), utility
17 | models, all applications for any of the foregoing, copyrights, design
18 | rights, trade or business names and confidential know-how including any
19 | licences in connections with the same.
20 | "NaCTeM" means the National Centre for Text Mining.
21 | "Non-Commercial Purpose" means the use of the Software solely for
22 | internal non-commercial research and academic purposes. Non-Commercial
23 | Purpose excludes, without limitation, any use of the Software, as part
24 | of, or in any way in connection with a product or service which is sold,
25 | offered for sale, licensed, leased, loaned, or rented.
26 | "Software" means the materials distributed with this licence for use by
27 | the Users.
28 | "Software Owner" means the person or body who created the Software and
29 | which owns the Intellectual Property in the Software.
30 | "University" means the University of Manchester and any of its servants,
31 | agents, successors and assigns and is a registered charity under
32 | Section 3 (5)(a) of the Charities Act 1993.
33 | "Users" means any person who is permitted to access the Software.
34 | 1.2 Words in the singular shall include the plural and vice versa, references
35 | to any gender shall include the others and references to legal persons shall
36 | include natural persons and vice versa.
37 | 1.3 The headings in these conditions are intended for reference only and shall
38 | not affect their construction.
39 |
40 | 2. Permitted Uses
41 | 2.1 The User may use the Software for Non-Commercial Purposes only and may:
42 | 2.1.1 electronically save the whole or any part or parts of the Software;
43 | 2.1.2 print out single copies of the whole or any part or parts of the
44 | Software;
45 | 2.1.3 make Adaptations to any parts of the Software;
46 | 2.1.4 use those parts of the Software which are the Intellectual Property
47 | of a third party in accordance with the licence terms granted by
48 | the respective Software Owner.
49 |
50 | 3. Restrictions
51 | 3.1 The User may not and may not authorise any other third party to:
52 | 3.1.1 sell or resell or otherwise make the information contained in the
53 | Software available in any manner or on any media to any one else
54 | without the written consent of the Software Owner;
55 | 3.1.2 remove, obscure or modify copyright notices, text acknowledging
56 | or other means of identification or disclaimers as they may appear
57 | without prior written permission of the Software Owner;
58 | 3.1.3 use all or any part of the Software for any Commercial Use or for
59 | any purpose other than Non-Commercial Purposes unless with the
60 | written consent of the Software Owner.
61 | 3.1.4 use any Adaptation of the Software except in accordance with all
62 | the terms of this Agreement unless with the written consent of the
63 | Software Owner.
64 | 3.2 This Clause shall survive termination of this Agreement for any reason.
65 |
66 | 4. Acknowledgement and Protection of Intellectual Property Rights
67 | 4.1 The User acknowledges that all copyrights, patent rights, trademarks,
68 | database rights, trade secrets and other intellectual property rights
69 | relating to the Software, are the property of the Software Owner or
70 | duly licensed and that this Agreement does not assign or transfer to
71 | the User any right, title or interest therein except for the right to
72 | use the Software in accordance with the terms and conditions of this
73 | Agreement.
74 | 4.2 The Users shall comply with the terms of the Copyright, Designs and
75 | Patents Act 1988 and in particular, but without limitation, shall
76 | recognise the rights, including moral right and the rights of attribution,
77 | of the Software Owner. Each use or adaptation of the Software shall
78 | make appropriate acknowledgement of the source, title, and copyright
79 | owner.
80 |
81 | 5. Representation, Warranties and Indemnification
82 | 5.1 The University makes no representation or warranty, and expressly
83 | disclaims any liability with respect to the Software including,
84 | but not limited to, errors or omissions contained therein, libel,
85 | defamation, infringements of rights of publicity, privacy, trademark
86 | rights, infringements of third party intellectual property rights,
87 | moral rights, or the disclosure of confidential information. It is
88 | expressly agreed that any use by the Users of the Software is at the
89 | User's sole risk.
90 | 5.2 The Software is provided on an 'as is' basis and the University
91 | disclaims any and all other warranties, conditions, or representations
92 | (express, implied, oral or written), relating to the Software or any
93 | part thereof, including, without limitation, any and all implied
94 | warranties of title, quality, performance, merchantability or fitness
95 | for a particular purpose. The University further expressly disclaims
96 | any warranty or representation to Users, or to any third party. The
97 | University accepts no liability for loss suffered or incurred by the
98 | User or any third party as a result of their reliance on the Software.
99 | 5.3 Nothing herein shall impose any liability upon the University in respect
100 | of use by the User of the Software and the University gives no indemnity
101 | in respect of any claim by the User or any third party relating to any
102 | action of the User in or arising from the Software.
103 | 5.4 It is the sole responsibility of the User to ensure that he has obtained
104 | any relevant third party permissions for any Adaptations of the Software
105 | made by the User and the User shall be responsible for any and all damages,
106 | liabilities, claims, causes of action, legal fees and costs incurred by
107 | the User in defending against any third party claim of intellectual
108 | property rights infringements or threats of claims thereof with respect
109 | of the use of the Software containing any Adaptations.
110 |
111 | 6. Consequential Loss
112 | 6.1 Neither party shall be liable to the other for any costs, claims, damages
113 | or expenses arising out of any act or omission or any breach of contract
114 | or statutory duty or in tort calculated by reference to profits, income,
115 | production or accruals or loss of such profits, income, production or
116 | accruals or by reference to accrual or such costs, claims, damages or
117 | expenses calculated on a time basis.
118 |
119 | 7. Termination
120 | 7.1 The University shall have the right to terminate this Agreement forthwith
121 | if the User shall have materially breached any of its obligations under
122 | this Agreement or in the event of a breach capable of remedy fails to
123 | remedy the same within thirty (30) days of the giving of notice by the
124 | University to the User of the alleged breach and of the action required
125 | to remedy the same.
126 |
127 | 8. General
128 | 8.1 Delay in exercising, or a failure to exercise, any right or remedy in
129 | connection with this Agreement shall not operate as a waiver of that
130 | right or remedy. A single or partial exercise of any right or remedy
131 | shall not preclude any other or further exercise of that right or remedy,
132 | or the exercise of any other right or remedy. A waiver of a breach of
133 | this Agreement shall not constitute a waiver of any subsequent breach.
134 | 8.2 Each of the parties acknowledges that it is not entering into this
135 | Agreement in reliance upon any representation, warranty, collateral
136 | contract or other assurance (except those set out in this Agreement
137 | and the documents referred to in it) made by or on behalf of any other
138 | party before the execution of this Agreement. Each of the parties waives
139 | all rights and remedies which, but for this clause, might otherwise be
140 | available to it in respect to any such representation, warranty, collateral
141 | contract or other assurance, provided that nothing in this Clause 8.2
142 | shall limit or exclude any liability for fraud.
143 | 8.3 Nothing in this Agreement shall constitute or be deemed to constitute a
144 | partnership or other form of joint venture between the parties or constitute
145 | or be deemed to constitute either party the agent or employee of the other
146 | for any purpose whatsoever.
147 | 8.4 No person who is not a party to this Agreement is entitled to enforce
148 | any of its terms, whether under the Contracts (Rights of Third Parties)
149 | Act 1999 or otherwise.
150 | 8.5 The parties intend each provision of this Agreement to be severable and
151 | distinct from the others. If a provision of this Agreement is held to be
152 | illegal, invalid or unenforceable, in whole or in part, the parties intend
153 | that the legality, validity and enforceability of the remainder of this
154 | Agreement shall not be affected.
155 |
156 | 9. Governing Law and Jurisdiction
157 | 9.1 This Agreement is governed by, and shall be interpreted in accordance with,
158 | English law and each party irrevocably submits to the non-exclusive
159 | jurisdiction of the English Courts in relation to all matters arising out
160 | of or in connection with this Agreement.
161 |
162 |
163 | Component Version License Link
164 |
165 | edge-oriented-graph 1.0 edge-oriented-graph https://github.com/fenchri/edge-oriented-graph/
166 |
167 | edge-oriented-graph License Information
168 |
169 | Copyright (c) 2019, Fenia Christopoulou, National Centre for Text Mining, School of Computer Science, The University of Manchester.
170 | All rights reserved.
171 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Edge-oriented Graph
2 | Source code for the EMNLP 2019 paper "[Connecting the Dots: Document-level Neural Relation Extraction with Edge-oriented Graphs](https://www.aclweb.org/anthology/D19-1498.pdf)".
3 |
4 |
5 |
6 |
7 |
8 |
9 | ### Environment
10 | `$ pip3 install -r requirements.txt`
11 | The model was trained on a Tesla K80 GPU running Ubuntu 16.04. Results are reproducible with a fixed seed.
12 |
13 |
14 |
15 | ### Reproducibility & Bug Fixes
16 |
17 | The original code contained a bug related to the word-embedding layer.
18 | If you want to reproduce the results presented in the paper, you need to use the "buggy" code: [reproduceEMNLP](https://github.com/fenchri/edge-oriented-graph/tree/reproduceEMNLP).
19 | Otherwise, we recommend using the current version, which achieves higher performance.
20 |
21 |
22 |
23 | ## Datasets & Pre-processing
24 | Download the datasets
25 | ```
26 | $ mkdir data && cd data
27 | $ wget https://biocreative.bioinformatics.udel.edu/media/store/files/2016/CDR_Data.zip && unzip CDR_Data.zip && mv CDR_Data CDR
28 | $ wget https://bitbucket.org/alexwuhkucs/gda-extraction/get/fd4a7409365e.zip && unzip fd4a7409365e.zip && mv alexwuhkucs-gda-extraction-fd4a7409365e GDA
29 | $ cd ..
30 | ```
31 |
32 | Download the GENIA Tagger and Sentence Splitter:
33 | ```
34 | $ cd data_processing
35 | $ mkdir common && cd common
36 | $ wget http://www.nactem.ac.uk/y-matsu/geniass/geniass-1.00.tar.gz && tar xvzf geniass-1.00.tar.gz
37 | $ cd geniass/ && make && cd ..
38 | $ git clone https://github.com/bornabesic/genia-tagger-py.git
39 | $ cd genia-tagger-py
40 | ```
41 | Here, you should modify the Makefile inside genia-tagger-py and replace line 3 with `wget http://www.nactem.ac.uk/GENIA/tagger/geniatagger-3.0.2.tar.gz`
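If you prefer not to edit the Makefile by hand, a one-line `sed` command such as the following performs the same replacement (a sketch; it assumes the download command still sits on line 3 of that Makefile):
```
$ sed -i '3s|.*|wget http://www.nactem.ac.uk/GENIA/tagger/geniatagger-3.0.2.tar.gz|' Makefile
```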
42 | ```
43 | $ make
44 | $ cd ../../
45 | ```
46 |
47 | > **Important**: If the GENIA splitter produces errors (e.g. it cannot find a temp file), make sure Ruby is installed: `sudo apt-get install ruby-full`
48 |
49 | In order to process the datasets, they should first be transformed into the PubTator format. Then run the processing scripts as follows:
50 | ```
51 | $ sh process_cdr.sh
52 | $ sh process_gda.sh
53 | ```
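Each line of the generated `.data` files contains the PMID, the `|`-joined tokenised sentences, and then 17 tab-separated columns per candidate entity pair. The exact layout is not documented in the repository; the sketch below is inferred from the write statements in `process.py` (the column names are descriptive only):
```
column 1        pair label, e.g. 1:CID:2 / 1:GDA:2, or 1:NR:2 for negative pairs
column 2        direction (L2R or R2L)
column 3        CROSS or NON-CROSS (whether the pair spans sentence boundaries)
columns 4-5     word spans "start-end" of the closest mentions of the two arguments
columns 6-11    arg1: KB ids | mention names | entity type | mention starts | mention ends | sentence ids
columns 12-17   arg2: the same six fields for the second entity
```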
54 |
55 | In order to get the data statistics run:
56 | ```
57 | python3 statistics.py --data ../data/CDR/processed/train_filter.data
58 | python3 statistics.py --data ../data/CDR/processed/dev_filter.data
59 | python3 statistics.py --data ../data/CDR/processed/test_filter.data
60 | ```
61 | This will additionally generate the gold-annotation file in the same folder with suffix `.gold`.
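Each `.gold` line encodes one annotated pair as `PMID|arg1_id|arg2_id|CROSS-or-NON-CROSS|label`, matching the write format in `statistics.py`. A hypothetical example line:
```
2234245|D008750|D007022|NON-CROSS|1:CID:2
```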
62 |
63 |
64 | ## Usage
65 | Run the main script for training and testing as follows. Use `--gpu -1` for CPU mode.
66 |
67 | **CDR dataset**: Train the model on the training set and evaluate on the dev set, in order to identify the best training epoch.
68 | For testing, re-train the model on the union of the train and dev sets (`train+dev_filter.data`) up to the best epoch and evaluate on the test set.
69 |
70 | **GDA dataset**: Simply train the model on the training set and evaluate on the dev set. Test the saved model on the test set.
71 |
72 | To enable the early stopping criterion, use the `--early_stop` option.
73 | If early stopping is not triggered during training, the maximum epoch (specified in the config file) will be used.
74 | 
75 | If you instead want to train up to a specific epoch, use the `--epoch epochNumber` option without early stopping.
76 | The maximum number of training epochs is defined by the `--epoch` option.
77 |
78 | For example, in the CDR dataset:
79 | ```
80 | $ cd src/
81 | $ python3 eog.py --config ../configs/parameters_cdr.yaml --train --gpu 0 --early_stop # using early stopping
82 | $ python3 eog.py --config ../configs/parameters_cdr.yaml --train --gpu 0 --epoch 15 # train until the 15th epoch *without* early stopping
83 | $ python3 eog.py --config ../configs/parameters_cdr.yaml --train --gpu 0 --epoch 15 --early_stop # set both early stop and max epoch
84 |
85 | $ python3 eog.py --config ../configs/parameters_cdr.yaml --test --gpu 0
86 | ```
87 |
88 | All necessary parameters can be stored in the yaml files inside the configs directory.
89 | The following parameters can also be given directly on the command line (see the example after the usage message below):
90 | ```
91 | usage: eog.py [-h] --config CONFIG [--train] [--test] [--gpu GPU]
92 | [--walks WALKS] [--window WINDOW] [--edges [EDGES [EDGES ...]]]
93 | [--types TYPES] [--context CONTEXT] [--dist DIST] [--example]
94 | [--seed SEED] [--early_stop] [--epoch EPOCH]
95 |
96 | optional arguments:
97 | -h, --help show this help message and exit
98 | --config CONFIG Yaml parameter file
99 | --train Training mode - model is saved
100 | --test Testing mode - needs a model to load
101 | --gpu GPU GPU number
102 | --walks WALKS Number of walk iterations
103 | --window WINDOW Window for training (empty processes the whole
104 | document, 1 processes 1 sentence at a time, etc)
105 | --edges [EDGES [EDGES ...]]
106 | Edge types
107 | --types TYPES Include node types (Boolean)
108 | --context CONTEXT Include MM context (Boolean)
109 | --dist DIST Include distance (Boolean)
110 | --example Show example
111 | --seed SEED Fixed random seed number
112 | --early_stop Use early stopping
113 | --epoch EPOCH Maximum training epoch
114 | ```
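For example, the number of walk iterations, the window size and the edge types can be passed directly (illustrative values; these flags correspond to the `walks_iter`, `window` and `edges` entries of the config files):
```
$ python3 eog.py --config ../configs/parameters_cdr.yaml --train --gpu 0 --walks 3 --window 3 --edges MM ME MS ES SS-ind
```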
115 |
116 | ### Evaluation
117 | In order to evaluate, first generate the gold-annotation files (via `statistics.py`, see above) and then run the evaluation script as follows:
118 | ```
119 | $ cd evaluation/
120 | $ python3 evaluate.py --pred path_to_predictions_file --gold ../data/CDR/processed/test_filter.gold --label 1:CID:2
121 | $ python3 evaluate.py --pred path_to_predictions_file --gold ../data/GDA/processed/test.gold --label 1:GDA:2
122 | ```
123 |
124 |
125 | ### Citation
126 |
127 | If you found this code useful and plan to use it, please cite the following paper =)
128 | ```
129 | @inproceedings{christopoulou2019connecting,
130 | title = "Connecting the Dots: Document-level Neural Relation Extraction with Edge-oriented Graphs",
131 | author = "Christopoulou, Fenia and Miwa, Makoto and Ananiadou, Sophia",
132 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
133 | year = "2019",
134 | publisher = "Association for Computational Linguistics",
135 | pages = "4927--4938"
136 | }
137 | ```
138 |
--------------------------------------------------------------------------------
/configs/parameters_cdr.yaml:
--------------------------------------------------------------------------------
1 | # network
2 | batch: 2
3 | epoch: 50
4 | bilstm_layers: 1
5 | word_dim: 200
6 | lstm_dim: 100
7 | out_dim: 100
8 | type_dim: 10
9 | beta: 0.8
10 | dist_dim: 10
11 | drop_i: 0.5
12 | drop_m: 0.0
13 | drop_o: 0.3
14 | lr: 0.002
15 | gc: 10
16 | reg: 0.0001
17 | opt: adam
18 | patience: 10
19 | unk_w_prob: 0.5
20 | min_w_freq: 1
21 | walks_iter: 4
22 |
23 | # data based
24 | train_data: ../data/CDR/processed/train+dev_filter.data
25 | test_data: ../data/CDR/processed/test_filter.data
26 | embeds: ../embeds/PubMed-CDR.txt
27 | folder: ../results/cdr-test
28 | save_pred: test
29 |
30 | # options (taken from command-line input if given, otherwise false)
31 | lowercase: false
32 | plot: true
33 | show_class: false
34 | param_avg: true
35 | early_stop: false
36 | save_model: true
37 | types: true
38 | context: true
39 | dist: true
40 | freeze_words: true
41 |
42 | # extra
43 | seed: 0
44 | shuffle_data: true
45 | label2ignore: 1:NR:2
46 | primary_metric: micro_f
47 | direction: l2r+r2l
48 | include_pairs: ['Chemical-Disease', 'Chemical-Chemical', 'Disease-Disease', 'Disease-Chemical']
49 | classify_pairs: ['Chemical-Disease']
50 | edges: ['MM', 'ME', 'MS', 'ES', 'SS-ind']
51 | window:
52 |
--------------------------------------------------------------------------------
/configs/parameters_gda.yaml:
--------------------------------------------------------------------------------
1 | # network
2 | batch: 3
3 | epoch: 50
4 | bilstm_layers: 1
5 | word_dim: 200
6 | lstm_dim: 100
7 | out_dim: 100
8 | type_dim: 10
9 | beta: 0.8
10 | dist_dim: 10
11 | drop_i: 0.5
12 | drop_m: 0.0
13 | drop_o: 0.3
14 | lr: 0.002
15 | gc: 10
16 | reg: 0.0001
17 | opt: adam
18 | patience: 5
19 | unk_w_prob: 0.5
20 | min_w_freq: 1
21 | walks_iter: 4
22 |
23 | # data based
24 | train_data: ../data/GDA/processed/train.data
25 | test_data: ../data/GDA/processed/dev.data
26 | embeds:
27 | folder: ../results/gda
28 | save_pred: test
29 |
30 | # options (taken from command-line input if given, otherwise false)
31 | lowercase: false
32 | plot: true
33 | show_class: false
34 | param_avg: true
35 | early_stop: false
36 | save_model: true
37 | types: true
38 | context: true
39 | dist: true
40 | freeze_words: false
41 |
42 | # extra
43 | seed: 0
44 | shuffle_data: true
45 | label2ignore: 1:NR:2
46 | primary_metric: micro_f
47 | direction: l2r+r2l
48 | include_pairs: ['Gene-Disease', 'Disease-Disease', 'Gene-Gene', 'Disease-Gene']
49 | classify_pairs: ['Gene-Disease']
50 | edges: ['MM', 'ME', 'MS', 'ES', 'SS-ind']
51 | window:
52 |
--------------------------------------------------------------------------------
/data_processing/__pycache__/readers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/data_processing/__pycache__/readers.cpython-36.pyc
--------------------------------------------------------------------------------
/data_processing/__pycache__/tools.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/data_processing/__pycache__/tools.cpython-36.pyc
--------------------------------------------------------------------------------
/data_processing/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/data_processing/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/data_processing/bin2txt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sat Aug 12 10:32:10 2017
4 |
5 | @author: fenia
6 | """
7 |
8 | from gensim.models.keyedvectors import KeyedVectors
9 | import sys
10 |
11 | """
12 | Transform from 'bin' to 'txt' word vectors.
13 | Input: the bin file
14 | Output: the txt file
15 | """
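# Example (illustrative file name, not part of the repository):
#   python3 bin2txt.py PubMed-w2v.bin   -> writes PubMed-w2v.txt next to the input file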
16 | inp = sys.argv[1]
17 | out = ''.join(inp.split('.bin')[:-1])+'.txt'
18 |
19 | model = KeyedVectors.load_word2vec_format(inp, binary=True)
20 | model.save_word2vec_format(out, binary=False)
21 |
--------------------------------------------------------------------------------
/data_processing/convertvec:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/data_processing/convertvec
--------------------------------------------------------------------------------
/data_processing/convertvec.c:
--------------------------------------------------------------------------------
1 | // Code to convert word2vec vectors between text and binary format
2 | // Created by Marek Rei
3 | // Credits to: https://github.com/marekrei/convertvec
4 |
5 | #include <stdio.h>
6 | #include <stdlib.h>
7 | #include <string.h>
8 | #include <math.h>
9 | #include <malloc.h>
10 |
11 | const long long max_w = 2000;
12 |
13 | // Convert from text format to binary
14 | void txt2bin(char * input_path, char * output_path){
15 | FILE * fi = fopen(input_path, "rb");
16 | FILE * fo = fopen(output_path, "wb");
17 |
18 | long long words, size;
19 | fscanf(fi, "%lld", &words);
20 | fscanf(fi, "%lld", &size);
21 | fscanf(fi, "%*[ ]");
22 | fscanf(fi, "%*[\n]");
23 |
24 | fprintf(fo, "%lld %lld\n", words, size);
25 |
26 | char word[max_w];
27 | char ch;
28 | float value;
29 | int b, a;
30 | for (b = 0; b < words; b++) {
31 | if(feof(fi))
32 | break;
33 |
34 | word[0] = 0;
35 | fscanf(fi, "%[^ ]", word);
36 | fscanf(fi, "%c", &ch);
37 | // This kind of whitespace handling is a bit more explicit than usual.
38 | // It allows us to correctly handle special characters that would otherwise be skipped.
39 |
40 | fprintf(fo, "%s ", word);
41 |
42 | for (a = 0; a < size; a++) {
43 | fscanf(fi, "%s", word);
44 | fscanf(fi, "%*[ ]");
45 | value = atof(word);
46 | fwrite(&value, sizeof(float), 1, fo);
47 | }
48 | fscanf(fi, "%*[\n]");
49 | fprintf(fo, "\n");
50 | }
51 |
52 | fclose(fi);
53 | fclose(fo);
54 | }
55 |
56 | // Convert from binary to text format
57 | void bin2txt(char * input_path, char * output_path){
58 | FILE * fi = fopen(input_path, "rb");
59 | FILE * fo = fopen(output_path, "wb");
60 |
61 | long long words, size;
62 | fscanf(fi, "%lld", &words);
63 | fscanf(fi, "%lld", &size);
64 | fscanf(fi, "%*[ ]");
65 | fscanf(fi, "%*[\n]");
66 |
67 | fprintf(fo, "%lld %lld\n", words, size);
68 |
69 | char word[max_w];
70 | char ch;
71 | float value;
72 | int b, a;
73 | for (b = 0; b < words; b++) {
74 | if(feof(fi))
75 | break;
76 |
77 | word[0] = 0;
78 | fscanf(fi, "%[^ ]", word);
79 | fscanf(fi, "%c", &ch);
80 |
81 | fprintf(fo, "%s ", word);
82 | for (a = 0; a < size; a++) {
83 | fread(&value, sizeof(float), 1, fi);
84 | fprintf(fo, "%lf ", value);
85 | }
86 | fscanf(fi, "%*[\n]");
87 | fprintf(fo, "\n");
88 | }
89 | fclose(fi);
90 | fclose(fo);
91 | }
92 |
93 | int main(int argc, char **argv) {
94 | if (argc < 4) {
95 | printf("USAGE: convertvec method input_path output_path\n");
96 | printf("Method is either bin2txt or txt2bin\n");
97 | return 0;
98 | }
99 |
100 | if(strcmp(argv[1], "bin2txt") == 0)
101 | bin2txt(argv[2], argv[3]);
102 | else if(strcmp(argv[1], "txt2bin") == 0)
103 | txt2bin(argv[2], argv[3]);
104 | else
105 | printf("Unknown method: %s\n", argv[1]);
106 |
107 | return 0;
108 | }
109 |
--------------------------------------------------------------------------------
/data_processing/filter_hypernyms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Created on 17/04/19
4 |
5 | author: fenia
6 | """
7 |
8 | import argparse
9 | import codecs
10 | from collections import defaultdict
11 | '''
12 | Adaptation of https://github.com/patverga/bran/blob/master/src/processing/utils/filter_hypernyms.py
13 | '''
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('-i', '--input_file', required=True, help='input file in 13col tsv')
17 | parser.add_argument('-m', '--mesh_file', required=True, help='mesh file to get hierarchy from')
18 | parser.add_argument('-o', '--output_file', required=True, help='write results to this file')
19 |
20 | args = parser.parse_args()
21 |
22 |
23 | def chunks(l, n):
24 | """
25 | Yield successive n-sized chunks from l.
26 | """
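    # e.g. list(chunks([1, 2, 3, 4], 2)) -> [[1, 2], [3, 4]]  (asserts len(l) is a multiple of n)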
27 | for i in range(0, len(l), n):
28 | assert len(l[i:i + n]) == n
29 | yield l[i:i + n]
30 |
31 |
32 | # read in mesh hierarchy
33 | ent_tree_map = defaultdict(list)
34 | with codecs.open(args.mesh_file, 'r') as f:
35 | lines = [l.rstrip().split('\t') for i, l in enumerate(f) if i > 0]
36 | [ent_tree_map[l[1]].append(l[0]) for l in lines]
37 |
38 |
39 | # read in positive input file and organize by document
40 | print('Loading examples from %s' % args.input_file)
41 | pos_doc_examples = defaultdict(list)
42 | neg_doc_examples = defaultdict(list)
43 |
44 | unfilitered_pos_count = 0
45 | unfilitered_neg_count = 0
46 | text = {}
47 | with open(args.input_file, 'r') as f:
48 | lines = [l.strip().split('\t') for l in f]
49 |
50 | for l in lines:
51 | pmid = l[0]
52 | text[pmid] = pmid+'\t'+l[1]
53 |
54 | for r in chunks(l[2:], 17):
55 |
56 | if r[0] == '1:NR:2':
57 | assert ((r[7] == 'Chemical') and (r[13] == 'Disease'))
58 | neg_doc_examples[pmid].append(r)
59 | unfilitered_neg_count += 1
60 | elif r[0] == '1:CID:2':
61 | assert ((r[7] == 'Chemical') and (r[13] == 'Disease'))
62 | pos_doc_examples[pmid].append(r)
63 | unfilitered_pos_count += 1
64 |
65 |
66 | # iterate over docs
67 | hypo_count = 0
68 | negative_count = 0
69 |
70 | all_pos = 0
71 | with open(args.output_file, 'w') as out_f:
72 | for doc_id in pos_doc_examples.keys():
73 | towrite = text[doc_id]
74 |
75 | for r in pos_doc_examples[doc_id]:
76 | towrite += '\t'
77 | towrite += '\t'.join(r)
78 | all_pos += len(pos_doc_examples[doc_id])
79 |
80 | # get nodes for all the positive diseases
81 | pos_e2_examples = [(pos_node, pe) for pe in pos_doc_examples[doc_id]
82 | for pos_node in ent_tree_map[pe[11]]]
83 |
84 | pos_e1_examples = [(pos_node, pe) for pe in pos_doc_examples[doc_id]
85 | for pos_node in ent_tree_map[pe[5]]]
86 |
87 | filtered_neg_exampled = []
88 | for ne in neg_doc_examples[doc_id]:
89 | neg_e1 = ne[5]
90 | neg_e2 = ne[11]
91 | example_hyponyms = 0
92 | for neg_node in ent_tree_map[ne[11]]:
93 | hyponyms = [pos_node for pos_node, pe in pos_e2_examples
94 | if neg_node in pos_node and neg_e1 == pe[5]] \
95 | + [pos_node for pos_node, pe in pos_e1_examples
96 | if neg_node in pos_node and neg_e2 == pe[11]]
97 | example_hyponyms += len(hyponyms)
98 | if example_hyponyms == 0:
99 | towrite += '\t'+'\t'.join(ne)
100 | negative_count += 1
101 | else:
102 | ne[0] = 'not_include' # just don't include the negative pairs, but keep the entities
103 | towrite += '\t'+'\t'.join(ne)
104 | hypo_count += example_hyponyms
105 | out_f.write(towrite+'\n')
106 |
107 | print('Mesh entities: %d' % len(ent_tree_map))
108 | print('Positive Docs: %d' % len(pos_doc_examples))
109 | print('Negative Docs: %d' % len(neg_doc_examples))
110 | print('Positive Count: %d Initial Negative Count: %d Final Negative Count: %d Hyponyms: %d' %
111 | (unfilitered_pos_count, unfilitered_neg_count, negative_count, hypo_count))
112 | print(all_pos)
--------------------------------------------------------------------------------
/data_processing/gda2pubtator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 06/09/2019
5 |
6 | author: fenia
7 | """
8 |
9 | import os, re, sys
10 | import argparse
11 | from tqdm import tqdm
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--input_folder', '-i', type=str)
15 | parser.add_argument('--output_file', '-o', type=str)
16 | args = parser.parse_args()
17 |
18 | if not os.path.exists('/'.join(args.output_file.split('/')[:-1])):
19 | os.makedirs('/'.join(args.output_file.split('/')[:-1]))
20 |
21 | abstracts = {}
22 | entities = {}
23 | relations = {}
24 | with open(args.input_folder + 'abstracts.txt', 'r') as infile:
25 | for line in infile:
26 | if line.rstrip().isdigit():
27 | pmid = line.rstrip()
28 | abstracts[pmid] = []
29 | relations[pmid] = []
30 | entities[pmid] = []
31 |
32 | elif line != '\n':
33 | abstracts[pmid] += [line.rstrip()]
34 |
35 | with open(args.input_folder + 'anns.txt', 'r') as infile:
36 | for line in infile:
37 | line = line.split('\t')
38 |
39 | if line[0].isdigit():
40 | entities[line[0]] += [tuple(line)]
41 |
42 | with open(args.input_folder + 'labels.csv', 'r') as infile:
43 | for line in infile:
44 | line = line.split(',')
45 |
46 | if line[0].isdigit() and line[3].rstrip() == '1':
47 | line = ','.join(line).rstrip().split(',')
48 |
49 | relations[line[0]] += [tuple([line[0]] + ['GDA'] + line[1:-1])]
50 |
51 | with open(args.output_file, 'w') as outfile:
52 | for d in tqdm(abstracts.keys(), desc='Writing 2 PubTator format'):
53 | if len(abstracts[d]) > 2:
54 |             print('Error: more than two text lines (title + abstract) found for PMID {}'.format(d))
55 |             exit(1)
56 |
57 | for i in range(0, len(abstracts[d])):
58 | if i == 0:
59 | outfile.write('{}|t|{}\n'.format(d, abstracts[d][i]))
60 | else:
61 | outfile.write('{}|a|{}\n'.format(d, abstracts[d][i]))
62 |
63 | for e in entities[d]:
64 | outfile.write('{}'.format('\t'.join(e)))
65 |
66 | for r in relations[d]:
67 | outfile.write('{}\n'.format('\t'.join(r)))
68 | outfile.write('\n')
69 |
70 |
--------------------------------------------------------------------------------
/data_processing/process.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on %(date)s
5 |
6 | @author: fenia
7 | """
8 |
9 | import os, sys
10 | import re
11 | from tqdm import tqdm
12 | from recordtype import recordtype
13 | from collections import OrderedDict
14 | import argparse
15 | import pickle
16 | from itertools import permutations, combinations
17 | from tools import sentence_split_genia, tokenize_genia
18 | from tools import adjust_offsets, find_mentions, find_cross, fix_sent_break, convert2sent, generate_pairs
19 | from readers import *
20 |
21 | TextStruct = recordtype('TextStruct', 'pmid txt')
22 | EntStruct = recordtype('EntStruct', 'pmid name off1 off2 type kb_id sent_no word_id bio')
23 | RelStruct = recordtype('RelStruct', 'pmid type arg1 arg2')
24 |
25 |
26 | def main():
27 | """
28 | Main processing function
29 | """
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--input_file', '-i', type=str)
32 | parser.add_argument('--output_file', '-o', type=str)
33 | parser.add_argument('--data', '-d', type=str)
34 | args = parser.parse_args()
35 |
36 | if args.data == 'GDA':
37 | abstracts, entities, relations = readPubTator(args)
38 | type1 = ['Gene']
39 | type2 = ['Disease']
40 |
41 | elif args.data == 'CDR':
42 | abstracts, entities, relations = readPubTator(args)
43 | type1 = ['Chemical']
44 | type2 = ['Disease']
45 |
46 | else:
47 | print('Dataset non-existent.')
48 | sys.exit()
49 |
50 | if not os.path.exists(args.output_file + '_files/'):
51 | os.makedirs(args.output_file + '_files/')
52 |
53 | # Process
54 | positive, negative = 0, 0
55 | with open(args.output_file + '.data', 'w') as data_out:
56 | pbar = tqdm(list(abstracts.keys()))
57 | for i in pbar:
58 | pbar.set_description("Processing Doc_ID {}".format(i))
59 |
60 | ''' Sentence Split '''
61 | orig_sentences = [item for sublist in [a.txt.split('\n') for a in abstracts[i]] for item in sublist]
62 | split_sents = sentence_split_genia(orig_sentences)
63 | split_sents = fix_sent_break(split_sents, entities[i])
64 | with open(args.output_file + '_files/' + i + '.split.txt', 'w') as f:
65 | f.write('\n'.join(split_sents))
66 |
67 | # adjust offsets
68 | new_entities = adjust_offsets(orig_sentences, split_sents, entities[i], show=False)
69 |
70 | ''' Tokenisation '''
71 | token_sents = tokenize_genia(split_sents)
72 | with open(args.output_file + '_files/' + i + '.split.tok.txt', 'w') as f:
73 | f.write('\n'.join(token_sents))
74 |
75 | # adjust offsets
76 | new_entities = adjust_offsets(split_sents, token_sents, new_entities, show=True)
77 |
78 | ''' Find mentions '''
79 | unique_entities = find_mentions(new_entities)
80 | with open(args.output_file + '_files/' + i + '.mention', 'wb') as f:
81 | pickle.dump(unique_entities, f, pickle.HIGHEST_PROTOCOL)
82 |
83 | ''' Generate Pairs '''
84 | if i in relations:
85 | pairs = generate_pairs(unique_entities, type1, type2, relations[i])
86 | else:
87 | pairs = generate_pairs(unique_entities, type1, type2, []) # generate only negative pairs
88 |
89 |             # output line: pmid <TAB> '|'-joined tokenised sentences, then 17 tab-separated columns per pair
90 | data_out.write('{}\t{}'.format(i, '|'.join(token_sents)))
91 |
92 | for args_, p in pairs.items():
93 | if p.type != '1:NR:2':
94 | positive += 1
95 | elif p.type == '1:NR:2':
96 | negative += 1
97 |
98 | data_out.write('\t{}\t{}\t{}\t{}-{}\t{}-{}'.format(p.type, p.dir, p.cross, p.closest[0].word_id[0],
99 | p.closest[0].word_id[-1]+1,
100 | p.closest[1].word_id[0],
101 | p.closest[1].word_id[-1]+1))
102 | data_out.write('\t{}\t{}\t{}\t{}\t{}\t{}'.format(
103 | '|'.join([g for g in p.arg1]),
104 | '|'.join([e.name for e in unique_entities[p.arg1]]),
105 | unique_entities[p.arg1][0].type,
106 | ':'.join([str(e.word_id[0]) for e in unique_entities[p.arg1]]),
107 | ':'.join([str(e.word_id[-1] + 1) for e in unique_entities[p.arg1]]),
108 | ':'.join([str(e.sent_no) for e in unique_entities[p.arg1]])))
109 |
110 | data_out.write('\t{}\t{}\t{}\t{}\t{}\t{}'.format(
111 | '|'.join([g for g in p.arg2]),
112 | '|'.join([e.name for e in unique_entities[p.arg2]]),
113 | unique_entities[p.arg2][0].type,
114 | ':'.join([str(e.word_id[0]) for e in unique_entities[p.arg2]]),
115 | ':'.join([str(e.word_id[-1] + 1) for e in unique_entities[p.arg2]]),
116 | ':'.join([str(e.sent_no) for e in unique_entities[p.arg2]])))
117 | data_out.write('\n')
118 | print('Total positive pairs:', positive)
119 | print('Total negative pairs:', negative)
120 |
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
--------------------------------------------------------------------------------
/data_processing/process_cdr.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | for d in "Training" "Development" "Test";
4 | do
5 | python3 process.py --input_file ../data/CDR/CDR.Corpus.v010516/CDR_${d}Set.PubTator.txt \
6 | --output_file ../data/CDR/processed/${d} \
7 | --data CDR
8 |
9 | python3 filter_hypernyms.py --mesh_file 2017MeshTree.txt \
10 | --input_file ../data/CDR/processed/${d}.data \
11 | --output_file ../data/CDR/processed/${d}_filter.data
12 | done
13 |
14 | mv ../data/CDR/processed/Training.data ../data/CDR/processed/train.data
15 | mv ../data/CDR/processed/Development.data ../data/CDR/processed/dev.data
16 | mv ../data/CDR/processed/Test.data ../data/CDR/processed/test.data
17 |
18 | mv ../data/CDR/processed/Training_filter.data ../data/CDR/processed/train_filter.data
19 | mv ../data/CDR/processed/Development_filter.data ../data/CDR/processed/dev_filter.data
20 | mv ../data/CDR/processed/Test_filter.data ../data/CDR/processed/test_filter.data
21 |
22 | # merge train and dev
23 | cat ../data/CDR/processed/train_filter.data > ../data/CDR/processed/train+dev_filter.data
24 | cat ../data/CDR/processed/dev_filter.data >> ../data/CDR/processed/train+dev_filter.data
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/data_processing/process_gda.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | for d in "training" "testing";
4 | do
5 | python3 gda2pubtator.py --input_folder ../data/GDA/${d}_data/ \
6 | --output_file ../data/GDA/processed/${d}.pubtator
7 |
8 | python3 process.py --input_file ../data/GDA/processed/${d}.pubtator \
9 | --output_file ../data/GDA/processed/${d} \
10 | --data GDA
11 | done
12 |
13 | mv ../data/GDA/processed/testing.data ../data/GDA/processed/test.data
14 | mv ../data/GDA/processed/training.data ../data/GDA/processed/train+dev.data
15 |
16 | python3 split_gda.py --input_file ../data/GDA/processed/train+dev.data \
17 | --output_train ../data/GDA/processed/train.data \
18 | --output_dev ../data/GDA/processed/dev.data \
19 | --list train_gda_docs
20 |
--------------------------------------------------------------------------------
/data_processing/readers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 14/08/2019
5 |
6 | author: fenia
7 | """
8 |
9 | import os, re, sys
10 | from utils import replace2symbol, replace2space
11 | from collections import OrderedDict
12 | from tqdm import tqdm
13 | from recordtype import recordtype
14 |
15 | TextStruct = recordtype('TextStruct', 'pmid txt')
16 | EntStruct = recordtype('EntStruct', 'pmid name off1 off2 type kb_id sent_no word_id bio')
17 | RelStruct = recordtype('RelStruct', 'pmid type arg1 arg2')
18 | PairStruct = recordtype('PairStruct', 'pmid type arg1 arg2 dir cross closest')
19 |
20 |
21 | def readPubTator(args):
22 | """
23 | Read data and store in structs
24 | """
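    # Expected PubTator-style input (illustrative lines; <TAB> marks a tab character):
    #   10021369|t|Title text ...
    #   10021369|a|Abstract text ...
    #   10021369<TAB>0<TAB>9<TAB>Indinavir<TAB>Chemical<TAB>D019469   <- 6-column entity line
    #   10021369<TAB>CID<TAB>D019469<TAB>D011507                      <- 4-column relation line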
25 | if not os.path.exists('/'.join(args.output_file.split('/')[:-1])):
26 | os.makedirs('/'.join(args.output_file.split('/')[:-1]))
27 |
28 | abstracts = OrderedDict()
29 | entities = OrderedDict()
30 | relations = OrderedDict()
31 |
32 | with open(args.input_file, 'r') as infile:
33 | for line in tqdm(infile):
34 |
35 | # text
36 | if len(line.rstrip().split('|')) == 3 and \
37 | (line.strip().split('|')[1] == 't' or line.strip().split('|')[1] == 'a'):
38 | line = line.strip().split('|')
39 |
40 | pmid = line[0]
41 | text = line[2] # .replace('>', '\n')
42 |
43 | # replace weird symbols and spaces
44 | text = replace2symbol(text)
45 | text = replace2space(text)
46 |
47 | if pmid not in abstracts:
48 | abstracts[pmid] = [TextStruct(pmid, text)]
49 | else:
50 | abstracts[pmid] += [TextStruct(pmid, text)]
51 |
52 | # entities
53 | elif len(line.rstrip().split('\t')) == 6:
54 | line = line.strip().split('\t')
55 | pmid = line[0]
56 | offset1 = int(line[1])
57 | offset2 = int(line[2])
58 | ent_name = line[3]
59 | ent_type = line[4]
60 | kb_id = line[5].split('|')
61 |
62 | # replace weird symbols and spaces
63 | ent_name = replace2symbol(ent_name)
64 | ent_name = replace2space(ent_name)
65 |
66 | # currently consider each possible ID as another entity
67 | for k in kb_id:
68 | if pmid not in entities:
69 | entities[pmid] = [EntStruct(pmid, ent_name, offset1, offset2, ent_type, [k], -1, [], [])]
70 | else:
71 | entities[pmid] += [EntStruct(pmid, ent_name, offset1, offset2, ent_type, [k], -1, [], [])]
72 |
73 | elif len(line.rstrip().split('\t')) == 7:
74 | line = line.strip().split('\t')
75 | pmid = line[0]
76 | offset1 = int(line[1])
77 | offset2 = int(line[2])
78 | ent_name = line[3]
79 | ent_type = line[4]
80 | kb_id = line[5].split('|')
81 | extra_ents = line[6].split('|')
82 |
83 | # replace weird symbols and spaces
84 | ent_name = replace2symbol(ent_name)
85 | ent_name = replace2space(ent_name)
86 | for i, e in enumerate(extra_ents):
87 | if pmid not in entities:
88 | entities[pmid] = [EntStruct(pmid, ent_name, offset1, offset2, ent_type, [kb_id[i]], -1, [], [])]
89 | else:
90 | entities[pmid] += [EntStruct(pmid, ent_name, offset1, offset2, ent_type, [kb_id[i]], -1, [], [])]
91 |
92 | # relations
93 | elif len(line.rstrip().split('\t')) == 4:
94 | line = line.strip().split('\t')
95 | pmid = line[0]
96 | rel_type = line[1]
97 | arg1 = tuple((line[2].split('|')))
98 | arg2 = tuple((line[3].split('|')))
99 |
100 | if pmid not in relations:
101 | relations[pmid] = [RelStruct(pmid, rel_type, arg1, arg2)]
102 | else:
103 | relations[pmid] += [RelStruct(pmid, rel_type, arg1, arg2)]
104 |
105 | elif line == '\n':
106 | continue
107 |
108 | return abstracts, entities, relations
109 |
--------------------------------------------------------------------------------
/data_processing/reduce_embeds.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 |
3 | import numpy as np
4 | import glob
5 | import sys
6 | from collections import OrderedDict
7 | import argparse
8 |
9 | """
10 | Crop embeddings to the size of the dataset, i.e. keeping only existing words.
11 | """
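# Example invocation (illustrative paths):
#   python3 reduce_embeds.py --full_embeds PubMed-w2v.txt --out_embeds ../embeds/PubMed-CDR.txt \
#       --in_data ../data/CDR/processed/*_filter.data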
12 |
13 | def load_pretrained_embeddings(embeds):
14 | """
15 | :param params: input parameters
16 | :returns
17 | dictionary with words (keys) and embeddings (values)
18 | """
19 | if embeds:
20 | E = OrderedDict()
21 | with open(embeds, 'r') as vectors:
22 | for x, line in enumerate(vectors):
23 | if x == 0 and len(line.split()) == 2:
24 | words, num = map(int, line.rstrip().split())
25 | else:
26 | word = line.rstrip().split()[0]
27 | vec = line.rstrip().split()[1:]
28 | n = len(vec)
29 | if len(vec) != num:
30 | # print('Wrong dimensionality: {} {} != {}'.format(word, len(vec), num))
31 | continue
32 | else:
33 | E[word] = np.asarray(vec, dtype=np.float32)
34 | print('Pre-trained word embeddings: {} x {}'.format(len(E), n))
35 | else:
36 | E = OrderedDict()
37 | print('No pre-trained word embeddings loaded.')
38 | return E
39 |
40 |
41 | def main():
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument('--full_embeds', type=str)
44 | parser.add_argument('--out_embeds', type=str)
45 | parser.add_argument('--in_data', nargs='+')
46 | args = parser.parse_args()
47 |
48 | words = []
49 | print('Extracting words from the dataset ... ', end="")
50 |
51 | for filef in args.in_data:
52 | with open(filef, 'r') as infile:
53 | for line in infile:
54 | line = line.strip().split('\t')[1]
55 | line = line.split('|')
56 | line = [l.split(' ') for l in line]
57 | line = [item for sublist in line for item in sublist]
58 |
59 | for l in line:
60 | words.append(l)
61 | print('Done')
62 |
63 | # make lowercase
64 | words_lower = list(map(lambda x:x.lower(), words))
65 |
66 | print('Loading embeddings ... ', end="")
67 | embeddings = load_pretrained_embeddings(args.full_embeds)
68 |
69 | print('Writing final embeddings ... ', end="")
70 | words = set(words)
71 | words_lower = set(words_lower) # lowercased
72 |
73 | new_embeds = OrderedDict()
74 | for w in embeddings.keys():
75 | if (w in words) or (w in words_lower):
76 | new_embeds[w] = embeddings[w]
77 |
78 | with open(args.out_embeds, 'w') as outfile:
79 | for g in new_embeds.keys():
80 | outfile.write('{} {}\n'.format(g, ' '.join(map(str, list(new_embeds[g])))))
81 | print('Done')
82 |
83 | if __name__ == "__main__":
84 | main()
85 |
--------------------------------------------------------------------------------
/data_processing/split_gda.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 06/09/2019
5 |
6 | author: fenia
7 | """
8 |
9 | import argparse
10 | import os
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--input_file', type=str)
14 | parser.add_argument('--output_train', type=str)
15 | parser.add_argument('--output_dev', type=str)
16 | parser.add_argument('--list', type=str)
17 | args = parser.parse_args()
18 |
19 |
20 | with open(args.list, 'r') as infile:
21 | docs = [i.rstrip() for i in infile]
22 | docs = frozenset(docs)
23 |
24 | with open(args.input_file, 'r') as infile, open(args.output_train, 'w') as otr, open(args.output_dev, 'w') as odev:
25 | for line in infile:
26 | pmid = line.rstrip().split('\t')[0]
27 |
28 | if pmid in docs:
29 | otr.write(line)
30 | else:
31 | odev.write(line)
32 |
33 |
--------------------------------------------------------------------------------
/data_processing/statistics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 09/05/2019
5 |
6 | author: fenia
7 | """
8 |
9 | import argparse
10 | import numpy as np
11 | from collections import OrderedDict
12 | from recordtype import recordtype
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--data', type=str)
16 | args = parser.parse_args()
17 |
18 |
19 | EntityInfo = recordtype('EntityInfo', 'type mstart mend sentNo')
20 | PairInfo = recordtype('PairInfo', 'type direction cross closeA closeB')
21 |
22 |
23 | def chunks(l, n):
24 | """
25 | Yield successive n-sized chunks from l.
26 | """
27 | for i in range(0, len(l), n):
28 | assert len(l[i:i + n]) == n
29 | yield l[i:i + n]
30 |
31 |
32 | for d, data in zip(['DATA'], [args.data]):
33 | documents = {}
34 | entities = {}
35 | relations = {}
36 |
37 | with open(data, 'r') as infile:
38 |
39 | for line in infile:
40 | line = line.rstrip().split('\t')
41 | pairs = chunks(line[2:], 17)
42 |
43 | id_ = line[0]
44 |
45 | if id_ not in documents:
46 | documents[id_] = []
47 |
48 | for sent in line[1].split('|'):
49 | documents[id_] += [sent]
50 |
51 | if id_ not in entities:
52 | entities[id_] = OrderedDict()
53 |
54 | if id_ not in relations:
55 | relations[id_] = OrderedDict()
56 |
57 | for p in pairs:
58 | # pairs
59 | if (p[5], p[11]) not in relations[id_]:
60 | relations[id_][(p[5], p[11])] = PairInfo(p[0], p[1], p[2], p[3], p[4])
61 | else:
62 | print('duplicates!')
63 |
64 | # entities
65 | if p[5] not in entities[id_]:
66 | entities[id_][p[5]] = EntityInfo(p[7], p[8], p[9], p[10])
67 |
68 | if p[11] not in entities[id_]:
69 | entities[id_][p[11]] = EntityInfo(p[13], p[14], p[15], p[16])
70 |
71 | docs = len(documents)
72 | pair_types = {}
73 | inter_types = {}
74 | intra_types = {}
75 | ent_types = {}
76 | men_types = {}
77 | dist = {}
78 | for id_ in relations.keys():
79 | for k, p in relations[id_].items():
80 |
81 | if p.type not in pair_types:
82 | pair_types[p.type] = 0
83 | pair_types[p.type] += 1
84 |
85 | if p.type not in dist:
86 | dist[p.type] = []
87 |
88 | if p.type not in inter_types:
89 | inter_types[p.type] = 0
90 |
91 | if p.type not in intra_types:
92 | intra_types[p.type] = 0
93 |
94 | if p.type not in dist:
95 | dist[p.type] = []
96 |
97 | if p.cross == 'CROSS':
98 | inter_types[p.type] += 1
99 | else:
100 | intra_types[p.type] += 1
101 |
102 | if p.cross == 'CROSS':
103 | dist_ = 10000
104 |
105 | for m1 in entities[id_][k[0]].sentNo.split(':'):
106 | for m2 in entities[id_][k[1]].sentNo.split(':'):
107 |
108 | if abs(int(m1) - int(m2)) < dist_:
109 | dist_ = abs(int(m1) - int(m2))
110 |
111 | dist[p.type] += [dist_]
112 |
113 | for e in entities[id_].values():
114 | if e.type not in ent_types:
115 | ent_types[e.type] = 0
116 | ent_types[e.type] += 1
117 |
118 | if e.type not in men_types:
119 | men_types[e.type] = 0
120 | for m in e.mstart.split(':'):
121 | men_types[e.type] += 1
122 |
123 | ents_per_doc = [len(entities[n]) for n in documents.keys()]
124 | ments_per_doc = [np.sum([len(e.sentNo.split(':')) for e in entities[n].values()]) for n in documents.keys()]
125 | ments_per_ent = [[len(e.sentNo.split(':')) for e in entities[n].values()] for n in documents.keys()]
126 | sents_per_doc = [len(s) for s in documents.values()]
127 | sent_len = [len(a.split()) for s in documents.values() for a in s]
128 |
129 | # write data
130 | with open('/'.join(args.data.split('/')[:-1]) + '/' + args.data.split('/')[-1].split('.')[0] + '.gold', 'w') as outfile:
131 | for id_ in relations.keys():
132 | for k, p in relations[id_].items():
133 | PairInfo = recordtype('PairInfo', 'type direction cross closeA closeB')
134 | outfile.write('{}|{}|{}|{}|{}\n'.format(id_, k[0], k[1], p.cross, p.type))
135 |
136 | print('''
137 | ----------------------- {} ----------------------
138 | Documents {}
139 | '''.format(d, docs))
140 |
141 | print(' Pairs')
142 |
143 | for x in ['{:<10}\t{:<5}'.format(k, v) for k, v in sorted(pair_types.items())]:
144 | print(' {}'.format(x))
145 | print()
146 |
147 | print(' Entities')
148 | for x in ['{:<10}\t{:<5}'.format(k, v) for k, v in sorted(ent_types.items())]:
149 | print(' {}'.format(x))
150 | print()
151 |
152 | print(' Mentions')
153 | for x in ['{:<10}\t{:<5}'.format(k, v) for k, v in sorted(men_types.items())]:
154 | print(' {}'.format(x))
155 | print()
156 |
157 | print(' Intra Pairs')
158 | for x in ['{:<10}\t{:<5}'.format(k, v) for k, v in sorted(intra_types.items())]:
159 | print(' {}'.format(x))
160 | print()
161 |
162 | print(' Inter Pairs')
163 | for x in ['{:<10}\t{:<5}'.format(k, v) for k, v in sorted(inter_types.items())]:
164 | print(' {}'.format(x))
165 | print()
166 |
167 | print(' Average/Max Sentence Distance')
168 | for x in ['{:<10}\t{:.1f}\t{}'.format(k, np.average(v), np.max(v)) for k, v in sorted(dist.items())]:
169 | print(' {}'.format(x))
170 | print()
171 |
172 | print('''
173 |              Average entities/doc {:.1f}
174 | Max {}
175 |
176 | Average mentions/doc {:.1f}
177 | Max {}
178 |
179 | Average mentions/entity {:.1f}
180 | Max {}
181 |
182 | Average sents/doc {:.1f}
183 | Max {}
184 |
185 | Average/max sent length {:.1f}
186 | Max {}
187 | '''.format(np.average(ents_per_doc),
188 | np.max(ents_per_doc),
189 | np.average(ments_per_doc),
190 | np.max(ments_per_doc),
191 | np.average([item for sublist in ments_per_ent for item in sublist]),
192 | np.max([item for sublist in ments_per_ent for item in sublist]),
193 | np.average(sents_per_doc),
194 | np.max(sents_per_doc),
195 | np.average(sent_len),
196 | np.max(sent_len)))
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
--------------------------------------------------------------------------------
/data_processing/tools.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on %(date)s
5 |
6 | @author: fenia
7 | """
8 |
9 | import os
10 | import sys
11 | import re
12 | from recordtype import recordtype
13 | from networkx.algorithms.components.connected import connected_components
14 | from itertools import combinations
15 | import numpy as np
16 | from collections import OrderedDict
17 | from utils import to_graph, to_edges, using_split2
18 | from tqdm import tqdm
19 | sys.path.append('./common/genia-tagger-py/')
20 | from geniatagger import GENIATagger
21 |
22 | pwd = '/'.join(os.path.realpath(__file__).split('/')[:-1])
23 |
24 | genia_splitter = os.path.join("./common", "geniass")
25 | genia_tagger = GENIATagger(os.path.join("./common", "genia-tagger-py", "geniatagger-3.0.2", "geniatagger"))
26 |
27 |
28 | TextStruct = recordtype('TextStruct', 'pmid txt')
29 | EntStruct = recordtype('EntStruct', 'pmid name off1 off2 type kb_id sent_no word_id bio')
30 | RelStruct = recordtype('RelStruct', 'pmid type arg1 arg2')
31 | PairStruct = recordtype('PairStruct', 'pmid type arg1 arg2 dir cross closest')
32 |
33 |
34 | def generate_pairs(uents, type1, type2, true_rels):
35 | """
36 | Generate pairs (both positive & negative):
37 | Type1 - Type2 should have 1-1 association, e.g. [A, A] [B, C] --> (A,B), (A,C)
38 | Args:
39 | uents:
40 | type1: (list) with entity semantic types
41 | type2: (list) with entity semantic types
42 | true_rels:
43 | """
44 | pairs = OrderedDict()
45 | combs = combinations(uents, 2)
46 |
47 | unk = 0
48 | total_rels = len(true_rels)
49 | found_rels = 0
50 |
51 | for c in combs:
52 | # all pairs
53 | diff = 99999
54 |
55 | target = []
56 | for e1 in uents[c[0]]:
57 | for e2 in uents[c[1]]:
58 | # find most close pair to each other
59 | if e1.word_id[-1] <= e2.word_id[0]:
60 | if abs(e2.word_id[0] - e1.word_id[-1]) < diff:
61 | target = [e1, e2]
62 | diff = abs(e2.word_id[0] - e1.word_id[-1])
63 | else:
64 | if abs(e1.word_id[0] - e2.word_id[-1]) < diff:
65 | target = [e1, e2]
66 |                         diff = abs(e1.word_id[0] - e2.word_id[-1])
67 |
68 | if target[0].word_id[-1] <= target[1].word_id[0]: # A before B (in text)
69 | a1 = target[0]
70 | a2 = target[1]
71 | else: # B before A (in text)
72 | a1 = target[1]
73 | a2 = target[0]
74 |
75 | if c[0][0].startswith('UNK:') or c[1][0].startswith('UNK:'): # ignore non-grounded entities
76 | continue
77 |
78 | cross_res = find_cross(c, uents)
79 | not_found_rels = 0
80 |
81 | for tr in true_rels:
82 |
83 | # AB existing relation
84 | if list(set(tr.arg1).intersection(set(c[0]))) and list(set(tr.arg2).intersection(set(c[1]))):
85 | for t1, t2 in zip(type1, type2):
86 | if uents[c[0]][0].type == t1 and uents[c[1]][0].type == t2:
87 | pairs[(c[0], c[1])] = \
88 | PairStruct(tr.pmid, '1:' + tr.type + ':2', c[0], c[1], 'L2R', cross_res, (a1, a2))
89 | found_rels += 1
90 |
91 | # BA existing relation
92 | elif list(set(tr.arg1).intersection(set(c[1]))) and list(set(tr.arg2).intersection(set(c[0]))):
93 | for t1, t2 in zip(type1, type2):
94 | if uents[c[1]][0].type == t1 and uents[c[0]][0].type == t2:
95 | pairs[(c[1], c[0])] = \
96 | PairStruct(tr.pmid, '1:'+tr.type+':2', c[1], c[0], 'R2L', cross_res, (a2, a1))
97 | found_rels += 1
98 |
99 | # relation not found
100 | else:
101 | not_found_rels += 1
102 |
103 | # this pair does not have a relation
104 | if not_found_rels == total_rels:
105 | for t1, t2 in zip(type1, type2):
106 | if uents[c[0]][0].type == t1 and uents[c[1]][0].type == t2:
107 | pairs[(c[0], c[1])] = PairStruct(a1.pmid, '1:NR:2', c[0], c[1], 'L2R', cross_res, (a1, a2))
108 | unk += 1
109 | elif uents[c[1]][0].type == t1 and uents[c[0]][0].type == t2:
110 | pairs[(c[1], c[0])] = PairStruct(a1.pmid, '1:NR:2', c[1], c[0], 'R2L', cross_res, (a2, a1))
111 | unk += 1
112 |
113 | assert found_rels == total_rels, '{} <> {}, {}, {}'.format(found_rels, total_rels, true_rels, pairs)
114 |
115 | # # Checking
116 | # if found_rels != total_rels:
117 | # print('NON-FOUND RELATIONS: {} <> {}'.format(found_rels, total_rels))
118 | # for p in true_rels:
119 | # if (p.arg1, p.arg2) not in pairs:
120 | # print(p.arg1, p.arg2)
121 | return pairs
122 |
123 |
124 | def convert2sent(arg1, arg2, token_sents):
125 | """
126 | Convert document info to sentence (for pairs in same sentence).
127 | Args:
128 | arg1:
129 | arg2:
130 | token_sents:
131 | """
132 | # make sure they are in the same sentence
133 | assert arg1.sent_no == arg2.sent_no, 'error: entities not in the same sentence'
134 |
135 | toks_per_sent = []
136 | sent_offs = []
137 | cnt = 0
138 | for i, s in enumerate(token_sents):
139 | toks_per_sent.append(len(s.split(' ')))
140 | sent_offs.append((cnt, cnt+len(s.split(' '))-1))
141 | cnt = len(' '.join(token_sents[:i+1]).split(' '))
142 |
143 | target_sent = token_sents[arg1.sent_no].split(' ')
144 | n = sum(toks_per_sent[0:arg1.sent_no])
145 |
146 | arg1_span = [a-n for a in arg1.word_id]
147 | arg2_span = [a-n for a in arg2.word_id]
148 | assert target_sent[arg1_span[0]:arg1_span[-1]+1] == \
149 | ' '.join(token_sents).split(' ')[arg1.word_id[0]:arg1.word_id[-1]+1]
150 | assert target_sent[arg2_span[0]:arg2_span[-1]+1] == \
151 | ' '.join(token_sents).split(' ')[arg2.word_id[0]:arg2.word_id[-1]+1]
152 |
153 | arg1_n = EntStruct(arg1.pmid, arg1.name, arg1.off1, arg1.off2, arg1.type,
154 | arg1.kb_id, arg1.sent_no, arg1_span, arg1.bio)
155 | arg2_n = EntStruct(arg2.pmid, arg2.name, arg2.off1, arg2.off2, arg2.type,
156 | arg2.kb_id, arg2.sent_no, arg2_span, arg2.bio)
157 |
158 | return arg1_n, arg2_n
159 |
160 |
161 | def find_cross(pair, unique_ents):
162 | """
163 |     Determine whether the pair is cross-sentence or not (NON-CROSS if any two mentions share a sentence).
164 | Args:
165 | pair: (tuple) target pair
166 | unique_ents: (dic) entities based on grounded IDs
167 | Returns: (str) cross/non-cross
168 | """
169 | non_cross = False
170 | for m1 in unique_ents[pair[0]]:
171 | for m2 in unique_ents[pair[1]]:
172 | if m1.sent_no == m2.sent_no:
173 | non_cross = True
174 | if non_cross:
175 | return 'NON-CROSS'
176 | else:
177 | return 'CROSS'
178 |
179 |
180 | def fix_sent_break(sents, entities):
181 | """
182 |     Fix sentence breaks that wrongly split an entity mention across sentences
183 | Args:
184 | sents: (list) sentences
185 | entities: (recordtype)
186 | Returns: (list) sentences with fixed sentence breaks
187 | """
188 | sents_break = '\n'.join(sents)
189 |
190 | for e in entities:
191 | if '\n' in sents_break[e.off1:e.off2]:
192 | sents_break = sents_break[0:e.off1] + sents_break[e.off1:e.off2].replace('\n', ' ') + sents_break[e.off2:]
193 | return sents_break.split('\n')
194 |
195 |
196 | def find_mentions(entities):
197 | """
198 | Find unique entities and their mentions
199 | Args:
200 |         entities: (list) an EntStruct for each entity mention
201 |     Returns: (dic) unique entities keyed by grounded ID; non-grounded mentions (ID -1) get their own 'UNK:<no>' key
202 | """
203 | equivalents = []
204 | for e in entities:
205 | if e.kb_id not in equivalents:
206 | equivalents.append(e.kb_id)
207 |
208 | # mention-level data sets
209 | g = to_graph(equivalents)
210 | cc = connected_components(g)
211 |
212 | unique_entities = OrderedDict()
213 | unk_id = 0
214 | for c in cc:
215 | if tuple(c)[0] == '-1':
216 | continue
217 | unique_entities[tuple(c)] = []
218 |
219 | # consider non-grounded entities as separate entities
220 | for e in entities:
221 | if e.kb_id[0] == '-1':
222 | unique_entities[tuple(('UNK:' + str(unk_id),))] = [e]
223 | unk_id += 1
224 | else:
225 | for ue in unique_entities.keys():
226 | if list(set(e.kb_id).intersection(set(ue))):
227 | unique_entities[ue] += [e]
228 |
229 | return unique_entities
230 |
231 |
232 | def sentence_split_genia(tabst):
233 | """
234 | Sentence Splitting Using GENIA sentence splitter
235 | Args:
236 | tabst: (list) title+abstract
237 |
238 | Returns: (list) all sentences in abstract
239 | """
240 | os.chdir(genia_splitter)
241 |
242 | with open('temp_file.txt', 'w') as ofile:
243 | for t in tabst:
244 | ofile.write(t+'\n')
245 | os.system('./geniass temp_file.txt temp_file.split.txt > /dev/null 2>&1')
246 |
247 | split_lines = []
248 | with open('temp_file.split.txt', 'r') as ifile:
249 | for line in ifile:
250 | line = line.rstrip()
251 | if line != '':
252 | split_lines.append(line.rstrip())
253 | os.system('rm temp_file.txt temp_file.split.txt')
254 | os.chdir(pwd)
255 | return split_lines
256 |
257 |
258 | def tokenize_genia(sents):
259 | """
260 | Tokenization using Genia Tokenizer
261 | Args:
262 | sents: (list) sentences
263 |
264 | Returns: (list) tokenized sentences
265 | """
266 | token_sents = []
267 | for i, s in enumerate(sents):
268 | tokens = []
269 |
270 | for word, base_form, pos_tag, chunk, named_entity in genia_tagger.tag(s):
271 | tokens += [word]
272 |
273 | text = []
274 | for t in tokens:
275 | if t == "'s":
276 | text.append(t)
277 | elif t == "''":
278 | text.append(t)
279 | else:
280 | text.append(t.replace("'", " ' "))
281 |
282 | text = ' '.join(text)
283 | text = text.replace("-LRB-", '(')
284 | text = text.replace("-RRB-", ')')
285 | text = text.replace("-LSB-", '[')
286 | text = text.replace("-RSB-", ']')
287 | text = text.replace("``", '"')
288 | text = text.replace("`", "'")
289 | text = text.replace("'s", " 's")
290 | text = text.replace('-', ' - ')
291 | text = text.replace('/', ' / ')
292 | text = text.replace('+', ' + ')
293 | text = text.replace('.', ' . ')
294 | text = text.replace('=', ' = ')
295 | text = text.replace('*', ' * ')
296 |         if '&amp;' in s:
297 |             pass  # the source text itself contains the escaped form; keep it
298 |         else:
299 |             text = text.replace("&amp;", "&")  # undo the tagger's ampersand escaping
300 |
301 | text = re.sub(' +', ' ', text).strip() # remove continuous spaces
302 |
303 | if "''" in ''.join(s):
304 | pass
305 | else:
306 | text = text.replace("''", '"')
307 |
308 | token_sents.append(text)
309 | return token_sents
310 |
311 |
312 | def adjust_offsets(old_sents, new_sents, old_entities, show=False):
313 | """
314 | Adjust offsets based on tokenization
315 | Args:
316 | old_sents: (list) old, non-tokenized sentences
317 | new_sents: (list) new, tokenized sentences
318 | old_entities: (dic) entities with old offsets
319 | Returns:
320 | new_entities: (dic) entities with adjusted offsets
321 | abst_seq: (list) abstract sequence with entity tags
322 | """
323 | cur = 0
324 | new_sent_range = []
325 | for s in new_sents:
326 | new_sent_range += [(cur, cur + len(s))]
327 | cur += len(s) + 1
328 |
329 | original = " ".join(old_sents)
330 | newtext = " ".join(new_sents)
331 | new_entities = []
332 | terms = {}
333 | for e in old_entities:
334 | start = int(e.off1)
335 | end = int(e.off2)
336 |
337 | if (start, end) not in terms:
338 | terms[(start, end)] = [[start, end, e.type, e.name, e.pmid, e.kb_id]]
339 | else:
340 | terms[(start, end)].append([start, end, e.type, e.name, e.pmid, e.kb_id])
341 |
342 | orgidx = 0
343 | newidx = 0
344 | orglen = len(original)
345 | newlen = len(newtext)
346 |
347 | terms2 = terms.copy()
348 | while orgidx < orglen and newidx < newlen:
349 | # print(repr(original[orgidx]), orgidx, repr(newtext[newidx]), newidx)
350 | if original[orgidx] == newtext[newidx]:
351 | orgidx += 1
352 | newidx += 1
353 | elif original[orgidx] == "`" and newtext[newidx] == "'":
354 | orgidx += 1
355 | newidx += 1
356 | elif newtext[newidx] == '\n':
357 | newidx += 1
358 | elif original[orgidx] == '\n':
359 | orgidx += 1
360 | elif newtext[newidx] == ' ':
361 | newidx += 1
362 | elif original[orgidx] == ' ':
363 | orgidx += 1
364 | elif newtext[newidx] == '\t':
365 | newidx += 1
366 | elif original[orgidx] == '\t':
367 | orgidx += 1
368 | elif newtext[newidx] == '.':
369 | # ignore extra "." for stanford
370 | newidx += 1
371 | else:
372 | print("Non-existent text: %d\t --> %s != %s " % (orgidx, repr(original[orgidx-10:orgidx+10]),
373 | repr(newtext[newidx-10:newidx+10])))
374 | exit(0)
375 |
376 | starts = [key[0] for key in terms2.keys()]
377 | ends = [key[1] for key in terms2.keys()]
378 |
379 | if orgidx in starts:
380 | tt = [key for key in terms2.keys() if key[0] == orgidx]
381 | for sel in tt:
382 | for l in terms[sel]:
383 | l[0] = newidx
384 |
385 | if orgidx in ends:
386 | tt2 = [key for key in terms2.keys() if key[1] == orgidx]
387 | for sel2 in tt2:
388 | for l in terms[sel2]:
389 | if l[1] == orgidx:
390 | l[1] = newidx
391 |
392 | for t_ in tt2:
393 | del terms2[t_]
394 |
395 | ent_sequences = []
396 | for ts in terms.values():
397 | for term in ts:
398 | condition = False
399 |
400 | if newtext[term[0]:term[1]].replace(" ", "").replace("\n", "") != term[3].replace(" ", "").replace('\n', ''):
401 | if newtext[term[0]:term[1]].replace(" ", "").replace("\n", "").lower() == \
402 | term[3].replace(" ", "").replace('\n', '').lower():
403 | condition = True
404 | tqdm.write('DOC_ID {}, Lowercase Issue: {} <-> {}'.format(term[4], newtext[term[0]:term[1]], term[3]))
405 | else:
406 | condition = False
407 | else:
408 | condition = True
409 |
410 | if condition:
411 | """ Convert to word Ids """
412 | tok_seq = []
413 | span2append = []
414 | bio = []
415 | tag = term[2]
416 |
417 | for tok_id, (tok, start, end) in enumerate(using_split2(newtext)):
418 | start = int(start)
419 | end = int(end)
420 |
421 | if (start, end) == (term[0], term[1]):
422 | bio.append('B-' + tag)
423 | tok_seq.append('B-' + tag)
424 | span2append.append(tok_id)
425 |
426 | elif start == term[0] and end < term[1]:
427 | bio.append('B-' + tag)
428 | tok_seq.append('B-' + tag)
429 | span2append.append(tok_id)
430 |
431 | elif start > term[0] and end < term[1]:
432 | bio.append('I-' + tag)
433 | tok_seq.append('I-' + tag)
434 | span2append.append(tok_id)
435 |
436 | elif start > term[0] and end == term[1] and (start != end):
437 | bio.append('I-' + tag)
438 | tok_seq.append('I-' + tag)
439 | span2append.append(tok_id)
440 |
441 | elif len(set(range(start, end)).intersection(set(range(term[0], term[1])))) > 0:
442 | span2append.append(tok_id)
443 |
444 | if show:
445 | tqdm.write('DOC_ID {}, entity: {:<20} ({:4}-{:4}) '
446 | '<-> token: {:<20} ({:4}-{:4}) <-> final: {:<20}'.format(
447 | term[4], newtext[term[0]:term[1]], term[0], term[1], tok, start, end,
448 | ' '.join(newtext.split(' ')[span2append[0]:span2append[-1]+1])))
449 |
450 | if not bio:
451 | bio.append('B-' + tag)
452 | tok_seq.append('B-' + tag)
453 | else:
454 | bio.append('I-' + tag)
455 | tok_seq.append('I-' + tag)
456 |
457 | else:
458 | tok_seq.append('O')
459 |
460 | ent_sequences += [tok_seq]
461 |
462 |             # include all tokens of the entity span
463 | if len(span2append) != len(newtext[term[0]:term[1]].split(' ')):
464 | tqdm.write('DOC_ID {}, entity {}, tokens {}\n{}'.format(
465 | term[4], newtext[term[0]:term[1]].split(' '), span2append, newtext))
466 |
467 | # Find sentence number of each entity
468 | sent_no = []
469 | for s_no, sr in enumerate(new_sent_range):
470 | if set(np.arange(term[0], term[1])).issubset(set(np.arange(sr[0], sr[1]))):
471 | sent_no += [s_no]
472 |
473 | assert (len(sent_no) == 1), '{} ({}, {}) -- {} -- {} <> {}'.format(sent_no, term[0], term[1],
474 | new_sent_range,
475 | newtext[term[0]:term[1]], term[3])
476 |
477 | new_entities += [EntStruct(term[4], newtext[term[0]:term[1]], term[0], term[1], term[2], term[5],
478 | sent_no[0], span2append, bio)]
479 | else:
480 | print(newtext, term[3])
481 | assert False, 'ERROR: {} ({}-{}) <=> {}'.format(repr(newtext[term[0]:term[1]]), term[0], term[1],
482 | repr(term[3]))
483 | return new_entities
484 |
--------------------------------------------------------------------------------
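
The pair generation above keeps, for every entity pair, the two mentions whose token spans lie closest together in the document. A minimal standalone sketch of that selection rule (hypothetical toy spans, not the EntStruct records used above):

# Minimal sketch of the closest-mention selection used in generate_pairs above.
# Mentions are reduced to their token-id spans (hypothetical toy data).
def closest_mentions(mentions_a, mentions_b):
    best, best_gap = None, float('inf')
    for m1 in mentions_a:
        for m2 in mentions_b:
            if m1[-1] <= m2[0]:              # m1 occurs fully before m2
                gap = m2[0] - m1[-1]
            else:                            # m2 occurs before (or overlaps) m1
                gap = abs(m1[0] - m2[-1])
            if gap < best_gap:
                best, best_gap = (m1, m2), gap
    return best, best_gap

ent_a = [[2, 3], [30, 31]]   # token spans of entity A's mentions
ent_b = [[10, 11], [28]]     # token spans of entity B's mentions
print(closest_mentions(ent_a, ent_b))   # (([30, 31], [28]), 2)
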
/data_processing/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on %(date)s
5 |
6 | @author: fenia
7 | """
8 |
9 | import networkx
10 |
11 |
12 | def to_graph(l):
13 | """
14 | https://stackoverflow.com/questions/4842613/merge-lists-that-share-common-elements
15 | """
16 | G = networkx.Graph()
17 | for part in l:
18 | # each sublist is a bunch of nodes
19 | G.add_nodes_from(part)
20 | # it also implies a number of edges:
21 | G.add_edges_from(to_edges(part))
22 | return G
23 |
24 |
25 | def to_edges(l):
26 | """
27 |         treat `l` as a graph and return its edges
28 |         to_edges(['a','b','c','d']) -> [(a,b), (b,c), (c,d)]
29 | """
30 | it = iter(l)
31 | last = next(it)
32 |
33 | for current in it:
34 | yield last, current
35 | last = current
36 |
37 |
38 | def using_split2(line, _len=len):
39 | """
40 | Credits to https://stackoverflow.com/users/1235039/aquavitae
41 |
42 | :param line: sentence
43 | :return: a list of words and their indexes in a string.
44 | """
45 | words = line.split(' ')
46 | index = line.index
47 | offsets = []
48 | append = offsets.append
49 | running_offset = 0
50 | for word in words:
51 | word_offset = index(word, running_offset)
52 | word_len = _len(word)
53 | running_offset = word_offset + word_len
54 | append((word, word_offset, running_offset))
55 | return offsets
56 |
57 |
58 | def replace2symbol(string):
59 | string = string.replace('”', '"').replace('’', "'").replace('–', '-').replace('‘', "'").replace('‑', '-').replace(
60 | '\x92', "'").replace('»', '"').replace('—', '-').replace('\uf8fe', ' ').replace('«', '"').replace(
61 | '\uf8ff', ' ').replace('£', '#').replace('\u2028', ' ').replace('\u2029', ' ')
62 |
63 | return string
64 |
65 |
66 | def replace2space(string):
67 | spaces = ["\r", '\xa0', '\xe2\x80\x85', '\xc2\xa0', '\u2009', '\u2002', '\u200a', '\u2005', '\u2003', '\u2006',
68 | 'Ⅲ', '…', 'Ⅴ', "\u202f"]
69 |
70 | for i in spaces:
71 | string = string.replace(i, ' ')
72 | return string
73 |
74 |
--------------------------------------------------------------------------------
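
to_graph and to_edges are used by find_mentions in process.py above to merge KB-id lists that share at least one identifier. A small self-contained illustration of that grouping, using made-up MeSH-like ids:

# KB-id lists that share an identifier end up in the same connected component,
# mirroring the to_graph/to_edges + connected_components step in find_mentions.
import networkx as nx

kb_id_lists = [['D003920', 'D003924'], ['D003924'], ['D001943']]   # made-up ids
G = nx.Graph()
for ids in kb_id_lists:
    G.add_nodes_from(ids)
    G.add_edges_from(zip(ids, ids[1:]))   # chain each list, as to_edges does
print(list(nx.connected_components(G)))  # [{'D003920', 'D003924'}, {'D001943'}] (set order may vary)
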
/evaluation/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 13/05/2019
5 |
6 | author: fenia
7 | """
8 |
9 | import argparse
10 |
11 |
12 | def prf(tp, fp, fn):
13 | micro_p = float(tp) / (tp + fp) if (tp + fp != 0) else 0.0
14 | micro_r = float(tp) / (tp + fn) if (tp + fn != 0) else 0.0
15 | micro_f = ((2 * micro_p * micro_r) / (micro_p + micro_r)) if micro_p != 0.0 and micro_r != 0.0 else 0.0
16 | return [micro_p, micro_r, micro_f]
17 |
18 |
19 | def main():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--gold', type=str)
22 | parser.add_argument('--pred', type=str)
23 | parser.add_argument('--label', type=str)
24 | args = parser.parse_args()
25 |
26 | with open(args.pred) as pred, open(args.gold) as gold:
27 | preds_all = []
28 | preds_intra = []
29 | preds_inter = []
30 |
31 | golds_all = []
32 | golds_intra = []
33 | golds_inter = []
34 |
35 | for line in pred:
36 | line = line.rstrip().split('|')
37 | if line[5] == args.label:
38 |
39 | if (line[0], line[1], line[2], line[3], line[5]) not in preds_all:
40 | preds_all += [(line[0], line[1], line[2], line[3], line[5])]
41 |
42 | if ((line[0], line[1], line[2], line[5]) not in preds_inter) and (line[3] == 'CROSS'):
43 | preds_inter += [(line[0], line[1], line[2], line[5])]
44 |
45 | if ((line[0], line[1], line[2], line[5]) not in preds_intra) and (line[3] == 'NON-CROSS'):
46 | preds_intra += [(line[0], line[1], line[2], line[5])]
47 |
48 | for line2 in gold:
49 | line2 = line2.rstrip().split('|')
50 |
51 | if line2[4] == args.label:
52 |
53 | if (line2[0], line2[1], line2[2], line2[3], line2[4]) not in golds_all:
54 | golds_all += [(line2[0], line2[1], line2[2], line2[3], line2[4])]
55 |
56 | if ((line2[0], line2[1], line2[2], line2[4]) not in golds_inter) and (line2[3] == 'CROSS'):
57 | golds_inter += [(line2[0], line2[1], line2[2], line2[4])]
58 |
59 | if ((line2[0], line2[1], line2[2], line2[4]) not in golds_intra) and (line2[3] == 'NON-CROSS'):
60 | golds_intra += [(line2[0], line2[1], line2[2], line2[4])]
61 |
62 | tp = len([a for a in preds_all if a in golds_all])
63 | tp_intra = len([a for a in preds_intra if a in golds_intra])
64 | tp_inter = len([a for a in preds_inter if a in golds_inter])
65 |
66 | fp = len([a for a in preds_all if a not in golds_all])
67 | fp_intra = len([a for a in preds_intra if a not in golds_intra])
68 | fp_inter = len([a for a in preds_inter if a not in golds_inter])
69 |
70 | fn = len([a for a in golds_all if a not in preds_all])
71 | fn_intra = len([a for a in golds_intra if a not in preds_intra])
72 | fn_inter = len([a for a in golds_inter if a not in preds_inter])
73 |
74 | r1 = prf(tp, fp, fn)
75 | r2 = prf(tp_intra, fp_intra, fn_intra)
76 | r3 = prf(tp_inter, fp_inter, fn_inter)
77 |
78 |         print('              \tP\tR\tF1\t| TOTAL\tTP\tFP\tFN')
79 | print('Overall P/R/F1\t{:.4f}\t{:.4f}\t{:.4f}\t| {}\t{}\t{}\t{}'.format(r1[0], r1[1], r1[2],
80 | tp + fn, tp, fp, fn))
81 | print('INTRA P/R/F1\t{:.4f}\t{:.4f}\t{:.4f}\t| {}\t{}\t{}\t{}'.format(r2[0], r2[1], r2[2],
82 | tp_intra + fn_intra,
83 | tp_intra, fp_intra,
84 | fn_intra))
85 | print('INTER P/R/F1\t{:.4f}\t{:.4f}\t{:.4f}\t| {}\t{}\t{}\t{}'.format(r3[0], r3[1], r3[2],
86 | tp_inter + fn_inter,
87 | tp_inter, fp_inter,
88 | fn_inter))
89 |
90 |
91 | if __name__ == "__main__":
92 | main()
93 |
--------------------------------------------------------------------------------
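
The prf helper above computes micro-averaged precision, recall and F1 as P = TP/(TP+FP), R = TP/(TP+FN) and F1 = 2PR/(P+R). A quick sanity check with hypothetical counts:

# Hypothetical counts, only to sanity-check the formulas used in prf above.
tp, fp, fn = 8, 2, 4
p = tp / (tp + fp)           # 0.8
r = tp / (tp + fn)           # 0.666...
f1 = 2 * p * r / (p + r)     # 0.727...
print(round(p, 4), round(r, 4), round(f1, 4))   # 0.8 0.6667 0.7273
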
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.1.0
2 | numpy==1.16.3
3 | matplotlib
4 | tqdm
5 | recordtype
6 | yamlordereddictloader
7 | tabulate
8 | networkx
9 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__init__.py
--------------------------------------------------------------------------------
/src/__pycache__/converter.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__pycache__/converter.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__pycache__/loader.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/reader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__pycache__/reader.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/src/bin/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | expA=("ME MS ES SS-ind"
4 | "MM MS ES SS-ind"
5 | "MM ME ES SS-ind"
6 | "MM ME MS SS-ind"
7 | "MM ME MS ES"
8 | "MM ME MS ES SS-ind"
9 | "MM ME"
10 | "ES SS-ind")
11 |
12 | #options=("--context"
13 | # "--types"
14 | # "--distances"
15 | # "--context --types"
16 | # "--context --distances"
17 | # "--distances --types")
18 |
19 |
20 | options=("--types --dist")
21 |
22 |
23 | for o in "${options[@]}"
24 | do
25 | counter=0
26 | for w in 1 2 3 4;
27 | do
28 | # EOG + SS-ind
29 | python3 eog.py --config ../configs/parameters_cdr.yaml --train \
30 | --edges MM ME MS ES SS-ind \
31 | --walks ${w} "${o}" \
32 | --gpu ${counter} &
33 | counter=$((counter+1))
34 |
35 | # EOG + SS
36 | python3 eog.py --config ../configs/parameters_cdr.yaml --train \
37 | --edges MM ME MS ES SS \
38 | --walks ${w} "${o}" \
39 | --gpu ${counter} &
40 | counter=$((counter+1))
41 | done
42 | # wait
43 | done
44 |
45 |
46 | wait
47 |
48 |
49 | for w in 1 2 3 4;
50 | do
51 | # rest
52 | counter=0
53 | for edg in "${expA[@]}";
54 | do
55 | python3 eog.py --config ../configs/parameters_cdr.yaml --train \
56 | --edges "${edg}" \
57 | --walks ${w} "${o}" \
58 | --gpu ${counter} &
59 | counter=$((counter+1))
60 | done
61 | wait
62 | done
63 |
64 |
65 | # fully connected
66 | for w in 1 2 3 4;
67 | do
68 | counter=0
69 | for edg in "${expA[@]}";
70 | do
71 | python3 eog.py --config ../configs/parameters_cdr.yaml --train \
72 | --edges "FULL" \
73 | --walks ${w} "${o}" \
74 | --gpu ${counter} &
75 | counter=$((counter+1))
76 | done
77 | done
78 |
79 |
80 | # sentence
81 | for w in 1 2 3 4;
82 | do
83 | counter=0
84 | for edg in "${expA[@]}";
85 | do
86 |     python3 eog.py --config ../configs/parameters_cdr.yaml --train \
87 | --edges MM ME MS ES SS-ind \
88 |            --walks ${w} "${o}" \
89 |            --gpu ${counter} \
90 |            --window 1 &
91 |
92 | counter=$((counter+1))
93 | done
94 | done
95 |
96 | wait
97 |
98 |
99 |
100 | # no inference
101 | #counter=0
102 | #for w in 0 1 2 3 4 5;
103 | #do
104 | # python3 eog.py --config ../configs/parameters_cdr.yaml --train \
105 | # --edges "EE" \
106 | # --walks ${w} \
107 | # --gpu ${counter} &
108 | # counter=$((counter+1))
109 | #done
110 |
111 |
112 | # sentence (different options, EOG model)
113 | #python3 eog.py --config ../configs/parameters_cdr.yaml --train \
114 | # --edges MM ME MS ES SS-ind \
115 | # --walks 3 \
116 | # --types \
117 | # --context \
118 | # --window 1 \
119 | # --gpu 0 &
120 | #
121 | #python3 eog.py --config ../configs/parameters_cdr.yaml --train \
122 | # --edges MM ME MS ES SS-ind \
123 | # --walks 3 \
124 | # --types \
125 | # --window 1 \
126 | # --gpu 1 &
127 | #
128 | #python3 eog.py --config ../configs/parameters_cdr.yaml --train \
129 | # --edges MM ME MS ES SS-ind \
130 | # --walks 3 \
131 | # --context \
132 | # --window 1 \
133 | # --gpu 2 &
134 | #
135 | #python3 eog.py --config ../configs/parameters_cdr.yaml --train \
136 | # --edges MM ME MS ES SS-ind \
137 | # --walks 3 \
138 | # --window 1 \
139 | # --gpu 3 &
140 |
141 |
142 |
--------------------------------------------------------------------------------
/src/converter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 |
3 | import torch
4 | import six
5 | import numpy
6 |
7 |
8 | def to_device(device, x):
9 | if device is None:
10 | return torch.as_tensor(x).long()
11 | return torch.as_tensor(x).long().to(device)
12 |
13 |
14 | def concat_examples(batch, device=None, padding=-1):
15 | assert device is None or isinstance(device, torch.device)
16 | if len(batch) == 0:
17 | raise ValueError('batch is empty')
18 |
19 | first_elem = batch[0]
20 |
21 | if isinstance(first_elem, tuple):
22 | result = []
23 | if not isinstance(padding, tuple):
24 | padding = [padding] * len(first_elem)
25 |
26 | for i in six.moves.range(len(first_elem)):
27 | result.append(to_device(device, _concat_arrays(
28 | [example[i] for example in batch], padding[i])))
29 |
30 | return tuple(result)
31 |
32 | elif isinstance(first_elem, dict):
33 | result = {}
34 | if not isinstance(padding, dict):
35 | padding = {key: padding for key in first_elem}
36 |
37 | for key in first_elem:
38 | result[key] = to_device(device, _concat_arrays(
39 | [example[key] for example in batch], padding[key]))
40 |
41 | return result
42 |
43 | else:
44 | return to_device(device, _concat_arrays(batch, padding))
45 |
46 |
47 | def _concat_arrays(arrays, padding):
48 | # Convert `arrays` to numpy.ndarray if `arrays` consists of the built-in
49 | # types such as int, float or list.
50 |     if not isinstance(arrays[0], numpy.ndarray):
51 |         arrays = numpy.asarray(arrays)
52 |
53 | if padding is not None:
54 | arr_concat = _concat_arrays_with_padding(arrays, padding)
55 | else:
56 | arr_concat = numpy.concatenate([array[None] for array in arrays])
57 |
58 | return arr_concat
59 |
60 |
61 | def _concat_arrays_with_padding(arrays, padding):
62 | shape = numpy.array(arrays[0].shape, dtype=int)
63 | for array in arrays[1:]:
64 | if numpy.any(shape != array.shape):
65 | numpy.maximum(shape, array.shape, shape)
66 | shape = tuple(numpy.insert(shape, 0, len(arrays)))
67 |
68 | result = numpy.full(shape, padding, dtype=arrays[0].dtype)
69 | for i in six.moves.range(len(arrays)):
70 | src = arrays[i]
71 | slices = tuple(slice(dim) for dim in src.shape)
72 | result[(i,) + slices] = src
73 |
74 | return result
75 |
--------------------------------------------------------------------------------
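
concat_examples batches variable-length arrays by padding them to a common shape. A minimal sketch of the padding behaviour of _concat_arrays_with_padding, on toy numpy arrays:

# Pad two toy 1-D arrays to a common length with -1, as _concat_arrays_with_padding does.
import numpy as np

arrays = [np.array([1, 2, 3]), np.array([4, 5])]
padding = -1

shape = (len(arrays), max(a.shape[0] for a in arrays))
batch = np.full(shape, padding, dtype=arrays[0].dtype)
for i, a in enumerate(arrays):
    batch[i, :a.shape[0]] = a
print(batch)
# [[ 1  2  3]
#  [ 4  5 -1]]
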
/src/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 13/06/2019
5 |
6 | author: fenia
7 | """
8 |
9 | import random
10 | random.seed(0)
11 | import numpy as np
12 | np.random.seed(0)
13 |
14 | from tqdm import tqdm
15 | from collections import OrderedDict
16 | from torch.utils.data import Dataset
17 |
18 |
19 | class DocRelationDataset:
20 | """
21 |     My simpler converter approach: store everything in a list and then just iterate over it to create batches.
22 | """
23 | def __init__(self, loader, data_type, params, mappings):
24 | self.unk_w_prob = params['unk_w_prob']
25 | self.mappings = mappings
26 | self.loader = loader
27 | self.data_type = data_type
28 | self.edges = params['edges']
29 | self.data = []
30 |
31 | def __len__(self):
32 | return len(self.data)
33 |
34 | def __call__(self):
35 | pbar = tqdm(self.loader.documents.keys())
36 | for pmid in pbar:
37 | pbar.set_description(' Preparing {} data - PMID {}'.format(self.data_type.upper(), pmid))
38 |
39 | # TEXT
40 | doc = []
41 | for sentence in self.loader.documents[pmid]:
42 | sent = []
43 | if self.data_type == 'train':
44 | for w, word in enumerate(sentence):
45 | if (word in self.mappings.singletons) and (random.uniform(0, 1) < float(self.unk_w_prob)):
46 |                             sent += [self.mappings.word2index['<UNK>']]  # UNK words = singletons for train
47 | else:
48 | sent += [self.mappings.word2index[word]]
49 |
50 | else:
51 |
52 | for w, word in enumerate(sentence):
53 | if word in self.mappings.word2index:
54 | sent += [self.mappings.word2index[word]]
55 | else:
56 |                             sent += [self.mappings.word2index['<UNK>']]
57 | assert len(sentence) == len(sent), '{}, {}'.format(len(sentence), len(sent))
58 | doc += [sent]
59 |
60 | # ENTITIES [id, type, start, end] + NODES [id, type, start, end, node_type_id]
61 | nodes = []
62 | ent = []
63 | for id_, (e, i) in enumerate(self.loader.entities[pmid].items()):
64 | nodes += [[id_, self.mappings.type2index[i.type], int(i.mstart.split(':')[0]),
65 | int(i.mend.split(':')[0]), i.sentNo.split(':')[0], 0]]
66 |
67 | for id_, (e, i) in enumerate(self.loader.entities[pmid].items()):
68 | for sent_id, m1, m2 in zip(i.sentNo.split(':'), i.mstart.split(':'), i.mend.split(':')):
69 | ent += [[id_, self.mappings.type2index[i.type], int(m1), int(m2), int(sent_id)]]
70 | nodes += [[id_, self.mappings.type2index[i.type], int(m1), int(m2), int(sent_id), 1]]
71 |
72 | for s, sentence in enumerate(self.loader.documents[pmid]):
73 | nodes += [[s, s, s, s, s, 2]]
74 |
75 | nodes = np.array(nodes, 'i')
76 | ent = np.array(ent, 'i')
77 |
78 | # RELATIONS
79 | ents_keys = list(self.loader.entities[pmid].keys()) # in order
80 | trel = -1 * np.ones((len(ents_keys), len(ents_keys)), 'i')
81 | rel_info = np.empty((len(ents_keys), len(ents_keys)), dtype='object_')
82 | for id_, (r, i) in enumerate(self.loader.pairs[pmid].items()):
83 | if i.type == 'not_include':
84 | continue
85 | trel[ents_keys.index(r[0]), ents_keys.index(r[1])] = self.mappings.rel2index[i.type]
86 | rel_info[ents_keys.index(r[0]), ents_keys.index(r[1])] = OrderedDict(
87 | [('pmid', pmid),
88 | ('sentA', self.loader.entities[pmid][r[0]].sentNo),
89 | ('sentB',
90 | self.loader.entities[pmid][r[1]].sentNo),
91 | ('doc', self.loader.documents[pmid]),
92 | ('entA', self.loader.entities[pmid][r[0]]),
93 | ('entB', self.loader.entities[pmid][r[1]]),
94 | ('rel', self.mappings.rel2index[i.type]),
95 | ('dir', i.direction),
96 | ('cross', i.cross)])
97 |
98 | #######################
99 | # DISTANCES
100 | #######################
101 | xv, yv = np.meshgrid(np.arange(nodes.shape[0]), np.arange(nodes.shape[0]), indexing='ij')
102 |
103 | r_id, c_id = nodes[xv, 5], nodes[yv, 5]
104 | r_Eid, c_Eid = nodes[xv, 0], nodes[yv, 0]
105 | r_Sid, c_Sid = nodes[xv, 4], nodes[yv, 4]
106 | r_Ms, c_Ms = nodes[xv, 2], nodes[yv, 2]
107 | r_Me, c_Me = nodes[xv, 3]-1, nodes[yv, 3]-1
108 |
109 | ignore_pos = self.mappings.n_dist
110 | self.mappings.dist2index[ignore_pos] = ignore_pos
111 | self.mappings.index2dist[ignore_pos] = ignore_pos
112 |
113 | dist = np.full((r_id.shape[0], r_id.shape[0]), ignore_pos, 'i')
114 |
115 | # MM: mention-mention
116 | a_start = np.where((r_id == 1) & (c_id == 1), r_Ms, -1)
117 | a_end = np.where((r_id == 1) & (c_id == 1), r_Me, -1)
118 | b_start = np.where((r_id == 1) & (c_id == 1), c_Ms, -1)
119 | b_end = np.where((r_id == 1) & (c_id == 1), c_Me, -1)
120 |
121 | dist = np.where((a_end < b_start) & (a_end != -1) & (b_start != -1), abs(b_start - a_end), dist)
122 | dist = np.where((b_end < a_start) & (b_end != -1) & (a_start != -1), abs(b_end - a_start), dist)
123 |
124 | # nested (find the distance between their last words)
125 | dist = np.where((b_start <= a_start) & (b_end >= a_end)
126 | & (b_start != -1) & (a_end != -1) & (b_end != -1) & (a_start != -1), abs(b_end-a_end), dist)
127 | dist = np.where((b_start >= a_start) & (b_end <= a_end)
128 | & (b_start != -1) & (a_end != -1) & (b_end != -1) & (a_start != -1), abs(a_end-b_end), dist)
129 |
130 | # diagonal
131 | dist[np.arange(nodes.shape[0]), np.arange(nodes.shape[0])] = 0
132 |
133 | # limit max distance according to training set
134 | dist = np.where(dist > self.mappings.max_distance, self.mappings.max_distance, dist)
135 |
136 | # restrictions: to MM pairs in the same sentence
137 | dist = np.where(((r_id == 1) & (c_id == 1) & (r_Sid == c_Sid)), dist, ignore_pos)
138 |
139 | # SS: sentence-sentence
140 | dist = np.where(((r_id == 2) & (c_id == 2)), abs(c_Sid - r_Sid), dist)
141 |
142 | #######################
143 | # GRAPH CONNECTIONS
144 | #######################
145 | adjacency = np.full((r_id.shape[0], r_id.shape[0]), 0, 'i')
146 |
147 | if 'FULL' in self.edges:
148 | adjacency = np.full(adjacency.shape, 1, 'i')
149 |
150 | if 'MM' in self.edges:
151 | # mention-mention
152 | adjacency = np.where((r_id == 1) & (c_id == 1) & (r_Sid == c_Sid), 1, adjacency) # in same sentence
153 |
154 | if ('EM' in self.edges) or ('ME' in self.edges):
155 | # entity-mention
156 | adjacency = np.where((r_id == 0) & (c_id == 1) & (r_Eid == c_Eid), 1, adjacency) # belongs to entity
157 | adjacency = np.where((r_id == 1) & (c_id == 0) & (r_Eid == c_Eid), 1, adjacency)
158 |
159 | if 'SS' in self.edges:
160 | # sentence-sentence (in order)
161 | adjacency = np.where((r_id == 2) & (c_id == 2) & (r_Sid == c_Sid - 1), 1, adjacency)
162 | adjacency = np.where((r_id == 2) & (c_id == 2) & (c_Sid == r_Sid - 1), 1, adjacency)
163 |
164 | if 'SS-ind' in self.edges:
165 | # sentence-sentence (direct + indirect)
166 | adjacency = np.where((r_id == 2) & (c_id == 2), 1, adjacency)
167 |
168 | if ('MS' in self.edges) or ('SM' in self.edges):
169 | # mention-sentence
170 | adjacency = np.where((r_id == 1) & (c_id == 2) & (r_Sid == c_Sid), 1, adjacency) # belongs to sentence
171 | adjacency = np.where((r_id == 2) & (c_id == 1) & (r_Sid == c_Sid), 1, adjacency)
172 |
173 | if ('ES' in self.edges) or ('SE' in self.edges):
174 | # entity-sentence
175 | for x, y in zip(xv.ravel(), yv.ravel()):
176 | if nodes[x, 5] == 0 and nodes[y, 5] == 2: # this is an entity-sentence edge
177 | z = np.where((r_Eid == nodes[x, 0]) & (r_id == 1) & (c_id == 2) & (c_Sid == nodes[y, 4]))
178 |
179 | # at least one M in S
180 | temp_ = np.where((r_id == 1) & (c_id == 2) & (r_Sid == c_Sid), 1, adjacency)
181 | temp_ = np.where((r_id == 2) & (c_id == 1) & (r_Sid == c_Sid), 1, temp_)
182 | adjacency[x, y] = 1 if (temp_[z] == 1).any() else 0
183 | adjacency[y, x] = 1 if (temp_[z] == 1).any() else 0
184 |
185 | if 'EE' in self.edges:
186 | adjacency = np.where((r_id == 0) & (c_id == 0), 1, adjacency)
187 |
188 | # self-loops = 0 [always]
189 | adjacency[np.arange(r_id.shape[0]), np.arange(r_id.shape[0])] = 0
190 |
191 | dist = list(map(lambda y: self.mappings.dist2index[y], dist.ravel().tolist())) # map
192 | dist = np.array(dist, 'i').reshape((nodes.shape[0], nodes.shape[0]))
193 |
194 | if (trel == -1).all(): # no relations --> ignore
195 | continue
196 |
197 | self.data += [{'ents': ent, 'rels': trel, 'dist': dist, 'text': doc, 'info': rel_info,
198 | 'adjacency': adjacency,
199 | 'section': np.array([len(self.loader.entities[pmid].items()), ent.shape[0], len(doc)], 'i'),
200 | 'word_sec': np.array([len(s) for s in doc]),
201 | 'words': np.hstack([np.array(s, 'i') for s in doc])}]
202 | return self.data
203 |
--------------------------------------------------------------------------------
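
The adjacency matrix in dataset.py above is built with vectorised np.where comparisons over a node table whose last column encodes the node kind (0 = entity, 1 = mention, 2 = sentence). A small sketch of the mention-mention and mention-sentence rules on a toy node table (toy values, same column layout):

# Toy node table with the column layout used above: [id, type, start, end, sent_id, kind],
# where kind is 0 = entity, 1 = mention, 2 = sentence.
import numpy as np

nodes = np.array([[0, 0, 0, 2, 0, 1],    # mention of entity 0 in sentence 0
                  [1, 0, 5, 6, 0, 1],    # mention of entity 1 in sentence 0
                  [0, 0, 0, 0, 0, 2],    # sentence-0 node
                  [1, 1, 1, 1, 1, 2]],   # sentence-1 node
                 dtype='i')

xv, yv = np.meshgrid(np.arange(nodes.shape[0]), np.arange(nodes.shape[0]), indexing='ij')
r_kind, c_kind = nodes[xv, 5], nodes[yv, 5]
r_sid, c_sid = nodes[xv, 4], nodes[yv, 4]

adjacency = np.zeros((nodes.shape[0], nodes.shape[0]), dtype='i')
# MM: mentions in the same sentence
adjacency = np.where((r_kind == 1) & (c_kind == 1) & (r_sid == c_sid), 1, adjacency)
# MS / SM: a mention connects to its own sentence node
adjacency = np.where((r_kind == 1) & (c_kind == 2) & (r_sid == c_sid), 1, adjacency)
adjacency = np.where((r_kind == 2) & (c_kind == 1) & (r_sid == c_sid), 1, adjacency)
adjacency[np.arange(nodes.shape[0]), np.arange(nodes.shape[0])] = 0   # no self-loops
print(adjacency)
# [[0 1 1 0]
#  [1 0 1 0]
#  [1 1 0 0]
#  [0 0 0 0]]
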
/src/eog.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 21-Feb-2019
5 |
6 | author: fenia
7 | """
8 |
9 | import torch
10 | import random
11 | import numpy as np
12 | from dataset import DocRelationDataset
13 | from loader import DataLoader, ConfigLoader
14 | from nnet.trainer import Trainer
15 | from utils import setup_log, save_model, load_model, plot_learning_curve, load_mappings
16 |
17 |
18 | def set_seed(seed):
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed(seed)
21 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU
22 | np.random.seed(seed) # Numpy module
23 | random.seed(seed) # Python random module
24 | torch.backends.cudnn.benchmark = False
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | def train(parameters):
29 | model_folder = setup_log(parameters, 'train')
30 |
31 | set_seed(parameters['seed'])
32 |
33 | ###################################
34 | # Data Loading
35 | ###################################
36 | print('Loading training data ...')
37 | train_loader = DataLoader(parameters['train_data'], parameters)
38 | train_loader(embeds=parameters['embeds'])
39 | train_data = DocRelationDataset(train_loader, 'train', parameters, train_loader).__call__()
40 |
41 | print('\nLoading testing data ...')
42 | test_loader = DataLoader(parameters['test_data'], parameters)
43 | test_loader()
44 | test_data = DocRelationDataset(test_loader, 'test', parameters, train_loader).__call__()
45 |
46 | ###################################
47 | # Training
48 | ###################################
49 | trainer = Trainer(train_loader, parameters, {'train': train_data, 'test': test_data}, model_folder)
50 | trainer.run()
51 |
52 | if parameters['plot']:
53 | plot_learning_curve(trainer, model_folder)
54 |
55 | if parameters['save_model']:
56 | save_model(model_folder, trainer, train_loader)
57 |
58 |
59 | def test(parameters):
60 | model_folder = setup_log(parameters, 'test')
61 |
62 | print('\nLoading mappings ...')
63 | train_loader = load_mappings(model_folder)
64 |
65 | print('\nLoading testing data ...')
66 | test_loader = DataLoader(parameters['test_data'], parameters)
67 | test_loader()
68 | test_data = DocRelationDataset(test_loader, 'test', parameters, train_loader).__call__()
69 |
70 | m = Trainer(train_loader, parameters, {'train': [], 'test': test_data}, model_folder)
71 | trainer = load_model(model_folder, m)
72 | trainer.eval_epoch(final=True, save_predictions=True)
73 |
74 |
75 | def main():
76 | config = ConfigLoader()
77 | parameters = config.load_config()
78 |
79 | if parameters['train']:
80 | train(parameters)
81 |
82 | elif parameters['test']:
83 | test(parameters)
84 |
85 |
86 | if __name__ == "__main__":
87 | main()
88 |
89 |
--------------------------------------------------------------------------------
/src/loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 21-Feb-2019
5 |
6 | author: fenia
7 | """
8 |
9 | import numpy as np
10 | import argparse
11 | import yaml
12 | import yamlordereddictloader
13 | from collections import OrderedDict
14 | from reader import read, read_subdocs
15 |
16 |
17 | def str2bool(i):
18 | if isinstance(i, bool):
19 | return i
20 | if i.lower() in ('yes', 'true', 't', 'y', '1'):
21 | return True
22 | elif i.lower() in ('no', 'false', 'f', 'n', '0'):
23 | return False
24 | else:
25 | raise argparse.ArgumentTypeError('Boolean value expected.')
26 |
27 |
28 | class ConfigLoader:
29 | def __init__(self):
30 | pass
31 |
32 | @staticmethod
33 | def load_cmd():
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--config', type=str, required=True, help='Yaml parameter file')
36 | parser.add_argument('--train', action='store_true', help='Training mode - model is saved')
37 | parser.add_argument('--test', action='store_true', help='Testing mode - needs a model to load')
38 | parser.add_argument('--gpu', type=int, help='GPU number')
39 | parser.add_argument('--walks', type=int, help='Number of walk iterations')
40 | parser.add_argument('--window', type=int, help='Window for training (empty processes the whole document, '
41 | '1 processes 1 sentence at a time, etc)')
42 | parser.add_argument('--edges', nargs='*', help='Edge types')
43 | parser.add_argument('--types', type=str2bool, help='Include node types (Boolean)')
44 | parser.add_argument('--context', type=str2bool, help='Include MM context (Boolean)')
45 | parser.add_argument('--dist', type=str2bool, help='Include distance (Boolean)')
46 | parser.add_argument('--example', help='Show example', action='store_true')
47 | parser.add_argument('--seed', help='Fixed random seed number', type=int)
48 | parser.add_argument('--early_stop', action='store_true', help='Use early stopping')
49 | parser.add_argument('--epoch', type=int, help='Maximum training epoch')
50 | return parser.parse_args()
51 |
52 | def load_config(self):
53 | inp = self.load_cmd()
54 | with open(vars(inp)['config'], 'r') as f:
55 | parameters = yaml.load(f, Loader=yamlordereddictloader.Loader)
56 |
57 | parameters = dict(parameters)
58 | if not inp.train and not inp.test:
59 | print('Please specify train/test mode.')
60 |             raise SystemExit(0)
61 |
62 | parameters['train'] = inp.train
63 | parameters['test'] = inp.test
64 | parameters['gpu'] = inp.gpu
65 | parameters['example'] = inp.example
66 |
67 | if inp.walks and inp.walks >= 0:
68 | parameters['walks_iter'] = inp.walks
69 |
70 | if inp.edges:
71 | parameters['edges'] = inp.edges
72 |
73 |         if inp.types is not None:
74 | parameters['types'] = inp.types
75 |
76 |         if inp.dist is not None:
77 | parameters['dist'] = inp.dist
78 |
79 | if inp.window:
80 | parameters['window'] = inp.window
81 |
82 |         if inp.context is not None:
83 | parameters['context'] = inp.context
84 |
85 | if inp.seed:
86 | parameters['seed'] = inp.seed
87 |
88 | if inp.epoch:
89 | parameters['epoch'] = inp.epoch
90 |
91 | if inp.early_stop:
92 | parameters['early_stop'] = True
93 |
94 | return parameters
95 |
96 |
97 | class DataLoader:
98 | def __init__(self, input_file, params):
99 | self.input = input_file
100 | self.params = params
101 |
102 | self.pre_words = []
103 | self.pre_embeds = OrderedDict()
104 | self.max_distance = -9999999999
105 | self.singletons = []
106 | self.label2ignore = -1
107 | self.ign_label = self.params['label2ignore']
108 |
109 |         self.word2index, self.index2word, self.n_words, self.word2count = {'<UNK>': 0}, {0: '<UNK>'}, 1, {'<UNK>': 1}
110 | self.type2index, self.index2type, self.n_type, self.type2count = {'': 0, '': 1, '': 2}, \
111 | {0: '', 1: '', 2: ''}, 3, \
112 | {'': 1, '': 1, '': 1}
113 | self.rel2index, self.index2rel, self.n_rel, self.rel2count = {}, {}, 0, {}
114 | self.dist2index, self.index2dist, self.n_dist, self.dist2count = {}, {}, 0, {}
115 | self.documents, self.entities, self.pairs = OrderedDict(), OrderedDict(), OrderedDict()
116 |
117 | def find_ignore_label(self):
118 | """
119 | Find relation Id to ignore
120 | """
121 | for key, val in self.index2rel.items():
122 | if val == self.ign_label:
123 | self.label2ignore = key
124 | assert self.label2ignore != -1
125 |
126 | @staticmethod
127 | def check_nested(p):
128 | starts1 = list(map(int, p[8].split(':')))
129 | ends1 = list(map(int, p[9].split(':')))
130 |
131 | starts2 = list(map(int, p[14].split(':')))
132 | ends2 = list(map(int, p[15].split(':')))
133 |
134 | for s1, e1, s2, e2 in zip(starts1, ends1, starts2, ends2):
135 | if bool(set(np.arange(s1, e1)) & set(np.arange(s2, e2))):
136 | print('nested pair', p)
137 |
138 | def find_singletons(self, min_w_freq=1):
139 | """
140 |         Find words with frequency <= min_w_freq (singletons); these may later be replaced by <UNK>
141 | """
142 | self.singletons = frozenset([elem for elem, val in self.word2count.items()
143 |                                      if (val <= min_w_freq) and elem != '<UNK>'])
144 |
145 | def add_relation(self, rel):
146 | if rel not in self.rel2index:
147 | self.rel2index[rel] = self.n_rel
148 | self.rel2count[rel] = 1
149 | self.index2rel[self.n_rel] = rel
150 | self.n_rel += 1
151 | else:
152 | self.rel2count[rel] += 1
153 |
154 | def add_word(self, word):
155 | if word not in self.word2index:
156 | self.word2index[word] = self.n_words
157 | self.word2count[word] = 1
158 | self.index2word[self.n_words] = word
159 | self.n_words += 1
160 | else:
161 | self.word2count[word] += 1
162 |
163 | def add_type(self, type):
164 | if type not in self.type2index:
165 | self.type2index[type] = self.n_type
166 | self.type2count[type] = 1
167 | self.index2type[self.n_type] = type
168 | self.n_type += 1
169 | else:
170 | self.type2count[type] += 1
171 |
172 | def add_dist(self, dist):
173 | if dist not in self.dist2index:
174 | self.dist2index[dist] = self.n_dist
175 | self.dist2count[dist] = 1
176 | self.index2dist[self.n_dist] = dist
177 | self.n_dist += 1
178 | else:
179 | self.dist2count[dist] += 1
180 |
181 | def add_sentence(self, sentence):
182 | for word in sentence:
183 | self.add_word(word)
184 |
185 | def add_document(self, document):
186 | for sentence in document:
187 | self.add_sentence(sentence)
188 |
189 | def load_embeds(self, word_dim):
190 | """
191 | Load pre-trained word embeddings if specified
192 | """
193 | self.pre_embeds = OrderedDict()
194 | with open(self.params['embeds'], 'r') as vectors:
195 | for x, line in enumerate(vectors):
196 |
197 | if x == 0 and len(line.split()) == 2:
198 | words, num = map(int, line.rstrip().split())
199 | else:
200 | word = line.rstrip().split()[0]
201 | vec = line.rstrip().split()[1:]
202 |
203 | n = len(vec)
204 | if n != word_dim:
205 | print('Wrong dimensionality! -- line No{}, word: {}, len {}'.format(x, word, n))
206 | continue
207 | self.add_word(word)
208 | self.pre_embeds[word] = np.asarray(vec, 'f')
209 | self.pre_words = [w for w, e in self.pre_embeds.items()]
210 | print(' Found pre-trained word embeddings: {} x {}'.format(len(self.pre_embeds), word_dim), end="")
211 |
212 | def find_max_length(self, lengths):
213 | self.max_distance = max(lengths) - 1
214 |
215 | def read_n_map(self):
216 | """
217 | Read input.
218 | Lengths is the max distance for each document
219 | """
220 | if not self.params['window']:
221 | lengths, sents, self.documents, self.entities, self.pairs = \
222 | read(self.input, self.documents, self.entities, self.pairs)
223 | else:
224 | lengths, sents, self.documents, self.entities, self.pairs = \
225 | read_subdocs(self.input, self.params['window'], self.documents, self.entities, self.pairs)
226 |
227 | self.find_max_length(lengths)
228 |
229 | # map types and positions and relation types
230 | for did, d in self.documents.items():
231 | self.add_document(d)
232 |
233 | for did, e in self.entities.items():
234 | for k, v in e.items():
235 | self.add_type(v.type)
236 |
237 | for dist in np.arange(0, self.max_distance+1):
238 | self.add_dist(dist)
239 |
240 | for did, p in self.pairs.items():
241 | for k, v in p.items():
242 | if v.type == 'not_include':
243 | continue
244 | self.add_relation(v.type)
245 | assert len(self.entities) == len(self.documents) == len(self.pairs)
246 |
247 | def statistics(self):
248 | """
249 | Print statistics for the dataset
250 | """
251 | print(' Documents: {:<5}\n Words: {:<5}'.format(len(self.documents), self.n_words))
252 |
253 | print(' Relations: {}'.format(sum([v for k, v in self.rel2count.items()])))
254 | for k, v in sorted(self.rel2count.items()):
255 | print('\t{:<10}\t{:<5}\tID: {}'.format(k, v, self.rel2index[k]))
256 |
257 | print(' Entities: {}'.format(sum([len(e) for e in self.entities.values()])))
258 | for k, v in sorted(self.type2count.items()):
259 | print('\t{:<10}\t{:<5}\tID: {}'.format(k, v, self.type2index[k]))
260 |
261 | print(' Singletons: {}/{}'.format(len(self.singletons), self.n_words))
262 |
263 | def __call__(self, embeds=None):
264 | self.read_n_map()
265 | self.find_ignore_label()
266 | self.find_singletons(self.params['min_w_freq']) # words with freq=1
267 | self.statistics()
268 | if embeds:
269 | self.load_embeds(self.params['word_dim'])
270 | print(' --> Words + Pre-trained: {:<5}'.format(self.n_words))
271 |
--------------------------------------------------------------------------------
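
The loader grows its vocabularies incrementally through the add_* methods, keeping the forward map, the inverse map and the frequency counts in sync. A compact standalone sketch of that pattern (the <UNK> placeholder at index 0 is an assumption mirroring the initialisation above):

# Minimal sketch of the add_word bookkeeping: index, inverse index and counts stay in sync.
word2index, index2word, word2count = {'<UNK>': 0}, {0: '<UNK>'}, {'<UNK>': 1}

def add_word(word):
    if word not in word2index:
        idx = len(word2index)
        word2index[word] = idx
        index2word[idx] = word
        word2count[word] = 1
    else:
        word2count[word] += 1

for w in 'the mutation of the gene'.split():
    add_word(w)
print(word2index)          # {'<UNK>': 0, 'the': 1, 'mutation': 2, 'of': 3, 'gene': 4}
print(word2count['the'])   # 2
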
/src/nnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__init__.py
--------------------------------------------------------------------------------
/src/nnet/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/nnet/__pycache__/attention.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/attention.cpython-36.pyc
--------------------------------------------------------------------------------
/src/nnet/__pycache__/init_net.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/init_net.cpython-36.pyc
--------------------------------------------------------------------------------
/src/nnet/__pycache__/modules.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/modules.cpython-36.pyc
--------------------------------------------------------------------------------
/src/nnet/__pycache__/network.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/network.cpython-36.pyc
--------------------------------------------------------------------------------
/src/nnet/__pycache__/trainer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/trainer.cpython-36.pyc
--------------------------------------------------------------------------------
/src/nnet/__pycache__/walks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenchri/edge-oriented-graph/629b84a630146e81ffe15cbf6f5e5cf4efd9fb34/src/nnet/__pycache__/walks.cpython-36.pyc
--------------------------------------------------------------------------------
/src/nnet/attention.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Created on 29/03/19
4 |
5 | author: fenia
6 | """
7 |
8 | import torch
9 | from torch import nn
10 | import math
11 |
12 |
13 | class Dot_Attention(nn.Module):
14 | """
15 | Adaptation from "Attention is all you need".
16 | Here the query is the target pair and the keys/values are the words of the sentence.
17 | The dimensionality of the queries and the values should be the same.
18 | """
19 | def __init__(self, input_size, device=-1, scale=False):
20 |
21 | super(Dot_Attention, self).__init__()
22 |
23 | self.softmax = nn.Softmax(dim=2)
24 | self.scale = scale
25 | if scale:
26 | self.sc = 1.0 / math.sqrt(input_size)
27 | self.device = device
28 |
29 | def create_mask(self, alpha, size_, lengths, idx_):
30 | """ Put 1 in valid tokens """
31 | mention_sents = torch.index_select(lengths, 0, idx_[:, 4])
32 |
33 |         # mask padded words (positions beyond the sentence length)
34 | tempa = torch.arange(size_).unsqueeze(0).repeat(alpha.shape[0], 1).to(self.device)
35 | mask_a = torch.ge(tempa, mention_sents[:, None])
36 |
37 | # mask tokens that are used as queries
38 | tempb = torch.arange(lengths.size(0)).unsqueeze(0).repeat(alpha.shape[0], 1).to(self.device) # m x sents
39 | sents = torch.where(torch.lt(tempb, idx_[:, 4].unsqueeze(1)),
40 | lengths.unsqueeze(0).repeat(alpha.shape[0], 1),
41 | torch.zeros_like(lengths.unsqueeze(0).repeat(alpha.shape[0], 1)))
42 |
43 | total_l = torch.cumsum(sents, dim=1)[:, -1]
44 | mask_b = torch.ge(tempa, (idx_[:, 2] - total_l)[:, None]) & torch.lt(tempa, (idx_[:, 3] - total_l)[:, None])
45 |
46 | mask = ~(mask_a | mask_b)
47 | del tempa, tempb, total_l
48 | return mask
49 |
50 | def forward(self, queries, values, idx, lengths):
51 | """
52 | a = softmax( q * H^T )
53 | v = a * H
54 | """
55 | alpha = torch.matmul(queries.unsqueeze(1), values.transpose(1, 2))
56 |
57 | if self.scale:
58 | alpha = alpha * self.sc
59 |
60 | mask_ = self.create_mask(alpha, values.size(1), lengths, idx)
61 | alpha = torch.where(mask_.unsqueeze(1),
62 | alpha,
63 | torch.as_tensor([float('-inf')]).to(self.device))
64 |
65 | # in case all context words are masked (e.g. only 2 mentions in the sentence) -- a naive fix
66 | alpha = torch.where(torch.isinf(alpha).all(dim=2, keepdim=True), torch.full_like(alpha, 1.0), alpha)
67 | alpha = self.softmax(alpha)
68 | alpha = torch.where(torch.eq(alpha, 1/alpha.shape[2]).all(dim=2, keepdim=True), torch.zeros_like(alpha), alpha)
69 |
70 | alpha = torch.squeeze(alpha, 1)
71 | return alpha
72 |
--------------------------------------------------------------------------------
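
Dot_Attention computes alpha = softmax(q · H^T) and pushes invalid positions to -inf before the softmax so they receive zero weight. A stripped-down sketch of that masking pattern, written against a recent PyTorch; the shapes and the validity mask are hypothetical, not the module's exact interface:

# Masked dot-product attention: invalid positions get -inf before the softmax.
import torch

torch.manual_seed(0)
queries = torch.randn(2, 4)              # 2 queries (e.g. target pairs), dim 4
values = torch.randn(2, 5, 4)            # per-query context: 5 tokens, dim 4
valid = torch.tensor([[1, 1, 1, 0, 0],   # 1 = attend to this token, 0 = mask it out
                      [1, 1, 1, 1, 0]], dtype=torch.bool)

alpha = torch.matmul(queries.unsqueeze(1), values.transpose(1, 2))    # (2, 1, 5)
alpha = alpha / values.size(-1) ** 0.5                                # optional scaling
alpha = torch.where(valid.unsqueeze(1), alpha, torch.full_like(alpha, float('-inf')))
alpha = torch.softmax(alpha, dim=2).squeeze(1)                        # (2, 5)
print(alpha.sum(dim=1))   # each row sums to 1; masked positions get weight 0
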
/src/nnet/init_net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 21-Feb-2019
5 |
6 | author: fenia
7 | """
8 |
9 | import torch
10 | from torch import nn
11 | from nnet.modules import EmbedLayer, Encoder, Classifier
12 | from nnet.attention import Dot_Attention
13 | from nnet.walks import WalkLayer
14 |
15 |
16 | class BaseNet(nn.Module):
17 | def __init__(self, params, pembeds, sizes=None, maps=None, lab2ign=None):
18 | super(BaseNet, self).__init__()
19 |
20 | self.edg = ['MM', 'SS', 'ME', 'MS', 'ES', 'EE']
21 |
22 | self.dims = {}
23 | for k in self.edg:
24 | self.dims[k] = 4 * params['lstm_dim']
25 |
26 | self.device = torch.device("cuda:{}".format(params['gpu']) if params['gpu'] != -1 else "cpu")
27 |
28 | self.encoder = Encoder(input_size=params['word_dim'],
29 | rnn_size=params['out_dim'],
30 | num_layers=1,
31 | bidirectional=True,
32 | dropout=0.0)
33 |
34 | self.word_embed = EmbedLayer(num_embeddings=sizes['word_size'],
35 | embedding_dim=params['word_dim'],
36 | dropout=params['drop_i'],
37 | ignore=None,
38 | freeze=params['freeze_words'],
39 | pretrained=pembeds,
40 | mapping=maps['word2idx'])
41 |
42 | if params['dist']:
43 | self.dims['MM'] += params['dist_dim']
44 | self.dims['SS'] += params['dist_dim']
45 | self.dist_embed = EmbedLayer(num_embeddings=sizes['dist_size'] + 1,
46 | embedding_dim=params['dist_dim'],
47 | dropout=0.0,
48 | ignore=sizes['dist_size'],
49 | freeze=False,
50 | pretrained=None,
51 | mapping=None)
52 |
53 | if params['context']:
54 | self.dims['MM'] += (2 * params['lstm_dim'])
55 | self.attention = Dot_Attention(input_size=2 * params['lstm_dim'],
56 | device=self.device,
57 | scale=False)
58 |
59 | if params['types']:
60 | for k in self.edg:
61 | self.dims[k] += (2 * params['type_dim'])
62 |
63 | self.type_embed = EmbedLayer(num_embeddings=3,
64 | embedding_dim=params['type_dim'],
65 | dropout=0.0,
66 | freeze=False,
67 | pretrained=None,
68 | mapping=None)
69 |
70 | self.reduce = nn.ModuleDict()
71 | for k in self.edg:
72 | if k != 'EE':
73 | self.reduce.update({k: nn.Linear(self.dims[k], params['out_dim'], bias=False)})
74 | elif (('EE' in params['edges']) or ('FULL' in params['edges'])) and (k == 'EE'):
75 | self.ee = True
76 | self.reduce.update({k: nn.Linear(self.dims[k], params['out_dim'], bias=False)})
77 | else:
78 | self.ee = False
79 |
80 | if params['walks_iter'] and params['walks_iter'] > 0:
81 | self.walk = WalkLayer(input_size=params['out_dim'],
82 | iters=params['walks_iter'],
83 | beta=params['beta'],
84 | device=self.device)
85 |
86 | self.classifier = Classifier(in_size=params['out_dim'],
87 | out_size=sizes['rel_size'],
88 | dropout=params['drop_o'])
89 | self.loss = nn.CrossEntropyLoss()
90 |
91 | # hyper-parameters for tuning
92 | self.beta = params['beta']
93 | self.dist_dim = params['dist_dim']
94 | self.type_dim = params['type_dim']
95 | self.drop_i = params['drop_i']
96 | self.drop_o = params['drop_o']
97 | self.gradc = params['gc']
98 | self.learn = params['lr']
99 | self.reg = params['reg']
100 | self.out_dim = params['out_dim']
101 |
102 | # other parameters
103 | self.mappings = {'word': maps['word2idx'], 'type': maps['type2idx'], 'dist': maps['dist2idx']}
104 | self.inv_mappings = {'word': maps['idx2word'], 'type': maps['idx2type'], 'dist': maps['idx2dist']}
105 | self.word_dim = params['word_dim']
106 | self.lstm_dim = params['lstm_dim']
107 | self.walks_iter = params['walks_iter']
108 | self.rel_size = sizes['rel_size']
109 | self.types = params['types']
110 | self.ignore_label = lab2ign
111 | self.context = params['context']
112 | self.dist = params['dist']
113 |
114 |
--------------------------------------------------------------------------------
/src/nnet/modules.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 06-Mar-2019
5 |
6 | author: fenia
7 | """
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
12 |
13 |
14 | class EmbedLayer(nn.Module):
15 | def __init__(self, num_embeddings, embedding_dim, dropout, ignore=None, freeze=False, pretrained=None, mapping=None):
16 | """
17 | Args:
18 | num_embeddings: (tensor) number of unique items
19 | embedding_dim: (int) dimensionality of vectors
20 | dropout: (float) dropout rate
21 |             freeze: (bool) keep the embedding weights fixed (do not train)
22 | pretrained: (dict) pretrained embeddings
23 | mapping: (dict) mapping of items to unique ids
24 | """
25 | super(EmbedLayer, self).__init__()
26 |
27 | self.embedding_dim = embedding_dim
28 | self.num_embeddings = num_embeddings
29 | self.freeze = freeze
30 |
31 | self.embedding = nn.Embedding(num_embeddings=num_embeddings,
32 | embedding_dim=embedding_dim,
33 | padding_idx=ignore)
34 |
35 | if pretrained:
36 | self.load_pretrained(pretrained, mapping)
37 | self.embedding.weight.requires_grad = not freeze
38 |
39 | self.drop = nn.Dropout(dropout)
40 |
41 | def load_pretrained(self, pretrained, mapping):
42 | """
43 | Args:
44 |             pretrained: (dict) keys are words, values are vectors
45 | mapping: (dict) keys are words, values are unique ids
46 |
47 | Returns: updates the embedding matrix with pre-trained embeddings
48 | """
49 | for word in mapping.keys():
50 | if word in pretrained:
51 | self.embedding.weight.data[mapping[word], :] = torch.from_numpy(pretrained[word])
52 | elif word.lower() in pretrained:
53 | self.embedding.weight.data[mapping[word], :] = torch.from_numpy(pretrained[word.lower()])
54 |
55 | assert (self.embedding.weight[mapping['and']].to('cpu').data.numpy() == pretrained['and']).all(), \
56 | 'ERROR: Embeddings not assigned'
57 |
58 | def forward(self, xs):
59 | """
60 | Args:
61 | xs: (tensor) batchsize x word_ids
62 |
63 | Returns: (tensor) batchsize x word_ids x dimensionality
64 | """
65 | embeds = self.embedding(xs)
66 | if self.drop.p > 0:
67 | embeds = self.drop(embeds)
68 |
69 | return embeds
70 |
71 |
72 | class Encoder(nn.Module):
73 | def __init__(self, input_size, rnn_size, num_layers, bidirectional, dropout):
74 | """
75 | Wrapper for LSTM encoder
76 | Args:
77 | input_size (int): the size of the input features
78 |             rnn_size (int): size of the LSTM hidden state
79 |             num_layers (int): number of stacked LSTM layers
80 |             bidirectional (bool): whether to use a bidirectional LSTM
81 |             dropout (float): dropout rate applied inside and after the LSTM
82 |         Returns: outputs
83 |             - **outputs** of shape `(batch, seq_len, feature_size)`:
84 |               tensor containing the output features `(h_t)`
85 |               from the last layer of the LSTM, for each t,
86 |               where feature_size = rnn_size (doubled if bidirectional).
87 |               Note that forward() returns only the full output sequence;
88 |               the final hidden state is not returned.
89 | """
90 | super(Encoder, self).__init__()
91 |
92 | self.enc = nn.LSTM(input_size=input_size,
93 | hidden_size=rnn_size,
94 | num_layers=num_layers,
95 | bidirectional=bidirectional,
96 | dropout=dropout,
97 | batch_first=True)
98 |
99 | # the dropout "layer" for the output of the RNN
100 | self.drop = nn.Dropout(dropout)
101 |
102 | # define output feature size
103 | self.feature_size = rnn_size
104 |
105 | if bidirectional:
106 | self.feature_size *= 2
107 |
108 | @staticmethod
109 | def sort(lengths):
110 | sorted_len, sorted_idx = lengths.sort() # indices that result in sorted sequence
111 | _, original_idx = sorted_idx.sort(0, descending=True)
112 | reverse_idx = torch.linspace(lengths.size(0) - 1, 0, lengths.size(0)).long() # for big-to-small
113 |
114 | return sorted_idx, original_idx, reverse_idx
115 |
116 | def forward(self, embeds, lengths, hidden=None):
117 | """
118 |         Encode a batch of variable-length sequences with the (Bi)LSTM:
119 |         pad, sort by length, pack, run the LSTM, then restore the original order.
120 |         Args:
121 |             embeds (list of tensors): word embeddings per sequence
122 |             lengths (tensor): the length of each sequence
123 |         Returns: the encoded sequences, shape (batch, seq_len, feature_size)
124 | """
125 | # sort sequence
126 | sorted_idx, original_idx, reverse_idx = self.sort(lengths)
127 |
128 | # pad - sort - pack
129 | embeds = nn.utils.rnn.pad_sequence(embeds, batch_first=True, padding_value=0)
130 | embeds = embeds[sorted_idx][reverse_idx] # big-to-small
131 | packed = pack_padded_sequence(embeds, list(lengths[sorted_idx][reverse_idx].data), batch_first=True)
132 |
133 | self.enc.flatten_parameters()
134 | out_packed, _ = self.enc(packed, hidden)
135 |
136 | # unpack
137 | outputs, _ = pad_packed_sequence(out_packed, batch_first=True)
138 |
139 | # apply dropout to the outputs of the RNN
140 | outputs = self.drop(outputs)
141 |
142 | # unsort the list
143 | outputs = outputs[reverse_idx][original_idx][reverse_idx]
144 | return outputs
145 |
146 |
147 | class Classifier(nn.Module):
148 | def __init__(self, in_size, out_size, dropout):
149 | """
150 | Args:
151 | in_size: input tensor dimensionality
152 |             out_size: output tensor dimensionality
153 | dropout: dropout rate
154 | """
155 | super(Classifier, self).__init__()
156 |
157 | self.drop = nn.Dropout(dropout)
158 | self.lin = nn.Linear(in_features=in_size,
159 | out_features=out_size,
160 | bias=True)
161 |
162 | def forward(self, xs):
163 | """
164 | Args:
165 | xs: (tensor) batchsize x * x features
166 |
167 | Returns: (tensor) batchsize x * x class_size
168 | """
169 | if self.drop.p > 0:
170 | xs = self.drop(xs)
171 |
172 | xs = self.lin(xs)
173 | return xs
174 |
175 |
176 |
177 |
178 |
179 |
180 |
--------------------------------------------------------------------------------
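
A small usage sketch of the three building blocks above, assuming `src/` is on `sys.path`; the vocabulary size, dimensions and batch below are arbitrary toy values, not the repo's configuration:

    import torch
    from nnet.modules import EmbedLayer, Encoder, Classifier

    embed = EmbedLayer(num_embeddings=100, embedding_dim=16, dropout=0.1, ignore=0)
    enc = Encoder(input_size=16, rnn_size=32, num_layers=1, bidirectional=True, dropout=0.0)
    clf = Classifier(in_size=enc.feature_size, out_size=3, dropout=0.1)

    lengths = torch.tensor([5, 3, 4])                           # words per sentence
    sents = [torch.randint(1, 100, (int(l),)) for l in lengths]
    vecs = [embed(s) for s in sents]                            # list of (len_i, 16) tensors
    encoded = enc(vecs, lengths)                                # (3, max_len, 64) after pad/pack/unsort
    logits = clf(encoded.mean(dim=1))                           # (3, 3) class scores (mean-pooled, padding included for simplicity)
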
/src/nnet/network.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 25-Feb-2019
5 |
6 | author: fenia
7 | """
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.nn.utils.rnn import pad_sequence
12 | from nnet.init_net import BaseNet
13 |
14 |
15 | class EOG(BaseNet):
16 | def input_layer(self, words_):
17 | """
18 | Word Embedding Layer
19 | """
20 | word_vec = self.word_embed(words_)
21 | return word_vec
22 |
23 | def encoding_layer(self, word_vec, word_sec):
24 | """
25 | Encoder Layer -> Encode sequences using BiLSTM.
26 | """
27 | ys = self.encoder(torch.split(word_vec, word_sec.tolist(), dim=0), word_sec)
28 | return ys
29 |
30 | def graph_layer(self, encoded_seq, info, word_sec, section, positions):
31 | """
32 | Graph Layer -> Construct a document-level graph
33 | The graph edges hold representations for the connections between the nodes.
34 | Args:
35 | encoded_seq: Encoded sequence, shape (sentences, words, dimension)
36 | info: (Tensor, 5 columns) entity_id, entity_type, start_wid, end_wid, sentence_id
37 | word_sec: (Tensor) number of words per sentence
38 |             section: (Tensor) #entities/#mentions/#sentences per document in the batch
39 | positions: distances between nodes (only M-M and S-S)
40 |
41 |         Returns: (Tensor) reduced graph, (Tensors) row/col node indices, (Tensor) node information, (Tensor) node mask
42 | """
43 | # SENTENCE NODES
44 | sentences = torch.mean(encoded_seq, dim=1) # sentence nodes (avg of sentence words)
45 |
46 | # MENTION & ENTITY NODES
47 | temp_ = torch.arange(word_sec.max()).unsqueeze(0).repeat(sentences.size(0), 1).to(self.device)
48 | remove_pad = (temp_ < word_sec.unsqueeze(1))
49 |
50 | mentions = self.merge_tokens(info, encoded_seq, remove_pad) # mention nodes
51 | entities = self.merge_mentions(info, mentions) # entity nodes
52 |
53 | # all nodes in order: entities - mentions - sentences
54 | nodes = torch.cat((entities, mentions, sentences), dim=0) # e + m + s (all)
55 | nodes_info = self.node_info(section, info) # info/node: node type | semantic type | sentence ID
56 |
57 | if self.types: # + node types
58 | nodes = torch.cat((nodes, self.type_embed(nodes_info[:, 0])), dim=1)
59 |
60 | # re-order nodes per document (batch)
61 | nodes = self.rearrange_nodes(nodes, section)
62 | nodes = self.split_n_pad(nodes, section, pad=0)
63 |
64 | nodes_info = self.rearrange_nodes(nodes_info, section)
65 | nodes_info = self.split_n_pad(nodes_info, section, pad=-1)
66 |
67 | # create initial edges (concat node representations)
68 | r_idx, c_idx = torch.meshgrid(torch.arange(nodes.size(1)).to(self.device),
69 | torch.arange(nodes.size(1)).to(self.device))
70 | graph = torch.cat((nodes[:, r_idx], nodes[:, c_idx]), dim=3)
71 | r_id, c_id = nodes_info[..., 0][:, r_idx], nodes_info[..., 0][:, c_idx] # node type indicators
72 |
73 | # pair masks
74 | pid = self.pair_ids(r_id, c_id)
75 |
76 | # Linear reduction layers
77 | reduced_graph = torch.where(pid['MS'].unsqueeze(-1), self.reduce['MS'](graph),
78 | torch.zeros(graph.size()[:-1] + (self.out_dim,)).to(self.device))
79 | reduced_graph = torch.where(pid['ME'].unsqueeze(-1), self.reduce['ME'](graph), reduced_graph)
80 | reduced_graph = torch.where(pid['ES'].unsqueeze(-1), self.reduce['ES'](graph), reduced_graph)
81 |
82 | if self.dist:
83 | dist_vec = self.dist_embed(positions) # distances
84 | reduced_graph = torch.where(pid['SS'].unsqueeze(-1),
85 | self.reduce['SS'](torch.cat((graph, dist_vec), dim=3)), reduced_graph)
86 | else:
87 | reduced_graph = torch.where(pid['SS'].unsqueeze(-1), self.reduce['SS'](graph), reduced_graph)
88 |
89 | if self.context and self.dist:
90 | m_cntx = self.attention(mentions, encoded_seq[info[:, 4]], info, word_sec)
91 | m_cntx = self.prepare_mention_context(m_cntx, section, r_idx, c_idx,
92 | encoded_seq[info[:, 4]], pid, nodes_info)
93 |
94 | reduced_graph = torch.where(pid['MM'].unsqueeze(-1),
95 | self.reduce['MM'](torch.cat((graph, dist_vec, m_cntx), dim=3)), reduced_graph)
96 |
97 | elif self.context:
98 | m_cntx = self.attention(mentions, encoded_seq[info[:, 4]], info, word_sec)
99 | m_cntx = self.prepare_mention_context(m_cntx, section, r_idx, c_idx,
100 | encoded_seq[info[:, 4]], pid, nodes_info)
101 |
102 | reduced_graph = torch.where(pid['MM'].unsqueeze(-1),
103 | self.reduce['MM'](torch.cat((graph, m_cntx), dim=3)), reduced_graph)
104 |
105 | elif self.dist:
106 | reduced_graph = torch.where(pid['MM'].unsqueeze(-1),
107 | self.reduce['MM'](torch.cat((graph, dist_vec), dim=3)), reduced_graph)
108 |
109 | else:
110 | reduced_graph = torch.where(pid['MM'].unsqueeze(-1), self.reduce['MM'](graph), reduced_graph)
111 |
112 | if self.ee:
113 | reduced_graph = torch.where(pid['EE'].unsqueeze(-1), self.reduce['EE'](graph), reduced_graph)
114 |
115 | mask = self.get_nodes_mask(section.sum(dim=1))
116 | return reduced_graph, (r_idx, c_idx), nodes_info, mask
117 |
118 | def prepare_mention_context(self, m_cntx, section, r_idx, c_idx, s_seq, pid, nodes_info):
119 | """
120 | Estimate attention scores for each pair
121 | (a1 + a2)/2 * sentence_words
122 | """
123 | # "fake" mention weight nodes
124 | m_cntx = torch.cat((torch.zeros(section.sum(dim=0)[0], m_cntx.size(1)).to(self.device),
125 | m_cntx,
126 | torch.zeros(section.sum(dim=0)[2], m_cntx.size(1)).to(self.device)), dim=0)
127 | m_cntx = self.rearrange_nodes(m_cntx, section)
128 | m_cntx = self.split_n_pad(m_cntx, section, pad=0)
129 | m_cntx = torch.div(m_cntx[:, r_idx] + m_cntx[:, c_idx], 2)
130 |
131 | # mask non-MM pairs
132 | # mask invalid weights (i.e. M-M not in the same sentence)
133 | mask_ = torch.eq(nodes_info[..., 2][:, r_idx], nodes_info[..., 2][:, c_idx]) & pid['MM']
134 | m_cntx = torch.where(mask_.unsqueeze(-1), m_cntx, torch.zeros_like(m_cntx))
135 |
136 | # "fake" mention sentences nodes
137 | sents = torch.cat((torch.zeros(section.sum(dim=0)[0], m_cntx.size(3), s_seq.size(2)).to(self.device),
138 | s_seq,
139 | torch.zeros(section.sum(dim=0)[2], m_cntx.size(3), s_seq.size(2)).to(self.device)), dim=0)
140 | sents = self.rearrange_nodes(sents, section)
141 | sents = self.split_n_pad(sents, section, pad=0)
142 | m_cntx = torch.matmul(m_cntx, sents)
143 | return m_cntx
144 |
145 | @staticmethod
146 | def pair_ids(r_id, c_id):
147 | pids = {
148 | 'EE': ((r_id == 0) & (c_id == 0)),
149 | 'MM': ((r_id == 1) & (c_id == 1)),
150 | 'SS': ((r_id == 2) & (c_id == 2)),
151 | 'ES': (((r_id == 0) & (c_id == 2)) | ((r_id == 2) & (c_id == 0))),
152 | 'MS': (((r_id == 1) & (c_id == 2)) | ((r_id == 2) & (c_id == 1))),
153 | 'ME': (((r_id == 1) & (c_id == 0)) | ((r_id == 0) & (c_id == 1)))
154 | }
155 | return pids
156 |
157 | @staticmethod
158 | def rearrange_nodes(nodes, section):
159 | """
160 | Re-arrange nodes so that they are in 'Entity - Mention - Sentence' order for each document (batch)
161 | """
162 | tmp1 = section.t().contiguous().view(-1).long().to(nodes.device)
163 | tmp3 = torch.arange(section.numel()).view(section.size(1),
164 | section.size(0)).t().contiguous().view(-1).long().to(nodes.device)
165 | tmp2 = torch.arange(section.sum()).to(nodes.device).split(tmp1.tolist())
166 | tmp2 = pad_sequence(tmp2, batch_first=True, padding_value=-1)[tmp3].view(-1)
167 | tmp2 = tmp2[(tmp2 != -1).nonzero().squeeze()] # remove -1 (padded)
168 |
169 | nodes = torch.index_select(nodes, 0, tmp2)
170 | return nodes
171 |
172 | @staticmethod
173 | def split_n_pad(nodes, section, pad=None):
174 | nodes = torch.split(nodes, section.sum(dim=1).tolist())
175 | nodes = pad_sequence(nodes, batch_first=True, padding_value=pad)
176 | return nodes
177 |
178 | @staticmethod
179 | def get_nodes_mask(nodes_size):
180 | """
181 | Create mask for padded nodes
182 | """
183 | n_total = torch.arange(nodes_size.max()).to(nodes_size.device)
184 | idx_r, idx_c, idx_d = torch.meshgrid(n_total, n_total, n_total)
185 |
186 | # masks for padded elements (1 in valid, 0 in padded)
187 | ns = nodes_size[:, None, None, None]
188 | mask3d = ~(torch.ge(idx_r, ns) | torch.ge(idx_c, ns) | torch.ge(idx_d, ns))
189 | return mask3d
190 |
191 | def node_info(self, section, info):
192 | """
193 | Col 0: node type | Col 1: semantic type | Col 2: sentence id
194 | """
195 | typ = torch.repeat_interleave(torch.arange(3).to(self.device), section.sum(dim=0)) # node types (0,1,2)
196 | rows_ = torch.bincount(info[:, 0]).cumsum(dim=0).sub(1)
197 | stypes = torch.neg(torch.ones(section[:, 2].sum())).to(self.device).long() # semantic type sentences = -1
198 | all_types = torch.cat((info[:, 1][rows_], info[:, 1], stypes), dim=0)
199 | sents_ = torch.arange(section.sum(dim=0)[2]).to(self.device)
200 | sent_id = torch.cat((info[:, 4][rows_], info[:, 4], sents_), dim=0) # sent_id
201 | return torch.cat((typ.unsqueeze(-1), all_types.unsqueeze(-1), sent_id.unsqueeze(-1)), dim=1)
202 |
203 | def estimate_loss(self, pred_pairs, truth):
204 | """
205 | Softmax cross entropy loss.
206 | Args:
207 |             pred_pairs (Tensor): Un-normalized pair scores (logits), shape (# pairs, classes)
208 | truth (Tensor): Ground-truth labels (# pairs, id)
209 |
210 | Returns: (Tensor) loss, (Tensors) TP/FP/FN
211 | """
212 | mask = torch.ne(truth, -1)
213 | truth = truth[mask]
214 | pred_pairs = pred_pairs[mask]
215 |
216 | assert (truth != -1).all()
217 | loss = self.loss(pred_pairs, truth)
218 |
219 | predictions = F.softmax(pred_pairs, dim=1).data.argmax(dim=1)
220 | stats = self.count_predictions(predictions, truth)
221 | return loss, stats, predictions
222 |
223 | @staticmethod
224 | def merge_mentions(info, mentions):
225 | """
226 | Merge mentions into entities;
227 | Find which rows (mentions) have the same entity id and average them
228 | """
229 | m_ids, e_ids = torch.broadcast_tensors(info[:, 0].unsqueeze(0),
230 | torch.arange(0, max(info[:, 0]) + 1).unsqueeze(-1).to(info.device))
231 | index_m = torch.eq(m_ids, e_ids).type('torch.FloatTensor').to(info.device)
232 | entities = torch.div(torch.matmul(index_m, mentions), torch.sum(index_m, dim=1).unsqueeze(-1)) # average
233 | return entities
234 |
235 | @staticmethod
236 | def merge_tokens(info, enc_seq, rm_pad):
237 | """
238 | Merge tokens into mentions;
239 | Find which tokens belong to a mention (based on start-end ids) and average them
240 | """
241 | enc_seq = enc_seq[rm_pad]
242 | start, end, w_ids = torch.broadcast_tensors(info[:, 2].unsqueeze(-1),
243 | info[:, 3].unsqueeze(-1),
244 | torch.arange(0, enc_seq.shape[0]).unsqueeze(0).to(info.device))
245 | index_t = (torch.ge(w_ids, start) & torch.lt(w_ids, end)).float().to(info.device)
246 | mentions = torch.div(torch.matmul(index_t, enc_seq), torch.sum(index_t, dim=1).unsqueeze(-1)) # average
247 | return mentions
248 |
249 | @staticmethod
250 | def select_pairs(combs, nodes_info, idx):
251 | """
252 | Select (entity node) pairs for classification based on input parameter restrictions (i.e. their entity type).
253 | """
254 | combs = torch.split(combs, 2, dim=0)
255 | sel = torch.zeros(nodes_info.size(0), nodes_info.size(1), nodes_info.size(1)).to(nodes_info.device)
256 |
257 | a_ = nodes_info[..., 0][:, idx[0]]
258 | b_ = nodes_info[..., 0][:, idx[1]]
259 | c_ = nodes_info[..., 1][:, idx[0]]
260 | d_ = nodes_info[..., 1][:, idx[1]]
261 | for ca, cb in combs:
262 | condition1 = torch.eq(a_, 0) & torch.eq(b_, 0) # needs to be an entity node (id=0)
263 | condition2 = torch.eq(c_, ca) & torch.eq(d_, cb) # valid pair semantic types
264 | sel = torch.where(condition1 & condition2, torch.ones_like(sel), sel)
265 | return sel.nonzero().unbind(dim=1)
266 |
267 | def count_predictions(self, y, t):
268 | """
269 | Count number of TP, FP, FN, TN for each relation class
270 | """
271 | label_num = torch.as_tensor([self.rel_size]).long().to(self.device)
272 | ignore_label = torch.as_tensor([self.ignore_label]).long().to(self.device)
273 |
274 | mask_t = torch.eq(t, ignore_label).view(-1) # where the ground truth needs to be ignored
275 | mask_p = torch.eq(y, ignore_label).view(-1) # where the predicted needs to be ignored
276 |
277 | true = torch.where(mask_t, label_num, t.view(-1).long().to(self.device)) # ground truth
278 | pred = torch.where(mask_p, label_num, y.view(-1).long().to(self.device)) # output of NN
279 |
280 | tp_mask = torch.where(torch.eq(pred, true), true, label_num)
281 | fp_mask = torch.where(torch.ne(pred, true), pred, label_num)
282 | fn_mask = torch.where(torch.ne(pred, true), true, label_num)
283 |
284 | tp = torch.bincount(tp_mask, minlength=self.rel_size + 1)[:self.rel_size]
285 | fp = torch.bincount(fp_mask, minlength=self.rel_size + 1)[:self.rel_size]
286 | fn = torch.bincount(fn_mask, minlength=self.rel_size + 1)[:self.rel_size]
287 | tn = torch.sum(mask_t & mask_p)
288 | return {'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn}
289 |
290 | def forward(self, batch):
291 | """
292 | Network Forward computation.
293 | Args:
294 | batch: dictionary with tensors
295 | Returns: (Tensors) loss, statistics, predictions, index
296 | """
297 | # Word Embeddings
298 | word_vec = self.input_layer(batch['words'])
299 |
300 | # Encoder
301 | encoded_seq = self.encoding_layer(word_vec, batch['word_sec'])
302 |
303 | # Graph
304 | graph, pindex, nodes_info, mask = self.graph_layer(encoded_seq, batch['entities'], batch['word_sec'],
305 | batch['section'], batch['distances'])
306 |
307 | # Inference/Walks
308 | if self.walks_iter and self.walks_iter > 0:
309 | graph = self.walk(graph, adj_=batch['adjacency'], mask_=mask)
310 |
311 | # Classification
312 | select = self.select_pairs(batch['pairs4class'], nodes_info, pindex)
313 | graph = self.classifier(graph[select])
314 |
315 | loss, stats, preds = self.estimate_loss(graph, batch['relations'][select].long())
316 | return loss, stats, preds, select
317 |
--------------------------------------------------------------------------------
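
To make the node-pair masks built in graph_layer / pair_ids concrete, here is a standalone toy example (the node-type layout is hypothetical: two entities, three mentions, two sentences); it reproduces the same boolean masks that route each pair through its reduction layer:

    import torch

    # node type per node: 0 = entity, 1 = mention, 2 = sentence
    node_types = torch.tensor([0, 0, 1, 1, 1, 2, 2])
    n = node_types.size(0)

    r_idx, c_idx = torch.meshgrid(torch.arange(n), torch.arange(n))
    r_id, c_id = node_types[r_idx], node_types[c_idx]

    pid = {
        'EE': (r_id == 0) & (c_id == 0),
        'MM': (r_id == 1) & (c_id == 1),
        'SS': (r_id == 2) & (c_id == 2),
        'ES': ((r_id == 0) & (c_id == 2)) | ((r_id == 2) & (c_id == 0)),
        'MS': ((r_id == 1) & (c_id == 2)) | ((r_id == 2) & (c_id == 1)),
        'ME': ((r_id == 1) & (c_id == 0)) | ((r_id == 0) & (c_id == 1)),
    }
    print(pid['ES'].int())   # 1 wherever one node is an entity and the other a sentence
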
/src/nnet/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 27-Feb-2019
5 |
6 | author: fenia
7 | """
8 |
9 | import torch
10 | import numpy as np
11 | import os
12 | from time import time
13 | import itertools
14 | import copy
15 | import datetime
16 | import random
17 | from random import shuffle
18 | from utils import print_results, write_preds, write_errors, print_options
19 | from converter import concat_examples
20 | from torch import autograd
21 | from nnet.network import EOG
22 | from torch import nn, optim
23 | import sys
24 | torch.set_printoptions(profile="full")
25 | np.set_printoptions(threshold=np.inf)
26 |
27 |
28 | class Trainer:
29 | def __init__(self, loader, params, data, model_folder):
30 | """
31 | Trainer object.
32 |
33 | Args:
34 | loader: loader object that holds information for training data
35 | params (dict): model parameters
36 | data (dict): 'train' and 'test' data
37 | """
38 | self.data = data
39 | self.params = params
40 | self.rel_size = loader.n_rel
41 | self.loader = loader
42 | self.model_folder = model_folder
43 |
44 | self.device = torch.device("cuda:{}".format(params['gpu']) if params['gpu'] != -1 else "cpu")
45 | self.gc = params['gc']
46 | self.epoch = params['epoch']
47 | self.example = params['example']
48 | self.pa = params['param_avg']
49 | self.es = params['early_stop']
50 | self.primary_metric = params['primary_metric']
51 | self.show_class = params['show_class']
52 | self.window = params['window']
53 | self.preds_file = os.path.join(model_folder, params['save_pred'])
54 | self.best_epoch = 0
55 |
56 | self.train_res = {'loss': [], 'score': []}
57 | self.test_res = {'loss': [], 'score': []}
58 |
59 | self.max_patience = self.params['patience']
60 | self.cur_patience = 0
61 | self.best_score = 0.0
62 | self.best_epoch = 0
63 |
64 | # pairs while training
65 | self.pairs4train = []
66 | for i in params['include_pairs']:
67 | m, n = i.split('-')
68 | self.pairs4train += [self.loader.type2index[m], self.loader.type2index[n]]
69 |
70 | # pairs to classify
71 | self.pairs4class = []
72 | for i in params['classify_pairs']:
73 | m, n = i.split('-')
74 | self.pairs4class += [self.loader.type2index[m], self.loader.type2index[n]]
75 |
76 | # parameter averaging
77 | if params['param_avg']:
78 | self.averaged_params = {}
79 |
80 | self.model = self.init_model()
81 | self.optimizer = self.set_optimizer(self.model)
82 |
83 | def init_model(self):
84 | model_0 = EOG(self.params, self.loader.pre_embeds,
85 | sizes={'word_size': self.loader.n_words, 'dist_size': self.loader.n_dist,
86 | 'type_size': self.loader.n_type, 'rel_size': self.loader.n_rel},
87 | maps={'word2idx': self.loader.word2index, 'idx2word': self.loader.index2word,
88 | 'rel2idx': self.loader.rel2index, 'idx2rel': self.loader.index2rel,
89 | 'type2idx': self.loader.type2index, 'idx2type': self.loader.index2type,
90 | 'dist2idx': self.loader.dist2index, 'idx2dist': self.loader.index2dist},
91 | lab2ign=self.loader.label2ignore)
92 |
93 | # GPU/CPU
94 | if self.params['gpu'] != -1:
95 | torch.cuda.set_device(self.device)
96 | model_0.to(self.device)
97 | return model_0
98 |
99 | def set_optimizer(self, model_0):
100 | # OPTIMIZER
101 | # do not regularize biases
102 | params2reg = []
103 | params0reg = []
104 | for p_name, p_value in model_0.named_parameters():
105 | if '.bias' in p_name:
106 | params0reg += [p_value]
107 | else:
108 | params2reg += [p_value]
109 | assert len(params0reg) + len(params2reg) == len(list(model_0.parameters()))
110 | groups = [dict(params=params2reg), dict(params=params0reg, weight_decay=.0)]
111 | optimizer = optim.Adam(groups, lr=self.params['lr'], weight_decay=self.params['reg'], amsgrad=True)
112 |
113 | # Train Model
114 | print_options(self.params)
115 | for p_name, p_value in model_0.named_parameters():
116 | if p_value.requires_grad:
117 | print(p_name)
118 | return optimizer
119 |
120 | @staticmethod
121 | def iterator(x, shuffle_=False, batch_size=1):
122 | """
123 | Create a new iterator for this epoch.
124 | Shuffle the data if specified.
125 | """
126 | if shuffle_:
127 | shuffle(x)
128 | new = [x[i:i+batch_size] for i in range(0, len(x), batch_size)]
129 | return new
130 |
131 | def run(self):
132 | """
133 | Main Training Loop.
134 | """
135 | print('\n======== START TRAINING: {} ========\n'.format(
136 | datetime.datetime.now().strftime("%d-%m-%y_%H:%M:%S")))
137 |
138 | random.shuffle(self.data['train']) # shuffle training data at least once
139 | for epoch in range(1, self.epoch+1):
140 | self.train_epoch(epoch)
141 |
142 | if self.pa:
143 | self.parameter_averaging()
144 |
145 | self.eval_epoch()
146 |
147 | stop = self.epoch_checking(epoch)
148 | if stop and self.es:
149 | break
150 |
151 | if self.pa:
152 | self.parameter_averaging(reset=True)
153 |
154 | print('Best epoch: {}'.format(self.best_epoch))
155 | if self.pa:
156 | self.parameter_averaging(epoch=self.best_epoch)
157 | self.eval_epoch(final=True, save_predictions=True)
158 |
159 | print('\n======== END TRAINING: {} ========\n'.format(
160 | datetime.datetime.now().strftime("%d-%m-%y_%H:%M:%S")))
161 |
162 | def train_epoch(self, epoch):
163 | """
164 |         Train the model for one epoch on the training set.
165 | """
166 | t1 = time()
167 | output = {'tp': [], 'fp': [], 'fn': [], 'tn': [], 'loss': [], 'preds': []}
168 | train_info = []
169 |
170 | self.model = self.model.train()
171 | train_iter = self.iterator(self.data['train'], batch_size=self.params['batch'],
172 | shuffle_=self.params['shuffle_data'])
173 | for batch_idx, batch in enumerate(train_iter):
174 | batch = self.convert_batch(batch)
175 |
176 | with autograd.detect_anomaly():
177 | self.optimizer.zero_grad()
178 | loss, stats, predictions, select = self.model(batch)
179 | loss.backward() # backward computation
180 |
181 | nn.utils.clip_grad_norm_(self.model.parameters(), self.gc) # gradient clipping
182 | self.optimizer.step() # update
183 |
184 | output['loss'] += [loss.item()]
185 | output['tp'] += [stats['tp'].to('cpu').data.numpy()]
186 | output['fp'] += [stats['fp'].to('cpu').data.numpy()]
187 | output['fn'] += [stats['fn'].to('cpu').data.numpy()]
188 | output['tn'] += [stats['tn'].to('cpu').data.numpy()]
189 | output['preds'] += [predictions.to('cpu').data.numpy()]
190 | train_info += [batch['info'][select[0].to('cpu').data.numpy(),
191 | select[1].to('cpu').data.numpy(),
192 | select[2].to('cpu').data.numpy()]]
193 |
194 | t2 = time()
195 | if self.window:
196 | total_loss, scores = self.subdocs_performance(output['loss'], output['preds'], train_info)
197 | else:
198 | total_loss, scores = self.performance(output)
199 |
200 | self.train_res['loss'] += [total_loss]
201 | self.train_res['score'] += [scores[self.primary_metric]]
202 | print('Epoch: {:02d} | TRAIN | LOSS = {:.05f}, '.format(epoch, total_loss), end="")
203 | print_results(scores, [], self.show_class, t2-t1)
204 |
205 | def eval_epoch(self, final=False, save_predictions=False):
206 | """
207 | Evaluate the model on the test set.
208 | No backward computation is allowed.
209 | """
210 | t1 = time()
211 | output = {'tp': [], 'fp': [], 'fn': [], 'tn': [], 'loss': [], 'preds': []}
212 | test_info = []
213 |
214 | self.model = self.model.eval()
215 | test_iter = self.iterator(self.data['test'], batch_size=self.params['batch'], shuffle_=False)
216 | for batch_idx, batch in enumerate(test_iter):
217 | batch = self.convert_batch(batch)
218 |
219 | with torch.no_grad():
220 | loss, stats, predictions, select = self.model(batch)
221 |
222 | output['loss'] += [loss.item()]
223 | output['tp'] += [stats['tp'].to('cpu').data.numpy()]
224 | output['fp'] += [stats['fp'].to('cpu').data.numpy()]
225 | output['fn'] += [stats['fn'].to('cpu').data.numpy()]
226 | output['tn'] += [stats['tn'].to('cpu').data.numpy()]
227 | output['preds'] += [predictions.to('cpu').data.numpy()]
228 | test_info += [batch['info'][select[0].to('cpu').data.numpy(),
229 | select[1].to('cpu').data.numpy(),
230 | select[2].to('cpu').data.numpy()]]
231 | t2 = time()
232 |
233 | # estimate performance
234 | if self.window:
235 | total_loss, scores = self.subdocs_performance(output['loss'], output['preds'], test_info)
236 | else:
237 | total_loss, scores = self.performance(output)
238 |
239 | if not final:
240 | self.test_res['loss'] += [total_loss]
241 | self.test_res['score'] += [scores[self.primary_metric]]
242 | print(' TEST | LOSS = {:.05f}, '.format(total_loss), end="")
243 | print_results(scores, [], self.show_class, t2-t1)
244 | print()
245 |
246 | if save_predictions:
247 | write_preds(output['preds'], test_info, self.preds_file, map_=self.loader.index2rel)
248 | write_errors(output['preds'], test_info, self.preds_file, map_=self.loader.index2rel)
249 |
250 | def parameter_averaging(self, epoch=None, reset=False):
251 | """
252 | Perform parameter averaging.
253 | For each epoch, average the parameters up to this epoch and then evaluate on test set.
254 |         If the 'reset' option is set: use the last epoch's parameters for the next epoch
255 | """
256 | for p_name, p_value in self.model.named_parameters():
257 | if p_name not in self.averaged_params:
258 | self.averaged_params[p_name] = []
259 |
260 | if reset:
261 | p_new = copy.deepcopy(self.averaged_params[p_name][-1]) # use last epoch param
262 |
263 | elif epoch:
264 | p_new = np.mean(self.averaged_params[p_name][:epoch], axis=0) # estimate average until this epoch
265 |
266 | else:
267 | self.averaged_params[p_name].append(p_value.data.to('cpu').numpy())
268 | p_new = np.mean(self.averaged_params[p_name], axis=0) # estimate average
269 |
270 | # assign to array
271 | if self.device != 'cpu':
272 | p_value.data = torch.from_numpy(p_new).to(self.device)
273 | else:
274 | p_value.data = torch.from_numpy(p_new)
275 |
276 | def epoch_checking(self, epoch):
277 | """
278 | Perform early stopping.
279 | If performance does not improve for a number of consecutive epochs ("max_patience")
280 | then stop the training and keep the best epoch: stopped_epoch - max_patience
281 |
282 | Args:
283 | epoch (int): current training epoch
284 |
285 |         Returns: (bool) whether training should stop
286 | """
287 | if self.test_res['score'][-1] > self.best_score: # improvement
288 | self.best_score = self.test_res['score'][-1]
289 | self.cur_patience = 0
290 | if self.es:
291 | self.best_epoch = epoch
292 | else:
293 | self.cur_patience += 1
294 | if not self.es:
295 | self.best_epoch = epoch
296 |
297 | if epoch % 5 == 0 and self.es:
298 | print('Current best {} score {:.6f} @ epoch {}\n'.format(self.params['primary_metric'],
299 | self.best_score, self.best_epoch))
300 |
301 | if self.max_patience == self.cur_patience and self.es: # early stop must happen
302 | self.best_epoch = epoch - self.max_patience
303 | return True
304 | else:
305 | return False
306 |
307 | @staticmethod
308 | def performance(stats):
309 | """
310 | Estimate total loss for an epoch.
311 | Calculate Micro and Macro P/R/F1 scores & Accuracy.
312 | Returns: (float) average loss, (float) micro and macro P/R/F1
313 | """
314 | def fbeta_score(precision, recall, beta=1.0):
315 | beta_square = beta * beta
316 | if (precision != 0.0) and (recall != 0.0):
317 | res = ((1 + beta_square) * precision * recall / (beta_square * precision + recall))
318 | else:
319 | res = 0.0
320 | return res
321 |
322 | def prf1(tp_, fp_, fn_, tn_):
323 | tp_ = np.sum(tp_, axis=0)
324 | fp_ = np.sum(fp_, axis=0)
325 | fn_ = np.sum(fn_, axis=0)
326 | tn_ = np.sum(tn_, axis=0)
327 |
328 | atp = np.sum(tp_)
329 | afp = np.sum(fp_)
330 | afn = np.sum(fn_)
331 | atn = np.sum(tn_)
332 |
333 | micro_p = (1.0 * atp) / (atp + afp) if (atp + afp != 0) else 0.0
334 | micro_r = (1.0 * atp) / (atp + afn) if (atp + afn != 0) else 0.0
335 | micro_f = fbeta_score(micro_p, micro_r)
336 |
337 |             pp = [0]   # per-class (macro) P/R/F1 are not computed here; these placeholders keep macro_* at 0.0
338 | rr = [0]
339 | ff = [0]
340 | macro_p = np.mean(pp)
341 | macro_r = np.mean(rr)
342 | macro_f = np.mean(ff)
343 |
344 | acc = (atp + atn) / (atp + atn + afp + afn) if (atp + atn + afp + afn) else 0.0
345 | return {'acc': acc,
346 | 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f': micro_f,
347 | 'macro_p': macro_p, 'macro_r': macro_r, 'macro_f': macro_f,
348 | 'tp': atp, 'true': atp+afn, 'pred': atp+afp, 'total': (atp + atn + afp + afn)}
349 |
350 | fin_loss = sum(stats['loss']) / len(stats['loss'])
351 | scores = prf1(stats['tp'], stats['fp'], stats['fn'], stats['tn'])
352 | return fin_loss, scores
353 |
354 | def subdocs_performance(self, loss_, preds, info):
355 | pairs = {}
356 | for p, i in zip(preds, info):
357 | i = [i_ for i_ in i if i_]
358 | assert len(p) == len(i)
359 |
360 | # pmid, e1, e2, pred, truth
361 | for k, j in zip(p, i):
362 | if (j['pmid'].split('__')[0], j['entA'].id, j['entB'].id, j['rel']) not in pairs:
363 | pairs[(j['pmid'].split('__')[0], j['entA'].id, j['entB'].id, j['rel'])] = []
364 | pairs[(j['pmid'].split('__')[0], j['entA'].id, j['entB'].id, j['rel'])] += [k]
365 |
366 | res = []
367 | tr = []
368 | for l in pairs.keys():
369 | res += [self.loader.label2ignore if all([c == self.loader.label2ignore for c in pairs[l]])
370 | else not self.loader.label2ignore]
371 | tr += [l[3]]
372 |
373 | rsize = self.rel_size
374 |
375 | mask_t = np.equal(tr, self.loader.label2ignore)
376 | mask_p = np.equal(res, self.loader.label2ignore)
377 |
378 | true = np.where(mask_t, rsize, tr)
379 | pred = np.where(mask_p, rsize, res)
380 |
381 | tp_mask = np.where(pred == true, true, rsize)
382 | fp_mask = np.where(pred != true, pred, rsize)
383 | fn_mask = np.where(pred != true, true, rsize)
384 |
385 | tp = np.bincount(tp_mask, minlength=rsize + 1)[:rsize]
386 | fp = np.bincount(fp_mask, minlength=rsize + 1)[:rsize]
387 | fn = np.bincount(fn_mask, minlength=rsize + 1)[:rsize]
388 | tn = np.sum(mask_t & mask_p)
389 | return self.performance({'loss': loss_, 'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn})
390 |
391 | def convert_batch(self, batch):
392 | new_batch = {'entities': []}
393 | ent_count, sent_count, word_count = 0, 0, 0
394 | full_text = []
395 |
396 | # TODO make this faster
397 | for i, b in enumerate(batch):
398 | current_text = list(itertools.chain.from_iterable(b['text']))
399 | full_text += current_text
400 |
401 | temp = []
402 | for e in b['ents']:
403 | # token ids are correct
404 | assert full_text[(e[2] + word_count):(e[3] + word_count)] == current_text[e[2]:e[3]], \
405 | '{} != {}'.format(full_text[(e[2] + word_count):(e[3] + word_count)], current_text[e[2]:e[3]])
406 | temp += [[e[0] + ent_count, e[1], e[2] + word_count, e[3] + word_count, e[4] + sent_count]]
407 |
408 | new_batch['entities'] += [np.array(temp)]
409 | word_count += sum([len(s) for s in b['text']])
410 | ent_count = max([t[0] for t in temp]) + 1
411 | sent_count += len(b['text'])
412 |
413 | new_batch['entities'] = torch.as_tensor(np.concatenate(new_batch['entities'], axis=0)).long().to(self.device)
414 |
415 | batch_ = [{k: v for k, v in b.items() if (k != 'info' and k != 'text')} for b in batch]
416 | converted_batch = concat_examples(batch_, device=self.device, padding=-1)
417 |
418 | converted_batch['adjacency'][converted_batch['adjacency'] == -1] = 0
419 | converted_batch['dist'][converted_batch['dist'] == -1] = self.loader.n_dist
420 |
421 | new_batch['adjacency'] = converted_batch['adjacency'].byte()
422 | new_batch['distances'] = converted_batch['dist']
423 | new_batch['relations'] = converted_batch['rels']
424 | new_batch['section'] = converted_batch['section']
425 | new_batch['word_sec'] = converted_batch['word_sec'][converted_batch['word_sec'] != -1].long()
426 | new_batch['words'] = converted_batch['words'][converted_batch['words'] != -1].long()
427 | new_batch['pairs4class'] = torch.as_tensor(self.pairs4class).long().to(self.device)
428 | new_batch['info'] = np.stack([np.array(np.pad(b['info'],
429 | ((0, new_batch['section'][:, 0].sum(dim=0).item() - b['info'].shape[0]),
430 | (0, new_batch['section'][:, 0].sum(dim=0).item() - b['info'].shape[0])),
431 | 'constant',
432 | constant_values=(-1, -1))) for b in batch], axis=0)
433 |
434 | if self.example:
435 | for i, b in enumerate(batch):
436 | print('===== DOCUMENT NO {} ====='.format(i))
437 | for s in b['text']:
438 | print(' '.join([self.loader.index2word[t] for t in s]))
439 | print(b['ents'])
440 |
441 | print(new_batch['relations'][i])
442 | print(new_batch['adjacency'][i])
443 | print(np.array([self.loader.index2dist[p] for p in
444 | new_batch['distances'][i].to('cpu').data.numpy().ravel()]).reshape(
445 | new_batch['distances'][i].shape))
446 | sys.exit()
447 | return new_batch
448 |
--------------------------------------------------------------------------------
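
The optimizer setup above excludes biases from weight decay by building two parameter groups. A standalone sketch of that pattern with a toy model (the model and hyper-parameter values below are placeholders, not the repo's):

    from torch import nn, optim

    # toy model; the real one is the EOG network built in Trainer.init_model
    model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

    params2reg, params0reg = [], []                      # weights vs. biases
    for p_name, p_value in model.named_parameters():
        (params0reg if '.bias' in p_name else params2reg).append(p_value)

    groups = [dict(params=params2reg),                   # inherits the global weight_decay
              dict(params=params0reg, weight_decay=0.0)] # biases: no L2 penalty
    optimizer = optim.Adam(groups, lr=1e-3, weight_decay=1e-4, amsgrad=True)
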
/src/nnet/walks.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 06/03/19
5 |
6 | author: fenia
7 | """
8 |
9 | import torch
10 | from torch import nn
11 | torch.set_printoptions(profile="full")
12 |
13 |
14 | class WalkLayer(nn.Module):
15 | def __init__(self, input_size, iters=0, beta=0.9, device=-1):
16 | """
17 | Walk Layer --> Walk on the edges of a predefined graph
18 | Args:
19 | input_size (int): input dimensionality
20 | iters (int): number of iterations --> 2^{iters} = walks-length
21 | beta (float): weight shorter/longer walks
22 | Return:
23 | pairs (Tensor): final pair representations
24 |                        size (batch, nodes, nodes, features)
25 | """
26 | super(WalkLayer, self).__init__()
27 |
28 | self.W = nn.Parameter(nn.init.normal_(torch.empty(input_size, input_size)), requires_grad=True)
29 | self.sigmoid = nn.Sigmoid()
30 | self.beta = beta
31 | self.iters = iters
32 | self.device = device
33 |
34 | @staticmethod
35 | def init_graph(graph, adj):
36 | """
37 | Initialize graph with 0 connections based on the adjacency matrix
38 | """
39 | graph = torch.where(adj.unsqueeze(-1), graph, torch.zeros_like(graph))
40 | return graph
41 |
42 | @staticmethod
43 | def mask_invalid_paths(graph, mask3d):
44 | """
45 | Mask invalid paths
46 | *(any node) -> A -> A
47 | A -> A -> *(any node)
48 | A -> *(any node) -> A
49 |
50 | Additionally mask paths that involve padded entities as intermediate nodes
51 | -inf so that sigmoid returns 0
52 | """
53 | items = range(graph.size(1))
54 | graph[:, :, items, items] = float('-inf') # *->A->A
55 | graph[:, items, items] = float('-inf') # A->A->*
56 | graph[:, items, :, items] = float('-inf') # A->*->A (self-connection)
57 |
58 | graph = torch.where(mask3d.unsqueeze(-1), graph, torch.as_tensor([float('-inf')]).to(graph.device)) # padded
59 | graph = torch.where(torch.eq(graph, 0.0).all(dim=4, keepdim=True),
60 | torch.as_tensor([float('-inf')]).to(graph.device),
61 | graph) # remaining (make sure the whole representation is zero)
62 | return graph
63 |
64 | def generate(self, old_graph):
65 | """
66 | Walk-generation: Combine consecutive edges.
67 | Returns: previous graph,
68 | extended graph with intermediate node connections (dim=2)
69 | """
70 | graph = torch.matmul(old_graph, self.W[None, None]) # (B, I, I, D)
71 | graph = torch.einsum('bijk, bjmk -> bijmk', graph, old_graph) # (B, I, I, I, D) -> dim=2 intermediate node
72 | return old_graph, graph
73 |
74 | def aggregate(self, old_graph, new_graph):
75 | """
76 | Walk-aggregation: Combine multiple paths via intermediate nodes.
77 | """
78 | # if the new representation is zero (i.e. impossible path), keep the original one --> [beta = 1]
79 | beta_mat = torch.where(torch.isinf(new_graph).all(dim=2),
80 | torch.ones_like(old_graph),
81 | torch.full_like(old_graph, self.beta))
82 |
83 | new_graph = self.sigmoid(new_graph)
84 | new_graph = torch.sum(new_graph, dim=2) # non-linearity & sum pooling
85 | new_graph = torch.lerp(new_graph, old_graph, weight=beta_mat)
86 | return new_graph
87 |
88 | def forward(self, graph, adj_=None, mask_=None):
89 | graph = self.init_graph(graph, adj_)
90 |
91 | for _ in range(0, self.iters):
92 | old_graph, graph = self.generate(graph)
93 | graph = self.mask_invalid_paths(graph, mask_)
94 | graph = self.aggregate(old_graph, graph)
95 | return graph
96 |
--------------------------------------------------------------------------------
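
A toy forward pass through WalkLayer, assuming `src/` is on `sys.path`; the graph below is a single fully-connected "document" with 4 nodes and 8-dimensional edge representations (all sizes are illustrative), so iters=2 yields walks of length 4:

    import torch
    from nnet.walks import WalkLayer

    B, N, D = 1, 4, 8
    graph = torch.randn(B, N, N, D)                       # initial edge representations
    adj = torch.ones(B, N, N, dtype=torch.bool)           # fully-connected adjacency

    # mask as produced by EOG.get_nodes_mask (no padded nodes in this toy batch)
    sizes = torch.tensor([N])
    n = torch.arange(N)
    idx_r, idx_c, idx_d = torch.meshgrid(n, n, n)
    ns = sizes[:, None, None, None]
    mask3d = ~((idx_r >= ns) | (idx_c >= ns) | (idx_d >= ns))

    walk = WalkLayer(input_size=D, iters=2, beta=0.9)
    out = walk(graph, adj_=adj, mask_=mask3d)             # (1, 4, 4, 8) walk-aggregated edges
    print(out.shape)
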
/src/reader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 19/06/2019
5 |
6 | author: fenia
7 | """
8 |
9 | from collections import OrderedDict
10 | from recordtype import recordtype
11 | import numpy as np
12 |
13 |
14 | EntityInfo = recordtype('EntityInfo', 'id type mstart mend sentNo')
15 | PairInfo = recordtype('PairInfo', 'type direction cross')
16 |
17 |
18 | def chunks(l, n):
19 | """
20 | Successive n-sized chunks from l.
21 | """
22 | res = []
23 | for i in range(0, len(l), n):
24 | assert len(l[i:i + n]) == n
25 | res += [l[i:i + n]]
26 | return res
27 |
28 |
29 | def overlap_chunk(chunk=1, lst=None):
30 | if len(lst) <= chunk:
31 | return [lst]
32 | else:
33 | return [lst[i:i + chunk] for i in range(0, len(lst)-chunk+1, 1)]
34 |
35 |
36 | def read_subdocs(input_file, window, documents, entities, relations):
37 | """
38 | Read documents as sub-documents of N consecutive sentences.
39 | Args:
40 | input_file: file with documents
41 | """
42 | lost_pairs, total_pairs = 0, 0
43 | lengths = []
44 | sents = []
45 | with open(input_file, 'r') as infile:
46 | for line in infile:
47 | line = line.rstrip().split('\t')
48 | pmid = line[0]
49 | text = line[1]
50 | prs = chunks(line[2:], 17) # all the pairs in the document
51 |
52 | sentences = text.split('|') # document sentences
53 | all_sent_lengths = [len(s.split(' ')) for s in sentences] # document sentence lengths
54 |
55 | sent_chunks = overlap_chunk(chunk=window, lst=sentences) # split document into sub-documents
56 |
57 | unique_pairs = []
58 | for num, sent in enumerate(sent_chunks):
59 | sent_ids = list(np.arange(int(window)) + num)
60 |
61 | sub_pmid = pmid+'__'+str(num)
62 |
63 | if sub_pmid not in documents:
64 | documents[sub_pmid] = [t.split(' ') for t in sent]
65 |
66 | if sub_pmid not in entities:
67 | entities[sub_pmid] = OrderedDict()
68 |
69 | if sub_pmid not in relations:
70 | relations[sub_pmid] = OrderedDict()
71 |
72 | lengths += [max([len(d) for d in documents[sub_pmid]])]
73 | sents += [len(sent)]
74 |
75 | for p in prs:
76 | # entities
77 | for (ent, typ_, start, end, sn) in [(p[5], p[7], p[8], p[9], p[10]),
78 | (p[11], p[13], p[14], p[15], p[16])]:
79 |
80 | if ent not in entities[sub_pmid]:
81 | s_ = list(map(int, sn.split(':'))) # doc-level ids
82 | m_s_ = list(map(int, start.split(':')))
83 | m_e_ = list(map(int, end.split(':')))
84 | assert len(s_) == len(m_s_) == len(m_e_)
85 |
86 | sent_no_new = []
87 | mstart_new = []
88 | mend_new = []
89 | for n, (old_s, old_ms, old_me) in enumerate(zip(s_, m_s_, m_e_)):
90 | if old_s in sent_ids:
91 | sub_ = sum(all_sent_lengths[0:old_s])
92 |
93 | assert sent[old_s-num] == sentences[old_s]
94 | assert sent[old_s-num].split(' ')[(old_ms-sub_):(old_me-sub_)] == \
95 | ' '.join(sentences).split(' ')[old_ms:old_me]
96 | sent_no_new += [old_s - num]
97 | mstart_new += [old_ms - sub_]
98 | mend_new += [old_me - sub_]
99 |
100 | if sent_no_new and mstart_new and mend_new:
101 | entities[sub_pmid][ent] = EntityInfo(ent, typ_,
102 | ':'.join(map(str, mstart_new)),
103 | ':'.join(map(str, mend_new)),
104 | ':'.join(map(str, sent_no_new)))
105 |
106 | for p in prs:
107 | # pairs
108 | if (p[5] in entities[sub_pmid]) and (p[11] in entities[sub_pmid]):
109 | if (p[5], p[11]) not in relations[sub_pmid]:
110 | relations[sub_pmid][(p[5], p[11])] = PairInfo(p[0], p[1], p[2])
111 |
112 | if (pmid, p[5], p[11]) not in unique_pairs:
113 | unique_pairs += [(pmid, p[5], p[11])]
114 |
115 | if len(prs) != len(unique_pairs):
116 | for x in prs:
117 | if (pmid, x[5], x[11]) not in unique_pairs:
118 | if x[0] != '1:NR:2' and x[0] != 'not_include':
119 | lost_pairs += 1
120 | print('--> Lost pair {}, {}, {}: {} {}'.format(pmid, x[5], x[11], x[10], x[16]))
121 | else:
122 | if x[0] != '1:NR:2' and x[0] != 'not_include':
123 | total_pairs += 1
124 |
125 | todel = []
126 | for pmid, d in relations.items():
127 | if not relations[pmid]:
128 | todel += [pmid]
129 |
130 | for pmid in todel:
131 | del documents[pmid]
132 | del entities[pmid]
133 | del relations[pmid]
134 |
135 | print('LOST PAIRS: {}/{}'.format(lost_pairs, total_pairs))
136 | assert len(entities) == len(documents) == len(relations)
137 | return lengths, sents, documents, entities, relations
138 |
139 |
140 | def read(input_file, documents, entities, relations):
141 | """
142 | Read the full document at a time.
143 | """
144 | lengths = []
145 | sents = []
146 | with open(input_file, 'r') as infile:
147 | for line in infile:
148 | line = line.rstrip().split('\t')
149 | pmid = line[0]
150 | text = line[1]
151 | prs = chunks(line[2:], 17)
152 |
153 | if pmid not in documents:
154 | documents[pmid] = [t.split(' ') for t in text.split('|')]
155 |
156 | if pmid not in entities:
157 | entities[pmid] = OrderedDict()
158 |
159 | if pmid not in relations:
160 | relations[pmid] = OrderedDict()
161 |
162 |             # longest sentence length (intra) or number of sentences (inter), whichever is larger
163 | lengths += [max([len(s) for s in documents[pmid]] + [len(documents[pmid])])]
164 | sents += [len(text.split('|'))]
165 |
166 | allp = 0
167 | for p in prs:
168 | if (p[5], p[11]) not in relations[pmid]:
169 | relations[pmid][(p[5], p[11])] = PairInfo(p[0], p[1], p[2])
170 | allp += 1
171 | else:
172 | print('duplicates!')
173 |
174 | # entities
175 | if p[5] not in entities[pmid]:
176 | entities[pmid][p[5]] = EntityInfo(p[5], p[7], p[8], p[9], p[10])
177 |
178 | if p[11] not in entities[pmid]:
179 | entities[pmid][p[11]] = EntityInfo(p[11], p[13], p[14], p[15], p[16])
180 |
181 | assert len(relations[pmid]) == allp
182 |
183 | todel = []
184 | for pmid, d in relations.items():
185 | if not relations[pmid]:
186 | todel += [pmid]
187 |
188 | for pmid in todel:
189 | del documents[pmid]
190 | del entities[pmid]
191 | del relations[pmid]
192 |
193 | return lengths, sents, documents, entities, relations
194 |
--------------------------------------------------------------------------------
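
The two helpers above drive both readers: chunks splits the tab-separated pair fields into 17-column records, and overlap_chunk slides a window of consecutive sentences over a document for read_subdocs. A quick sketch of their behaviour on toy inputs (assuming `src/` is on `sys.path` and the recordtype dependency is installed, since importing reader pulls it in):

    from reader import chunks, overlap_chunk

    # 34 pair fields after the pmid and text columns -> two 17-column pair records
    fields = ['f{}'.format(i) for i in range(34)]
    print(len(chunks(fields, 17)))                        # 2

    # sliding window of 3 consecutive sentences (used by read_subdocs)
    sentences = ['s0', 's1', 's2', 's3', 's4']
    print(overlap_chunk(chunk=3, lst=sentences))
    # [['s0', 's1', 's2'], ['s1', 's2', 's3'], ['s2', 's3', 's4']]
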
/src/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on 22/02/19
5 |
6 | author: fenia
7 | """
8 |
9 | import sys
10 | import os
11 | from tabulate import tabulate
12 | import itertools
13 | import numpy as np
14 | import pickle as pkl
15 | import torch
16 | import matplotlib
17 | matplotlib.use('Agg')
18 | import matplotlib.pyplot as plt
19 |
20 |
21 | def solve(A, B):
22 | A = list(map(int, A))
23 | B = list(map(int, B))
24 | m = len(A)
25 | n = len(B)
26 | A.sort()
27 | B.sort()
28 | a = 0
29 | b = 0
30 | result = sys.maxsize
31 |
32 | while a < m and b < n:
33 | if abs(A[a] - B[b]) < result:
34 | result = abs(A[a] - B[b])
35 |
36 | # Move Smaller Value
37 | if A[a] < B[b]:
38 | a += 1
39 | else:
40 | b += 1
41 |     # return the smallest absolute difference found
42 | return result
43 |
44 |
45 | def write_errors(preds, info, ofile, map_=None):
46 | """ Write model errors to file """
47 |     print('Saving errors ... ', end="")
48 | with open(ofile+'.errors', 'w') as outfile:
49 | for p, i in zip(preds, info):
50 | i = [i_ for i_ in i if i_]
51 | assert len(p) == len(i)
52 |
53 | for k, j in zip(p, i):
54 | if k != j['rel']:
55 | outfile.write('Prediction: {} \t Truth: {} \t Type: {} \n'.format(map_[k], map_[j['rel']], j['cross']))
56 | doc = [it for items in j['doc'] for it in items]
57 | outfile.write('{}\n{}\n'.format(j['pmid'], ' '.join(doc)))
58 |
59 | gg1 = ' | '.join([' '.join(doc[int(m1):int(m2)]) for m1,m2 in
60 | zip(j['entA'].mstart.split(':'), j['entA'].mend.split(':'))])
61 | gg2 = ' | '.join([' '.join(doc[int(m1):int(m2)]) for m1, m2 in
62 | zip(j['entB'].mstart.split(':'), j['entB'].mend.split(':'))])
63 |
64 | outfile.write('Arg1: {} | {}\n'.format(j['entA'].id, gg1))
65 | outfile.write('Arg2: {} | {}\n'.format(j['entB'].id, gg2))
66 | outfile.write('Distance: {}\n'.format(solve(j['sentA'].split(':'), j['sentB'].split(':'))))
67 | outfile.write('\n')
68 | print('DONE')
69 |
70 |
71 | def write_preds(preds, info, ofile, map_=None):
72 | """ Write predictions to file """
73 |     print('Saving predictions ... ', end="")
74 | with open(ofile+'.preds', 'w') as outfile:
75 | for p, i in zip(preds, info):
76 | i = [i_ for i_ in i if i_]
77 | assert len(p) == len(i)
78 |
79 | for k, j in zip(p, i):
80 | # pmid, e1, e2, pred, truth
81 | if map_[k] == '1:NR:2':
82 | pass
83 | else:
84 | outfile.write('{}\n'.format('|'.join([j['pmid'].split('__')[0],
85 | j['entA'].id, j['entB'].id, j['cross'],
86 | str(solve(j['sentA'].split(':'), j['sentB'].split(':'))),
87 | map_[k]])))
88 | print('DONE')
89 |
90 |
91 | def plot_learning_curve(trainer, model_folder):
92 | """
93 | Plot the learning curves for training and test set (loss and primary score measure)
94 |
95 | Args:
96 | trainer (Class): trainer object
97 | model_folder (str): folder to save figures
98 | """
99 | x = list(map(int, np.arange(len(trainer.train_res['loss']))))
100 | fig = plt.figure()
101 | plt.subplot(2, 1, 1)
102 | plt.plot(x, trainer.train_res['loss'], 'b', label='train')
103 | plt.plot(x, trainer.test_res['loss'], 'g', label='test')
104 | plt.legend()
105 | plt.ylabel('Loss')
106 | plt.yticks(np.arange(0, 1, 0.1))
107 |
108 | plt.subplot(2, 1, 2)
109 | plt.plot(x, trainer.train_res['score'], 'b', label='train')
110 | plt.plot(x, trainer.test_res['score'], 'g', label='test')
111 | plt.legend()
112 | plt.ylabel('F1-score')
113 | plt.xlabel('Epochs')
114 | plt.yticks(np.arange(0, 1, 0.1))
115 |
116 | fig.savefig(model_folder + '/learn_curves.png', bbox_inches='tight')
117 |
118 |
119 | def print_results(scores, scores_class, show_class, time):
120 | """
121 | Print class-wise results.
122 |
123 | Args:
124 | scores (dict): micro and macro scores
125 | scores_class: score per class
126 | show_class (bool): show or not
127 |         time: elapsed time (seconds)
128 | """
129 |
130 | def indent(txt, spaces=18):
131 | return "\n".join(" " * spaces + ln for ln in txt.splitlines())
132 |
133 | if show_class:
134 | # print results for every class
135 | scores_class.append(['-----', None, None, None])
136 | scores_class.append(['macro score', scores['macro_p'], scores['macro_r'], scores['macro_f']])
137 | scores_class.append(['micro score', scores['micro_p'], scores['micro_r'], scores['micro_f']])
138 | print(' | {}\n'.format(humanized_time(time)))
139 | print(indent(tabulate(scores_class,
140 | headers=['Class', 'P', 'R', 'F1'],
141 | tablefmt='orgtbl',
142 |                        floatfmt=".4f",
143 | missingval="")))
144 | print()
145 | else:
146 | print('ACC = {:.04f} , '
147 | 'MICRO P/R/F1 = {:.04f}\t{:.04f}\t{:.04f} | '.format(scores['acc'], scores['micro_p'], scores['micro_r'],
148 | scores['micro_f']), end="")
149 |
150 | l = ':<7' # +str(len(str(scores['total'])))
151 | s = 'TP/ACTUAL/PRED = {'+l+'}/{'+l+'}/{'+l+'}, TOTAL {'+l+'}'
152 | print(s.format(scores['tp'], scores['true'], scores['pred'], scores['total']), end="")
153 | print(' | {}'.format(humanized_time(time)))
154 |
155 |
156 | class Tee(object):
157 | """
158 | Object to print stdout to a file.
159 | """
160 | def __init__(self, *files):
161 | self.files = files
162 |
163 | def write(self, obj):
164 | for f_ in self.files:
165 | f_.write(obj)
166 | f_.flush() # If you want the output to be visible immediately
167 |
168 | def flush(self):
169 | for f_ in self.files:
170 | f_.flush()
171 |
172 |
173 | def humanized_time(second):
174 | """
175 | :param second: time in seconds
176 | :return: human readable time (hours, minutes, seconds)
177 | """
178 | m, s = divmod(second, 60)
179 | h, m = divmod(m, 60)
180 | return "%dh %02dm %02ds" % (h, m, s)
181 |
182 |
183 | def setup_log(params, mode):
184 | """
185 | Setup .log file to record training process and results.
186 |
187 | Args:
188 | params (dict): model parameters
189 |
190 | Returns:
191 | model_folder (str): model directory
192 | """
193 | if params['walks_iter'] == 0:
194 | length = 1
195 | elif not params['walks_iter']:
196 | length = 1
197 | else:
198 | length = 2**params['walks_iter']
199 | # folder_name = 'b{}-wd{}-ld{}-od{}-td{}-beta{}-pd{}-di{}-do{}-lr{}-gc{}-r{}-p{}-walks{}'.format(
200 | # params['batch'], params['word_dim'], params['lstm_dim'], params['out_dim'], params['type_dim'],
201 | # params['pos_dim'], params['beta'], params['pos_dim'], params['drop_i'], params['drop_o'], params['lr'],
202 | # params['gc'], params['reg'], params['patience'], length)
203 |
204 | folder_name = 'b{}-walks{}'.format(params['batch'], length)
205 |
206 | folder_name += '_'+'_'.join(params['edges'])
207 |
208 | if params['context']:
209 | folder_name += '_context'
210 |
211 | if params['types']:
212 | folder_name += '_types'
213 |
214 | if params['dist']:
215 | folder_name += '_dist'
216 |
217 | if params['freeze_words']:
218 | folder_name += '_freeze'
219 |
220 | if params['window']:
221 | folder_name += '_win'+str(params['window'])
222 |
223 | model_folder = params['folder'] + '/' + folder_name
224 | if not os.path.exists(model_folder):
225 | os.makedirs(model_folder)
226 | log_file = model_folder + '/info_'+mode+'.log'
227 |
228 | f = open(log_file, 'w')
229 | sys.stdout = Tee(sys.stdout, f)
230 | return model_folder
231 |
232 |
233 | def observe(model):
234 | """
235 | Observe model parameters: name, range of matrices & gradients
236 |
237 | Args
238 | model: specified model object
239 | """
240 | for name, param in model.named_parameters():
241 | p_data, p_grad = param.data, param.grad.data
242 | print('Name: {:<30}\tRange of data: [{:.4f}, {:.4f}]\tRange of gradient: [{:.4f}, {:.4f}]'.format(name,
243 | np.min(p_data.data.to('cpu').numpy()),
244 | np.max(p_data.data.to('cpu').numpy()),
245 | np.min(p_grad.data.to('cpu').numpy()),
246 | np.max(p_grad.data.to('cpu').numpy())))
247 | print('--------------------------------------')
248 |
249 |
250 | def save_model(model_folder, trainer, loader):
251 | print('\nSaving the model & the parameters ...')
252 | # save mappings
253 | with open(os.path.join(model_folder, 'mappings.pkl'), 'wb') as f:
254 | pkl.dump(loader, f, pkl.HIGHEST_PROTOCOL)
255 | torch.save(trainer.model.state_dict(), os.path.join(model_folder, 're.model'))
256 |
257 |
258 | def load_model(model_folder, trainer):
259 | print('\nLoading model & parameters ...')
260 | trainer.model.load_state_dict(torch.load(os.path.join(model_folder, 're.model'),
261 | map_location=trainer.model.device))
262 | return trainer
263 |
264 |
265 | def load_mappings(model_folder):
266 | with open(os.path.join(model_folder, 'mappings.pkl'), 'rb') as f:
267 | loader = pkl.load(f)
268 | return loader
269 |
270 |
271 | def print_options(params):
272 | print('''\nParameters:
273 | - Train Data {}
274 | - Test Data {}
275 | - Embeddings {}, Freeze: {}
276 | - Save folder {}
277 |
278 | - batchsize {}
279 | - Walks iteration {} -> Length = {}
280 | - beta {}
281 |
282 | - Context {}
283 | - Node Type {}
284 | - Distances {}
285 | - Edge Types {}
286 | - Window {}
287 |
288 | - Epoch {}
289 | - UNK word prob {}
290 | - Parameter Average {}
291 | - Early stop {} -> Patience = {}
292 | - Regularization {}
293 | - Gradient Clip {}
294 | - Dropout I/O {}/{}
295 | - Learning rate {}
296 | - Seed {}
297 | '''.format(params['train_data'], params['test_data'], params['embeds'], params['freeze_words'],
298 | params['folder'],
299 | params['batch'],
300 | params['walks_iter'], 2 ** params['walks_iter'] if params['walks_iter'] else 0, params['beta'],
301 | params['context'], params['types'], params['dist'], params['edges'],
302 | params['window'], params['epoch'],
303 | params['unk_w_prob'], params['param_avg'],
304 | params['early_stop'], params['patience'],
305 | params['reg'], params['gc'], params['drop_i'], params['drop_o'], params['lr'], params['seed']))
306 |
--------------------------------------------------------------------------------
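
A brief sketch of the small utilities above, assuming `src/` is on `sys.path` (importing utils also pulls in tabulate, matplotlib and torch): solve returns the minimum absolute difference between two lists of sentence ids, humanized_time formats seconds, and Tee mirrors stdout to a log file the way setup_log does; the log file name below is hypothetical.

    import sys
    from utils import solve, humanized_time, Tee

    print(solve(['1', '4', '9'], ['6', '7']))     # 2  (closest sentence ids: 4 and 6)
    print(humanized_time(3725))                   # "1h 02m 05s"

    log = open('example.log', 'w')                # hypothetical log file
    sys.stdout = Tee(sys.stdout, log)             # everything printed now also goes to the file
    print('mirrored line')
    sys.stdout = sys.__stdout__                   # restore plain stdout
    log.close()
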