├── requirements.txt ├── tmp └── readme.txt ├── LICENSE ├── .gitignore ├── sample_data ├── ptb-gold │ ├── train.conllx │ ├── test.conllx │ └── dev.conllx └── amr-split │ ├── amr-dev.txt │ ├── amr-training.txt │ └── amr-test.txt ├── README.md └── src ├── main.py ├── embed.py ├── models.py ├── evaluation_amr.py ├── evaluation_ptb.py ├── probe.py └── utils.py /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.52.0 2 | networkx==2.5 3 | numpy==1.19.2 4 | scikit-learn==0.23.2 5 | gensim==3.8.3 6 | transformers==4.6.0 7 | penman==1.1.0 8 | conllu==4.2.1 9 | -------------------------------------------------------------------------------- /tmp/readme.txt: -------------------------------------------------------------------------------- 1 | This folder is for temporary use. The embeddings of language models and graphs are stored here. If the embedding file is corrupted, you can simply remove that file and the code will recalculate it. 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 yifan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | 113 | # Pyre type checker 114 | .pyre/ 115 | -------------------------------------------------------------------------------- /sample_data/ptb-gold/train.conllx: -------------------------------------------------------------------------------- 1 | 1 In _ IN IN _ 45 prep _ _ 2 | 2 an _ DT DT _ 5 det _ _ 3 | 3 Oct. _ NNP NNP _ 5 nn _ _ 4 | 4 19 _ CD CD _ 5 num _ _ 5 | 5 review _ NN NN _ 1 pobj _ _ 6 | 6 of _ IN IN _ 5 prep _ _ 7 | 7 `` _ `` `` _ 9 punct _ _ 8 | 8 The _ DT DT _ 9 det _ _ 9 | 9 Misanthrope _ NN NN _ 6 pobj _ _ 10 | 10 '' _ '' '' _ 9 punct _ _ 11 | 11 at _ IN IN _ 9 prep _ _ 12 | 12 Chicago _ NNP NNP _ 15 poss _ _ 13 | 13 's _ POS POS _ 12 possessive _ _ 14 | 14 Goodman _ NNP NNP _ 15 nn _ _ 15 | 15 Theatre _ NNP NNP _ 11 pobj _ _ 16 | 16 -LRB- _ -LRB- -LRB- _ 20 punct _ _ 17 | 17 `` _ `` `` _ 20 punct _ _ 18 | 18 Revitalized _ VBN VBN _ 19 amod _ _ 19 | 19 Classics _ NNS NNS _ 20 nsubj _ _ 20 | 20 Take _ VBP VBP _ 5 dep _ _ 21 | 21 the _ DT DT _ 22 det _ _ 22 | 22 Stage _ NN NN _ 20 dobj _ _ 23 | 23 in _ IN IN _ 20 prep _ _ 24 | 24 Windy _ NNP NNP _ 25 nn _ _ 25 | 25 City _ NNP NNP _ 23 pobj _ _ 26 | 26 , _ , , _ 20 punct _ _ 27 | 27 '' _ '' '' _ 20 punct _ _ 28 | 28 Leisure _ NN NN _ 20 dep _ _ 29 | 29 & _ CC CC _ 28 cc _ _ 30 | 30 Arts _ NNS NNS _ 28 conj _ _ 31 | 31 -RRB- _ -RRB- -RRB- _ 20 punct _ _ 32 | 32 , _ , , _ 45 punct _ _ 33 | 33 the _ DT DT _ 34 det _ _ 34 | 34 role _ NN NN _ 45 nsubjpass _ _ 35 | 35 of _ IN IN _ 34 prep _ _ 36 | 36 Celimene _ NNP NNP _ 35 pobj _ _ 37 | 37 , _ , , _ 34 punct _ _ 38 | 38 played _ VBN VBN _ 34 partmod _ _ 39 | 39 by _ IN IN _ 38 prep _ _ 40 | 40 Kim _ NNP NNP _ 41 nn _ _ 41 | 41 Cattrall _ NNP NNP _ 39 pobj _ _ 42 | 42 , _ , , _ 34 punct _ _ 43 | 43 was _ VBD VBD _ 45 auxpass _ _ 44 | 44 mistakenly _ RB RB _ 45 advmod _ _ 45 | 45 attributed _ VBN VBN _ 0 root _ _ 46 | 46 to _ TO TO _ 45 prep _ _ 47 | 47 Christina _ NNP NNP _ 48 nn _ _ 48 | 48 Haag _ NNP NNP _ 46 pobj _ _ 49 | 49 . _ . . _ 45 punct _ _ 50 | 51 | 1 Ms. _ NNP NNP _ 2 nn _ _ 52 | 2 Haag _ NNP NNP _ 3 nsubj _ _ 53 | 3 plays _ VBZ VBZ _ 0 root _ _ 54 | 4 Elianti _ NNP NNP _ 3 dobj _ _ 55 | 5 . _ . . _ 3 punct _ _ 56 | 57 | 1 Rolls-Royce _ NNP NNP _ 4 nn _ _ 58 | 2 Motor _ NNP NNP _ 4 nn _ _ 59 | 3 Cars _ NNPS NNPS _ 4 nn _ _ 60 | 4 Inc. 
_ NNP NNP _ 5 nsubj _ _ 61 | 5 said _ VBD VBD _ 0 root _ _ 62 | 6 it _ PRP PRP _ 7 nsubj _ _ 63 | 7 expects _ VBZ VBZ _ 5 ccomp _ _ 64 | 8 its _ PRP$ PRP$ _ 10 poss _ _ 65 | 9 U.S. _ NNP NNP _ 10 nn _ _ 66 | 10 sales _ NNS NNS _ 13 nsubj _ _ 67 | 11 to _ TO TO _ 13 aux _ _ 68 | 12 remain _ VB VB _ 13 cop _ _ 69 | 13 steady _ JJ JJ _ 7 xcomp _ _ 70 | 14 at _ IN IN _ 13 prep _ _ 71 | 15 about _ IN IN _ 16 quantmod _ _ 72 | 16 1,200 _ CD CD _ 17 num _ _ 73 | 17 cars _ NNS NNS _ 14 pobj _ _ 74 | 18 in _ IN IN _ 13 prep _ _ 75 | 19 1990 _ CD CD _ 18 pobj _ _ 76 | 20 . _ . . _ 5 punct _ _ 77 | 78 | 1 The _ DT DT _ 4 det _ _ 79 | 2 luxury _ NN NN _ 4 nn _ _ 80 | 3 auto _ NN NN _ 4 nn _ _ 81 | 4 maker _ NN NN _ 7 nsubj _ _ 82 | 5 last _ JJ JJ _ 6 amod _ _ 83 | 6 year _ NN NN _ 7 tmod _ _ 84 | 7 sold _ VBD VBD _ 0 root _ _ 85 | 8 1,214 _ CD CD _ 9 num _ _ 86 | 9 cars _ NNS NNS _ 7 dobj _ _ 87 | 10 in _ IN IN _ 7 prep _ _ 88 | 11 the _ DT DT _ 12 det _ _ 89 | 12 U.S. _ NNP NNP _ 10 pobj _ _ 90 | 91 | -------------------------------------------------------------------------------- /sample_data/amr-split/amr-dev.txt: -------------------------------------------------------------------------------- 1 | # AMR-English alignment release (generated on Fri Feb 2, 2018 at 16:13:17) 2 | 3 | # ::id bolt12_64545_0526.1 ::amr-annotator SDL-AMR-09 ::preferred 4 | # ::tok There are many who have a sense of urgency , quietly watching how things develop , you are dragons coiling , you are tigers crouching , I admire noble @-@ minded patriots . 5 | # ::alignments 1-1.2.1.r 2-1.1 6-1.1.1 7-1.1.1.1.r 8-1.1.1.1 10-1.1.1.2.3 11-1.1.1.2 12-1.1.1.2.3.r 13-1.1.1.2.2 13-1.1.1.2.2.1.1 14-1.1.1.2.2.1 16-1.2.1 17-1.2.1.r 18-1.2 19-1.2.2 21-1.3.1 22-1.3.1.r 23-1.3 24-1.3.2 26-1.4.1 27-1.4 28-1.4.2.1.1 30-1.4.2.1 31-1.4.2 6 | (m / multi-sentence 7 | :snt1 (m2 / many~e.2 8 | :ARG0-of (s / sense-01~e.6 9 | :ARG1~e.7 (u / urgency~e.8) 10 | :time (w / watch-01~e.11 11 | :ARG0 m2 12 | :ARG1 (t3 / thing~e.13 13 | :manner-of (d / develop-02~e.14 14 | :ARG0 (t / thing~e.13))) 15 | :manner~e.12 (q / quiet-04~e.10 16 | :ARG1 m2)))) 17 | :snt2 (d2 / dragon~e.18 18 | :domain~e.1,17 (y / you~e.16) 19 | :ARG0-of (c / coil-01~e.19)) 20 | :snt3 (t2 / tiger~e.23 21 | :domain~e.22 (y2 / you~e.21) 22 | :ARG0-of (c2 / crouch-01~e.24)) 23 | :snt4 (a / admire-01~e.27 24 | :ARG0 (i / i~e.26) 25 | :ARG1 (p / patriot~e.31 26 | :poss-of (m3 / mind~e.30 27 | :mod (n / noble~e.28))))) 28 | 29 | # ::id bolt12_64545_0526.2 ::amr-annotator SDL-AMR-09 ::preferred 30 | # ::tok Has history given us too many lessons ? , 530 , 412 , 64 31 | # ::alignments 0-1.1.2.1 1-1.1.1 2-1.1 3-1.1.3 4-1.1.2.1.2 5-1.1.2.1.1 6-1.1.2 7-1.1.4 7-1.1.4.r 9-1.2.1 11-1.2.2 13-1.2.3 32 | (m2 / multi-sentence 33 | :snt1 (g / give-01~e.2 34 | :ARG0 (h / history~e.1) 35 | :ARG1 (l / lesson~e.6 36 | :ARG1-of (h2 / have-quant-91~e.0 37 | :ARG2 (m / many~e.5) 38 | :ARG3 (t / too~e.4))) 39 | :ARG2 (w / we~e.3) 40 | :polarity~e.7 (a2 / amr-unknown~e.7)) 41 | :snt2 (a / and :op1 530~e.9 :op2 412~e.11 :op3 64~e.13)) 42 | 43 | # ::id bolt12_64545_0527.1 ::amr-annotator SDL-AMR-09 ::preferred 44 | # ::tok taking a look 45 | # ::alignments 0-1 2-1 46 | (l / look-01~e.0,2) 47 | 48 | # ::id bolt12_64545_0528.1 ::amr-annotator SDL-AMR-09 ::preferred 49 | # ::tok the ones who are suffering are the ordinary people : even if the body of a salted fish is turned over , it is still a salted fish ... 
50 | # ::alignments 3-1.2.3.r 4-1.1 7-1.1.1.1 8-1.1.1 10-1.2.4 11-1.2.4 13-1.2.4.1.1 14-1.2.4.1.1.1.r 16-1.2.4.1.1.1.1 17-1.2.4.1.1.1 19-1.2.4.1 20-1.2.4.1.2 22-1.2.3 23-1.2.3.r 24-1.2.2 26-1.2.1 27-1.2 51 | (m / multi-sentence 52 | :snt1 (s / suffer-01~e.4 53 | :ARG0 (p / person~e.8 54 | :mod (o2 / ordinary~e.7))) 55 | :snt2 (f / fish~e.27 56 | :ARG1-of (s2 / salt-01~e.26) 57 | :mod (s3 / still~e.24) 58 | :domain~e.3,23 f2~e.22 59 | :concession (e / even-if~e.10,11 60 | :op1 (t / turn-01~e.19 61 | :ARG1 (b / body~e.13 62 | :poss~e.14 (f2 / fish~e.17 63 | :ARG1-of (s4 / salt-01~e.16))) 64 | :direction (o3 / over~e.20))))) -------------------------------------------------------------------------------- /sample_data/ptb-gold/test.conllx: -------------------------------------------------------------------------------- 1 | 1 No _ RB RB _ 7 discourse _ _ 2 | 2 , _ , , _ 7 punct _ _ 3 | 3 it _ PRP PRP _ 7 nsubj _ _ 4 | 4 was _ VBD VBD _ 7 cop _ _ 5 | 5 n't _ RB RB _ 7 neg _ _ 6 | 6 Black _ NNP NNP _ 7 nn _ _ 7 | 7 Monday _ NNP NNP _ 0 root _ _ 8 | 8 . _ . . _ 7 punct _ _ 9 | 10 | 1 But _ CC CC _ 33 cc _ _ 11 | 2 while _ IN IN _ 10 mark _ _ 12 | 3 the _ DT DT _ 7 det _ _ 13 | 4 New _ NNP NNP _ 7 nn _ _ 14 | 5 York _ NNP NNP _ 7 nn _ _ 15 | 6 Stock _ NNP NNP _ 7 nn _ _ 16 | 7 Exchange _ NNP NNP _ 10 nsubj _ _ 17 | 8 did _ VBD VBD _ 10 aux _ _ 18 | 9 n't _ RB RB _ 10 neg _ _ 19 | 10 fall _ VB VB _ 33 advcl _ _ 20 | 11 apart _ RB RB _ 10 advmod _ _ 21 | 12 Friday _ NNP NNP _ 10 tmod _ _ 22 | 13 as _ IN IN _ 19 mark _ _ 23 | 14 the _ DT DT _ 18 det _ _ 24 | 15 Dow _ NNP NNP _ 18 nn _ _ 25 | 16 Jones _ NNP NNP _ 18 nn _ _ 26 | 17 Industrial _ NNP NNP _ 18 nn _ _ 27 | 18 Average _ NNP NNP _ 19 nsubj _ _ 28 | 19 plunged _ VBD VBD _ 10 advcl _ _ 29 | 20 190.58 _ CD CD _ 21 num _ _ 30 | 21 points _ NNS NNS _ 19 dobj _ _ 31 | 22 -- _ : : _ 23 punct _ _ 32 | 23 most _ JJS JJS _ 21 dep _ _ 33 | 24 of _ IN IN _ 23 prep _ _ 34 | 25 it _ PRP PRP _ 24 pobj _ _ 35 | 26 in _ IN IN _ 23 prep _ _ 36 | 27 the _ DT DT _ 29 det _ _ 37 | 28 final _ JJ JJ _ 29 amod _ _ 38 | 29 hour _ NN NN _ 26 pobj _ _ 39 | 30 -- _ : : _ 23 punct _ _ 40 | 31 it _ PRP PRP _ 33 nsubj _ _ 41 | 32 barely _ RB RB _ 33 advmod _ _ 42 | 33 managed _ VBD VBD _ 0 root _ _ 43 | 34 to _ TO TO _ 37 aux _ _ 44 | 35 stay _ VB VB _ 37 cop _ _ 45 | 36 this _ DT DT _ 37 det _ _ 46 | 37 side _ NN NN _ 33 xcomp _ _ 47 | 38 of _ IN IN _ 37 prep _ _ 48 | 39 chaos _ NN NN _ 38 pobj _ _ 49 | 40 . _ . . 
_ 33 punct _ _ 50 | 51 | 1 Some _ DT DT _ 4 det _ _ 52 | 2 `` _ `` `` _ 4 punct _ _ 53 | 3 circuit _ NN NN _ 4 nn _ _ 54 | 4 breakers _ NNS NNS _ 12 nsubj _ _ 55 | 5 '' _ '' '' _ 4 punct _ _ 56 | 6 installed _ VBN VBN _ 4 partmod _ _ 57 | 7 after _ IN IN _ 6 prep _ _ 58 | 8 the _ DT DT _ 11 det _ _ 59 | 9 October _ NNP NNP _ 11 nn _ _ 60 | 10 1987 _ CD CD _ 11 num _ _ 61 | 11 crash _ NN NN _ 7 pobj _ _ 62 | 12 failed _ VBD VBD _ 0 root _ _ 63 | 13 their _ PRP$ PRP$ _ 15 poss _ _ 64 | 14 first _ JJ JJ _ 15 amod _ _ 65 | 15 test _ NN NN _ 12 dobj _ _ 66 | 16 , _ , , _ 18 punct _ _ 67 | 17 traders _ NNS NNS _ 18 nsubj _ _ 68 | 18 say _ VBP VBP _ 12 parataxis _ _ 69 | 19 , _ , , _ 18 punct _ _ 70 | 20 unable _ JJ JJ _ 12 dep _ _ 71 | 21 to _ TO TO _ 22 aux _ _ 72 | 22 cool _ VB VB _ 20 xcomp _ _ 73 | 23 the _ DT DT _ 25 det _ _ 74 | 24 selling _ NN NN _ 25 nn _ _ 75 | 25 panic _ NN NN _ 22 dobj _ _ 76 | 26 in _ IN IN _ 25 prep _ _ 77 | 27 both _ DT DT _ 28 preconj _ _ 78 | 28 stocks _ NNS NNS _ 26 pobj _ _ 79 | 29 and _ CC CC _ 28 cc _ _ 80 | 30 futures _ NNS NNS _ 28 conj _ _ 81 | 31 . _ . . _ 12 punct _ _ 82 | 83 | 1 The _ DT DT _ 5 det _ _ 84 | 2 49 _ CD CD _ 5 num _ _ 85 | 3 stock _ NN NN _ 5 nn _ _ 86 | 4 specialist _ NN NN _ 5 nn _ _ 87 | 5 firms _ NNS NNS _ 31 nsubj _ _ 88 | 6 on _ IN IN _ 5 prep _ _ 89 | 7 the _ DT DT _ 10 det _ _ 90 | 8 Big _ NNP NNP _ 10 nn _ _ 91 | 9 Board _ NNP NNP _ 10 nn _ _ 92 | 10 floor _ NN NN _ 6 pobj _ _ 93 | 11 -- _ : : _ 5 punct _ _ 94 | 12 the _ DT DT _ 13 det _ _ 95 | 13 buyers _ NNS NNS _ 5 dep _ _ 96 | 14 and _ CC CC _ 13 cc _ _ 97 | 15 sellers _ NNS NNS _ 13 conj _ _ 98 | 16 of _ IN IN _ 13 prep _ _ 99 | 17 last _ JJ JJ _ 18 amod _ _ 100 | 18 resort _ NN NN _ 16 pobj _ _ 101 | 19 who _ WP WP _ 21 nsubjpass _ _ 102 | 20 were _ VBD VBD _ 21 auxpass _ _ 103 | 21 criticized _ VBN VBN _ 13 rcmod _ _ 104 | 22 after _ IN IN _ 21 prep _ _ 105 | 23 the _ DT DT _ 25 det _ _ 106 | 24 1987 _ CD CD _ 25 num _ _ 107 | 25 crash _ NN NN _ 22 pobj _ _ 108 | 26 -- _ : : _ 5 punct _ _ 109 | 27 once _ RB RB _ 28 advmod _ _ 110 | 28 again _ RB RB _ 31 advmod _ _ 111 | 29 could _ MD MD _ 31 aux _ _ 112 | 30 n't _ RB RB _ 31 neg _ _ 113 | 31 handle _ VB VB _ 0 root _ _ 114 | 32 the _ DT DT _ 34 det _ _ 115 | 33 selling _ NN NN _ 34 nn _ _ 116 | 34 pressure _ NN NN _ 31 dobj _ _ 117 | 35 . _ . . _ 31 punct _ _ 118 | 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Bird's Eye: Probing for Linguistic Graph Structures with a Simple Information-Theoretic Approach 2 | 3 | #### Authors: [Yifan Hou](https://yifan-h.github.io/), [Mrinmaya Sachan](http://www.mrinmaya.io/) 4 | 5 | ### Overview 6 | 7 | NLP has a rich history of representing our prior understanding of language in the form of graphs. Recent work on analyzing contextualized text representations has focused on hand-designed probe models to understand how and to what extent do these representations encode a particular linguistic phenomenon. However, due to the inter-dependence of various phenomena and randomness of training probe models, detecting how these representations encode the rich information in these linguistic graphs remains a challenging problem. 8 | 9 | In this paper, we propose a new information-theoretic probe, Bird's Eye, which is a fairly simple probe method for detecting if and how these representations encode the information in these linguistic graphs. 
Instead of using classifier performance, our probe takes an information-theoretic view of probing and estimates the mutual information between the linguistic graph embedded in a continuous space and the contextualized word representations. Furthermore, we propose an approach that uses our probe to investigate localized linguistic information in the linguistic graphs using perturbation analysis. We call this probing setup Worm's Eye. Using these probes, we analyze BERT models on their ability to encode a syntactic and a semantic graph structure, and find that these models encode both syntactic and semantic information to some degree, albeit syntactic information to a greater extent. 10 | 11 | Please see the [paper](https://arxiv.org/abs/2105.02629) for more details. 12 | 13 | *Note:* If you make use of this code, or the probing model in your work, please cite the following paper *(the formal bibliography will be updated soon)*: 14 | 15 | @inproceedings{DBLP:conf/acl/HouS20, 16 | author = {Yifan Hou and Mrinmaya Sachan}, 17 | editor = {Chengqing Zong and Fei Xia and Wenjie Li and Roberto Navigli}, 18 | title = {Bird's Eye: Probing for Linguistic Graph Structures with a Simple Information-Theoretic Approach}, 19 | booktitle = {Proceedings of the 59th Annual Meeting of the Association for Computational 20 | Linguistics and the 11th International Joint Conference on Natural 21 | Language Processing, {ACL/IJCNLP} 2021, (Volume 1: Long Papers), Virtual 22 | Event, August 1-6, 2021}, 23 | pages = {1844--1859}, 24 | publisher = {Association for Computational Linguistics}, 25 | year = {2021}, 26 | url = {https://doi.org/10.18653/v1/2021.acl-long.145}, 27 | doi = {10.18653/v1/2021.acl-long.145}, 28 | timestamp = {Fri, 30 Jul 2021 16:41:20 +0200}, 29 | biburl = {https://dblp.org/rec/conf/acl/HouS20.bib}, 30 | bibsource = {dblp computer science bibliography, https://dblp.org} 31 | } 32 | 33 | 34 | ### Requirements 35 | 36 | [PyTorch](https://pytorch.org/get-started/locally/) (1.8.1) should be installed based on your system. For the other libraries, recent versions of tqdm, networkx, numpy, scikit-learn, gensim, transformers, penman, and conllu are required. You can install these required packages with the following command: 37 | 38 | $ pip install -r requirements.txt 39 | 40 | ### How to run 41 | 42 | Refer to the comments and hyperparameters in *./src/main.py* for the functions that reproduce all experiments in the paper. 43 | 44 | Simply run *./src/main.py*; the *BERT-base-uncased* model is then probed with the *Penn Treebank* sample data in a layer-wise manner: 45 | 46 | $ python ./src/main.py 47 | 48 | #### Input Format 49 | 50 | Refer to the sample data for details about the input format. 51 | 52 | * syntax tree: Penn Treebank -- CoNLL-X format, parsed with the CoNLL-U parser (conllu). 53 | * semantic graph: AMR Bank -- PENMAN notation, parsed with the PENMAN graph notation library (penman). 54 | 55 | #### Usage of *Bird's Eye* 56 | 57 | Preprocess your data into the same format as the sample data (syntax or semantics), then run the code directly to get the probing results. Different models from [transformers](https://huggingface.co/transformers/) can be selected flexibly by changing a few hyperparameters.
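For example, to probe a different checkpoint on the AMR sample data, you can pass the corresponding hyperparameters on the command line. The sketch below is illustrative only: the values assume *bert-large-uncased*, whose 24 Transformer layers plus the embedding layer give `--bert-layers-num 25` and whose hidden size is 1024, and the matching probing function in *main_func* must be uncommented for the `amr_bert` task:

    $ python ./src/main.py --task amr_bert --model-name bert-large-uncased \
          --bert-layers-num 25 --bert-hidden-num 1024 --device 0

Set `--device -1` to run on the CPU.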
58 | 59 | ### Academic Paper 60 | 61 | [**ACL 2021**] **Bird's Eye: Probing for Linguistic Graph Structures with a Simple Information-Theoretic Approach**, Yifan Hou, Mrinmaya Sachan 62 | -------------------------------------------------------------------------------- /sample_data/ptb-gold/dev.conllx: -------------------------------------------------------------------------------- 1 | 1 Influential _ JJ JJ _ 2 amod _ _ 2 | 2 members _ NNS NNS _ 10 nsubj _ _ 3 | 3 of _ IN IN _ 2 prep _ _ 4 | 4 the _ DT DT _ 6 det _ _ 5 | 5 House _ NNP NNP _ 6 nn _ _ 6 | 6 Ways _ NNP NNP _ 3 pobj _ _ 7 | 7 and _ CC CC _ 6 cc _ _ 8 | 8 Means _ NNP NNP _ 9 nn _ _ 9 | 9 Committee _ NNP NNP _ 6 conj _ _ 10 | 10 introduced _ VBD VBD _ 0 root _ _ 11 | 11 legislation _ NN NN _ 10 dobj _ _ 12 | 12 that _ WDT WDT _ 14 nsubj _ _ 13 | 13 would _ MD MD _ 14 aux _ _ 14 | 14 restrict _ VB VB _ 11 rcmod _ _ 15 | 15 how _ WRB WRB _ 22 advmod _ _ 16 | 16 the _ DT DT _ 20 det _ _ 17 | 17 new _ JJ JJ _ 20 amod _ _ 18 | 18 savings-and-loan _ NN NN _ 20 nn _ _ 19 | 19 bailout _ NN NN _ 20 nn _ _ 20 | 20 agency _ NN NN _ 22 nsubj _ _ 21 | 21 can _ MD MD _ 22 aux _ _ 22 | 22 raise _ VB VB _ 14 ccomp _ _ 23 | 23 capital _ NN NN _ 22 dobj _ _ 24 | 24 , _ , , _ 14 punct _ _ 25 | 25 creating _ VBG VBG _ 14 xcomp _ _ 26 | 26 another _ DT DT _ 28 det _ _ 27 | 27 potential _ JJ JJ _ 28 amod _ _ 28 | 28 obstacle _ NN NN _ 25 dobj _ _ 29 | 29 to _ TO TO _ 28 prep _ _ 30 | 30 the _ DT DT _ 31 det _ _ 31 | 31 government _ NN NN _ 33 poss _ _ 32 | 32 's _ POS POS _ 31 possessive _ _ 33 | 33 sale _ NN NN _ 29 pobj _ _ 34 | 34 of _ IN IN _ 33 prep _ _ 35 | 35 sick _ JJ JJ _ 36 amod _ _ 36 | 36 thrifts _ NNS NNS _ 34 pobj _ _ 37 | 37 . _ . . _ 10 punct _ _ 38 | 39 | 1 The _ DT DT _ 2 det _ _ 40 | 2 bill _ NN NN _ 17 nsubj _ _ 41 | 3 , _ , , _ 2 punct _ _ 42 | 4 whose _ WP$ WP$ _ 5 poss _ _ 43 | 5 backers _ NNS NNS _ 6 nsubj _ _ 44 | 6 include _ VBP VBP _ 2 rcmod _ _ 45 | 7 Chairman _ NNP NNP _ 9 nn _ _ 46 | 8 Dan _ NNP NNP _ 9 nn _ _ 47 | 9 Rostenkowski _ NNP NNP _ 6 dobj _ _ 48 | 10 -LRB- _ -LRB- -LRB- _ 11 punct _ _ 49 | 11 D. _ NNP NNP _ 9 appos _ _ 50 | 12 , _ , , _ 11 punct _ _ 51 | 13 Ill. _ NNP NNP _ 11 dep _ _ 52 | 14 -RRB- _ -RRB- -RRB- _ 11 punct _ _ 53 | 15 , _ , , _ 2 punct _ _ 54 | 16 would _ MD MD _ 17 aux _ _ 55 | 17 prevent _ VB VB _ 0 root _ _ 56 | 18 the _ DT DT _ 21 det _ _ 57 | 19 Resolution _ NNP NNP _ 21 nn _ _ 58 | 20 Trust _ NNP NNP _ 21 nn _ _ 59 | 21 Corp. _ NNP NNP _ 17 dobj _ _ 60 | 22 from _ IN IN _ 17 prep _ _ 61 | 23 raising _ VBG VBG _ 22 pcomp _ _ 62 | 24 temporary _ JJ JJ _ 26 amod _ _ 63 | 25 working _ VBG VBG _ 26 amod _ _ 64 | 26 capital _ NN NN _ 23 dobj _ _ 65 | 27 by _ IN IN _ 17 prep _ _ 66 | 28 having _ VBG VBG _ 27 pcomp _ _ 67 | 29 an _ DT DT _ 31 det _ _ 68 | 30 RTC-owned _ JJ JJ _ 31 amod _ _ 69 | 31 bank _ NN NN _ 28 dobj _ _ 70 | 32 or _ CC CC _ 31 cc _ _ 71 | 33 thrift _ NN NN _ 35 nn _ _ 72 | 34 issue _ NN NN _ 35 nn _ _ 73 | 35 debt _ NN NN _ 31 conj _ _ 74 | 36 that _ WDT WDT _ 40 nsubjpass _ _ 75 | 37 would _ MD MD _ 40 aux _ _ 76 | 38 n't _ RB RB _ 40 neg _ _ 77 | 39 be _ VB VB _ 40 auxpass _ _ 78 | 40 counted _ VBN VBN _ 31 rcmod _ _ 79 | 41 on _ IN IN _ 40 prep _ _ 80 | 42 the _ DT DT _ 44 det _ _ 81 | 43 federal _ JJ JJ _ 44 amod _ _ 82 | 44 budget _ NN NN _ 41 pobj _ _ 83 | 45 . _ . . 
_ 17 punct _ _ 84 | 85 | 1 The _ DT DT _ 2 det _ _ 86 | 2 bill _ NN NN _ 3 nsubj _ _ 87 | 3 intends _ VBZ VBZ _ 0 root _ _ 88 | 4 to _ TO TO _ 5 aux _ _ 89 | 5 restrict _ VB VB _ 3 xcomp _ _ 90 | 6 the _ DT DT _ 7 det _ _ 91 | 7 RTC _ NNP NNP _ 5 dobj _ _ 92 | 8 to _ TO TO _ 5 prep _ _ 93 | 9 Treasury _ NNP NNP _ 10 nn _ _ 94 | 10 borrowings _ NNS NNS _ 8 pobj _ _ 95 | 11 only _ RB RB _ 10 advmod _ _ 96 | 12 , _ , , _ 3 punct _ _ 97 | 13 unless _ IN IN _ 16 mark _ _ 98 | 14 the _ DT DT _ 15 det _ _ 99 | 15 agency _ NN NN _ 16 nsubj _ _ 100 | 16 receives _ VBZ VBZ _ 3 advcl _ _ 101 | 17 specific _ JJ JJ _ 19 amod _ _ 102 | 18 congressional _ JJ JJ _ 19 amod _ _ 103 | 19 authorization _ NN NN _ 16 dobj _ _ 104 | 20 . _ . . _ 3 punct _ _ 105 | 106 | 1 `` _ `` `` _ 22 punct _ _ 107 | 2 Such _ JJ JJ _ 7 amod _ _ 108 | 3 agency _ NN NN _ 7 nn _ _ 109 | 4 ` _ `` `` _ 7 punct _ _ 110 | 5 self-help _ NN NN _ 7 nn _ _ 111 | 6 ' _ '' '' _ 7 punct _ _ 112 | 7 borrowing _ NN NN _ 9 nsubj _ _ 113 | 8 is _ VBZ VBZ _ 9 cop _ _ 114 | 9 unauthorized _ JJ JJ _ 22 ccomp _ _ 115 | 10 and _ CC CC _ 9 cc _ _ 116 | 11 expensive _ JJ JJ _ 9 conj _ _ 117 | 12 , _ , , _ 9 punct _ _ 118 | 13 far _ RB RB _ 15 advmod _ _ 119 | 14 more _ RBR RBR _ 15 advmod _ _ 120 | 15 expensive _ JJ JJ _ 9 dep _ _ 121 | 16 than _ IN IN _ 15 prep _ _ 122 | 17 direct _ JJ JJ _ 19 amod _ _ 123 | 18 Treasury _ NNP NNP _ 19 nn _ _ 124 | 19 borrowing _ NN NN _ 16 pobj _ _ 125 | 20 , _ , , _ 22 punct _ _ 126 | 21 '' _ '' '' _ 22 punct _ _ 127 | 22 said _ VBD VBD _ 0 root _ _ 128 | 23 Rep. _ NNP NNP _ 25 nn _ _ 129 | 24 Fortney _ NNP NNP _ 25 nn _ _ 130 | 25 Stark _ NNP NNP _ 22 nsubj _ _ 131 | 26 -LRB- _ -LRB- -LRB- _ 27 punct _ _ 132 | 27 D. _ NNP NNP _ 25 appos _ _ 133 | 28 , _ , , _ 27 punct _ _ 134 | 29 Calif. _ NNP NNP _ 27 dep _ _ 135 | 30 -RRB- _ -RRB- -RRB- _ 27 punct _ _ 136 | 31 , _ , , _ 25 punct _ _ 137 | 32 the _ DT DT _ 33 det _ _ 138 | 33 bill _ NN NN _ 36 poss _ _ 139 | 34 's _ POS POS _ 33 possessive _ _ 140 | 35 chief _ JJ JJ _ 36 amod _ _ 141 | 36 sponsor _ NN NN _ 25 appos _ _ 142 | 37 . _ . . _ 22 punct _ _ 143 | 144 | -------------------------------------------------------------------------------- /sample_data/amr-split/amr-training.txt: -------------------------------------------------------------------------------- 1 | # AMR-English alignment release (generated on Fri Feb 2, 2018 at 16:13:17) 2 | 3 | # ::id bolt12_07_4800.1 ::amr-annotator SDL-AMR-09 ::preferred 4 | # ::tok Establishing Models in Industrial Innovation 5 | # ::alignments 0-1 1-1.1 2-1.1.1.r 3-1.1.1.1 4-1.1.1 6 | (e / establish-01~e.0 7 | :ARG1 (m / model~e.1 8 | :mod~e.2 (i / innovate-01~e.4 9 | :ARG1 (i2 / industry~e.3)))) 10 | 11 | # ::id bolt12_07_4800.2 ::amr-annotator SDL-AMR-09 ::preferred 12 | # ::tok After its competitor invented the front loading washing machine , the CEO of the American IM company believed that each of its employees had the ability for innovation , and formulated strategic countermeasures for innovation in the industry . 
13 | # ::alignments 0-1.3 1-1.3.1.1.1.1 1-1.3.1.1.1.1.r 2-1.3.1.1.1 3-1.3.1 5-1.3.1.2.2.1 6-1.3.1.2.2 7-1.3.1.2.1 8-1.3.1.2 11-1.1.1.1.2 11-1.1.1.1.2.1 11-1.1.1.1.2.1.r 11-1.1.1.1.2.2 11-1.1.1.1.2.2.r 11-1.2.1 11-1.2.1.1 11-1.2.1.1.r 11-1.2.1.2 11-1.2.1.2.r 12-1.1.1.1 14-1.1.1.1.1.3.2.1 14-1.1.1.1.1.3.2.2 15-1.1.1.1.1.2.1 16-1.3.1.1 17-1.1 18-1.1.2.r 19-1.1.2.1.2 22-1.1.2.1 22-1.1.2.1.1 22-1.1.2.1.1.r 25-1.1.2 26-1.1.2.2.r 27-1.1.2.2 29-1 30-1.2 31-1.2.2.1 32-1.2.2 33-1.2.2.2.r 34-1.2.2.2 35-1.2.2.2.1.r 37-1.2.2.2.1 14 | (a / and~e.29 15 | :op1 (b / believe-01~e.17 16 | :ARG0 (p2 / person 17 | :ARG0-of (h2 / have-org-role-91~e.12 18 | :ARG1 (c2 / company :wiki - 19 | :name (n / name :op1 "IM"~e.15) 20 | :mod (c3 / country :wiki "United_States" 21 | :name (n2 / name :op1 "United"~e.14 :op2 "States"~e.14))) 22 | :ARG2 (o / officer~e.11 23 | :mod~e.11 (e3 / executive~e.11) 24 | :mod~e.11 (c7 / chief~e.11)))) 25 | :ARG1~e.18 (c8 / capable-01~e.25 26 | :ARG1 (p / person~e.22 27 | :ARG1-of~e.22 (e / employ-01~e.22 28 | :ARG0 c2) 29 | :mod (e2 / each~e.19)) 30 | :ARG2~e.26 (i / innovate-01~e.27 31 | :ARG0 p))) 32 | :op2 (f / formulate-01~e.30 33 | :ARG0 (o2 / officer~e.11 34 | :mod~e.11 (e4 / executive~e.11) 35 | :mod~e.11 (c / chief~e.11)) 36 | :ARG1 (c4 / countermeasure~e.32 37 | :mod (s / strategy~e.31) 38 | :purpose~e.33 (i2 / innovate-01~e.34 39 | :topic~e.35 (i3 / industry~e.37)))) 40 | :time (a3 / after~e.0 41 | :op1 (i4 / invent-01~e.3 42 | :ARG0 (c5 / company~e.16 43 | :ARG0-of (c6 / compete-02~e.2 44 | :ARG1~e.1 c2~e.1)) 45 | :ARG1 (m / machine~e.8 46 | :ARG0-of (w / wash-01~e.7) 47 | :ARG1-of (l / load-01~e.6 48 | :mod (f2 / front~e.5)))))) 49 | 50 | # ::id bolt12_07_4800.3 ::amr-annotator SDL-AMR-09 ::preferred 51 | # ::tok 1 . Establish an innovation fund with a maximum amount of 1,000 U.S. dollars . 52 | # ::alignments 0-1.1 2-1 4-1.2.1 5-1.2 6-1.2.2.r 8-1.2.2.1 9-1.2.2 10-1.2.2.1.1.r 11-1.2.2.1.1.1 12-1.2.2.1.1.2.1.2.1 12-1.2.2.1.1.2.1.2.2 13-1.2.2.1.1.2 53 | (e / establish-01~e.2 :li 1~e.0 54 | :ARG1 (f2 / fund~e.5 55 | :purpose (i / innovate-01~e.4) 56 | :ARG1-of~e.6 (a / amount-01~e.9 57 | :ARG2 (a2 / at-most~e.8 58 | :op1~e.10 (m / monetary-quantity :quant 1000~e.11 59 | :unit (d / dollar~e.13 60 | :mod (c / country :wiki "United_States" 61 | :name (n / name :op1 "United"~e.12 :op2 "States"~e.12)))))))) 62 | 63 | # ::id bolt12_07_4800.4 ::amr-annotator SDL-AMR-09 ::preferred 64 | # ::tok 2 . Choose 100 innovative concepts to encourage employees to conduct research and development during their work time or spare time . 
65 | # ::alignments 0-1.1 2-1 3-1.2.1 4-1.2.2 5-1.2 7-1.3 8-1.3.2 8-1.3.2.1 8-1.3.2.1.r 11-1.3.3.1 12-1.3.3 13-1.3.3.2 14-1.3.3.3.r 15-1.3.3.3.1.1 15-1.3.3.3.1.1.r 16-1.3.3.3.1 17-1.3.3.3.2 17-1.3.3.3.r 18-1.3.3.3 19-1.3.3.3.2.2 20-1.3.3.3.2 66 | (c / choose-01~e.2 :li 2~e.0 67 | :ARG1 (c2 / concept~e.5 :quant 100~e.3 68 | :ARG1-of (i / innovate-01~e.4)) 69 | :purpose (e / encourage-01~e.7 70 | :ARG0 c2 71 | :ARG1 (p / person~e.8 72 | :ARG1-of~e.8 (e2 / employ-01~e.8)) 73 | :ARG2 (a / and~e.12 74 | :op1 (r / research-01~e.11 75 | :ARG0 p) 76 | :op2 (d / develop-02~e.13 77 | :ARG0 p) 78 | :time~e.14,17 (o / or~e.18 79 | :op1 (w / work-01~e.16 80 | :ARG0~e.15 p~e.15) 81 | :op2 (t2 / time~e.17,20 82 | :poss p 83 | :mod (s / spare~e.19)))))) -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | import json 5 | import numpy as np 6 | import argparse 7 | 8 | from probe import mi_bert_ptb, mi_bert_amr, mi_mlps_ptb, mi_mlps_amr 9 | from evaluation_ptb import test_ge_ptb, test_bert_ptb, test_mi_ptb, mi_noise_ptb, test_random_ptb 10 | from evaluation_amr import test_ge_amr, test_bert_amr, test_mi_amr, mi_noise_amr, test_random_amr 11 | 12 | def main_func(args): 13 | # cpu or gpu 14 | if args.device < 0: 15 | device = torch.device("cpu") 16 | else: 17 | device = torch.device("cuda:" + str(args.device)) 18 | args.device = device 19 | 20 | ''' 21 | "ptb" = penn tree bank 22 | "amr" = amr bank 23 | "_bert" = no data split 24 | "_graph" = train/dev/test split (default) 25 | ''' 26 | 27 | ## function descriptions 28 | ## please refer to specific py file for sub-function descriptions 29 | ''' 30 | mi_bert_ptb: probe BERT layers with PTB dataset (uncontext=True for ELMo0) 31 | mi_bert_amr: probe BERT layers with AMR dataset (uncontext=True for ELMo0) 32 | mi_mlps_ptb: probe GloVe embeddings with PTB dataset 33 | mi_mlps_amr: probe GloVe embeddings with AMR dataset 34 | test_ge_ptb: test the graph embedding quality with PTB dataset 35 | test_ge_amr: test the graph embedding quality with AMR dataset 36 | test_bert_ptb: test the BERT embedding quality (recover original graphs) with PTB dataset 37 | test_bert_amr: test the BERT embedding quality (recover original graphs) with AMR dataset 38 | test_mi_ptb: calculate MI upper bound (global) with PTB dataset (different levels of noise) 39 | test_mi_amr: calculate MI upper bound (global) with AMR dataset (different levels of noise) 40 | mi_noise_ptb: calculate MI estimate I(X,G) (local) with PTB dataset (different corrupted types of edges) 41 | mi_noise_amr: calculate MI estimate I(X,G) (local) with AMR dataset (different corrupted types of edges) 42 | test_random_ptb: calculate classifier accuracy (local) with PTB dataset (different corrupted types of edges) 43 | test_random_amr: calculate classifier accuracy (local) with AMR dataset (different corrupted types of edges) 44 | ''' 45 | 46 | if args.task == 'ptb_bert': 47 | # mi_noise_ptb(args, pos=True) 48 | mi_bert_ptb(args) 49 | # mi_bert_ptb(args, uncontext=True) 50 | # test_mi_ptb(args) 51 | # mi_noise_ptb(args) 52 | # test_ge_ptb(args) 53 | # test_ge_ptb(args, data_split=False) 54 | # mi_mlps_ptb(args) 55 | pass 56 | elif args.task == 'ptb_graph': 57 | # test_random_ptb(args) 58 | # test_random_ptb(args, corrupt=True) 59 | # test_ge_ptb(args) 60 | # test_bert_ptb(args) 61 | # test_bert_ptb(args, model_name='elmo') 62 | # test_bert_ptb(args, 
model_name='glove') 63 | pass 64 | elif args.task == 'amr_bert': 65 | # mi_bert_ptb(args, npeet=True) 66 | # mi_bert_amr(args) 67 | # mi_bert_amr(args, model='elmo') 68 | # test_mi_amr(args) 69 | # mi_noise_amr(args) 70 | # test_ge_ptb(args) 71 | # test_ge_amr(args, data_split=False) 72 | # mi_mlps_amr(args) 73 | pass 74 | elif args.task == 'amr_graph': 75 | # test_random_amr(args) 76 | # test_random_amr(args, corrupt=True) 77 | # test_ge_amr(args) 78 | # test_bert_amr(args) 79 | # test_bert_amr(args, model_name='elmo') 80 | # test_bert_amr(args, model_name='glove') 81 | pass 82 | return 83 | 84 | 85 | 86 | 87 | if __name__ == '__main__': 88 | parser = argparse.ArgumentParser(description='Bird\'s Eye Probing. Please refer to main_func for specific task selection') 89 | parser.add_argument("--device", type=int, default=0, 90 | help="which GPU to use. set -1 to use CPU.") 91 | parser.add_argument("--lr", type=float, default=1e-4, 92 | help="learning rate") 93 | parser.add_argument("--patience", type=int, default=2, 94 | help="early stopping patience") 95 | parser.add_argument("--baselines", type=bool, default=False, 96 | help="whether to calculate MI baselines") 97 | parser.add_argument("--repeat", type=int, default=5, 98 | help="number of times to repeat the MI calculation") 99 | parser.add_argument("--classifier-layers-num", type=int, default=5, 100 | help="number of layers of the binary classifier") 101 | parser.add_argument("--bert-layers-num", type=int, default=13, 102 | help="number of layers of the BERT model + 1") 103 | parser.add_argument("--bert-hidden-num", type=int, default=768, 104 | help="number of hidden units of the BERT model") 105 | parser.add_argument("--hidden-num", type=int, default=128, 106 | help="number of hidden units of the mutual information estimator") 107 | parser.add_argument("--model-name", type=str, default='bert-base-uncased', 108 | help="which model to probe, e.g., bert-base-uncased") 109 | parser.add_argument('--task', type=str, default='ptb_bert', 110 | help="task: ptb_bert, ptb_graph, amr_bert, or amr_graph") 111 | args = parser.parse_args() 112 | print(args) 113 | 114 | main_func(args) 115 | -------------------------------------------------------------------------------- /sample_data/amr-split/amr-test.txt: -------------------------------------------------------------------------------- 1 | # AMR-English alignment release (generated on Fri Feb 2, 2018 at 16:13:17) 2 | 3 | # ::id bolt12_64556_5627.1 ::amr-annotator SDL-AMR-09 ::preferred 4 | # ::tok Resolutely support the thread starter ! 
I compose a poem in reply : 5 | # ::alignments 0-1.1.4 0-1.1.4.r 1-1.1 3-1.1.3.1.1 4-1.1.3 4-1.1.3.1 4-1.1.3.1.r 5-1.1.1 5-1.1.1.r 6-1.2.1 7-1.2.2 9-1.2.2.2 11-1.2 6 | (m / multi-sentence 7 | :snt1 (s / support-01~e.1 :mode~e.5 imperative~e.5 8 | :ARG0 (y / you) 9 | :ARG1 (p / person~e.4 10 | :ARG0-of~e.4 (s2 / start-01~e.4 11 | :ARG1 (t / thread~e.3))) 12 | :manner~e.0 (r / resolute~e.0)) 13 | :snt2 (r2 / reply-01~e.11 14 | :ARG0 (i / i~e.6) 15 | :ARG2 (c / compose-02~e.7 16 | :ARG0 i 17 | :ARG1 (p2 / poem~e.9)))) 18 | 19 | # ::id bolt12_64556_5627.2 ::amr-annotator SDL-AMR-09 ::preferred 20 | # ::tok Pledge to fight to the death defending the Diaoyu Islands and the related islands 21 | # ::alignments 0-1 2-1.3 3-1.3.3.r 5-1.3.3 6-1.3.2 8-1.3.2.2.1.2.1 9-1.3.2.2.1.2.2 10-1.3.2.2 12-1.3.2.2.2.1 13-1.3.2.2.2 22 | (p / pledge-01~e.0 :mode imperative 23 | :ARG0 (y / you) 24 | :ARG2 (f / fight-01~e.2 25 | :ARG0 y 26 | :ARG2 (d2 / defend-01~e.6 27 | :ARG0 y 28 | :ARG1 (a / and~e.10 29 | :op1 (i / island :wiki "Senkaku_Islands" 30 | :name (n / name :op1 "Diaoyu"~e.8 :op2 "Islands"~e.9)) 31 | :op2 (i2 / island~e.13 32 | :ARG1-of (r / relate-01~e.12 33 | :ARG2 i)))) 34 | :manner~e.3 (d / die-01~e.5 35 | :ARG1 y))) 36 | 37 | # ::id bolt12_64556_5627.3 ::amr-annotator SDL-AMR-09 ::preferred 38 | # ::tok Fleets bumping fishing boats . Little evil Japanese ghosts stirring up trouble and unrest . With hearts of thieves and arrogant form , they again show their wolfish appearance 39 | # ::alignments 0-1.1.2 1-1.1 2-1.1.1.1 3-1.1.1 5-1.3.1.1.2 6-1.3.1.1.3 7-1.3.1.1.1.2.1 8-1.3.1.1 9-1.3.1 10-1.3.1 11-1.3.2 12-1.3 13-1.3.1.2 16-1.2.1.1.1.1 17-1.2.1.1.1.1.1.r 18-1.2.1.1.1.1.1 18-1.2.1.1.1.1.1.1 18-1.2.1.1.1.1.1.1.r 19-1.2.1.1.1 20-1.2.1.1.1.2.1 21-1.2.1.1.1.2 23-1.2.1 24-1.2.3 25-1.2 26-1.2.2.1 26-1.2.2.1.r 27-1.2.2.2 28-1.2.2 40 | (m / multi-sentence 41 | :snt1 (b / bump-01~e.1 42 | :ARG1 (b2 / boat~e.3 43 | :purpose (f / fish-01~e.2)) 44 | :ARG2 (f2 / fleet~e.0)) 45 | :snt3 (s2 / show-01~e.25 46 | :ARG0 (t2 / they~e.23 47 | :ARG0-of (h2 / have-03 48 | :ARG1 (a4 / and~e.19 49 | :op1 (h / heart~e.16 50 | :mod~e.17 (p / person~e.18 51 | :ARG0-of~e.18 (s3 / steal-01~e.18))) 52 | :op2 (f3 / form~e.21 53 | :mod (a5 / arrogance~e.20))))) 54 | :ARG1 (a3 / appearance~e.28 55 | :poss~e.26 t2~e.26 56 | :mod (w / wolfish~e.27)) 57 | :mod (a2 / again~e.24)) 58 | :snt2 (a6 / and~e.12 59 | :op1 (s / stir-up-04~e.9,10 60 | :ARG0 (g / ghost~e.8 61 | :mod (c / country :wiki "Japan" 62 | :name (n / name :op1 "Japan"~e.7)) 63 | :mod (l / little~e.5) 64 | :mod (e / evil~e.6)) 65 | :ARG1 (u / unrest~e.13)) 66 | :op2 (m2 / make-trouble-06~e.11 67 | :ARG0 g))) 68 | 69 | # ::id bolt12_64556_5627.4 ::amr-annotator SDL-AMR-09 ::preferred 70 | # ::tok Never go back to that time , our humiliating appearance when signing the treaties . China be strong , swords be sharp and knives be shining , let 's bury the approaching enemies ! 
71 | # ::alignments 0-1.1.2 0-1.1.2.r 0-1.1.5 1-1.1 2-1.1 2-1.1.6 2-1.1.6.r 3-1.1.4.r 4-1.1.4.1 5-1.1.4 5-1.1.4.2.r 7-1.1.4.2.1 7-1.1.4.2.1.r 8-1.1.4.2.2 9-1.1.4.2 10-1.1.4.2.3.r 11-1.1.4.2.3 13-1.1.4.2.3.2 15-1.2.2.2.1 17-1.2 19-1.3.1.2 21-1.3.1 22-1.3 23-1.3.2.2 25-1.3.2 27-1.1.1 27-1.1.1.r 27-1.2.1 27-1.2.1.r 27-1.3.1.1 27-1.3.2.1 27-1.3.2.1.r 28-1.4.1 28-1.4.1.r 28-1.4.2 29-1.4 31-1.4.3.1 32-1.4.3 33-1.1.1.r 72 | (m / multi-sentence 73 | :snt1 (g / go-back-19~e.1,2 :mode~e.27,33 imperative~e.27 :polarity~e.0 -~e.0 74 | :ARG1 (y / you) 75 | :ARG2~e.3 (t2 / time~e.5 76 | :mod (t3 / that~e.4) 77 | :time-of~e.5 (a / appear-01~e.9 78 | :ARG1~e.7 (w / we~e.7) 79 | :ARG0-of (h / humiliate-01~e.8) 80 | :time~e.10 (s / sign-02~e.11 81 | :ARG0 w 82 | :ARG1 (t / treaty~e.13)))) 83 | :time (e / ever~e.0) 84 | :direction~e.2 (b / back~e.2)) 85 | :snt2 (s2 / strong-02~e.17 :mode~e.27 imperative~e.27 86 | :ARG1 (c / country :wiki "China" 87 | :name (n / name :op1 "China"~e.15))) 88 | :snt3 (a2 / and~e.22 89 | :op1 (s3 / sharp-02~e.21 :mode imperative~e.27 90 | :ARG1 (s4 / sword~e.19)) 91 | :op2 (s5 / shine-01~e.25 :mode~e.27 imperative~e.27 92 | :ARG0 (k / knife~e.23))) 93 | :snt4 (b3 / bury-01~e.29 :mode~e.28 imperative~e.28 94 | :ARG0 (w2 / we~e.28) 95 | :ARG1 (e2 / enemy~e.32 96 | :ARG1-of (a4 / approach-01~e.31)))) -------------------------------------------------------------------------------- /src/embed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import penman 3 | import numpy as np 4 | import networkx as nx 5 | from tqdm import tqdm 6 | import torch 7 | import torch.nn.functional as F 8 | from gensim.models import Word2Vec 9 | from transformers import BertTokenizer, BertForPreTraining 10 | from utils import random_walks, load_data, clean_string 11 | 12 | 13 | ## sub-function descriptions 14 | ## please refer to main py file for function descriptions 15 | ''' 16 | graph_embeddings: load / calculate graph embedding (PTB dataset) 17 | bert_embeddings: load / calculate BERT embedding (PTB dataset) 18 | get_embeddings: load / calculate graph embedding and BERT embedding (AMR dataset) 19 | ''' 20 | 21 | 22 | def graph_embeddings(args, global_graph, doc_id, sen_id, data_div=''): 23 | if not os.path.exists('./tmp/ge_'+args.task+data_div+'.npz'): 24 | ''' 25 | # get global graph embedding ge 26 | print('1.2 start to calculate global graph embedding...') 27 | global_walks = random_walks(global_graph) 28 | global_model = Word2Vec(global_walks, 29 | size=640, 30 | window=2, 31 | min_count=0, 32 | sg=1, 33 | hs=1, 34 | workers=20) 35 | ''' 36 | # get local graph embedding le 37 | print('1.3 start to calculate local graphs embedding...') 38 | ge_vec = [] 39 | for i in tqdm(range(len(doc_id))): 40 | # global_idx = [str(idx[1]) for idx in doc_id[i]] 41 | local_idx = [str(idx[1]) for idx in sen_id[i]] 42 | # get local graph embeddings 43 | local_graph = nx.Graph() 44 | for (s, t) in sen_id[i]: local_graph.add_edge(s, t) 45 | local_walks = random_walks(local_graph, 100, 10) 46 | if len(local_idx) > 1: 47 | local_model = Word2Vec(local_walks, 48 | size=128, 49 | window=2, 50 | min_count=0, 51 | sg=1, 52 | hs=1, 53 | workers=20) 54 | local_vec = local_model.wv[local_idx] 55 | else: 56 | local_vec = np.zeros((1,128)) 57 | # save graph embeddings (global + local) 58 | # global_vec = global_model.wv[global_idx] 59 | # ge_vec.append(np.concatenate((global_vec, local_vec), axis=1)) 60 | ge_vec.append(local_vec) 61 | 62 | # save graph embeddings 63 | savez_dict 
= {} 64 | for i in range(len(ge_vec)): savez_dict['s'+str(i)] = ge_vec[i] 65 | np.savez('./tmp/ge_'+args.task+data_div+'.npz', **savez_dict) 66 | 67 | return np.load('./tmp/ge_'+args.task+data_div+'.npz') 68 | 69 | 70 | def bert_embeddings(args, sentences, data_div=''): 71 | if args.model_name == 'bert-base-uncased': 72 | data_path = './tmp/be12_'+args.task+data_div+'.npz' 73 | else: 74 | data_path = './tmp/be24_'+args.task+data_div+'.npz' 75 | if not os.path.exists(data_path): 76 | # get BERT hidden representations 77 | tokenizer = BertTokenizer.from_pretrained(args.model_name) 78 | model = BertForPreTraining.from_pretrained(args.model_name, 79 | return_dict=True, 80 | output_hidden_states = True) 81 | bert_hs = [[] for l in range(args.bert_layers_num)] 82 | for s in tqdm(sentences): 83 | inputs = tokenizer(s, return_tensors="pt", is_split_into_words = True) 84 | outputs = model(**inputs) 85 | outputs = outputs.hidden_states 86 | 87 | # average word pieces to get whole word embedding 88 | s_pieces = tokenizer.tokenize(' '.join(s)) 89 | w_ids = [] 90 | for i in range(len(s_pieces)): 91 | if len(s_pieces[i]) < 2 or s_pieces[i][:2] != '##': 92 | w_ids.append(i) 93 | # check piece number 94 | if len(w_ids) != len(s): 95 | print('Error! failed to get whole word embedding!', s, s_pieces, w_ids) 96 | hidden_s = [] 97 | for l in range(len(outputs)): 98 | # remove EOS BOS tokens 99 | piece_embed = torch.squeeze(outputs[l].data)[1:-1] 100 | # get word embeddings 101 | word_embed = piece_embed[w_ids] 102 | # average word embedding 103 | for i in range(len(w_ids)-1): 104 | if w_ids[i+1] - w_ids[i] != 1: 105 | tmp_idx = [w_ids[i]+j for j in range(w_ids[i+1] - w_ids[i])] 106 | word_embed[i] = torch.mean(piece_embed[tmp_idx], dim=0) 107 | hidden_s.append(word_embed) 108 | # bert embedding for sentence s: len(s) * 768 109 | for l in range(len(hidden_s)): 110 | bert_hs[l].append(hidden_s[l].detach().cpu().data.numpy()) 111 | 112 | if len(data_div): 113 | l = args.bert_layers_num - 1 114 | savez_dict = {} 115 | for i in range(len(bert_hs[l])): savez_dict['s'+str(i)] = bert_hs[l][i] 116 | np.savez('./tmp/be'+str(l)+'_'+args.task+data_div+'.npz', **savez_dict) 117 | else: 118 | # save bert embeddings 119 | for l in range(args.bert_layers_num): 120 | savez_dict = {} 121 | for i in range(len(bert_hs[l])): savez_dict['s'+str(i)] = bert_hs[l][i] 122 | np.savez('./tmp/be'+str(l)+'_'+args.task+data_div+'.npz', **savez_dict) 123 | 124 | return ['./tmp/be'+str(l)+'_'+args.task+data_div+'.npz' for l in range(args.bert_layers_num)] 125 | 126 | 127 | def get_embeddings(args, amr_s, data_div=''): 128 | print('1. 
start to parse and embed the sentences...') 129 | if args.model_name == 'bert-base-uncased': 130 | data_path = './tmp/be12_'+args.task+data_div+'.npz' 131 | else: 132 | data_path = './tmp/be24_'+args.task+data_div+'.npz' 133 | if not os.path.exists(data_path): 134 | # get BERT hidden representations 135 | tokenizer = BertTokenizer.from_pretrained(args.model_name) 136 | model = BertForPreTraining.from_pretrained(args.model_name, 137 | return_dict=True, 138 | output_hidden_states = True) 139 | ge_vecs = [] 140 | bert_hs = [[] for l in range(args.bert_layers_num)] 141 | for s in tqdm(amr_s): 142 | # parse 143 | penman_g = penman.decode(s) 144 | s = penman_g.metadata.get('tok').split(' ') 145 | wid = [] 146 | var = [] # k=word id; v=variable 147 | for k, v in penman_g.epidata.items(): 148 | if k[1] == ':instance': 149 | if len(v): 150 | if type(v[0]) == penman.surface.Alignment: 151 | wid.append(v[0].indices[0]) 152 | var.append(k[0]) 153 | 154 | # BERT embedding 155 | c_s = [] 156 | for w in s: 157 | c_w = clean_string(w) 158 | if len(c_w) == 0: c_w = ',' 159 | c_s.append(c_w) 160 | inputs = tokenizer(c_s, return_tensors="pt", is_split_into_words = True) 161 | outputs = model(**inputs) 162 | outputs = outputs.hidden_states 163 | # average word pieces to get whole word embedding 164 | s_pieces = tokenizer.tokenize(' '.join(c_s)) 165 | w_ids = [] 166 | for i in range(len(s_pieces)): 167 | if len(s_pieces[i]) < 2 or s_pieces[i][:2] != '##': 168 | w_ids.append(i) 169 | # check piece number 170 | if len(w_ids) != len(c_s): 171 | print('Error! failed to get BERT word embedding!', c_s, s_pieces, w_ids) 172 | hidden_s = [] 173 | for l in range(len(outputs)): 174 | # remove EOS BOS tokens 175 | piece_embed = torch.squeeze(outputs[l].data)[1:-1] 176 | # get word embeddings 177 | word_embed = piece_embed[w_ids] 178 | # average word embedding 179 | for i in range(len(w_ids)-1): 180 | if w_ids[i+1] - w_ids[i] != 1: 181 | tmp_idx = [w_ids[i]+j for j in range(w_ids[i+1] - w_ids[i])] 182 | word_embed[i] = torch.mean(piece_embed[tmp_idx], dim=0) 183 | hidden_s.append(word_embed[wid]) 184 | # bert embedding for sentence c_s: 13 - len(c_s) * 768 185 | for l in range(len(hidden_s)): 186 | bert_hs[l].append(hidden_s[l].detach().cpu().data.numpy()) 187 | 188 | # graph embedding 189 | g = nx.Graph() 190 | for v in penman_g.variables(): g.add_node(v) 191 | for e in penman_g.edges(): g.add_edge(e.source, e.target) 192 | walks = random_walks(g, 50, 10, True) 193 | if len(var) > 1: 194 | ge_model = Word2Vec(walks, 195 | size=128, 196 | window=2, 197 | min_count=0, 198 | sg=1, 199 | hs=1, 200 | workers=20) 201 | ge_vec = ge_model.wv[var] 202 | else: 203 | ge_vec = np.zeros((1,args.bert_hidden_num)) 204 | ge_vecs.append(ge_vec) 205 | 206 | if len(data_div): 207 | l = args.bert_layers_num - 1 208 | savez_dict = {} 209 | for i in range(len(bert_hs[l])): savez_dict['s'+str(i)] = bert_hs[l][i] 210 | np.savez('./tmp/be'+str(l)+'_'+args.task+data_div+'.npz', **savez_dict) 211 | else: 212 | # save bert embedding 213 | for l in range(args.bert_layers_num): 214 | savez_dict = {} 215 | for i in range(len(bert_hs[l])): savez_dict['s'+str(i)] = bert_hs[l][i] 216 | np.savez('./tmp/be'+str(l)+'_'+args.task+data_div+'.npz', **savez_dict) 217 | # save graph embeddings 218 | savez_dict = {} 219 | for i in range(len(ge_vecs)): savez_dict['s'+str(i)] = ge_vecs[i] 220 | np.savez('./tmp/ge_'+args.task+data_div+'.npz', **savez_dict) 221 | 222 | return np.load('./tmp/ge_'+args.task+data_div+'.npz'), \ 223 | 
['./tmp/be'+str(l)+'_'+args.task+data_div+'.npz' for l in range(args.bert_layers_num)] -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import numpy as np 6 | 7 | 8 | ## sub-function descriptions 9 | ## please refer to main py file for function descriptions 10 | ''' 11 | binary_classifer: PyTorch model for link prediction (with different hidden layers) 12 | graph_probe: PyTorch model for probing 13 | ''' 14 | 15 | 16 | class binary_classifer(nn.Module): 17 | def __init__(self, 18 | layers_num=5, 19 | feat_dim=0, 20 | hidden_dim=128): 21 | super(binary_classifer, self).__init__() 22 | self.layers_num = layers_num 23 | self.linear_sh = nn.Linear(feat_dim, hidden_dim) 24 | self.linear_dh = nn.Linear(feat_dim, hidden_dim) 25 | self.linear_h1 = nn.Linear(2* hidden_dim, hidden_dim) 26 | self.linear_h2 = nn.Linear(hidden_dim, hidden_dim) 27 | self.linear_h3 = nn.Linear(hidden_dim, hidden_dim) 28 | self.linear_h4 = nn.Linear(hidden_dim, hidden_dim) 29 | if layers_num == 0: # 0 30 | self.linear_ca = nn.Linear(2*feat_dim, 1) 31 | elif layers_num <= 1: # 1 32 | self.linear_ca = nn.Linear(2*hidden_dim, 1) 33 | else: # 2, 3, 4, 5 34 | self.linear_ca = nn.Linear(hidden_dim, 1) 35 | nn.init.xavier_normal_(self.linear_ca.weight.data) 36 | 37 | def forward(self, src, dst): 38 | # layers_num = 0 39 | h = torch.cat((src, dst), dim=1) 40 | if self.layers_num == 0: 41 | return torch.sigmoid(self.linear_ca(h)) 42 | 43 | # layers_num = 1 44 | s = F.relu(self.linear_sh(src)) 45 | d = F.relu(self.linear_dh(dst)) 46 | h = torch.cat((s, d), dim=1) 47 | if self.layers_num == 1: 48 | return torch.sigmoid(self.linear_ca(h)) 49 | 50 | # layers_num = 2 51 | h = F.relu(self.linear_h1(h)) 52 | if self.layers_num == 2: 53 | return torch.sigmoid(self.linear_ca(h)) 54 | 55 | # layers_num = 3 56 | h = F.relu(self.linear_h2(h)) 57 | if self.layers_num == 3: 58 | return torch.sigmoid(self.linear_ca(h)) 59 | 60 | # layers_num = 4 61 | h = F.relu(self.linear_h3(h)) 62 | if self.layers_num == 4: 63 | return torch.sigmoid(self.linear_ca(h)) 64 | 65 | # layers_num = 5 66 | h = F.relu(self.linear_h4(h)) 67 | if self.layers_num == 5: 68 | return torch.sigmoid(self.linear_ca(h)) 69 | 70 | 71 | class graph_probe(nn.Module): 72 | def __init__(self, 73 | adj_dim, 74 | feat_dim, 75 | hidden_dim=64): 76 | super(graph_probe, self).__init__() 77 | self.linear_ah = nn.Linear(adj_dim, hidden_dim) 78 | self.linear_fh = nn.Linear(feat_dim, hidden_dim) 79 | self.linear_hh2 = nn.Linear(hidden_dim*2, hidden_dim) 80 | self.linear_hh3 = nn.Linear(hidden_dim, hidden_dim) 81 | self.linear_hh4 = nn.Linear(hidden_dim, hidden_dim) 82 | self.linear_hh5 = nn.Linear(hidden_dim, hidden_dim*2) 83 | self.linear_hm = nn.Linear(hidden_dim*2, 1) 84 | nn.init.xavier_normal_(self.linear_ah.weight.data) 85 | nn.init.xavier_normal_(self.linear_fh.weight.data) 86 | nn.init.xavier_normal_(self.linear_hm.weight.data) 87 | 88 | def forward(self, a, f): 89 | a = self.linear_ah(a) 90 | f = self.linear_fh(f) 91 | h = torch.cat((f, a), dim=1) 92 | # h = F.elu(self.linear_hh2(h)) 93 | # h = F.elu(self.linear_hh3(h)) 94 | # h = F.elu(self.linear_hh4(h)) 95 | # h = F.elu(self.linear_hh5(h)) 96 | return F.elu(self.linear_hm(h)) 97 | # return self.linear_hm(h) 98 | 99 | 100 | ''' 101 | class autoencoder(nn.Module): 102 | def __init__(self, 103 
| feat_dim, 104 | hidden_dim=8): 105 | super(autoencoder, self).__init__() 106 | self.linear_fh = nn.Linear(feat_dim, hidden_dim) 107 | self.linear_hf = nn.Linear(hidden_dim, feat_dim) 108 | nn.init.xavier_normal_(self.linear_fh.weight.data) 109 | nn.init.xavier_normal_(self.linear_hf.weight.data) 110 | 111 | def forward(self, f, encoding=False): 112 | h = F.elu(self.linear_fh(f)) 113 | if encoding: 114 | return h.detach().cpu().numpy() 115 | else: 116 | return F.elu(self.linear_hf(h)) 117 | 118 | class graph_probe_attn(nn.Module): 119 | def __init__(self, 120 | adj_dim, 121 | feat_dim, 122 | hidden_dim=64, 123 | d_model=96, 124 | dropout=0.1): 125 | super(graph_probe_attn, self).__init__() 126 | self.dropout = dropout 127 | self.linear_context = nn.Linear(adj_dim, d_model) 128 | self.linear_query = nn.Linear(feat_dim, d_model) 129 | w4C = torch.empty(d_model, 1) 130 | w4Q = torch.empty(d_model, 1) 131 | w4mlu = torch.empty(1, d_model) 132 | self.w4C = nn.Parameter(w4C) 133 | self.w4Q = nn.Parameter(w4Q) 134 | self.w4mlu = nn.Parameter(w4mlu) 135 | bias = torch.empty(1) 136 | self.bias = nn.Parameter(bias) 137 | nn.init.xavier_uniform_(w4C) 138 | nn.init.xavier_uniform_(w4Q) 139 | nn.init.xavier_uniform_(w4mlu) 140 | nn.init.constant_(bias, 0) 141 | 142 | self.linear_oh = nn.Linear(d_model*4, hidden_dim) 143 | self.linear_hh = nn.Linear(hidden_dim, hidden_dim) 144 | self.linear_hm = nn.Linear(hidden_dim, 1) 145 | nn.init.xavier_normal_(self.linear_oh.weight.data) 146 | nn.init.xavier_normal_(self.linear_hh.weight.data) 147 | nn.init.xavier_normal_(self.linear_hm.weight.data) 148 | 149 | def forward(self, a, f): 150 | # context-query attention 151 | C = self.linear_context(a) # context=graph 152 | Q = self.linear_query(f) # query=feature 153 | Lc, d_model = C.shape 154 | Lq, d_model = Q.shape 155 | S = self.trilinear_for_attention(C, Q) 156 | S1 = F.softmax(S, dim=1) 157 | S2 = F.softmax(S, dim=0) 158 | A = torch.matmul(S1, Q) 159 | B = torch.matmul(torch.matmul(S1, S2.transpose(0, 1)), C) 160 | out = torch.cat([C, A, torch.mul(C, A), torch.mul(C, B)], dim=1) 161 | 162 | h = F.elu(self.linear_oh(out)) 163 | h = F.elu(self.linear_hh(h)) 164 | return F.elu(self.linear_hm(h)) 165 | 166 | 167 | def trilinear_for_attention(self, C, Q): 168 | Lc, d_model = C.shape 169 | Lq, d_model = Q.shape 170 | dropout = self.dropout 171 | C = F.dropout(C, p=dropout, training=self.training) 172 | Q = F.dropout(Q, p=dropout, training=self.training) 173 | subres0 = torch.matmul(C, self.w4C).expand([-1, Lq]) 174 | subres1 = torch.matmul(Q, self.w4Q).transpose(0, 1).expand([Lc, -1]) 175 | subres2 = torch.matmul(C * self.w4mlu, Q.transpose(0,1)) 176 | res = subres0 + subres1 + subres2 177 | res += self.bias 178 | return res 179 | 180 | 181 | # MINEE 182 | def _resample(data, batch_size, replace=False): 183 | # Resample the given data sample. 
184 | index = np.random.choice( 185 | range(data.shape[0]), size=batch_size, replace=replace) 186 | batch = data[index] 187 | return batch 188 | 189 | def _normal_sample(data, batch_size): 190 | # Sample the reference uniform distribution 191 | data_min = data.min(dim=0)[0] 192 | data_max = data.max(dim=0)[0] 193 | # return (data_max - data_min) * torch.rand((batch_size, data_min.shape[0])) + data_min 194 | return torch.randn((batch_size, data_min.shape[0])) 195 | 196 | def _div(net, data, ref): 197 | # Calculate the divergence estimate using a neural network 198 | mean_f = net(data).mean() 199 | log_mean_ef_ref = torch.logsumexp(net(ref), 0) - np.log(ref.shape[0]) 200 | return mean_f - log_mean_ef_ref 201 | 202 | class MINEE(): 203 | class Net(nn.Module): 204 | # Inner class that defines the neural network architecture 205 | def __init__(self, input_size=2, hidden_size=100, sigma=0.02): 206 | super().__init__() 207 | self.fc1 = nn.Linear(input_size, hidden_size) 208 | self.fc2 = nn.Linear(hidden_size, hidden_size) 209 | self.fc3 = nn.Linear(hidden_size, 1) 210 | nn.init.xavier_normal_(self.fc1.weight) 211 | nn.init.xavier_normal_(self.fc2.weight) 212 | nn.init.xavier_normal_(self.fc3.weight) 213 | 214 | def forward(self, inputs): 215 | 216 | output = F.elu(self.fc1(input)) 217 | output = F.elu(self.fc2(output)) 218 | output = self.fc3(output) 219 | return output 220 | 221 | def __init__(self, x_dim, y_dim, ref_batch_factor=1, lr=1e-4, hidden_size=100): 222 | self.lr = lr 223 | self.ref_batch_factor = ref_batch_factor 224 | self.XY_net = MINEE.Net(input_size=x_dim + y_dim, hidden_size=100) 225 | self.X_net = MINEE.Net(input_size=x_dim, hidden_size=100) 226 | self.Y_net = MINEE.Net(input_size=y_dim, hidden_size=100) 227 | self.XY_optimizer = optim.Adam(self.XY_net.parameters(), lr=lr) 228 | self.X_optimizer = optim.Adam(self.X_net.parameters(), lr=lr) 229 | self.Y_optimizer = optim.Adam(self.Y_net.parameters(), lr=lr) 230 | 231 | def step(self, X, Y, iter=1): 232 | r"""Train the neural networks for one or more steps. 233 | Argument: 234 | iter (int, optional): number of steps to train. 
235 | """ 236 | self.X = X 237 | self.Y = Y 238 | self.batch_size = X.shape[0] 239 | self.XY = torch.cat((self.X, self.Y), dim=1) 240 | for i in range(iter): 241 | self.XY_optimizer.zero_grad() 242 | self.X_optimizer.zero_grad() 243 | self.Y_optimizer.zero_grad() 244 | batch_XY = _resample(self.XY, batch_size=self.batch_size) 245 | batch_X = _resample(self.X, batch_size=self.batch_size) 246 | batch_Y = _resample(self.Y, batch_size=self.batch_size) 247 | batch_X_ref = _normal_sample(self.X, batch_size=int( 248 | self.ref_batch_factor * self.batch_size)) 249 | batch_Y_ref = _normal_sample(self.Y, batch_size=int( 250 | self.ref_batch_factor * self.batch_size)) 251 | batch_XY_ref = torch.cat((batch_X_ref, batch_Y_ref), dim=1) 252 | 253 | batch_loss_XY = -_div(self.XY_net, batch_XY, batch_XY_ref) 254 | batch_loss_X = -_div(self.X_net, batch_X, batch_X_ref) 255 | batch_loss_Y = -_div(self.Y_net, batch_Y, batch_Y_ref) 256 | 257 | val_loss_XY = batch_loss_XY.data.item() 258 | val_loss_X = batch_loss_X.data.item() 259 | val_loss_Y = batch_loss_Y.data.item() 260 | val_loss_sum = val_loss_XY + val_loss_X + val_loss_Y 261 | 262 | if val_loss_sum != 0: 263 | batch_loss_XY = (1 - val_loss_XY/val_loss_sum) * batch_loss_XY 264 | batch_loss_X = (1 - val_loss_X/val_loss_sum) * batch_loss_X 265 | batch_loss_Y = (1 - val_loss_Y/val_loss_sum) * batch_loss_Y 266 | 267 | total_loss = batch_loss_XY + batch_loss_X + batch_loss_Y 268 | total_loss.backward() 269 | self.XY_optimizer.step() 270 | self.X_optimizer.step() 271 | self.Y_optimizer.step() 272 | 273 | return batch_loss_XY.data.item(), batch_loss_X.data.item(), \ 274 | batch_loss_Y.data.item() 275 | 276 | def forward(self, X, Y): 277 | r"""Evaluate the neural networks to return an array of 3 divergences estimates 278 | (dXY, dX, dY). 279 | Outputs: 280 | dXY: divergence of sample joint distribution of (X,Y) 281 | to the uniform reference 282 | dX: divergence of sample marginal distribution of X 283 | to the uniform reference 284 | dY: divergence of sample marginal distribution of Y 285 | to the uniform reference 286 | Arguments: 287 | X (tensor, optional): samples of X. 288 | Y (tensor, optional): samples of Y. 289 | By default, X and Y for training is used. 290 | The arguments are useful for testing/validation with a separate data set. 291 | """ 292 | XY = torch.cat((X, Y), dim=1) 293 | X_ref = _normal_sample(X, batch_size=int(self.ref_batch_factor * X.shape[0])) 294 | Y_ref = _normal_sample(Y, batch_size=int(self.ref_batch_factor * Y.shape[0])) 295 | XY_ref = torch.cat((X_ref, Y_ref), dim=1) 296 | 297 | ce_XY = _div(self.XY_net, XY, XY_ref).cpu().item() 298 | ce_X = _div(self.X_net, X, X_ref).cpu().item() 299 | ce_Y = _div(self.Y_net, Y, Y_ref).cpu().item() 300 | 301 | return ce_XY, ce_X, ce_Y 302 | 303 | def estimate(self, X=None, Y=None): 304 | r"""Return the mutual information estimate. 305 | Arguments: 306 | X (tensor, optional): samples of X. 307 | Y (tensor, optional): samples of Y. 308 | By default, X and Y for training is used. 309 | The arguments are useful for testing/validation with a separate data set. 310 | """ 311 | dXY, dX, dY = self.forward(X, Y) 312 | return dXY - dX - dY 313 | 314 | def state_dict(self): 315 | r"""Return a dictionary storing the state of the estimator. 
316 | """ 317 | return { 318 | 'XY_net': self.XY_net.state_dict(), 319 | 'XY_optimizer': self.XY_optimizer.state_dict(), 320 | 'X_net': self.X_net.state_dict(), 321 | 'X_optimizer': self.X_optimizer.state_dict(), 322 | 'Y_net': self.Y_net.state_dict(), 323 | 'Y_optimizer': self.Y_optimizer.state_dict(), 324 | 'X': self.X, 325 | 'Y': self.Y, 326 | 'lr': self.lr, 327 | 'batch_size': self.batch_size, 328 | 'ref_batch_factor': self.ref_batch_factor 329 | } 330 | 331 | def load_state_dict(self, state_dict): 332 | r"""Load the dictionary of state state_dict. 333 | """ 334 | self.XY_net.load_state_dict(state_dict['XY_net']) 335 | self.XY_optimizer.load_state_dict(state_dict['XY_optimizer']) 336 | self.X_net.load_state_dict(state_dict['X_net']) 337 | self.X_optimizer.load_state_dict(state_dict['X_optimizer']) 338 | self.Y_net.load_state_dict(state_dict['Y_net']) 339 | self.Y_optimizer.load_state_dict(state_dict['Y_optimizer']) 340 | self.X = state_dict['X'] 341 | self.Y = state_dict['Y'] 342 | if 'lr' in state_dict: 343 | self.lr = state_dict['lr'] 344 | if 'batch_size' in state_dict: 345 | self.batch_size = state_dict['batch_size'] 346 | if 'ref_batch_factor' in state_dict: 347 | self.ref_batch_factor = state_dict['ref_batch_factor'] 348 | ''' -------------------------------------------------------------------------------- /src/evaluation_amr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import roc_auc_score, jaccard_score 4 | from tqdm import tqdm 5 | 6 | from probe import mine_probe 7 | from models import binary_classifer 8 | from embed import get_embeddings 9 | from utils import load_data, get_edge_idx_amr, load_noisy_graphs, uuas_score,\ 10 | load_graph_labels, load_split_emb 11 | 12 | def test_ge_amr(args, data_split=True): 13 | # load data & graph embeddings 14 | if data_split: 15 | s_train = load_data('amr_dataset', 'train') 16 | s_dev = load_data('amr_dataset', 'dev') 17 | s_test = load_data('amr_dataset', 'test') 18 | ge_train, _ = get_embeddings(args, s_train, data_div='_train') 19 | ge_dev, _ = get_embeddings(args, s_dev, data_div='_dev') 20 | ge_test, _ = get_embeddings(args, s_test, data_div='_test') 21 | else: 22 | s_train = load_data('amr_dataset', 'train') 23 | s_dev = load_data('amr_dataset', 'dev') 24 | s_test = load_data('amr_dataset', 'test') 25 | s_train = s_train + s_dev + s_test 26 | s_dev, s_test = s_train, s_train 27 | ge_train, _ = get_embeddings(args, s_train) 28 | ge_dev, ge_test = ge_train, ge_train 29 | 30 | model = binary_classifer(args.classifier_layers_num, 31 | ge_train['s0'].shape[1]).to(args.device) 32 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 33 | loss_fcn = torch.nn.BCELoss() 34 | print('2. start to train model... 
') 35 | # train 36 | model.train() 37 | train_losses = [999 for _ in range(args.patience)] 38 | for _ in range(10): 39 | loss_train = 0 40 | for s in range(len(s_train)): 41 | # get graph embedding 42 | graph_emb = ge_train['s'+str(s)] 43 | graph_emb = torch.FloatTensor(graph_emb).to(args.device) 44 | # get ground-truth graph 45 | src_idx, dst_idx, edge_labels = get_edge_idx_amr(s_train[s]) 46 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 47 | if len(src_idx) <= 1: continue 48 | optimizer.zero_grad() 49 | edge_pred = model(graph_emb[src_idx], graph_emb[dst_idx]) 50 | edge_pred = torch.squeeze(edge_pred) 51 | loss = loss_fcn(edge_pred, edge_labels) 52 | loss.backward() 53 | optimizer.step() 54 | loss_train += loss.data.item() 55 | loss_train = loss_train/len(s_train) 56 | print(' the training loss is: {:.4f}'.format(loss_train)) 57 | # early stop 58 | if loss_train < max(train_losses): 59 | train_losses.remove(max(train_losses)) 60 | train_losses.append(loss_train) 61 | else: 62 | break 63 | 64 | if data_split: 65 | # validation 66 | model.eval() 67 | loss_dev = 0 68 | for s in range(len(s_dev)): 69 | # get graph embedding 70 | graph_emb = ge_dev['s'+str(s)] 71 | graph_emb = torch.FloatTensor(graph_emb).to(args.device) 72 | # get ground-truth graph 73 | src_idx, dst_idx, edge_labels = get_edge_idx_amr(s_dev[s]) 74 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 75 | if len(src_idx) <= 1: continue 76 | edge_pred = model(graph_emb[src_idx], graph_emb[dst_idx]) 77 | edge_pred = torch.squeeze(edge_pred) 78 | loss = loss_fcn(edge_pred, edge_labels) 79 | loss_dev += loss.data.item() 80 | loss_dev = loss_dev/len(s_dev) 81 | print('2. start to test model... | Train loss: {:.4f} | Val loss: {:.4f}'.format( 82 | loss_train, loss_dev)) 83 | 84 | # test 85 | model.eval() 86 | auc, jaccard, uuas = [], [], [] 87 | print('2. start to test graph embedding model... 
') 88 | for s in range(len(s_test)): 89 | # get graph embedding 90 | graph_emb = ge_test['s'+str(s)] 91 | graph_emb = torch.FloatTensor(graph_emb).to(args.device) 92 | # get ground-truth graph 93 | src_idx, dst_idx, edge_labels = get_edge_idx_amr(s_test[s]) 94 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 95 | if len(src_idx) <= 1: continue 96 | edge_pred = model(graph_emb[src_idx], graph_emb[dst_idx]) 97 | edge_pred = torch.squeeze(edge_pred).detach().cpu().numpy() 98 | edge_labels = edge_labels.detach().cpu().numpy() 99 | if edge_labels.sum() > 0: 100 | auc.append(roc_auc_score(edge_labels, edge_pred)) 101 | edge_pred = np.where(edge_pred > 0.5, 1, 0) 102 | jaccard.append(jaccard_score(edge_labels, edge_pred)) 103 | # uuas.append(uuas_score(src_idx, dst_idx, edge_labels, edge_pred)) 104 | 105 | print(sum(auc)/len(auc), sum(jaccard)/len(jaccard)) 106 | 107 | return 108 | 109 | 110 | def test_bert_amr(args, model_name='bert'): 111 | # load data & graph embeddings 112 | s_train = load_data('amr_dataset', 'train') 113 | s_dev = load_data('amr_dataset', 'dev') 114 | s_test = load_data('amr_dataset', 'test') 115 | _, s_train_paths = get_embeddings(args, s_train, data_div='_train') 116 | _, s_dev_paths = get_embeddings(args, s_dev, data_div='_dev') 117 | _, s_test_paths = get_embeddings(args, s_test, data_div='_test') 118 | bert_train = np.load(s_train_paths[-1]) 119 | bert_dev = np.load(s_dev_paths[-1]) 120 | bert_test = np.load(s_test_paths[-1]) 121 | if model_name != 'bert': 122 | bert_train, bert_dev, bert_test = load_split_emb(len(bert_train), len(bert_dev), 123 | len(bert_test), model_name, 'amr') 124 | 125 | feat_dim = bert_train['s0'].shape[1] 126 | model = binary_classifer(args.classifier_layers_num, 127 | feat_dim).to(args.device) 128 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 129 | loss_fcn = torch.nn.BCELoss() 130 | 131 | print('2. start to train model...') 132 | # train 133 | model.train() 134 | train_losses = [999 for _ in range(args.patience)] 135 | for _ in range(10): 136 | loss_train = 0 137 | for s in range(len(bert_train)): 138 | # get graph embedding 139 | bert_emb = bert_train['s'+str(s)] 140 | bert_emb = torch.FloatTensor(bert_emb).to(args.device) 141 | # get ground-truth graph 142 | src_idx, dst_idx, edge_labels = get_edge_idx_amr(s_train[s]) 143 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 144 | if len(src_idx) <= 1: continue 145 | 146 | optimizer.zero_grad() 147 | edge_pred = model(bert_emb[src_idx], bert_emb[dst_idx]) 148 | edge_pred = torch.squeeze(edge_pred) 149 | loss = loss_fcn(edge_pred, edge_labels) 150 | loss.backward() 151 | optimizer.step() 152 | loss_train += loss.data.item() 153 | loss_train = loss_train/len(bert_train) 154 | print(' the training loss is: {:.4f}'.format(loss_train)) 155 | # early stop 156 | if loss_train < max(train_losses): 157 | train_losses.remove(max(train_losses)) 158 | train_losses.append(loss_train) 159 | else: 160 | break 161 | 162 | print('2. 
start to validate model...') 163 | # validation 164 | model.eval() 165 | loss_dev = 0 166 | for s in range(len(bert_dev)): 167 | # get graph embedding 168 | bert_emb = bert_dev['s'+str(s)] 169 | bert_emb = torch.FloatTensor(bert_emb).to(args.device) 170 | # get ground-truth graph 171 | src_idx, dst_idx, edge_labels = get_edge_idx_amr(s_dev[s]) 172 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 173 | if len(src_idx) <= 1: continue 174 | 175 | edge_pred = model(bert_emb[src_idx], bert_emb[dst_idx]) 176 | edge_pred = torch.squeeze(edge_pred) 177 | loss = loss_fcn(edge_pred, edge_labels) 178 | loss_dev += loss.data.item() 179 | loss_dev = loss_dev/len(bert_dev) 180 | 181 | print('2. | Train loss: {:.4f} | Val loss: {:.4f}'.format(loss_train, loss_dev)) 182 | # test 183 | model.eval() 184 | auc, jaccard, uuas = [], [], [] 185 | print('2. start to test model...') 186 | for s in range(len(bert_test)): 187 | # get graph embedding 188 | bert_emb = bert_test['s'+str(s)] 189 | bert_emb = torch.FloatTensor(bert_emb).to(args.device) 190 | # get ground-truth graph 191 | src_idx, dst_idx, edge_labels = get_edge_idx_amr(s_test[s]) 192 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 193 | if len(src_idx) <= 1: continue 194 | edge_pred = model(bert_emb[src_idx], bert_emb[dst_idx]) 195 | edge_pred = torch.squeeze(edge_pred).detach().cpu().numpy() 196 | edge_labels = edge_labels.detach().cpu().numpy() 197 | if edge_labels.sum() > 0: 198 | auc.append(roc_auc_score(edge_labels, edge_pred)) 199 | edge_pred = np.where(edge_pred > 0.5, 1, 0) 200 | jaccard.append(jaccard_score(edge_labels, edge_pred)) 201 | # uuas.append(uuas_score(src_idx, dst_idx, edge_labels, edge_pred)) 202 | 203 | print(sum(auc)/len(auc), sum(jaccard)/len(jaccard)) 204 | 205 | return 206 | 207 | 208 | def test_mi_amr(args): 209 | # load data & embeddings 210 | s_train = load_data('amr_dataset', 'train') 211 | s_dev = load_data('amr_dataset', 'dev') 212 | s_test = load_data('amr_dataset', 'test') 213 | amr_s = s_train + s_dev + s_test 214 | graph_emb, _ = get_embeddings(args, amr_s) 215 | 216 | print('2.2 start to test bert MI...') 217 | task_mi = [] 218 | for n in tqdm(range(11)): 219 | noise_mi = [] 220 | for r in range(args.repeat): 221 | mi_s = mine_probe(args, graph_emb, graph_emb, len(graph_emb), n/10) 222 | mi_s = sum(mi_s)/len(mi_s) 223 | noise_mi.append(mi_s) 224 | task_mi.append(noise_mi) 225 | 226 | print(task_mi) 227 | 228 | return 229 | 230 | 231 | def mi_noise_amr(args): 232 | # load data & embeddings 233 | s_train = load_data('amr_dataset', 'train') 234 | s_dev = load_data('amr_dataset', 'dev') 235 | s_test = load_data('amr_dataset', 'test') 236 | amr_s = s_train + s_dev + s_test 237 | graph_emb, bert_emb_paths = get_embeddings(args, amr_s) 238 | bert_emb = np.load(bert_emb_paths[-1]) 239 | noisy_g = load_noisy_graphs(args) 240 | 241 | results = {} 242 | for k, v in noisy_g.items(): 243 | if k not in results: results[k] = [] 244 | for r in range(args.repeat): 245 | mi = mine_probe(args, graph_emb, bert_emb, len(graph_emb), 'noisy', v) 246 | results[k].append(sum(mi) / len(mi)) 247 | print(results) 248 | return 249 | 250 | 251 | def test_random_amr(args, corrupt=False): 252 | # load data & graph embeddings 253 | s_train = load_data('amr_dataset', 'train') 254 | s_dev = load_data('amr_dataset', 'dev') 255 | s_test = load_data('amr_dataset', 'test') 256 | _, s_train_paths = get_embeddings(args, s_train, data_div='_train') 257 | _, s_dev_paths = get_embeddings(args, s_dev, data_div='_dev') 258 | _, 
s_test_paths = get_embeddings(args, s_test, data_div='_test') 259 | bert_train = np.load(s_train_paths[-1]) 260 | bert_dev = np.load(s_dev_paths[-1]) 261 | bert_test = np.load(s_test_paths[-1]) 262 | feat_dim = args.bert_hidden_num 263 | noisy_g = load_noisy_graphs(args) 264 | 265 | for noisy_tag, noisy_id in noisy_g.items(): 266 | print('2. corrupt type: ', noisy_tag) 267 | model = binary_classifer(args.classifier_layers_num, feat_dim).to(args.device) 268 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 269 | loss_fcn = torch.nn.BCELoss() 270 | 271 | print('2. start to train model...') 272 | # train 273 | model.train() 274 | train_losses = [999 for _ in range(args.patience)] 275 | for _ in range(10): 276 | loss_train = 0 277 | for s in range(len(bert_train)): 278 | # get graph embedding 279 | bert_emb = bert_train['s'+str(s)] 280 | if corrupt: 281 | rand_vec = np.random.randn(bert_emb.shape[0], bert_emb.shape[1]) 282 | bert_emb[noisy_id[s]] = rand_vec[noisy_id[s]] 283 | bert_emb = torch.FloatTensor(bert_emb).to(args.device) 284 | # get ground-truth graph 285 | src_idx, dst_idx, edge_labels = get_edge_idx_amr(s_train[s]) 286 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 287 | if len(src_idx) <= 1: continue 288 | 289 | optimizer.zero_grad() 290 | edge_pred = model(bert_emb[src_idx], bert_emb[dst_idx]) 291 | edge_pred = torch.squeeze(edge_pred) 292 | loss = loss_fcn(edge_pred, edge_labels) 293 | loss.backward() 294 | optimizer.step() 295 | loss_train += loss.data.item() 296 | loss_train = loss_train/len(bert_train) 297 | print(' the training loss is: {:.4f}'.format(loss_train)) 298 | # early stop 299 | if loss_train < max(train_losses): 300 | train_losses.remove(max(train_losses)) 301 | train_losses.append(loss_train) 302 | else: 303 | break 304 | 305 | # test 306 | model.eval() 307 | label_ids = load_graph_labels(args) 308 | auc_dict = {} 309 | print('2. 
start to test model...') 310 | for k, v_l in label_ids.items(): 311 | auc_tmp = [] 312 | for s in range(len(bert_test)): 313 | # get graph embedding 314 | bert_emb = bert_test['s'+str(s)] 315 | bert_emb = torch.FloatTensor(bert_emb).to(args.device) 316 | # get ground-truth graph 317 | src_idx, dst_idx, edge_labels = get_edge_idx_amr(s_test[s]) 318 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 319 | if len(src_idx) <= 1: continue 320 | edge_pred = model(bert_emb[src_idx], bert_emb[dst_idx]) 321 | edge_pred = torch.squeeze(edge_pred).detach().cpu().numpy() 322 | edge_labels = edge_labels.detach().cpu().numpy() 323 | edge_mask = [] 324 | for i in range(bert_emb.shape[0]): 325 | for j in range(bert_emb.shape[0]): 326 | if i in v_l[s] or j in v_l[s]: 327 | edge_mask.append(1) 328 | else: 329 | edge_mask.append(0) 330 | edge_mask = np.array(edge_mask) 331 | edge_pred = edge_pred[edge_mask==1] 332 | edge_labels = edge_labels[edge_mask==1] 333 | if edge_labels.sum() > 0: 334 | auc_tmp.append(roc_auc_score(edge_labels, edge_pred)) 335 | auc_dict[k] = sum(auc_tmp) / (1e-5 + len(auc_tmp)) 336 | print(auc_dict) 337 | 338 | return -------------------------------------------------------------------------------- /src/evaluation_ptb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import roc_auc_score, jaccard_score 4 | from tqdm import tqdm 5 | 6 | from probe import mine_probe 7 | from models import binary_classifer 8 | from embed import graph_embeddings, bert_embeddings 9 | from utils import get_edge_idx, load_data, construct_graph, get_graph_emb, \ 10 | load_noisy_trees, uuas_score, load_tree_labels, load_split_emb 11 | 12 | 13 | ## sub-function descriptions 14 | ## please refer to main py file for function descriptions 15 | ''' 16 | load_graph: load graph embedding 17 | load_bert: load BERT embedding 18 | load_embeddings: load graph embedding and BERT embedding (only top layer) 19 | ''' 20 | 21 | 22 | def load_graph(args, data_split=True): 23 | if not data_split: 24 | _, p_train = load_data('penn_treebank_dataset', 'train') 25 | doc_id, sen_id_train, global_graph = construct_graph(p_train) 26 | _, p_dev = load_data('penn_treebank_dataset', 'dev') 27 | doc_id, sen_id_dev, global_graph = construct_graph(p_dev) 28 | _, p_test = load_data('penn_treebank_dataset', 'test') 29 | doc_id, sen_id_test, global_graph = construct_graph(p_test) 30 | parsed = p_train + p_dev + p_test 31 | sen_id = sen_id_train + sen_id_dev + sen_id_test 32 | graph_emb = graph_embeddings(args, global_graph, doc_id, sen_id) 33 | return graph_emb, sen_id 34 | else: 35 | _, p_train = load_data('penn_treebank_dataset', 'train') 36 | doc_id, sen_id_train, global_graph = construct_graph(p_train) 37 | ge_train = graph_embeddings(args, global_graph, doc_id, sen_id_train, '_train') 38 | _, p_dev = load_data('penn_treebank_dataset', 'dev') 39 | doc_id, sen_id_dev, global_graph = construct_graph(p_dev) 40 | ge_dev = graph_embeddings(args, global_graph, doc_id, sen_id_dev, '_dev') 41 | _, p_test = load_data('penn_treebank_dataset', 'test') 42 | doc_id, sen_id_test, global_graph = construct_graph(p_test) 43 | ge_test = graph_embeddings(args, global_graph, doc_id, sen_id_test, '_test') 44 | return ge_train, ge_dev, ge_test, sen_id_train, sen_id_dev, sen_id_test 45 | 46 | 47 | def load_bert(args): 48 | s_train, p_train = load_data('penn_treebank_dataset', 'train') 49 | doc_id, sen_id_train, global_graph = construct_graph(p_train) 50 
| bert_train_paths = bert_embeddings(args, s_train, '_train') 51 | bert_train = np.load(bert_train_paths[-1]) 52 | 53 | s_dev, p_dev = load_data('penn_treebank_dataset', 'dev') 54 | doc_id, sen_id_dev, global_graph = construct_graph(p_dev) 55 | bert_dev_paths = bert_embeddings(args, s_dev, '_dev') 56 | bert_dev = np.load(bert_dev_paths[-1]) 57 | 58 | s_test, p_test = load_data('penn_treebank_dataset', 'test') 59 | doc_id, sen_id_test, global_graph = construct_graph(p_test) 60 | bert_test_paths = bert_embeddings(args, s_test, '_test') 61 | bert_test = np.load(bert_test_paths[-1]) 62 | 63 | return bert_train, bert_dev, bert_test, sen_id_train, sen_id_dev, sen_id_test 64 | 65 | 66 | def load_embeddings(args): 67 | # load data 68 | s_train, p_train = load_data('penn_treebank_dataset', 'train') 69 | s_dev, p_dev = load_data('penn_treebank_dataset', 'dev') 70 | s_test, p_test = load_data('penn_treebank_dataset', 'test') 71 | sentences = s_train + s_dev + s_test 72 | parsed = p_train + p_dev + p_test 73 | # sentences = s_test 74 | # parsed = p_test 75 | doc_id, sen_id, global_graph = construct_graph(parsed) 76 | # load embeddings 77 | graph_emb = graph_embeddings(args, global_graph, doc_id, sen_id) 78 | bert_emb_paths = bert_embeddings(args, sentences) 79 | # graph_emb = graph_embeddings(args, global_graph, doc_id, sen_id, '_test') 80 | # bert_emb_paths = bert_embeddings(args, sentences, '_test') 81 | bert_emb = np.load(bert_emb_paths[-1]) 82 | 83 | return graph_emb, bert_emb 84 | 85 | 86 | def test_ge_ptb(args, data_split=True): 87 | # load data 88 | if data_split: 89 | ge_train, ge_dev, ge_test, sid_train, sid_dev, sid_test \ 90 | = load_graph(args, data_split) 91 | else: 92 | graph_emb, sen_id = load_graph(args, data_split) 93 | ge_train, ge_dev, ge_test = graph_emb, graph_emb, graph_emb 94 | sid_train, sid_dev, sid_test = sen_id, sen_id, sen_id 95 | 96 | # test_tasks = ['all', 'local', 'global'] 97 | test_tasks = ['local'] 98 | aucs, jaccards, uuass = {}, {}, {} 99 | 100 | for test_task in test_tasks: 101 | feat_dim = ge_train['s0'].shape[1] 102 | model = binary_classifer(args.classifier_layers_num, 103 | feat_dim).to(args.device) 104 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 105 | loss_fcn = torch.nn.BCELoss() 106 | 107 | print('2. start to train model: ', test_task) 108 | # train 109 | model.train() 110 | train_losses = [999 for _ in range(args.patience)] 111 | for _ in range(10): # epoch 112 | loss_train = 0 113 | for s in range(len(sid_train)): 114 | # get graph embedding 115 | graph_emb = ge_train['s'+str(s)] 116 | graph_emb = torch.FloatTensor(graph_emb).to(args.device) 117 | # get ground-truth graph 118 | src_idx, dst_idx, edge_labels = get_edge_idx(sid_train[s]) 119 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 120 | if len(src_idx) <= 1: continue 121 | optimizer.zero_grad() 122 | edge_pred = model(graph_emb[src_idx], graph_emb[dst_idx]) 123 | edge_pred = torch.squeeze(edge_pred) 124 | loss = loss_fcn(edge_pred, edge_labels) 125 | loss.backward() 126 | optimizer.step() 127 | loss_train += loss.data.item() 128 | loss_train = loss_train/len(sid_train) 129 | print(' the training loss is: {:.4f}'.format(loss_train)) 130 | # early stop 131 | if loss_train < max(train_losses): 132 | train_losses.remove(max(train_losses)) 133 | train_losses.append(loss_train) 134 | else: 135 | break 136 | 137 | # validation 138 | if data_split: 139 | print('2. 
start to validate model: ', test_task) 140 | # validation 141 | model.eval() 142 | loss_dev = 0 143 | for s in range(len(sid_dev)): 144 | # get graph embedding 145 | graph_emb = ge_dev['s'+str(s)] 146 | graph_emb = torch.FloatTensor(graph_emb).to(args.device) 147 | # get ground-truth graph 148 | src_idx, dst_idx, edge_labels = get_edge_idx(sid_dev[s]) 149 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 150 | if len(src_idx) <= 1: continue 151 | edge_pred = model(graph_emb[src_idx], graph_emb[dst_idx]) 152 | edge_pred = torch.squeeze(edge_pred) 153 | loss = loss_fcn(edge_pred, edge_labels) 154 | loss_dev += loss.data.item() 155 | loss_dev = loss_dev/len(sid_dev) 156 | print('2. start to test model: {} |Train loss: {:.4f} |Val loss: {:.4f}'.format( 157 | test_task, loss_train, loss_dev)) 158 | 159 | # test 160 | model.eval() 161 | auc, jaccard, uuas = [], [], [] 162 | print('2. start to test model: ', test_task) 163 | for s in range(len(sid_test)): 164 | # get graph embedding 165 | graph_emb = ge_test['s'+str(s)] 166 | graph_emb = torch.FloatTensor(graph_emb).to(args.device) 167 | # get ground-truth graph 168 | src_idx, dst_idx, edge_labels = get_edge_idx(sid_test[s]) 169 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 170 | if len(src_idx) <= 1: continue 171 | edge_pred = model(graph_emb[src_idx], graph_emb[dst_idx]) 172 | edge_pred = torch.squeeze(edge_pred).detach().cpu().numpy() 173 | edge_labels = edge_labels.detach().cpu().numpy() 174 | if edge_labels.sum() > 0: 175 | auc.append(roc_auc_score(edge_labels, edge_pred)) 176 | edge_pred = np.where(edge_pred > 0.5, 1, 0) 177 | jaccard.append(jaccard_score(edge_labels, edge_pred)) 178 | # uuas.append(uuas_score(src_idx, dst_idx, edge_labels, edge_pred)) 179 | aucs[test_task] = auc 180 | jaccards[test_task] = jaccard 181 | uuass[test_task] = uuas 182 | 183 | print(sum(aucs['local']) / len(aucs['local']), \ 184 | sum(jaccards['local']) / len(jaccards['local'])) 185 | 186 | return 187 | 188 | 189 | def test_bert_ptb(args, model_name='bert'): 190 | # load data 191 | bert_train, bert_dev, bert_test, sid_train, sid_dev, sid_test = load_bert(args) 192 | if model_name != 'bert': 193 | bert_train, bert_dev, bert_test = load_split_emb(len(bert_train), len(bert_dev), 194 | len(bert_test), model_name, 'ptb') 195 | feat_dim = bert_train['s0'].shape[1] 196 | model = binary_classifer(args.classifier_layers_num, 197 | feat_dim).to(args.device) 198 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 199 | loss_fcn = torch.nn.BCELoss() 200 | 201 | print('2. 
start to train model...') 202 | # train 203 | model.train() 204 | train_losses = [999 for _ in range(args.patience)] 205 | for _ in range(10): # epoch 206 | loss_train = 0 207 | for s in range(len(sid_train)): 208 | # get graph embedding 209 | bert_emb = bert_train['s'+str(s)] 210 | bert_emb = torch.FloatTensor(bert_emb).to(args.device) 211 | # get ground-truth graph 212 | src_idx, dst_idx, edge_labels = get_edge_idx(sid_train[s]) 213 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 214 | if len(src_idx) <= 1: continue 215 | optimizer.zero_grad() 216 | edge_pred = model(bert_emb[src_idx], bert_emb[dst_idx]) 217 | edge_pred = torch.squeeze(edge_pred) 218 | loss = loss_fcn(edge_pred, edge_labels) 219 | loss.backward() 220 | optimizer.step() 221 | loss_train += loss.data.item() 222 | loss_train = loss_train/len(sid_train) 223 | print(' the training loss is: {:.4f}'.format(loss_train)) 224 | # early stop 225 | if loss_train < max(train_losses): 226 | train_losses.remove(max(train_losses)) 227 | train_losses.append(loss_train) 228 | else: 229 | break 230 | 231 | print('2. start to validate model...') 232 | # validation 233 | model.eval() 234 | loss_dev = 0 235 | for s in range(len(sid_dev)): 236 | # get graph embedding 237 | bert_emb = bert_dev['s'+str(s)] 238 | bert_emb = torch.FloatTensor(bert_emb).to(args.device) 239 | # get ground-truth graph 240 | src_idx, dst_idx, edge_labels = get_edge_idx(sid_dev[s]) 241 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 242 | if len(src_idx) <= 1: continue 243 | 244 | edge_pred = model(bert_emb[src_idx], bert_emb[dst_idx]) 245 | edge_pred = torch.squeeze(edge_pred) 246 | loss = loss_fcn(edge_pred, edge_labels) 247 | loss_dev += loss.data.item() 248 | loss_dev = loss_dev/len(sid_dev) 249 | 250 | print('2. | Train loss: {:.4f} | Val loss: {:.4f}'.format(loss_train, loss_dev)) 251 | # test 252 | model.eval() 253 | auc, jaccard, uuas = [], [], [] 254 | print('2. 
start to test model...') 255 | for s in range(len(sid_test)): 256 | # get graph embedding 257 | bert_emb = bert_test['s'+str(s)] 258 | bert_emb = torch.FloatTensor(bert_emb).to(args.device) 259 | # get ground-truth graph 260 | src_idx, dst_idx, edge_labels = get_edge_idx(sid_test[s]) 261 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 262 | if len(src_idx) <= 1: continue 263 | edge_pred = model(bert_emb[src_idx], bert_emb[dst_idx]) 264 | edge_pred = torch.squeeze(edge_pred).detach().cpu().numpy() 265 | edge_labels = edge_labels.detach().cpu().numpy() 266 | if edge_labels.sum() > 0: 267 | auc.append(roc_auc_score(edge_labels, edge_pred)) 268 | edge_pred = np.where(edge_pred > 0.5, 1, 0) 269 | jaccard.append(jaccard_score(edge_labels, edge_pred)) 270 | # uuas.append(uuas_score(src_idx, dst_idx, edge_labels, edge_pred)) 271 | 272 | print(sum(auc)/len(auc), sum(jaccard)/len(jaccard)) 273 | 274 | return 275 | 276 | 277 | def test_mi_ptb(args): 278 | print('4.1 start to load embeddings...') 279 | graph_emb, bert_emb = load_embeddings(args) 280 | 281 | print('4.2 start to test graph MI...') 282 | # graph task = graph_emb: {all, local, global} X graph noise {0.0 ~ 1.0} 283 | graph_results = {'local':[]} 284 | ''' 285 | graph_results = {'all':[], 'local':[], 'global':[]} 286 | for k, _ in graph_results.items(): 287 | graph_emb = get_graph_emb(graph_emb, k) 288 | task_mi = [] 289 | for n in tqdm(range(11)): 290 | noise_mi = [] 291 | for r in range(args.repeat): 292 | mi_s = mine_probe(args, graph_emb, graph_emb, len(graph_emb), n/10) 293 | mi_s = sum(mi_s)/len(mi_s) 294 | noise_mi.append(mi_s) 295 | task_mi.append(noise_mi) 296 | graph_results[k] = task_mi 297 | ''' 298 | 299 | print('4.2 start to test bert MI...') 300 | # bert task = graph_emb: {all, local, global} X bert noise {0.0 ~ 1.0} 301 | # bert_results = {'all':[], 'local':[], 'global':[]} 302 | bert_results = {'local':[]} 303 | for k, _ in bert_results.items(): 304 | task_mi = [] 305 | for n in tqdm(range(11)): 306 | noise_mi = [] 307 | for r in range(args.repeat): 308 | mi_s = mine_probe(args, graph_emb, bert_emb, len(graph_emb), n/10) 309 | mi_s = sum(mi_s)/len(mi_s) 310 | noise_mi.append(mi_s) 311 | task_mi.append(noise_mi) 312 | bert_results[k] = task_mi 313 | print(graph_results, bert_results) 314 | return 315 | 316 | 317 | def mi_noise_ptb(args, pos=False): 318 | graph_emb, bert_emb = load_embeddings(args) 319 | noisy_g = load_noisy_trees(args, pos) 320 | results = {} 321 | for k, v in noisy_g.items(): 322 | if k not in results: results[k] = [] 323 | for r in range(args.repeat): 324 | mi = mine_probe(args, graph_emb, bert_emb, len(graph_emb), 'noisy', v) 325 | results[k].append(sum(mi) / len(mi)) 326 | print(results) 327 | return 328 | 329 | 330 | def test_random_ptb(args, pos=False, corrupt=False): 331 | # load data 332 | bert_train, bert_dev, bert_test, sid_train, sid_dev, sid_test = load_bert(args) 333 | feat_dim = args.bert_hidden_num 334 | noisy_g = load_noisy_trees(args, data_split=True) 335 | 336 | for noisy_tag, noisy_id in noisy_g.items(): 337 | print('2. corrupt type: ', noisy_tag) 338 | model = binary_classifer(args.classifier_layers_num, feat_dim).to(args.device) 339 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 340 | loss_fcn = torch.nn.BCELoss() 341 | print('2. 
start to train model...') 342 | # train 343 | model.train() 344 | train_losses = [999 for _ in range(args.patience)] 345 | for _ in range(10): # epoch 346 | loss_train = 0 347 | for s in range(len(sid_train)): 348 | # get graph embedding 349 | bert_emb = bert_train['s'+str(s)] 350 | if corrupt: 351 | rand_vec = np.random.randn(bert_emb.shape[0], bert_emb.shape[1]) 352 | bert_emb[noisy_id[s]] = rand_vec[noisy_id[s]] 353 | bert_emb = torch.FloatTensor(bert_emb).to(args.device) 354 | # get ground-truth graph 355 | src_idx, dst_idx, edge_labels = get_edge_idx(sid_train[s]) 356 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 357 | if len(src_idx) <= 1: continue 358 | 359 | optimizer.zero_grad() 360 | edge_pred = model(bert_emb[src_idx], bert_emb[dst_idx]) 361 | edge_pred = torch.squeeze(edge_pred) 362 | loss = loss_fcn(edge_pred, edge_labels) 363 | loss.backward() 364 | optimizer.step() 365 | loss_train += loss.data.item() 366 | loss_train = loss_train/len(sid_train) 367 | print(' the training loss is: {:.4f}'.format(loss_train)) 368 | # early stop 369 | if loss_train < max(train_losses): 370 | train_losses.remove(max(train_losses)) 371 | train_losses.append(loss_train) 372 | else: 373 | break 374 | 375 | # test 376 | model.eval() 377 | label_ids = load_tree_labels(args, pos) 378 | auc_dict = {} 379 | print('2. start to test model...') 380 | for k, v_l in label_ids.items(): 381 | auc_tmp = [] 382 | for s in range(len(sid_test)): 383 | # get graph embedding 384 | bert_emb = bert_test['s'+str(s)] 385 | bert_emb = torch.FloatTensor(bert_emb).to(args.device) 386 | # get ground-truth graph 387 | src_idx, dst_idx, edge_labels = get_edge_idx(sid_test[s]) 388 | edge_labels = torch.FloatTensor(edge_labels).to(args.device) 389 | if len(src_idx) <= 1: continue 390 | edge_pred = model(bert_emb[src_idx], bert_emb[dst_idx]) 391 | edge_pred = torch.squeeze(edge_pred).detach().cpu().numpy() 392 | edge_labels = edge_labels.detach().cpu().numpy() 393 | edge_mask = [] 394 | for i in range(len(sid_test[s])): 395 | for j in range(len(sid_test[s])): 396 | if i in v_l[s] or j in v_l[s]: 397 | edge_mask.append(1) 398 | else: 399 | edge_mask.append(0) 400 | edge_mask = np.array(edge_mask) 401 | edge_pred = edge_pred[edge_mask==1] 402 | edge_labels = edge_labels[edge_mask==1] 403 | if edge_labels.sum() > 0: 404 | auc_tmp.append(roc_auc_score(edge_labels, edge_pred)) 405 | auc_dict[k] = sum(auc_tmp) / (1e-5 + len(auc_tmp)) 406 | print(auc_dict) 407 | 408 | return -------------------------------------------------------------------------------- /src/probe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import random 4 | import torch 5 | import networkx as nx 6 | import numpy as np 7 | from tqdm import tqdm 8 | from models import graph_probe 9 | from embed import graph_embeddings, bert_embeddings, get_embeddings 10 | from utils import load_data, construct_graph, load_glove, load_elmo, load_elmos 11 | 12 | 13 | ## sub-function descriptions 14 | ## please refer to main py file for function descriptions 15 | ''' 16 | mine_probe: probing function. 
maximizing lower bound as estimation 17 | ''' 18 | 19 | 20 | def mine_probe(args, graph_emb, bert_emb, sen_num, task_name, noisy_id=[]): 21 | bert_dim = bert_emb['s0'].shape[1] 22 | graph_dim = graph_emb['s0'].shape[1] 23 | if task_name == 'upper': 24 | bert_dim = graph_dim 25 | model = graph_probe(graph_dim, bert_dim).to(args.device) 26 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 27 | 28 | bad_np = [39927] 29 | mi_es = [-1 for _ in range(args.patience)] 30 | model.train() 31 | for epoch in range(10): # epoch 32 | mi_train = [] 33 | for i in range(sen_num): # batch 34 | if i in bad_np: continue 35 | graph_vec = graph_emb['s'+str(i)] 36 | if task_name == 'lower': 37 | feat_vec = torch.randn(size=bert_emb['s'+str(i)].shape) 38 | elif task_name == 'upper': 39 | feat_vec = graph_vec 40 | elif type(task_name) == int: 41 | feat_vec = bert_emb['s'+str(i)] 42 | elif type(task_name) == float: 43 | vec_shape = bert_emb['s'+str(i)].shape 44 | feat_vec = (1-task_name) * bert_emb['s'+str(i)] + \ 45 | task_name * np.random.randn(vec_shape[0], vec_shape[1]) 46 | elif task_name == 'noisy': 47 | feat_vec = bert_emb['s'+str(i)] 48 | graph_vec = graph_emb['s'+str(i)] 49 | vec_shape = graph_emb['s'+str(i)].shape 50 | rand_vec = np.random.randn(vec_shape[0], vec_shape[1]) 51 | # graph_vec[noisy_id[i]] = 0.5*rand_vec[noisy_id[i]] +\ 52 | # 0.5*graph_vec[noisy_id[i]] 53 | graph_vec[noisy_id[i]] = rand_vec[noisy_id[i]] 54 | else: 55 | print('Error probe task name: ', task_name) 56 | graph_vec = torch.FloatTensor(graph_vec).to(args.device) 57 | feat_vec = torch.FloatTensor(feat_vec).to(args.device) 58 | 59 | optimizer.zero_grad() 60 | if graph_vec.shape[0] <= 1: continue 61 | if feat_vec.shape[0] <= 1: continue 62 | joint = model(graph_vec, feat_vec) 63 | feat_shuffle = feat_vec[torch.randperm(feat_vec.shape[0])] 64 | marginal = torch.exp(torch.clamp(model(graph_vec, feat_shuffle), max=88)) 65 | mi = torch.mean(joint) - torch.log(torch.mean(marginal)) 66 | loss = -mi 67 | mi_train.append(mi.data.item()) 68 | 69 | loss.backward() 70 | optimizer.step() 71 | ''' 72 | print(" Training probe model: {} | Epoch {:05d} | MI: {:.4f}".format( 73 | task_name, 74 | epoch + 1, 75 | sum(mi_train)/len(mi_train))) 76 | ''' 77 | # early stop 78 | if sum(mi_train)/len(mi_train) > min(mi_es): 79 | mi_es.remove(min(mi_es)) 80 | mi_es.append(sum(mi_train)/len(mi_train)) 81 | else: 82 | break 83 | 84 | mi_eval = [] 85 | model.eval() 86 | for i in range(sen_num): # batch 87 | graph_vec = graph_emb['s'+str(i)] 88 | if task_name == 'lower': 89 | feat_vec = torch.randn(size=bert_emb['s'+str(i)].shape) 90 | elif task_name == 'upper': 91 | feat_vec = graph_vec 92 | elif type(task_name) == int: 93 | feat_vec = bert_emb['s'+str(i)] 94 | elif type(task_name) == float: 95 | vec_shape = bert_emb['s'+str(i)].shape 96 | feat_vec = (1-task_name) * bert_emb['s'+str(i)] + \ 97 | task_name * np.random.randn(vec_shape[0], vec_shape[1]) 98 | elif task_name == 'noisy': 99 | feat_vec = bert_emb['s'+str(i)] 100 | graph_vec = graph_emb['s'+str(i)] 101 | vec_shape = graph_emb['s'+str(i)].shape 102 | rand_vec = np.random.randn(vec_shape[0], vec_shape[1]) 103 | graph_vec[noisy_id[i]] = rand_vec[noisy_id[i]] 104 | else: 105 | print('Error probe task name: ', task_name) 106 | graph_vec = torch.FloatTensor(graph_vec).to(args.device) 107 | feat_vec = torch.FloatTensor(feat_vec).to(args.device) 108 | 109 | optimizer.zero_grad() 110 | if graph_vec.shape[0] <= 1: 111 | mi_eval.append(0.) 
112 | continue 113 | if feat_vec.shape[0] <= 1: 114 | mi_eval.append(0.) 115 | continue 116 | joint = model(graph_vec, feat_vec) 117 | feat_shuffle = feat_vec[torch.randperm(feat_vec.shape[0])] 118 | marginal = torch.exp(torch.clamp(model(graph_vec, feat_shuffle), max=88)) 119 | mi = torch.mean(joint) - torch.log(torch.mean(marginal)) 120 | loss = -mi 121 | mi_eval.append(mi.data.item()) 122 | 123 | loss.backward() 124 | optimizer.step() 125 | 126 | print(" ----Testing probe model: {} | Epoch {:05d} | MI: {:.4f}".format( 127 | task_name, 128 | epoch + 1, 129 | sum(mi_eval)/len(mi_eval))) 130 | 131 | # free memory 132 | model = None 133 | optimizer = None 134 | torch.cuda.empty_cache() 135 | gc.collect() 136 | 137 | # return [max(0, min(1, m)) for m in mi_eval] 138 | return mi_eval 139 | 140 | 141 | ''' 142 | def npeet_probe(args, graph_emb, bert_emb, sen_num, task_name, noisy_id=[]): 143 | print(' ----Start to train autoencoder...') 144 | # train graph autoencoder 145 | graph_model = autoencoder(graph_emb['s0'].shape[1]).to(args.device) 146 | if not os.path.exists('./tmp/g_model_'+args.task+'.pkl'): 147 | optimizer = torch.optim.Adam(graph_model.parameters(), lr=args.lr) 148 | loss_fcn = torch.nn.MSELoss() 149 | graph_model.train() 150 | train_losses = [999 for _ in range(args.patience)] 151 | for _ in range(100): 152 | loss_train = 0 153 | for i in range(sen_num): # batch 154 | graph_vec = graph_emb['s'+str(i)] 155 | if graph_vec.shape[0] <= 3: continue 156 | graph_vec = torch.FloatTensor(graph_vec).to(args.device) 157 | optimizer.zero_grad() 158 | pred = graph_model(graph_vec) 159 | loss = loss_fcn(pred, graph_vec) 160 | loss.backward() 161 | optimizer.step() 162 | loss_train += loss.data.item() 163 | loss_train = loss_train / sen_num 164 | print('----Training graph autoencoder loss: {:.4f}'.format(loss_train)) 165 | # early stop 166 | if loss_train < max(train_losses): 167 | train_losses.remove(max(train_losses)) 168 | train_losses.append(loss_train) 169 | else: 170 | break 171 | torch.save(graph_model.state_dict(), './tmp/g_model_'+args.task+'.pkl') 172 | graph_model.load_state_dict(torch.load('./tmp/g_model_'+args.task+'.pkl')) 173 | graph_model = graph_model.to(args.device) 174 | 175 | # train bert autoencoder 176 | bert_model = autoencoder(bert_emb['s0'].shape[1]).to(args.device) 177 | if not os.path.exists('./tmp/b_model_'+args.task+'.pkl'): 178 | optimizer = torch.optim.Adam(bert_model.parameters(), lr=args.lr) 179 | loss_fcn = torch.nn.MSELoss() 180 | bert_model.train() 181 | train_losses = [999 for _ in range(args.patience)] 182 | for _ in range(100): 183 | loss_train = 0 184 | for i in range(sen_num): # batch 185 | feat_vec = bert_emb['s'+str(i)] 186 | if feat_vec.shape[0] <= 3: continue 187 | feat_vec = torch.FloatTensor(feat_vec).to(args.device) 188 | optimizer.zero_grad() 189 | pred = bert_model(feat_vec) 190 | loss = loss_fcn(pred, feat_vec) 191 | loss.backward() 192 | optimizer.step() 193 | loss_train += loss.data.item() 194 | loss_train = loss_train / sen_num 195 | print('----Training BERT autoencoder loss: {:.4f}'.format(loss_train)) 196 | # early stop 197 | if loss_train < max(train_losses): 198 | train_losses.remove(max(train_losses)) 199 | train_losses.append(loss_train) 200 | else: 201 | break 202 | torch.save(bert_model.state_dict(), './tmp/b_model_'+args.task+'.pkl') 203 | bert_model.load_state_dict(torch.load('./tmp/b_model_'+args.task+'.pkl')) 204 | bert_model = bert_model.to(args.device) 205 | 206 | from npeet import entropy_estimators as ee 207 | from pycit 
import itest 208 | print(' ----Start to calculate low-dimensional representations...') 209 | graph_model.eval() 210 | bert_model.eval() 211 | mi_estimate = [] 212 | graph_vecs, bert_vecs = [], [] 213 | for i in tqdm(range(sen_num)): # batch 214 | graph_vec = torch.FloatTensor(graph_emb['s'+str(i)]).to(args.device) 215 | graph_vec = graph_model(graph_vec, True) 216 | if task_name == 'lower': 217 | vec_shape = graph_vec.shape 218 | feat_vec = np.random.randn(vec_shape[0], vec_shape[1]) 219 | elif task_name == 'upper': 220 | feat_vec = graph_vec 221 | elif type(task_name) == int: 222 | feat_vec = torch.FloatTensor(bert_emb['s'+str(i)]).to(args.device) 223 | feat_vec = bert_model(feat_vec, True) 224 | elif type(task_name) == float: 225 | feat_vec = torch.FloatTensor(bert_emb['s'+str(i)]).to(args.device) 226 | feat_vec = bert_model(feat_vec, True) 227 | feat_vec = (1 - task_name) * feat_vec + \ 228 | task_name * np.random.randn(feat_vec[0], feat_vec[1]) 229 | else: 230 | print('Error probe task name: ', task_name) 231 | graph_vecs.append(graph_vec) 232 | bert_vecs.append(feat_vec) 233 | graph_vecs = np.concatenate(graph_vecs) 234 | bert_vecs = np.concatenate(bert_vecs) 235 | 236 | print(' ----Start to calculate MI with NPEET...') 237 | batch_size = 1000 238 | ksg_mis, bi_ksg_mis = [], [] 239 | for i in tqdm(range(int(graph_vecs.shape[0]/batch_size))): 240 | range_lower = i*batch_size 241 | range_upper = min(graph_vecs.shape[0], (i+1)*batch_size) 242 | ksg_mis.append(itest(graph_vecs[range_lower:range_upper], 243 | bert_vecs[range_lower:range_upper], 244 | test_args={'statistic': 'ksg_mi', 'n_jobs': 10})) 245 | bi_ksg_mis.append(itest(graph_vecs[range_lower:range_upper], 246 | bert_vecs[range_lower:range_upper], 247 | test_args={'statistic': 'bi_ksg_mi', 'n_jobs': 10})) 248 | # mi_estimate.append(ee.mi(graph_vecs[range_lower:range_upper], 249 | # bert_vecs[range_lower:range_upper], k=knn_k)) 250 | # if graph_vec.shape[0] <= knn_k+8: continue 251 | # if feat_vec.shape[0] <= knn_k+8: continue 252 | # mi_estimate.append(ee.mi(graph_vec, feat_vec, k=knn_k)) 253 | # mi_estimate = sum(mi_estimate) / len(mi_estimate) 254 | # print(' ----Testing estimate MI value: {:.4f}'.format(mi_estimate)) 255 | ksg_mi_val = sum(ksg_mis) / len(ksg_mis) 256 | bi_ksg_mi_val = sum(bi_ksg_mis) / len(bi_ksg_mis) 257 | print(' ---Testing estimate MI value: {:.4f} | {:.4f}'.format(ksg_mi_val, bi_ksg_mi_val)) 258 | # free memory 259 | gc.collect() 260 | 261 | return mi_estimate 262 | ''' 263 | 264 | 265 | def mi_bert_ptb(args, npeet=False, uncontext=False): 266 | # load data 267 | s_train, p_train = load_data('penn_treebank_dataset', 'train') 268 | s_dev, p_dev = load_data('penn_treebank_dataset', 'dev') 269 | s_test, p_test = load_data('penn_treebank_dataset', 'test') 270 | sentences = s_train + s_dev + s_test 271 | parsed = p_train + p_dev + p_test 272 | doc_id, sen_id, global_graph = construct_graph(parsed) 273 | s_train, p_train, s_dev, p_dev, s_test, p_test = [], [], [], [], [], [] 274 | 275 | # load embeddings 276 | graph_emb = graph_embeddings(args, global_graph, doc_id, sen_id) 277 | if uncontext: 278 | bert_emb = load_glove(args, sentences) 279 | # bert_emb = load_elmo(args, sentences) 280 | else: 281 | bert_emb_paths = bert_embeddings(args, sentences) 282 | # bert_emb_paths = load_elmos(args, sentences) 283 | bert_emb = np.load(bert_emb_paths[0], allow_pickle=True) 284 | 285 | # initialize mi 286 | mir, mig, mib = [], [], [] 287 | for l in range(args.bert_layers_num): mib.append([]) 288 | for s in 
range(len(sentences)): 289 | mir.append(0.) 290 | mig.append(0.) 291 | for l in range(args.bert_layers_num): 292 | mib[l].append(0.) 293 | 294 | if args.baselines: 295 | print('3.1 start to calculate baselines of MI...') 296 | # calculate MI baselines 297 | for r in range(args.repeat): 298 | tmp_mir = mine_probe(args, graph_emb, bert_emb, len(sentences), 'lower') 299 | tmp_mig = mine_probe(args, graph_emb, bert_emb, len(sentences), 'upper') 300 | # get sum value 301 | mir = [mir[s]+tmp_mir[s] for s in range(len(tmp_mir))] 302 | mig = [mig[s]+tmp_mig[s] for s in range(len(tmp_mig))] 303 | 304 | print('3.2 start to calculate BERT hidden states of MI...') 305 | if uncontext: 306 | for r in range(args.repeat): 307 | tmp_mib = mine_probe(args, graph_emb, bert_emb, len(sentences), 308 | args.bert_layers_num - 1) 309 | mib[-1] = [mib[-1][s]+tmp_mib[s] for s in range(len(tmp_mib))] 310 | mib_layers = sum(mib[-1]) / (len(mib[-1]) * args.repeat) 311 | print('MI(G, Glove): {} |'.format(mib_layers)) 312 | else: 313 | # calculate MI of BERT 314 | for l in range(args.bert_layers_num): 315 | bert_emb = np.load(bert_emb_paths[l], allow_pickle=True) 316 | for r in range(args.repeat): 317 | tmp_mib = mine_probe(args, graph_emb, bert_emb, len(sentences), l) 318 | mib[l] = [mib[l][s]+tmp_mib[s] for s in range(len(tmp_mib))] 319 | # compute average values for all results 320 | mir = [mi/args.repeat for mi in mir] 321 | mig = [mi/args.repeat for mi in mig] 322 | for l in range(args.bert_layers_num): 323 | mib[l] = [mi/args.repeat for mi in mib[l]] 324 | mib_layers = [sum(mib[l])/len(mib[l]) for l in range(len(mib))] 325 | 326 | # print general results 327 | results = {'lower:': mir, 'upper': mig, 'bert': mib} 328 | # print('\n', results, '\n') 329 | 330 | print('MI(G, R): {} | MI(G, G): {}| MI(G, BERT): {} |'.format(sum( 331 | mir)/len(mir), 332 | sum(mig)/len(mig), 333 | mib_layers)) 334 | 335 | return 336 | 337 | 338 | def mi_bert_amr(args, uncontext=False): 339 | # load data & embeddings 340 | s_train = load_data('amr_dataset', 'train') 341 | s_dev = load_data('amr_dataset', 'dev') 342 | s_test = load_data('amr_dataset', 'test') 343 | amr_s = s_train + s_dev + s_test 344 | print(amr_s[45672], amr_s[599]) 345 | graph_emb, bert_emb_paths = get_embeddings(args, amr_s) 346 | # bert_emb_paths = load_elmos(args, amr_s, dataset='amr') 347 | s_num = len(graph_emb) 348 | if uncontext: 349 | bert_emb = load_glove(args, amr_s, dataset='amr') 350 | # bert_emb = load_elmo(args, amr_s, dataset='amr') 351 | else: 352 | bert_emb = np.load(bert_emb_paths[0], allow_pickle=True) 353 | 354 | print('2.1 start to calculate baselines of MI...') 355 | # initialize mi 356 | mir, mig, mib = [], [], [] 357 | for l in range(args.bert_layers_num): mib.append([]) 358 | 359 | if args.baselines: 360 | print('3.1 start to calculate baselines of MI...') 361 | # calculate MI baselines 362 | for r in range(args.repeat): 363 | tmp_mir = mine_probe(args, graph_emb, bert_emb, s_num, 'lower') 364 | tmp_mig = mine_probe(args, graph_emb, bert_emb, s_num, 'upper') 365 | # get sum value 366 | if len(mir) == 0: 367 | mir = tmp_mir 368 | else: 369 | mir = [mir[s]+tmp_mir[s] for s in range(len(tmp_mir))] 370 | if len(mig) == 0: 371 | mig = tmp_mig 372 | else: 373 | mig = [mig[s]+tmp_mig[s] for s in range(len(tmp_mig))] 374 | 375 | print('2.2 start to calculate BERT hidden states of MI...') 376 | # calculate MI of BERT 377 | if uncontext: 378 | for r in range(args.repeat): 379 | tmp_mib = mine_probe(args, graph_emb, bert_emb, s_num, 
args.bert_layers_num-1) 380 | if len(mib[-1]) == 0: 381 | mib[-1] = tmp_mib 382 | else: 383 | mib[-1] = [mib[-1][s]+tmp_mib[s] for s in range(len(tmp_mib))] 384 | mib_layers = sum(mib[-1]) / (len(mib[-1]) * args.repeat) 385 | print('MI(G, Glove): {} |'.format(mib_layers)) 386 | else: 387 | for l in range(args.bert_layers_num): 388 | bert_emb = np.load(bert_emb_paths[l], allow_pickle=True) 389 | for r in range(args.repeat): 390 | tmp_mib = mine_probe(args, graph_emb, bert_emb, s_num, l) 391 | if len(mib[l]) == 0: 392 | mib[l] = tmp_mib 393 | else: 394 | mib[l] = [mib[l][s]+tmp_mib[s] for s in range(len(tmp_mib))] 395 | 396 | # compute average values for all results 397 | mir = [mi/args.repeat for mi in mir] 398 | mig = [mi/args.repeat for mi in mig] 399 | for l in range(args.bert_layers_num): 400 | mib[l] = [mi/args.repeat for mi in mib[l]] 401 | 402 | # print general results 403 | results = {'lower:': mir, 'upper': mig, 'bert': mib} 404 | print('\n', results, '\n') 405 | mib_layers = [sum(mib[l])/len(mib[l]) for l in range(len(mib)) if len(mib)] 406 | print('MI(G, R): {} | MI(G, G): {}| MI(G, BERT): {} |'.format(sum( 407 | mir)/len(mir), 408 | sum(mig)/len(mig), 409 | mib_layers)) 410 | 411 | return 412 | 413 | 414 | def mi_mlps_ptb(args): 415 | # load data 416 | s_train, p_train = load_data('penn_treebank_dataset', 'train') 417 | s_dev, p_dev = load_data('penn_treebank_dataset', 'dev') 418 | s_test, p_test = load_data('penn_treebank_dataset', 'test') 419 | sentences = s_train + s_dev + s_test 420 | parsed = p_train + p_dev + p_test 421 | doc_id, sen_id, global_graph = construct_graph(parsed) 422 | s_train, p_train, s_dev, p_dev, s_test, p_test = [], [], [], [], [], [] 423 | 424 | # load embeddings 425 | graph_emb = graph_embeddings(args, global_graph, doc_id, sen_id) 426 | bert_emb = load_glove(args, sentences) 427 | # bert_emb = load_elmo(args, sentences) 428 | 429 | # bert_emb_paths = bert_embeddings(args, sentences) 430 | # bert_emb = np.load(bert_emb_paths[0], allow_pickle=True) 431 | 432 | 433 | # initialize mi 434 | mir, mig, mib = [], [], [] 435 | for l in range(args.bert_layers_num): mib.append([]) 436 | for s in range(len(sentences)): 437 | mir.append(0.) 438 | mig.append(0.) 439 | for l in range(args.bert_layers_num): 440 | mib[l].append(0.) 
441 | 442 | if args.baselines: 443 | print('3.1 start to calculate baselines of MI...') 444 | # calculate MI baselines 445 | for r in range(args.repeat): 446 | tmp_mir = mine_probe(args, graph_emb, bert_emb, len(sentences), 'lower') 447 | tmp_mig = mine_probe(args, graph_emb, bert_emb, len(sentences), 'upper') 448 | # get sum value 449 | mir = [mir[s]+tmp_mir[s] for s in range(len(tmp_mir))] 450 | mig = [mig[s]+tmp_mig[s] for s in range(len(tmp_mig))] 451 | 452 | print('3.2 start to calculate BERT hidden states of MI...') 453 | for r in range(args.repeat): 454 | tmp_mib = mine_probe(args, graph_emb, bert_emb, len(sentences), 455 | args.bert_layers_num - 1) 456 | mib[-1] = [mib[-1][s]+tmp_mib[s] for s in range(len(tmp_mib))] 457 | mib_layers = sum(mib[-1]) / (len(mib[-1]) * args.repeat) 458 | print('MI(G, Glove): {} |'.format(mib_layers)) 459 | 460 | 461 | 462 | def mi_mlps_amr(args): 463 | return 464 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import penman 3 | import networkx as nx 4 | import numpy as np 5 | import random 6 | import re 7 | import penman 8 | from tqdm import tqdm 9 | from conllu import parse 10 | # from torchnlp.datasets import penn_treebank_dataset 11 | 12 | 13 | ## sub-function descriptions 14 | ## please refer to main py file for function descriptions 15 | ''' 16 | clean_string: get clean (no special character) string 17 | load_data: load raw dataset 18 | construct_graph: get graph structure (PTB dataset) 19 | random_walks: run random walk for graph embedding 20 | get_edge_idx: generate index list for link prediction (PTB) dataset 21 | get_edge_idx_amr: generate index list for link prediction (AMR) dataset 22 | uuas_score (deprecated): calculate UUAS score 23 | get_graph_emb (deprecated): get global / local graph embedding 24 | load_split_emb: load graph embedding in the split setting 25 | load_noisy_trees: load corrupted PTB tree structures 26 | load_tree_labels: load corrupted PTB tree labels 27 | load_noisy_graphs: load corrupted AMR tree structures 28 | load_graph_labels: load corrupted AMR tree labels 29 | load_glove: load GloVe embedding 30 | load_elmo (python version <= 3.7): load ELMo0 embedding 31 | load_elmos (python version <= 3.7): load ELMo embedding 32 | ''' 33 | 34 | 35 | def clean_string(s): 36 | return re.sub('[^A-Za-z0-9]+', '', s) 37 | 38 | def load_data(data_name, data_type): 39 | if data_name == 'penn_treebank_dataset': 40 | # process for stanza dependency parsing 41 | sentences, tmp = [], [] 42 | if data_type == 'train': 43 | with open ("./sample_data/ptb-gold/train.conllx", "r") as f: 44 | data = f.read() 45 | elif data_type == 'test': 46 | with open ("./sample_data/ptb-gold/test.conllx", "r") as f: 47 | data = f.read() 48 | if data_type == 'dev': 49 | with open ("./sample_data/ptb-gold/dev.conllx", "r") as f: 50 | data = f.read() 51 | parsed = parse(data) 52 | for s in parsed: 53 | tmp = [] 54 | for t in s: 55 | tmp_txt = clean_string(t.get('form')) 56 | if len(tmp_txt): tmp.append(tmp_txt) 57 | if len(tmp): sentences.append(tmp) 58 | return sentences, parsed 59 | 60 | elif data_name == 'amr_dataset': 61 | # process for stanza dependency parsing 62 | sentences, tmp = [], [] 63 | if data_type == 'train': 64 | with open ("./sample_data/amr-split/amr-training.txt", "r") as f: 65 | data = f.read() 66 | elif data_type == 'test': 67 | with open ("./sample_data/amr-split/amr-test.txt", "r") as f: 68 | 
data = f.read() 69 | if data_type == 'dev': 70 | with open ("./sample_data/amr-split/amr-dev.txt", "r") as f: 71 | data = f.read() 72 | data = data.split('\n\n') 73 | for s in data: 74 | if s[:4] != '# ::': 75 | data.remove(s) 76 | return data 77 | 78 | else: 79 | print('Error data name!') 80 | return 81 | 82 | def construct_graph(parsed): 83 | ''' 84 | # stanza doc type to python dictionary 85 | nlp = stanza.Pipeline(lang='en', 86 | processors='tokenize, pos, lemma, depparse', 87 | tokenize_pretokenized=True) 88 | ''' 89 | word_dict = {'global_sentence_root': 0} 90 | word_id = 1 91 | global_graph = nx.Graph() 92 | doc_id, sen_id = [], [] 93 | print('1. start to tokenize and construct global graph...') 94 | for s in parsed: 95 | # construct word dictionary 96 | for t in s: 97 | if len(clean_string(t.get('form'))) == 0: continue 98 | if clean_string(t.get('form')) not in word_dict: 99 | word_dict[clean_string(t.get('form'))] = word_id 100 | word_id += 1 101 | # construct doc_id and global graph 102 | tmp1, tmp2 = [], [] 103 | for t in s: 104 | tail_txt = clean_string(t.get('form')) 105 | if len(tail_txt) == 0: continue 106 | tail_id = word_dict.get(tail_txt) 107 | head_idx = t.get('head') 108 | if head_idx == 0: 109 | head_id = 0 110 | else: 111 | head_txt = clean_string(s[head_idx-1].get('form')) 112 | if len(head_txt) == 0: 113 | head_id = tail_id 114 | else: 115 | head_id = word_dict.get(head_txt) 116 | global_graph.add_edge(head_id, tail_id) 117 | tmp1.append((t.get('head'), t.get('id'))) 118 | tmp2.append((head_id, tail_id)) 119 | if len(tmp1): sen_id.append(tmp1) 120 | if len(tmp2): doc_id.append(tmp2) 121 | ''' 122 | print(' global graph (V,E) numbers: ({}, {})'.format( 123 | global_graph.number_of_nodes(), 124 | global_graph.number_of_edges())) 125 | ''' 126 | return doc_id, sen_id, global_graph 127 | 128 | 129 | def random_walks(G, num_walks=100, walk_len=10, string_nid=False): 130 | paths = [] 131 | # add self loop 132 | for nid in G.nodes(): G.add_edge(nid, nid) 133 | if not string_nid: 134 | for nid in G.nodes(): 135 | if G.degree(nid) == 0: continue 136 | for i in range(num_walks): 137 | tmp_path = [str(nid)] 138 | for j in range(walk_len): 139 | neighbors = [str(n) for n in G.neighbors(int(tmp_path[-1]))] 140 | tmp_path.append(random.choice(neighbors)) 141 | paths.append(tmp_path) 142 | else: 143 | for nid in G.nodes(): 144 | if G.degree(nid) == 0: continue 145 | for i in range(num_walks): 146 | tmp_path = [nid] 147 | for j in range(walk_len): 148 | neighbors = [n for n in G.neighbors(tmp_path[-1])] 149 | tmp_path.append(random.choice(neighbors)) 150 | paths.append(tmp_path) 151 | 152 | return paths 153 | 154 | 155 | def get_edge_idx(edge_list): 156 | batch_ids = [] 157 | for (_, j) in edge_list: 158 | batch_ids.append(j) 159 | batch_ids = list(set(batch_ids)) 160 | tmp_dict = {} 161 | for i in range(len(batch_ids)): 162 | tmp_dict[batch_ids[i]] = i 163 | 164 | sort_edge_list = [] 165 | for (s, d) in edge_list: 166 | sort_edge_list.append((tmp_dict.get(s), tmp_dict.get(d))) 167 | 168 | edge_space = [] 169 | for i in range(len(batch_ids)): 170 | for j in range(len(batch_ids)): 171 | edge_space.append((i, j)) 172 | # random.shuffle(edge_space) 173 | src_idx = [i for (i, j) in edge_space] 174 | dst_idx = [j for (i, j) in edge_space] 175 | 176 | edge_labels = [] 177 | for e in edge_space: 178 | if e in sort_edge_list: 179 | edge_labels.append(1) 180 | elif (e[-1], e[0]) in sort_edge_list: 181 | edge_labels.append(1) 182 | else: 183 | edge_labels.append(0) 184 | 185 | return 
src_idx, dst_idx, np.array(edge_labels) 186 | 187 | 188 | def get_edge_idx_amr(s): 189 | # parse 190 | penman_g = penman.decode(s) 191 | s = penman_g.metadata.get('tok').split(' ') 192 | wid = [] 193 | var = [] # k=word id; v=variable 194 | for k, v in penman_g.epidata.items(): 195 | if k[1] == ':instance': 196 | if len(v): 197 | if type(v[0]) == penman.surface.Alignment: 198 | wid.append(v[0].indices[0]) 199 | var.append(k[0]) 200 | # graph construction 201 | g = nx.Graph() 202 | for v in penman_g.variables(): g.add_node(v) 203 | for e in penman_g.edges(): g.add_edge(e.source, e.target) 204 | 205 | edge_space = [] 206 | for i in range(len(var)): 207 | for j in range(len(var)): 208 | edge_space.append((i, j)) 209 | # random.shuffle(edge_space) 210 | src_idx = [i for (i, j) in edge_space] 211 | dst_idx = [j for (i, j) in edge_space] 212 | 213 | edge_labels = [] 214 | for e in edge_space: 215 | if (var[e[0]], var[e[1]]) in g.edges(): 216 | edge_labels.append(1) 217 | elif (var[e[1]], var[e[0]]) in g.edges(): 218 | edge_labels.append(1) 219 | else: 220 | edge_labels.append(0) 221 | 222 | return src_idx, dst_idx, np.array(edge_labels) 223 | 224 | 225 | def uuas_score(src_idx, dst_idx, edge_labels, edge_pred): 226 | g_label = nx.Graph() 227 | g_pred = nx.Graph() 228 | tmp_num = max(src_idx) + 1 229 | for src in src_idx: 230 | for dst in dst_idx: 231 | # add edge in label (ground-truth) graph 232 | if edge_labels[src*tmp_num+dst] == 1: 233 | g_label.add_edge(src, dst) 234 | # add edge in predicted graph 235 | if (src, dst) not in g_pred.edges(): 236 | weight_1 = edge_pred[src*tmp_num+dst] 237 | weight_2 = edge_pred[dst*tmp_num+src] 238 | g_pred.add_edge(src, dst, weight=weight_1+weight_2) 239 | g_mst = nx.minimum_spanning_tree(g_pred) 240 | total_num = g_mst.number_of_edges() 241 | uuas_num = 0 242 | for e in g_mst.edges(): 243 | if e in g_label: 244 | uuas_num += 1 245 | return uuas_num / total_num 246 | 247 | 248 | def get_graph_emb(graph_emb, task_name): 249 | new_graph_emb = {} 250 | if task_name == 'local': 251 | for s in range(len(graph_emb)): 252 | new_graph_emb['s'+str(s)] = graph_emb['s'+str(s)][:,640:] 253 | elif task_name == 'global': 254 | for s in range(len(graph_emb)): 255 | new_graph_emb['s'+str(s)] = graph_emb['s'+str(s)][:,:640] 256 | else: 257 | new_graph_emb = graph_emb 258 | 259 | return new_graph_emb 260 | 261 | 262 | def load_split_emb(train_len, dev_len, test_len, model_name, task='ptb'): 263 | if task == 'ptb': 264 | if model_name == 'elmo': 265 | emb_path = './tmp/elmo_ptb_bert.npz' 266 | else: 267 | emb_path = './tmp/glove_ptb_bert.npz' 268 | else: 269 | if model_name == 'elmo': 270 | emb_path = './tmp/elmo_amr_bert.npz' 271 | else: 272 | emb_path = './tmp/glove_amr_bert.npz' 273 | all_emb = np.load(emb_path) 274 | train_emb, dev_emb, test_emb = {}, {}, {} 275 | if len(all_emb) != train_len + dev_len + test_len: 276 | print('Error of length !', len(all_emb), train_len, dev_len, test_len) 277 | tmp_count = 0 278 | for i in range(len(all_emb)): 279 | curr_key = 's' + str(i) 280 | if i < train_len: 281 | train_key = curr_key 282 | train_emb[train_key] = all_emb[curr_key] 283 | elif i < train_len + dev_len: 284 | dev_key = 's' + str(i-train_len) 285 | dev_emb[dev_key] = all_emb[curr_key] 286 | else: 287 | test_key = 's' + str(i-train_len-dev_len) 288 | test_emb[test_key] = all_emb[curr_key] 289 | 290 | return train_emb, dev_emb, test_emb 291 | 292 | 293 | def load_noisy_trees(args, pos=False, data_split=False): 294 | print('2. 
start to calculate noisy id...') 295 | 296 | if data_split: 297 | sentences, parsed = load_data('penn_treebank_dataset', 'train') 298 | else: 299 | s_train, p_train = load_data('penn_treebank_dataset', 'train') 300 | s_dev, p_dev = load_data('penn_treebank_dataset', 'dev') 301 | s_test, p_test = load_data('penn_treebank_dataset', 'test') 302 | sentences = s_train + s_dev + s_test 303 | parsed = p_train + p_dev + p_test 304 | # get noisy node ids 305 | edge_labels = {} 306 | if pos: # pos noisy 307 | for s in parsed: 308 | for w in s: 309 | if w.get('upos') not in edge_labels: 310 | edge_labels[w.get('upos')] = 1 311 | else: 312 | edge_labels[w.get('upos')] += 1 313 | noisy_id = {} 314 | k_list = ['IN', 'NNP', 'DT', 'JJ', 'NNS'] 315 | for k, v in edge_labels.items(): 316 | if k in k_list: 317 | noisy_id[k] = [] 318 | else: # edge noisy 319 | for s in parsed: 320 | for w in s: 321 | if w.get('deprel') not in edge_labels: 322 | edge_labels[w.get('deprel')] = 1 323 | else: 324 | edge_labels[w.get('deprel')] += 1 325 | noisy_id = {} 326 | k_list = ['prep', 'det', 'nn', 'pobj', 'nsubj'] 327 | for k in k_list: 328 | noisy_id[k] = [] 329 | 330 | for k, _ in noisy_id.items(): 331 | sen_id = [] 332 | for s in parsed: 333 | tmp = set() 334 | tmp_embed = [] 335 | for w in s: 336 | if len(clean_string(w.get('form'))) == 0: 337 | continue 338 | tmp_embed.append((w.get('head'), w.get('id'))) 339 | if pos: # pos noisy 340 | if w.get('upos') == k: tmp.add(w.get('id')) 341 | else: # edge noisy 342 | if w.get('deprel') == k: 343 | tmp.add(w.get('id')) 344 | tmp.add(w.get('head')) 345 | tmp_set = set() 346 | for _, dst in tmp_embed: 347 | tmp_set.add(dst) 348 | tmp_dict = {} 349 | count = 0 350 | for idx in tmp_set: 351 | if idx > 0: 352 | tmp_dict[idx] = count 353 | count += 1 354 | tmp_new = [] 355 | for idx in tmp: 356 | if idx in tmp_dict: 357 | tmp_new.append(tmp_dict[idx]) 358 | tmp_new.sort() 359 | if len(tmp_embed): sen_id.append(tmp_new) 360 | noisy_id[k] = sen_id 361 | 362 | # leverage noisy node number 363 | drop_ratio = {} 364 | for k, v in noisy_id.items(): 365 | drop_ratio[k] = 0 366 | for s in v: 367 | drop_ratio[k] += len(s) 368 | min_count = min([v for k, v in drop_ratio.items()]) 369 | for k, v in drop_ratio.items(): 370 | drop_ratio[k] = min_count / drop_ratio[k] 371 | 372 | noisy_id_new = {} 373 | for k, _ in noisy_id.items(): 374 | noisy_id_new[k] = [] 375 | for k, v in noisy_id.items(): 376 | tmp_s = [] 377 | for s in v: 378 | tmp_w = [] 379 | for w in s: 380 | if random.random() < drop_ratio[k]: 381 | tmp_w.append(w) 382 | tmp_s.append(tmp_w) 383 | noisy_id_new[k] = tmp_s 384 | 385 | ''' 386 | test_dict = {} 387 | for k, v in noisy_id_new.items(): 388 | count = 0 389 | for s in v: 390 | count += len(s) 391 | test_dict[k] = count 392 | print(test_dict) 393 | ''' 394 | 395 | return noisy_id_new 396 | 397 | 398 | def load_tree_labels(args, pos=False): 399 | print('2. 
start to calculate noisy id...') 400 | sentences, parsed = load_data('penn_treebank_dataset', 'test') 401 | # get noisy node ids 402 | edge_labels = {} 403 | if pos: # pos noisy 404 | for s in parsed: 405 | for w in s: 406 | if w.get('upos') not in edge_labels: 407 | edge_labels[w.get('upos')] = 1 408 | else: 409 | edge_labels[w.get('upos')] += 1 410 | noisy_id = {} 411 | for k, v in edge_labels.items(): 412 | noisy_id[k] = [] 413 | else: # edge noisy 414 | for s in parsed: 415 | for w in s: 416 | if w.get('deprel') not in edge_labels: 417 | edge_labels[w.get('deprel')] = 1 418 | else: 419 | edge_labels[w.get('deprel')] += 1 420 | noisy_id = {} 421 | for k, v in edge_labels.items(): 422 | noisy_id[k] = [] 423 | 424 | for k, _ in noisy_id.items(): 425 | sen_id = [] 426 | for s in parsed: 427 | tmp = set() 428 | tmp_embed = [] 429 | for w in s: 430 | if len(clean_string(w.get('form'))) == 0: 431 | continue 432 | tmp_embed.append((w.get('head'), w.get('id'))) 433 | if pos: # pos noisy 434 | if w.get('upos') == k: tmp.add(w.get('id')) 435 | else: # edge noisy 436 | if w.get('deprel') == k: 437 | tmp.add(w.get('id')) 438 | tmp.add(w.get('head')) 439 | tmp_set = set() 440 | for _, dst in tmp_embed: 441 | tmp_set.add(dst) 442 | tmp_dict = {} 443 | count = 0 444 | for idx in tmp_set: 445 | if idx > 0: 446 | tmp_dict[idx] = count 447 | count += 1 448 | tmp_new = [] 449 | for idx in tmp: 450 | if idx in tmp_dict: 451 | tmp_new.append(tmp_dict[idx]) 452 | tmp_new.sort() 453 | if len(tmp_embed): sen_id.append(tmp_new) 454 | noisy_id[k] = sen_id 455 | 456 | ''' 457 | test_dict = {} 458 | for k, v in noisy_id.items(): 459 | count = 0 460 | for s in v: 461 | count += len(s) 462 | test_dict[k] = count 463 | print(test_dict, edge_labels) 464 | ''' 465 | 466 | return noisy_id 467 | 468 | 469 | def load_noisy_graphs(args): 470 | print('2. 
start to calculate noisy id...') 471 | s_train = load_data('amr_dataset', 'train') 472 | s_dev = load_data('amr_dataset', 'dev') 473 | s_test = load_data('amr_dataset', 'test') 474 | amr_s = s_train + s_dev + s_test 475 | 476 | label_arg = [':ARG0', ':ARG1', ':ARG2', ':ARG3', ':ARG4',\ 477 | ':ARG5', ':ARG6', ':ARG7', ':ARG8', ':ARG9'] 478 | label_op = [':op1', ':op2', ':op3', ':op4', ':op5', ':op6', ':op7',\ 479 | ':op8', ':op9', ':op10', ':op11', ':op12', ':op13', ':op14',\ 480 | ':op15', ':op16', ':op17', ':op18', ':op19'] 481 | label_general = [':accompanier', ':age', ':beneficiary', ':concession',\ 482 | ':condition', ':consist', ':degree', ':destination',\ 483 | ':direction', ':domain', ':duration', ':example', \ 484 | ':extent', ':frequency', ':instrument', ':location',\ 485 | ':manner', ':medium', ':mod', ':name', ':part', ':path',\ 486 | ':polarity', ':poss', ':purpose', ':source', ':subevent',\ 487 | ':subset', ':time', ':topic', ':value', ':ord', ':range'] 488 | labels_dict = {'arg': label_arg, 'op': label_op, 'general': label_general} 489 | 490 | # get edge label 491 | noisy_id = {} 492 | for k, _ in labels_dict.items(): 493 | noisy_id[k] = [] 494 | for k_label, _ in noisy_id.items(): 495 | sen_id = [] 496 | for s in amr_s: 497 | penman_g = penman.decode(s) 498 | var = [] # k=word id; v=variable 499 | for k, v in penman_g.epidata.items(): 500 | if k[1] == ':instance': 501 | if len(v): 502 | if type(v[0]) == penman.surface.Alignment: 503 | var.append(k[0]) 504 | tmp_idx, tmp_set = [], set() 505 | for e in penman_g.edges(): 506 | if e.role in labels_dict[k_label]: 507 | tmp_set.add(e.source) 508 | tmp_set.add(e.target) 509 | for n in tmp_set: 510 | if n in var: 511 | tmp_idx.append(var.index(n)) 512 | tmp_idx.sort() 513 | sen_id.append(tmp_idx) 514 | noisy_id[k_label] = sen_id 515 | 516 | # leverage edge number 517 | drop_ratio = {} 518 | for k, v in noisy_id.items(): 519 | drop_ratio[k] = 0 520 | for s in v: 521 | drop_ratio[k] += len(s) 522 | min_count = min([v for k, v in drop_ratio.items()]) 523 | for k, v in drop_ratio.items(): 524 | drop_ratio[k] = min_count / drop_ratio[k] 525 | 526 | noisy_id_new = {} 527 | for k, _ in noisy_id.items(): 528 | noisy_id_new[k] = [] 529 | for k, v in noisy_id.items(): 530 | tmp_s = [] 531 | for s in v: 532 | tmp_w = [] 533 | for w in s: 534 | if random.random() < drop_ratio[k]: 535 | tmp_w.append(w) 536 | tmp_s.append(tmp_w) 537 | noisy_id_new[k] = tmp_s 538 | 539 | return noisy_id_new 540 | 541 | 542 | def load_graph_labels(args): 543 | print('2. 
start to calculate noisy id...') 544 | s_train = load_data('amr_dataset', 'train') 545 | s_dev = load_data('amr_dataset', 'dev') 546 | s_test = load_data('amr_dataset', 'test') 547 | amr_s = s_train + s_dev + s_test 548 | 549 | label_arg = [':ARG0', ':ARG1', ':ARG2', ':ARG3', ':ARG4',\ 550 | ':ARG5', ':ARG6', ':ARG7', ':ARG8', ':ARG9'] 551 | label_op = [':op1', ':op2', ':op3', ':op4', ':op5', ':op6', ':op7',\ 552 | ':op8', ':op9', ':op10', ':op11', ':op12', ':op13', ':op14',\ 553 | ':op15', ':op16', ':op17', ':op18', ':op19'] 554 | label_general = [':accompanier', ':age', ':beneficiary', ':concession',\ 555 | ':condition', ':consist', ':degree', ':destination',\ 556 | ':direction', ':domain', ':duration', ':example', \ 557 | ':extent', ':frequency', ':instrument', ':location',\ 558 | ':manner', ':medium', ':mod', ':name', ':part', ':path',\ 559 | ':polarity', ':poss', ':purpose', ':source', ':subevent',\ 560 | ':subset', ':time', ':topic', ':value', ':ord', ':range'] 561 | label_quantities = [':quant', ':scale', ':unit'] 562 | label_date = [':dayperiod', ':calendar', ':season', ':timezone', ':weekday'] 563 | labels_dict = {'arg': label_arg, 'op': label_op, 'general': label_general, 564 | 'quantities': label_quantities, 'date': label_date} 565 | 566 | # get edge label 567 | noisy_id = {} 568 | for k, _ in labels_dict.items(): 569 | noisy_id[k] = [] 570 | for k_label, _ in noisy_id.items(): 571 | sen_id = [] 572 | for s in amr_s: 573 | penman_g = penman.decode(s) 574 | var = [] # k=word id; v=variable 575 | for k, v in penman_g.epidata.items(): 576 | if k[1] == ':instance': 577 | if len(v): 578 | if type(v[0]) == penman.surface.Alignment: 579 | var.append(k[0]) 580 | tmp_idx, tmp_set = [], set() 581 | for e in penman_g.edges(): 582 | if e.role in labels_dict[k_label]: 583 | tmp_set.add(e.source) 584 | tmp_set.add(e.target) 585 | for n in tmp_set: 586 | if n in var: 587 | tmp_idx.append(var.index(n)) 588 | tmp_idx.sort() 589 | sen_id.append(tmp_idx) 590 | noisy_id[k_label] = sen_id 591 | 592 | 593 | test_dict = {} 594 | for k, v in noisy_id.items(): 595 | count = 0 596 | for s in v: 597 | count += len(s) 598 | test_dict[k] = count 599 | print(test_dict) 600 | 601 | return noisy_id 602 | 603 | def load_glove(args, sentences, data_div='', dataset='ptb'): 604 | data_path = './tmp/glove_'+args.task+data_div+'.npz' 605 | if os.path.exists(data_path): 606 | return np.load(data_path) 607 | 608 | savez_dict = {} 609 | embeddings_dict = {} 610 | with open('./tmp/glove/glove.42B.300d.txt', 'r') as f: 611 | for line in f: 612 | values = line.split() 613 | word = values[0] 614 | vector = np.asarray(values[1:], "float32") 615 | embeddings_dict[word] = vector 616 | 617 | if dataset == 'ptb': 618 | for s in range(len(sentences)): 619 | word_emb = [] 620 | for w in sentences[s]: 621 | if w.lower() in embeddings_dict: 622 | word_emb.append(np.expand_dims(embeddings_dict[w.lower()], axis=0)) 623 | else: 624 | word_emb.append(np.expand_dims(embeddings_dict[','], axis=0)) 625 | savez_dict['s'+str(s)] = np.concatenate(word_emb) 626 | else: 627 | for s in range(len(sentences)): 628 | word_emb = [] 629 | # parse 630 | penman_g = penman.decode(sentences[s]) 631 | sen = penman_g.metadata.get('tok').split(' ') 632 | wid = [] 633 | var = [] # k=word id; v=variable 634 | for k, v in penman_g.epidata.items(): 635 | if k[1] == ':instance': 636 | if len(v): 637 | if type(v[0]) == penman.surface.Alignment: 638 | wid.append(v[0].indices[0]) 639 | var.append(k[0]) 640 | c_s = [] 641 | for w in sen: 642 | c_w = clean_string(w) 
643 | if len(c_w) == 0: c_w = ',' 644 | c_s.append(c_w) 645 | for w in c_s: 646 | if w.lower() in embeddings_dict: 647 | word_emb.append(np.expand_dims(embeddings_dict[w.lower()], axis=0)) 648 | else: 649 | word_emb.append(np.expand_dims(embeddings_dict[','], axis=0)) 650 | if len(wid) == 0: wid = [0] 651 | savez_dict['s'+str(s)] = np.concatenate([word_emb[i] for i in wid]) 652 | np.savez('./tmp/glove_'+args.task+data_div+'.npz', **savez_dict) 653 | 654 | return np.load(data_path) 655 | 656 | 657 | def load_elmo(args, sentences, data_div='', dataset='ptb'): 658 | data_path = './tmp/elmo_'+args.task+data_div+'.npz' 659 | if os.path.exists(data_path): 660 | return np.load(data_path) 661 | else: 662 | import nlu 663 | elmo_model = nlu.load('elmo') 664 | 665 | savez_dict = {} 666 | if dataset == 'ptb': 667 | word_set = set() 668 | for s in range(len(sentences)): 669 | for w in sentences[s]: 670 | if w.lower() not in word_set: word_set.add(w.lower()) 671 | embeddings_dict = {} 672 | print('1. start to calculate ELMo embeddings...') 673 | for w in tqdm(word_set): 674 | output = elmo_model.predict(w) 675 | embeddings_dict[w] = output.to_numpy()[0,1] 676 | for s in range(len(sentences)): 677 | word_emb = [] 678 | for w in sentences[s]: 679 | if w.lower() in embeddings_dict: 680 | word_emb.append(np.expand_dims(embeddings_dict[w.lower()], axis=0)) 681 | else: 682 | word_emb.append(np.expand_dims(embeddings_dict[','], axis=0)) 683 | savez_dict['s'+str(s)] = np.concatenate(word_emb) 684 | else: 685 | word_set = set() 686 | for s in range(len(sentences)): 687 | penman_g = penman.decode(sentences[s]) 688 | sen = penman_g.metadata.get('tok').split(' ') 689 | for w in sen: 690 | if w.lower() not in word_set: word_set.add(w.lower()) 691 | embeddings_dict = {} 692 | print('1. start to calculate ELMo embeddings...') 693 | for w in tqdm(word_set): 694 | output = elmo_model.predict(w) 695 | embeddings_dict[w] = output.to_numpy()[0,1] 696 | for s in range(len(sentences)): 697 | word_emb = [] 698 | # parse 699 | penman_g = penman.decode(sentences[s]) 700 | sen = penman_g.metadata.get('tok').split(' ') 701 | wid = [] 702 | var = [] # k=word id; v=variable 703 | for k, v in penman_g.epidata.items(): 704 | if k[1] == ':instance': 705 | if len(v): 706 | if type(v[0]) == penman.surface.Alignment: 707 | wid.append(v[0].indices[0]) 708 | var.append(k[0]) 709 | c_s = [] 710 | for w in sen: 711 | c_w = clean_string(w) 712 | if len(c_w) == 0: c_w = ',' 713 | c_s.append(c_w) 714 | for w in c_s: 715 | if w.lower() in embeddings_dict: 716 | word_emb.append(np.expand_dims(embeddings_dict[w.lower()], axis=0)) 717 | else: 718 | word_emb.append(np.expand_dims(embeddings_dict[','], axis=0)) 719 | if len(wid) == 0: wid = [0] 720 | savez_dict['s'+str(s)] = np.concatenate([word_emb[i] for i in wid]) 721 | np.savez(data_path, **savez_dict) 722 | 723 | return np.load(data_path) 724 | 725 | 726 | def load_elmos(args, sentences, data_div='', dataset='ptb'): 727 | data_path = './tmp/elmo0_'+args.task+data_div+'.npz' 728 | data_paths = [] 729 | for i in range(3): 730 | data_paths.append('./tmp/elmo'+str(i)+'_'+args.task+data_div+'.npz') 731 | 732 | if os.path.exists(data_path): 733 | return data_paths 734 | else: 735 | import nlu 736 | pipe = nlu.load('elmo') 737 | 738 | layers_name = ['lstm_outputs1', 'lstm_outputs2', 'word_emb'] 739 | savez_dict = {} 740 | if dataset == 'ptb': 741 | print('1. 
start to calculate ELMo embeddings...') 742 | for n in range(len(layers_name)): 743 | for s in tqdm(range(len(sentences))): 744 | pipe['elmo'].setPoolingLayer(layers_name[n]) 745 | outputs = pipe.predict(sentences[s]) 746 | output = outputs.to_numpy()[:,1:] 747 | if len(sentences[s]) <= 1: 748 | savez_dict['s'+str(s)] = output[0] 749 | continue 750 | output_vectors = np.empty((output.shape[0], output[1,0].shape[0])) 751 | for i in range(output_vectors.shape[0]): 752 | output_vectors[i,:] = output[i,0] 753 | if len(sentences[s]) != output_vectors.shape[0]: 754 | print('Error! failed to get whole word', s, sentences[s], output_vectors) 755 | savez_dict['s'+str(s)] = output_vectors 756 | np.savez(data_paths[n], **savez_dict) 757 | else: 758 | print('1. start to calculate ELMo embeddings...') 759 | for n in range(len(layers_name)): 760 | for s in tqdm(range(len(sentences))): 761 | # parse 762 | penman_g = penman.decode(sentences[s]) 763 | sen = penman_g.metadata.get('tok').split(' ') 764 | wid = [] 765 | var = [] # k=word id; v=variable 766 | for k, v in penman_g.epidata.items(): 767 | if k[1] == ':instance': 768 | if len(v): 769 | if type(v[0]) == penman.surface.Alignment: 770 | wid.append(v[0].indices[0]) 771 | var.append(k[0]) 772 | c_s = [] 773 | for w in sen: 774 | c_w = clean_string(w) 775 | if len(c_w) == 0: c_w = ',' 776 | c_s.append(c_w) 777 | 778 | outputs = pipe.predict(c_s) 779 | output = outputs.to_numpy()[:,1:] 780 | if len(c_s) <= 1: 781 | savez_dict['s'+str(s)] = output[0] 782 | continue 783 | output_vectors = np.empty((output.shape[0], output[1,0].shape[0])) 784 | for i in range(output_vectors.shape[0]): 785 | output_vectors[i,:] = output[i,0] 786 | if len(c_s) != output_vectors.shape[0]: 787 | print('Error! failed to get whole word', s, c_s, outputs.to_numpy()[:,1:]) 788 | if len(wid) == 0: wid = [0] 789 | savez_dict['s'+str(s)] = output_vectors[wid] 790 | np.savez(data_paths[n], **savez_dict) 791 | 792 | return data_paths 793 | 794 | 795 | --------------------------------------------------------------------------------
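Usage note (not part of the repository): a minimal sketch of how get_edge_idx and uuas_score above fit together. It assumes this file is importable as src/utils.py and that edge_pred holds a probe's per-pair distances (lower means an edge is more likely, since uuas_score keeps the minimum spanning tree of the predicted graph); gold_edges, n, and edge_pred are illustrative values, not data from the repository.

# Hypothetical usage sketch; assumes the helpers above live in src/utils.py
# and that the script is run from the src/ directory.
import numpy as np
from utils import get_edge_idx, uuas_score

# Gold dependency edges as (head, dependent) token ids; 0 is the artificial root.
gold_edges = [(0, 1), (1, 2), (1, 3)]
src_idx, dst_idx, edge_labels = get_edge_idx(gold_edges)

# Stand-in for a probe's n*n per-pair distances in row-major order.
# Lower distance = more likely edge, because uuas_score extracts the
# minimum spanning tree of the predicted graph.
n = max(src_idx) + 1
edge_pred = np.full(n * n, 9.0)
edge_pred[0 * n + 1] = edge_pred[1 * n + 0] = 0.1  # tokens 1-2 predicted close
edge_pred[0 * n + 2] = edge_pred[2 * n + 0] = 0.2  # tokens 1-3 predicted close

# Both spanning-tree edges match gold edges here, so this prints 1.0.
print(uuas_score(src_idx, dst_idx, edge_labels, edge_pred))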