├── .gitignore
├── LICENSE
├── README.md
├── config
├── bert.cfg
├── bert_large.cfg
├── context.cfg
└── word.cfg
├── data
└── download_data.sh
├── evaluation
├── evaluation_all_bert.py
├── evaluation_all_context.py
└── evaluation_all_word.py
├── img
├── dis.png
└── framework.png
├── model
├── models.py
└── visual_attention.py
├── train_bert2score.py
├── train_context2score.py
├── train_word2score.py
└── utils
├── data_helper.py
├── data_helper_4bert.py
├── data_helper_4context.py
├── loader.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 HKUST-KnowComp
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComHyper [](https://opensource.org/licenses/MIT)
2 |
3 |
4 |
5 | Code for EMNLP'20 paper "When Hearst Is not Enough: Improving Hypernymy Detection from Corpus with Distributional Models" ([arXiv](https://arxiv.org/abs/2010.04941v1))
6 |
7 |
8 |
9 | In a nutshell, ComHyper is a complementary framework for hypernymy detection that addresses the blind spots of Hearst pattern-based methods. As shown in the left figure, long-tailed nouns are not well covered by Hearst patterns and thus cause non-negligible sparsity. For such cases, we propose supervised distributional models to complement the pattern-based models, as shown in the right figure.
10 |
11 |
![dis](img/dis.png) ![framework](img/framework.png)
12 |
13 |
14 |
15 | ## Use ComHyper
16 |
17 | ### 1. Download Hearst pattern files and corpus.
18 |
19 | First, prepare the extracted Hearst pattern pairs, e.g. `hearst_counts.txt.gz` from the repo [hypernymysuite](https://github.com/facebookresearch/hypernymysuite) or `data-concept.zip` from the Microsoft Concept Graph (also known as [Probase](https://concept.research.microsoft.com/Home/Download)). Set the `pattern_filename` parameter in the `config` file to the location of this file.
20 |
21 | ```
22 | wget https://github.com/facebookresearch/hypernymysuite/raw/master/hearst_counts.txt.gz
23 | curl -L "https://concept.research.microsoft.com/Home/StartDownload" > data-concept.zip
24 | ```
25 |
26 | Then extract the contexts for words from a large-scale corpus such as Wiki + Gigaword or ukWaC. All the contexts of one word should be collected into a single `txt` file, with one context per line.
27 |
28 | For words appearing in the Hearst patterns (**IP words**), place their context files in the directory given by `context` in the `config`. For out-of-pattern (**OOP**) words, place their context files in the directory given by `context_oov`. A minimal sketch of the expected layout is given below.
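
As a rough illustration (not part of the released code), one way to read such per-word context files, assuming one `txt` file per word with one context per line; the `load_contexts` helper, the example paths, and the cap of 10 contexts (mirroring `context_num` in the configs) are illustrative assumptions:

```python
import os

def load_contexts(context_dir, word, max_contexts=10):
    """Hypothetical helper: read up to `max_contexts` contexts for `word`
    from `<context_dir>/<word>.txt` (one context per line)."""
    path = os.path.join(context_dir, word + ".txt")
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()][:max_contexts]

# IP word: its contexts live under the `context` directory of the config
# ip_contexts = load_contexts("/home/shared/hearst_full_context", "dog")
# OOP word: its contexts live under the `context_oov` directory
# oop_contexts = load_contexts("/home/shared/context", "axolotl")
```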
29 |
30 | ### 2. Train and evaluate the ComHyper.
31 |
32 | To train the distributional models supervised by the output of the pattern-based models, different context encoders are provided:
33 |
34 | ```console
35 | python train_word2score.py config/word.cfg
36 | python train_context2score.py config/context.cfg
37 | python train_bert2score.py config/bert.cfg
38 | ```
39 |
40 | The same evaluation scripts work for all settings. To reproduce the results, run:
41 |
42 | ```console
43 | python evaluation/evaluation_all_context.py ../config/context.cfg
44 | ```
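
The evaluation script reconstructs the checkpoint path from the config's `ckpt` directory and the hyperparameters (see `make_hparam_string` in `evaluation/evaluation_all_context.py`). A minimal sketch of how that path is resolved, assuming the context config and that paths are relative to the repository root:

```python
import configparser
import os

config = configparser.RawConfigParser()
config.read("config/context.cfg")

# Mirrors make_hparam_string() in evaluation/evaluation_all_context.py
hparam = "{}/s{}_h{}-{}_n{}_c{}-{}_b{}".format(
    config.get("hyperparameters", "model"),
    config.get("hyperparameters", "svd_dimension"),
    config.get("hyperparameters", "number_hidden_layers"),
    config.get("hyperparameters", "hidden_layer_size"),
    config.get("hyperparameters", "negative_num"),
    config.get("hyperparameters", "context_num"),
    config.get("hyperparameters", "context_len"),
    config.get("hyperparameters", "batch_size"),
)

# The trained model is expected at <ckpt>/<hparam>/best.ckpt
ckpt_path = os.path.join(config.get("data", "ckpt"), hparam, "best.ckpt")
print(ckpt_path)
```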
45 |
46 | Note that we chose not to report the `BERT` encoder results in our original paper due to efficiency concerns, but we release the relevant code for incorporating effective pre-trained contextualized encoders to further improve performance. Pull requests are welcome, or contact cyuaq # cse.ust.hk !
47 |
48 |
49 | ## Citation
50 |
51 | Please cite the following paper if you find our method helpful. Thanks!
52 |
53 | ```
54 | @inproceedings{yu-etal-2020-hearst,
55 | title = "When Hearst Is Not Enough: Improving Hypernymy Detection from Corpus with Distributional Models",
56 | author = "Yu, Changlong and Han, Jialong and Wang, Peifeng and Song, Yangqiu and Zhang, Hongming and Ng, Wilfred and Shi, Shuming",
57 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
58 | month = "nov",
59 | year = "2020",
60 | address = "Online",
61 | publisher = "Association for Computational Linguistics",
62 | url = "https://www.aclweb.org/anthology/2020.emnlp-main.502",
63 | pages = "6208--6217",
64 | }
65 | ```
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/config/bert.cfg:
--------------------------------------------------------------------------------
1 | ;TIP: one can comment lines in this config format by adding a ; at the start of a line
2 |
3 | [data]
4 |
5 | pattern_filename=/home/shared/hypernymysuite/hearst_counts.txt.gz
6 | context = /home/shared/hearst_full_context/
7 | context_oov = /home/shared/context/
8 | bert_path = /home/shared/pretrained-lm/bert/bert_base_uncased/
9 | ckpt = /home/cyuaq/comHyper/checkpoints_binary
10 |
11 |
12 | [hyperparameters]
13 |
14 | model = bert_base
15 | svd_dimension = 50
16 | number_hidden_layers = 2
17 | hidden_layer_size = 300
18 | batch_size = 32
19 | negative_num = 1
20 | max_epochs = 500
21 | learning_rate = 0.00001
22 | weight_decay = 0
23 |
24 | context_num = 10
25 | context_len = 10
26 | max_seq_length = 64
27 |
28 | gpu_device = 0,1,2,3
29 |
--------------------------------------------------------------------------------
/config/bert_large.cfg:
--------------------------------------------------------------------------------
1 | ;TIP: one can comment lines in this config format by adding a ; at the start of a line
2 |
3 | [data]
4 |
5 | pattern_filename=/home/shared/hypernymysuite/hearst_counts.txt.gz
6 | context = /home/shared/hearst_full_context/
7 | context_oov = /home/shared/context/
8 | bert_path = /home/shared/pretrained-lm/bert/bert_large_uncased/
9 | ckpt = /home/cyuaq/compHyper/checkpoints_binary
10 |
11 |
12 | [hyperparameters]
13 |
14 | model = bert_large
15 | svd_dimension = 50
16 | number_hidden_layers = 2
17 | hidden_layer_size = 300
18 | batch_size = 8
19 | negative_num = 1
20 | max_epochs = 500
21 | learning_rate = 0.00001
22 | weight_decay = 0
23 |
24 | context_num = 10
25 | context_len = 10
26 | max_seq_length = 64
27 |
28 | gpu_device = 0,1,2,3
29 |
--------------------------------------------------------------------------------
/config/context.cfg:
--------------------------------------------------------------------------------
1 | ;TIP: one can comment lines in this config format by adding a ; at the start of a line
2 |
3 | [data]
4 |
5 | pattern_filename=/home/shared/hypernymysuite/hearst_counts.txt.gz
6 | context = /home/shared/task1_hearst_full_context/
7 | context_oov = /home/shared/task1_context/
8 | ckpt = /home/cyuaq/compHyper/checkpoints_context
9 |
10 |
11 | [hyperparameters]
12 |
13 | model = han
14 | svd_dimension = 50
15 | number_hidden_layers = 2
16 | hidden_layer_size = 300
17 | batch_size = 32
18 | negative_num = 1
19 | max_epochs = 500
20 | learning_rate = 0.001
21 | weight_decay = 0
22 |
23 | context_num = 10
24 | context_len = 10
25 |
26 | gpu_device = 3
--------------------------------------------------------------------------------
/config/word.cfg:
--------------------------------------------------------------------------------
1 | ;TIP: one can comment lines in this config format by adding a ; at the start of a line
2 |
3 | [data]
4 |
5 |
6 | ; two training data files have to be aligned (the two vectors of the same word in the same line)
7 |
8 | pattern_filename=/home/shared/hypernymysuite/hearst_counts.txt.gz
9 | context = /home/shared/context
10 | ckpt = /home/cyuaq/compHyper/checkpoints_word/
11 |
12 |
13 | [hyperparameters]
14 |
15 | model = mlp_unisample_svd
16 | svd_dimension = 50
17 | number_hidden_layers = 2
18 | hidden_layer_size = 300
19 | batch_size = 128
20 | negative_num = 400
21 | max_epochs = 500
22 | learning_rate = 0.001
23 | weight_decay = 0
24 |
25 | gpu_device = 3
--------------------------------------------------------------------------------
/data/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright (c) 2017-present, Facebook, Inc.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 | #
9 | # -------------------------------------------------------------------------------
10 | # This shell script downloads and preprocesses all the datasets
11 | # -------------------------------------------------------------------------------
12 |
13 | # Directly from the repo: https://github.com/facebookresearch/hypernymysuite/blob/master/download_data.sh
14 |
15 | # Immediately quit on error
16 | set -e
17 |
18 | # if you have any proxies, etc., put them here
19 | CURL_OPTIONS="-s"
20 |
21 |
22 | # URLS of each of the different datasets
23 | OMER_URL="http://u.cs.biu.ac.il/~nlp/wp-content/uploads/lexical_inference.zip"
24 | SHWARTZ_URL="https://raw.githubusercontent.com/vered1986/HypeNET/v2/dataset/datasets.rar"
25 | VERED_REPO_URL="https://raw.githubusercontent.com/vered1986/UnsupervisedHypernymy/e3b22709365c7b3042126e5887c9baa03631354e/datasets"
26 | KIMANH_REPO_URL="https://raw.githubusercontent.com/nguyenkh/HyperVec/bd2cb15a6be2a4726ffbf9c0d7e742144790dee3/datasets_classification"
27 | HYPERLEX_URL="https://raw.githubusercontent.com/ivulic/hyperlex/master/hyperlex-data.zip"
28 |
29 | function warning () {
30 | echo "$1" >&2
31 | }
32 |
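# Deterministic pseudo-random byte stream derived from a fixed seed
# (AES-CTR keystream over /dev/zero); used as --random-source for sort
# so that the shuffles below are reproducible.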
33 | get_seeded_random()
34 | {
35 | seed="$1"
36 | openssl enc -aes-256-ctr -pass pass:"$seed" -nosalt \
37 |     </dev/zero 2>/dev/null
38 | }
39 |
40 | function deterministic_shuffle () {
41 | # sort randomly but with a predictable seed
42 | sort --random-sort --random-source=<(get_seeded_random 42)
43 | }
44 |
45 | function download_hyperlex () {
46 | TMPFILE="$(mktemp)"
47 | TMPDIRE="$(mktemp -d)"
48 | curl $CURL_OPTIONS "$HYPERLEX_URL" > "$TMPFILE"
49 | unzip "$TMPFILE" -d "$TMPDIRE" > /dev/null
50 |
51 | echo -e 'word1\tword2\tpos\tlabel\tscore\tfold'
52 | grep -v WORD1 "$TMPDIRE/splits/random/hyperlex_training_all_random.txt" | \
53 | cut -d' ' -f1-5 | tr ' ' '\t' | \
54 | awk -F'\t' '$0=$0"\ttrain"'
55 | grep -v WORD1 "$TMPDIRE/splits/random/hyperlex_dev_all_random.txt" | \
56 | cut -d' ' -f1-5 | tr ' ' '\t' | \
57 | awk -F'\t' '$0=$0"\tval"'
58 | grep -v WORD1 "$TMPDIRE/splits/random/hyperlex_test_all_random.txt" | \
59 | cut -d' ' -f1-5 | tr ' ' '\t' | \
60 | awk -F'\t' '$0=$0"\ttest"'
61 |
62 | rm -rf "$TMPFILE" "$TMPDIRE"
63 | }
64 |
65 | function download_bless () {
66 | TMPFILE="$(mktemp)"
67 | TMPDIRE="$(mktemp -d)"
68 | curl $CURL_OPTIONS "$OMER_URL" > "$TMPFILE"
69 | unzip "$TMPFILE" -d "$TMPDIRE" > /dev/null
70 |
71 | echo -e 'word1\tword2\tlabel\tfold'
72 | cat "${TMPDIRE}/lexical_entailment/bless2011/data_rnd_test.tsv" \
73 | "${TMPDIRE}/lexical_entailment/bless2011/data_rnd_train.tsv" \
74 | "${TMPDIRE}/lexical_entailment/bless2011/data_rnd_val.tsv" | \
75 | tr -d '\15' | \
76 | deterministic_shuffle | \
77 | awk '{if (NR < 1454) {print $0 "\tval"} else {print $0 "\ttest"}}'
78 |
79 | rm -rf "$TMPFILE" "$TMPDIRE"
80 | }
81 |
82 | function download_leds () {
83 | TMPFILE="$(mktemp)"
84 | TMPDIRE="$(mktemp -d)"
85 | curl $CURL_OPTIONS "$OMER_URL" > "$TMPFILE"
86 | unzip "$TMPFILE" -d "$TMPDIRE" > /dev/null
87 |
88 | echo -e 'word1\tword2\tlabel\tfold'
89 | cat "${TMPDIRE}/lexical_entailment/baroni2012/data_rnd_test.tsv" \
90 | "${TMPDIRE}/lexical_entailment/baroni2012/data_rnd_train.tsv" \
91 | "${TMPDIRE}/lexical_entailment/baroni2012/data_rnd_val.tsv" | \
92 | tr -d '\15' | \
93 | deterministic_shuffle | \
94 | awk '{if (NR < 276) {print $0 "\tval"} else {print $0 "\ttest"}}'
95 |
96 | rm -rf "$TMPFILE" "$TMPDIRE"
97 | }
98 |
99 | function download_shwartz () {
100 | TMPFILE="$(mktemp)"
101 | TMPDIRE="$(mktemp -d)"
102 | curl $CURL_OPTIONS "$SHWARTZ_URL" > "$TMPFILE"
103 |
104 | unrar x "$TMPFILE" "$TMPDIRE" >/dev/null
105 | echo -e 'word1\tword2\tlabel\tfold'
106 | cat "$TMPDIRE/dataset_rnd/train.tsv" \
107 | "$TMPDIRE/dataset_rnd/test.tsv" \
108 | "$TMPDIRE/dataset_rnd/val.tsv" | \
109 | grep -v ' ' | \
110 | deterministic_shuffle | \
111 | awk '{if (NR < 5257) {print $0 "\tval"} else {print $0 "\ttest"}}'
112 |
113 | rm -rf "$TMPFILE" "$TMPDIRE"
114 | }
115 |
116 | function download_bibless () {
117 | echo -e 'word1\tword2\trelation\tlabel'
118 | curl $CURL_OPTIONS "$KIMANH_REPO_URL/ABIBLESS.txt" | \
119 | cut -f1,2,4 | \
120 | awk -F'\t' '{if ($3 == "hyper") {print $0 "\t1"} else if ($3 == "other") {print $0 "\t0"} else {print $0 "\t-1"}}'
121 | }
122 |
123 | function download_wbless () {
124 | echo -e 'word1\tword2\tlabel\trelation\tfold'
125 | curl $CURL_OPTIONS "$KIMANH_REPO_URL/AWBLESS.txt" | \
126 | deterministic_shuffle | \
127 | awk '{if (NR < 168) {print $0 "\tval"} else {print $0 "\ttest"}}'
128 | }
129 |
130 | function download_eval () {
131 | echo -e 'word1\tword2\tlabel\trelation\tfold'
132 | curl $CURL_OPTIONS "$VERED_REPO_URL/EVALution.val" "$VERED_REPO_URL/EVALution.test" | \
133 | sort | uniq | sed 's/-[jvn]\t/\t/g' | \
134 | deterministic_shuffle | \
135 | awk '{if (NR < 737) {print $0 "\tval"} else {print $0 "\ttest"}}'
136 | }
137 |
138 |
139 | # Let the user specify output directory, default to `data`
140 | # Ex: `HYPERNYMY_DATA_OUTPUT=.my_data_dir bash download_data.sh`
141 | if [ -z $HYPERNYMY_DATA_OUTPUT ]; then
142 | HYPERNYMY_DATA_OUTPUT="data"
143 | fi
144 |
145 | if [ -d "$HYPERNYMY_DATA_OUTPUT" ]
146 | then
147 | echo "Warning: Already found the data. Please run 'rm -rf $HYPERNYMY_DATA_OUTPUT'" >&2
148 | exit 1
149 | fi
150 |
151 | if [ ! -x "$(command -v unrar)" ]
152 | then
153 | warning "This script requires the 'unrar' tool. Please run"
154 | warning " brew install unrar"
155 | warning "or whatever your system's equivalent is."
156 | exit 1
157 | fi
158 |
159 | if [ ! -x "$(command -v openssl)" ]
160 | then
161 | warning "This script requires the 'openssl' tool. Please run"
162 | warning "    brew install openssl"
163 | warning "or whatever your system's equivalent is."
164 | exit 1
165 | fi
166 |
167 |
168 |
169 | # prep the output folder
170 | mkdir -p "$HYPERNYMY_DATA_OUTPUT"
171 |
172 |
173 | warning "[1/7] Downloading BLESS"
174 | download_bless > "$HYPERNYMY_DATA_OUTPUT/bless.tsv"
175 |
176 | warning "[2/7] Downloading LEDS"
177 | download_leds > "$HYPERNYMY_DATA_OUTPUT/leds.tsv"
178 |
179 | warning "[3/7] Downloading EVAL"
180 | download_eval > "$HYPERNYMY_DATA_OUTPUT/eval.tsv"
181 |
182 | warning "[4/7] Downloading Shwartz"
183 | download_shwartz > "$HYPERNYMY_DATA_OUTPUT/shwartz.tsv"
184 |
185 | warning "[5/7] Downloading Hyperlex"
186 | download_hyperlex > "$HYPERNYMY_DATA_OUTPUT/hyperlex_rnd.tsv"
187 |
188 | warning "[6/7] Downloading WBLESS"
189 | download_wbless > "$HYPERNYMY_DATA_OUTPUT/wbless.tsv"
190 |
191 | warning "[7/7] Downloading BiBLESS"
192 | download_bibless > "$HYPERNYMY_DATA_OUTPUT/bibless.tsv"
193 |
194 | warning "All done."
--------------------------------------------------------------------------------
/evaluation/evaluation_all_bert.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import os
4 | import logging
5 | import tqdm
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 |
11 | import numpy as np
12 | import configparser
13 | from model.models import *
14 | from utils.data_helper_4bert import Dataset
15 | from utils.loader import Testdataset
16 | from scipy import stats
17 | from gensim.models import Word2Vec
18 | from sklearn.metrics import average_precision_score,precision_recall_curve
19 | from collections import OrderedDict
20 |
21 | SIEGE_EVALUATIONS = [
22 | ("bless", "data/bless.tsv"),
23 | ("eval", "data/eval.tsv"),
24 | ("leds", "data/leds.tsv"),
25 | ("shwartz", "data/shwartz.tsv"),
26 | ("weeds", "data/wbless.tsv"),
27 | ]
28 |
29 | CORRELATION_EVAL_DATASETS = [("hyperlex", "data/hyperlex_rnd.tsv"),
30 | ("hyperlex_noun", "data/hyperlex_noun.tsv")]
31 |
32 |
33 | def make_hparam_string(config):
34 | hparam = "{}/s{}_h{}-{}_n{}_c{}-{}_b{}".format(
35 | config.get("hyperparameters", "model"),
36 | config.get("hyperparameters", "svd_dimension"),
37 | config.get("hyperparameters", "number_hidden_layers"),
38 | config.get("hyperparameters", "hidden_layer_size"),
39 | config.get("hyperparameters", "negative_num"),
40 | # config.get("hyperparameters", "weight_decay"),
41 | config.get("hyperparameters", "context_num"),
42 | config.get("hyperparameters", "context_len"),
43 | config.get("hyperparameters", "batch_size")
44 | )
45 | return hparam
46 |
47 | def init_model(config, ckpt_path, device):
48 |
49 | encoder_type = config.get("hyperparameters", "model")
50 | number_hidden_layers = int(config.getfloat("hyperparameters", "number_hidden_layers"))
51 | hidden_layer_size = int(config.getfloat("hyperparameters", "hidden_layer_size"))
52 | bert_dir = config.get("data", "bert_path")
53 | model = Bert2Score(encoder_type, bert_dir, hidden_layer_size, 0.1)
54 | model.to(device)
55 | pretrain = torch.load(ckpt_path)
56 | new_pretrain = OrderedDict()
57 | #for k, v in pretrain.items():
58 | # name = k[7:]
59 | # new_pretrain[name] = v
60 | # pretrain.pop("word_embedding.weight")
61 | model.load_state_dict(pretrain)
62 | model.eval()
63 |
64 | return model
65 |
66 |
67 | def predict_many(data, model, hypos, hypers, reverse, device):
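    # Score each (hyponym, hypernym) pair: pairs covered by the Hearst-pattern
    # vocabulary are scored with the truncated-SVD factors U/V, while
    # out-of-pattern pairs fall back to the BERT context model; pairs that
    # cannot be scored at all get 0.0. Returns the combined scores, the
    # SVD-only scores, and the fraction of out-of-pattern pairs.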
68 |
69 | num = 0
70 | result = []
71 | result_svd = []
72 | count_oop = 0
73 | count_pair = 0
74 | for hypon, hyper in zip(hypos, hypers):
75 | count_pair += 1
76 | if hypon in data.vocab and hyper in data.vocab:
77 | l = data.word2id[hypon]
78 | r = data.word2id[hyper]
79 |
80 | if reverse:
81 | pred = data.U[r].dot(data.V[l])
82 | else:
83 | pred = data.U[l].dot(data.V[r])
84 | result_svd.append(pred)
85 |
86 | else:
87 | # out of pattern mode
88 | result_svd.append(0.0)
89 | count_oop += 1
90 | try:
91 | hypon_id = data.context_w2i[hypon]
92 | hyper_id = data.context_w2i[hyper]
93 | hypon_word_context = data.context_dict[hypon_id]['ids']
94 | hyper_word_context = data.context_dict[hyper_id]['ids']
95 |
96 | hypon_word_mask = data.context_dict[hypon_id]['mask']
97 | hyper_word_mask = data.context_dict[hyper_id]['mask']
98 |
99 |
100 | if reverse:
101 | inputs = torch.tensor(np.asarray([[hyper_word_context, hypon_word_context]]), dtype=torch.long).to(device)
102 | inputs_mask = torch.tensor(np.asarray([[hyper_word_mask, hypon_word_mask]]), dtype=torch.long).to(device)
103 | pred = model(inputs, inputs_mask).detach().cpu().numpy()[0]
104 | else:
105 | inputs = torch.tensor(np.asarray([[hypon_word_context, hyper_word_context]]), dtype=torch.long).to(device)
106 | inputs_mask = torch.tensor(np.asarray([[hypon_word_mask, hyper_word_mask]]), dtype=torch.long).to(device)
107 | pred = model(inputs, inputs_mask).detach().cpu().numpy()[0]
108 | except Exception:
109 | num +=1
110 | pred = 0.0
111 |
112 | result.append(pred)
113 | # num == 0 -> every out-of-pattern word had contexts available
114 | oop_rate = count_oop * 1.0 / count_pair
115 | return np.array(result, dtype=np.float32), np.array(result_svd, dtype=np.float32), oop_rate
116 |
117 |
118 | def detection_setup(file_name, model, matrix_data ,device):
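    # Hypernymy detection benchmark: in-pattern pairs are scored with the SVD
    # factors, out-of-pattern pairs with the context model (when contexts are
    # available); reports average precision on the validation and test splits
    # for both the combined scores and the SVD-only baseline.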
119 |
120 | logger.info("-" * 80)
121 | logger.info("processing dataset :{}".format(file_name))
122 | ds = Testdataset(file_name, matrix_data.vocab)
123 |
124 | m_val = ds.val_mask
125 | m_test = ds.test_mask
126 |
127 | h = np.zeros(len(ds))
128 |
129 | h_ip = np.zeros(len(ds))
130 |
131 | print(len(ds))
132 |
133 | predict_mask = np.full(len(ds), True)
134 | inpattern_mask = np.full(len(ds), True)
135 |
136 | true_prediction = []
137 | in_pattern_prediction = []
138 |
139 | count_context = 0
140 |
141 | mask_idx = 0
142 | for x,y in zip(ds.hypos, ds.hypers):
143 | if x in matrix_data.vocab and y in matrix_data.vocab:
144 |
145 | l = matrix_data.word2id[x]
146 | r = matrix_data.word2id[y]
147 | score = matrix_data.U[l].dot(matrix_data.V[r])
148 |
149 | in_pattern_prediction.append(score)
150 | true_prediction.append(score)
151 | else:
152 |
153 | inpattern_mask[mask_idx] = False
154 | # out of pattern
155 | try:
156 | hypon_id = matrix_data.context_w2i[x]
157 | hyper_id = matrix_data.context_w2i[y]
158 |
159 | hypon_word_context = matrix_data.context_dict[hypon_id]['ids']
160 | hyper_word_context = matrix_data.context_dict[hyper_id]['ids']
161 |
162 | hypon_word_mask = matrix_data.context_dict[hypon_id]['mask']
163 | hyper_word_mask = matrix_data.context_dict[hyper_id]['mask']
164 |
165 | inputs = torch.tensor(np.asarray([[hypon_word_context, hyper_word_context]]), dtype=torch.long).to(device)
166 | inputs_mask = torch.tensor(np.asarray([[hypon_word_mask, hyper_word_mask]]), dtype=torch.long).to(device)
167 | score = model(inputs, inputs_mask).detach().cpu().numpy()[0]
168 | count_context +=1
169 |
170 | true_prediction.append(score)
171 |
172 | except Exception as e:
173 | print(repr(e))
174 | print(file_name)
175 | predict_mask[mask_idx] = False
176 | mask_idx +=1
177 |
178 | h[predict_mask] = np.array(true_prediction, dtype=np.float32)
179 | h[~predict_mask] = h[predict_mask].min()
180 |
181 | h_ip[inpattern_mask] = np.array(in_pattern_prediction, dtype=np.float32)
182 | h_ip[~inpattern_mask] = h_ip[inpattern_mask].min()
183 |
184 | y = ds.y
185 |
186 |
187 | result= {
188 | "ap_val": average_precision_score(y[m_val],h[m_val]),
189 | "ap_test": average_precision_score(y[m_test],h[m_test]),
190 | }
191 |
192 | result['true_oov'] = int(np.sum(ds.oov_mask & ds.y))
193 |
194 | result['oov_rate'] = np.mean(ds.oov_mask)
195 | result['predict_num'] = int(np.sum(predict_mask))
196 | result['oov_num'] = int(np.sum(ds.oov_mask))
197 |
198 | logger.info("{:2d}/{:2d} OOV pairs have context".format(count_context, result['oov_num']))
199 | logger.info("Bert : AP for validation is :{} || for test is :{}".format(average_precision_score(y[m_val],h[m_val]),
200 | average_precision_score(y[m_test],h[m_test]) ))
201 | logger.info("Svdppmi : AP for validation is :{} || for test is :{}".format(average_precision_score(y[m_val],h_ip[m_val]),
202 | average_precision_score(y[m_test],h_ip[m_test]) ))
203 | logger.info("OOV true number is {}".format(result['true_oov']))
204 |
205 | return result
206 |
207 | def hyperlex_setup(file_name, model, matrix_data,device):
208 |
209 | logger.info("-" * 80)
210 | logger.info("processing dataset :{}".format(file_name))
211 |
212 | ds = Testdataset(file_name, matrix_data.vocab, ycolumn='score')
213 |
214 | h = np.zeros(len(ds))
215 | h_ip = np.zeros(len(ds))
216 |
217 | predict_mask = np.full(len(ds), True)
218 | inpattern_mask = np.full(len(ds), True)
219 |
220 | true_prediction = []
221 | in_pattern_prediction = []
222 |
223 | mask_idx = 0
224 | for x,y in zip(ds.hypos, ds.hypers):
225 | if x in matrix_data.vocab and y in matrix_data.vocab:
226 |
227 | l = matrix_data.word2id[x]
228 | r = matrix_data.word2id[y]
229 | score = matrix_data.U[l].dot(matrix_data.V[r])
230 |
231 | true_prediction.append(score)
232 | in_pattern_prediction.append(score)
233 |
234 | else:
235 | # out of pattern
236 | inpattern_mask[mask_idx] = False
237 | try:
238 | hypon_id = matrix_data.context_w2i[x]
239 | hyper_id = matrix_data.context_w2i[y]
240 |
241 | hypon_word_context = matrix_data.context_dict[hypon_id]['ids']
242 | hyper_word_context = matrix_data.context_dict[hyper_id]['ids']
243 |
244 | hypon_word_mask = matrix_data.context_dict[hypon_id]['mask']
245 | hyper_word_mask = matrix_data.context_dict[hyper_id]['mask']
246 |
247 | inputs = torch.tensor(np.asarray([[hypon_word_context, hyper_word_context]]), dtype=torch.long).to(device)
248 | inputs_mask = torch.tensor(np.asarray([[hypon_word_mask, hyper_word_mask]]), dtype=torch.long).to(device)
249 |
250 | score = model(inputs, inputs_mask).detach().cpu().numpy()[0]
251 |
252 | true_prediction.append(score)
253 |
254 | except Exception as e:
255 | print(repr(e))
256 | print(file_name)
257 | predict_mask[mask_idx] = False
258 |
259 | mask_idx +=1
260 |
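    # Pairs with no available prediction are filled with the median score so
    # the Spearman correlation is still computed over the full dataset.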
261 | h[predict_mask] = np.array(true_prediction, dtype=np.float32)
262 | h[~predict_mask] = np.median(h[predict_mask])
263 |
264 | h_ip[inpattern_mask] = np.array(in_pattern_prediction, dtype=np.float32)
265 | h_ip[~inpattern_mask] = np.median(h_ip[inpattern_mask])
266 |
267 |
268 | y = ds.labels
269 |
270 | m_train = ds.train_mask
271 | m_val = ds.val_mask
272 | m_test = ds.test_mask
273 |
274 | result = {
275 | "spearman_train": stats.spearmanr(y[m_train], h[m_train])[0],
276 | "spearman_val": stats.spearmanr(y[m_val], h[m_val])[0],
277 | "spearman_test": stats.spearmanr(y[m_test], h[m_test])[0],
278 | }
279 |
280 |
281 | result['oov_rate'] = np.mean(ds.oov_mask)
282 | result['predict_num'] = int(np.sum(predict_mask))
283 | result['oov_num'] = int(np.sum(ds.oov_mask))
284 |
285 | svd_train = stats.spearmanr(y[m_train], h_ip[m_train])[0]
286 | svd_test = stats.spearmanr(y[m_test], h_ip[m_test])[0]
287 |
288 | oov_train = stats.spearmanr(y[ds.oov_mask], h[ds.oov_mask])
289 |
290 | logger.info("Bert: train cor: {} | test cor:{}".format(result['spearman_train'],result['spearman_test']))
291 | logger.info("OOV cor: {}".format(oov_train))
292 |
293 | logger.info("Svdppmi: train cor: {} | test cor:{}".format(svd_train, svd_test))
294 |
295 | return result
296 |
297 |
298 | def dir_bless_setup(model, matrix_data ,device):
299 |
300 | logger.info("-" * 80)
301 | logger.info("processing dataset : dir_bless")
302 | ds = Testdataset("data/bless.tsv", matrix_data.vocab)
303 |
304 | hypos = ds.hypos[ds.y]
305 | hypers = ds.hypers[ds.y]
306 |
307 | m_val = ds.val_mask[ds.y]
308 | m_test = ds.test_mask[ds.y]
309 |
310 | h = np.zeros(len(ds))
311 |
312 | pred_score_list = []
313 | svd_pred_list = []
314 | count_oop = 0
315 | count_pair = 0
316 |
317 | for hypon, hyper in zip(hypos, hypers):
318 | if hypon in matrix_data.vocab and hyper in matrix_data.vocab:
319 | l = matrix_data.word2id[hypon]
320 | r = matrix_data.word2id[hyper]
321 |
322 | forward_pred = matrix_data.U[l].dot(matrix_data.V[r])
323 | reverse_pred = matrix_data.U[r].dot(matrix_data.V[l])
324 |
325 | if forward_pred > reverse_pred:
326 | pred_score_list.append(1)
327 | svd_pred_list.append(1)
328 | else:
329 | pred_score_list.append(0)
330 | svd_pred_list.append(0)
331 | else:
332 | # out of pattern mode
333 | svd_pred_list.append(0)
334 | count_oop += 1
335 | try:
336 | hypon_id = matrix_data.context_w2i[hypon]
337 | hyper_id = matrix_data.context_w2i[hyper]
338 | hypon_word_context = matrix_data.context_dict[hypon_id]['ids']
339 | hyper_word_context = matrix_data.context_dict[hyper_id]['ids']
340 |
341 | hypon_word_mask = matrix_data.context_dict[hypon_id]['mask']
342 | hyper_word_mask = matrix_data.context_dict[hyper_id]['mask']
343 | inputs = torch.tensor(np.asarray([[hypon_word_context, hyper_word_context]]), dtype=torch.long).to(device)
344 | inputs_mask = torch.tensor(np.asarray([[hypon_word_mask, hyper_word_mask]]), dtype=torch.long).to(device)
345 |
346 | forward_pred = model(inputs, inputs_mask).detach().cpu().numpy()[0]
347 |
348 | inputs = torch.tensor(np.asarray([[hyper_word_context, hypon_word_context]]), dtype=torch.long).to(device)
349 | inputs_mask = torch.tensor(np.asarray([[hyper_word_mask, hypon_word_mask]]), dtype=torch.long).to(device)
350 |
351 | reverse_pred = model(inputs, inputs_mask).detach().cpu().numpy()[0]
352 |
353 | if forward_pred > reverse_pred:
354 | pred_score_list.append(1)
355 | else:
356 | pred_score_list.append(0)
357 | except Exception as e:
358 | print(repr(e))
359 | pred_score_list.append(0)
360 |
361 | acc = np.mean(np.asarray(pred_score_list))
362 | acc_val = np.mean(np.asarray(pred_score_list)[m_val])
363 | acc_test = np.mean(np.asarray(pred_score_list)[m_test])
364 |
365 | s_acc = np.mean(np.asarray(svd_pred_list))
366 |
367 | logger.info("Val Acc : {} || Test Acc: {} ".format(acc_val, acc_test))
368 | logger.info("Sppmi Acc: {} ".format(s_acc))
369 |
370 |
371 | def dir_wbless_setup(model, data ,device):
372 |
373 | logger.info("-" * 80)
374 | logger.info("processing dataset : dir_wbless")
375 | data_path = "data/wbless.tsv"
376 | ds = Testdataset(data_path, data.vocab)
377 |
378 |
379 | rng = np.random.RandomState(42)
380 | VAL_PROB = .02
381 | NUM_TRIALS = 1000
382 |
383 | # We have no way of handling oov
384 | h, h_svd, _ = predict_many(data, model, ds.hypos, ds.hypers, False, device)
385 | y = ds.y
386 |
387 | val_scores = []
388 | test_scores = []
389 |
390 | for _ in range(NUM_TRIALS):
391 | # Generate a new mask every time
392 | m_val = rng.rand(len(y)) < VAL_PROB
393 | # Test is everything except val
394 | m_test = ~m_val
395 | _, _, t = precision_recall_curve(y[m_val], h[m_val])
396 | # pick the highest accuracy on the validation set
397 | thr_accs = np.mean((h[m_val, np.newaxis] >= t) == y[m_val, np.newaxis], axis=0)
398 | best_t = t[thr_accs.argmax()]
399 | preds_val = h[m_val] >= best_t
400 | preds_test = h[m_test] >= best_t
401 | # Evaluate
402 | val_scores.append(np.mean(preds_val == y[m_val]))
403 | test_scores.append(np.mean(preds_test == y[m_test]))
404 | # sanity check
405 | assert np.allclose(val_scores[-1], thr_accs.max())
406 |
407 | # report average across many folds
408 | logger.info("bert: acc_val_inv: {} acc_test_inv: {}".format(np.mean(val_scores), np.mean(test_scores)))
409 |
410 | val_scores = []
411 | test_scores = []
412 |
413 | for _ in range(NUM_TRIALS):
414 | # Generate a new mask every time
415 | m_val = rng.rand(len(y)) < VAL_PROB
416 | # Test is everything except val
417 | m_test = ~m_val
418 | _, _, t = precision_recall_curve(y[m_val], h_svd[m_val])
419 | # pick the highest accuracy on the validation set
420 | thr_accs = np.mean((h_svd[m_val, np.newaxis] >= t) == y[m_val, np.newaxis], axis=0)
421 | best_t = t[thr_accs.argmax()]
422 | preds_val = h_svd[m_val] >= best_t
423 | preds_test = h_svd[m_test] >= best_t
424 | # Evaluate
425 | val_scores.append(np.mean(preds_val == y[m_val]))
426 | test_scores.append(np.mean(preds_test == y[m_test]))
427 | # sanity check
428 | assert np.allclose(val_scores[-1], thr_accs.max())
429 |
430 | # report average across many folds
431 | logger.info("sppmi: acc_val_inv: {} acc_test_inv: {}".format(np.mean(val_scores), np.mean(test_scores)))
432 |
433 |
434 | def dir_bibless_setup(model, data ,device):
435 |
436 | logger.info("-" * 80)
437 | logger.info("processing dataset : dir_bibless")
438 | data_path = "data/bibless.tsv"
439 | ds = Testdataset(data_path, data.vocab)
440 |
441 |
442 | rng = np.random.RandomState(42)
443 | VAL_PROB = .02
444 | NUM_TRIALS = 1000
445 |
446 |
447 | #y = ds.y[ds.invocab_mask]
448 | y = ds.y
449 | # hypernymy could be either direction
450 | yh = y != 0
451 |
452 | # get forward and backward predictions
453 | hf, hf_svd, oop_rate = predict_many(data, model, ds.hypos, ds.hypers, False, device)
454 | hr, hr_svd, _ = predict_many(data, model, ds.hypos, ds.hypers, True, device)
455 | logger.info('OOP Rate: {}'.format(oop_rate))
456 | h = np.max([hf, hr], axis=0)
457 | h_svd = np.max([hf_svd, hr_svd], axis=0)
458 |
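    # Direction prediction in {-1, +1}: +1 if the forward (hypo -> hyper)
    # score is at least the reverse score, -1 otherwise.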
459 | dir_pred = 2 * np.float32(hf >= hr) - 1
460 | dir_pred_svd = 2 * np.float32(hf_svd >= hr_svd) - 1
461 |
462 | val_scores = []
463 | test_scores = []
464 | for _ in range(NUM_TRIALS):
465 | # Generate a new mask every time
466 | m_val = rng.rand(len(y)) < VAL_PROB
467 | # Test is everything except val
468 | m_test = ~m_val
469 |
470 | # set the threshold based on the maximum score
471 | _, _, t = precision_recall_curve(yh[m_val], h[m_val])
472 | thr_accs = np.mean((h[m_val, np.newaxis] >= t) == yh[m_val, np.newaxis], axis=0)
473 | best_t = t[thr_accs.argmax()]
474 |
475 | det_preds_val = h[m_val] >= best_t
476 | det_preds_test = h[m_test] >= best_t
477 |
478 | fin_preds_val = det_preds_val * dir_pred[m_val]
479 | fin_preds_test = det_preds_test * dir_pred[m_test]
480 |
481 | val_scores.append(np.mean(fin_preds_val == y[m_val]))
482 | test_scores.append(np.mean(fin_preds_test == y[m_test]))
483 |
484 | # report average across many folds
485 | logger.info("bert: acc_val_all: {}, acc_test_all: {}".format(np.mean(val_scores),np.mean(test_scores)))
486 |
487 | val_scores = []
488 | test_scores = []
489 | for _ in range(NUM_TRIALS):
490 | # Generate a new mask every time
491 | m_val = rng.rand(len(y)) < VAL_PROB
492 | # Test is everything except val
493 | m_test = ~m_val
494 |
495 | # set the threshold based on the maximum score
496 | _, _, t = precision_recall_curve(yh[m_val], h_svd[m_val])
497 | thr_accs = np.mean((h_svd[m_val, np.newaxis] >= t) == yh[m_val, np.newaxis], axis=0)
498 | best_t = t[thr_accs.argmax()]
499 |
500 | det_preds_val = h_svd[m_val] >= best_t
501 | det_preds_test = h_svd[m_test] >= best_t
502 |
503 | fin_preds_val = det_preds_val * dir_pred_svd[m_val]
504 | fin_preds_test = det_preds_test * dir_pred_svd[m_test]
505 |
506 | val_scores.append(np.mean(fin_preds_val == y[m_val]))
507 | test_scores.append(np.mean(fin_preds_test == y[m_test]))
508 |
509 | # report average across many folds
510 | logger.info("sppmi: acc_val_all: {}, acc_test_all: {}".format(np.mean(val_scores),np.mean(test_scores)))
511 |
512 | def evaluation_all(config, ckpt_path):
513 |
514 | #embedding = load_gensim_word2vec()
515 |
516 | matrix_data = Dataset(config,train=False)
517 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
518 |
519 | model = init_model(config, ckpt_path, device)
520 |
521 | results = {}
522 |
523 | for taskname, filename in SIEGE_EVALUATIONS:
524 | result = detection_setup(filename, model, matrix_data ,device)
525 | results["detec_{}".format(taskname)] = result
526 |
527 | for taskname, filename in CORRELATION_EVAL_DATASETS:
528 | result = hyperlex_setup(filename, model, matrix_data, device)
529 | results["corr_{}".format(taskname)] = result
530 |
531 | dir_bless_setup(model,matrix_data, device)
532 | dir_wbless_setup(model, matrix_data, device)
533 | dir_bibless_setup(model, matrix_data, device)
534 |
535 | return results
536 |
537 | if __name__ == "__main__":
538 |
539 | config_file = sys.argv[1]
540 | config = configparser.RawConfigParser()
541 | config.read(config_file)
542 |
543 | ckpt_dir = config.get("data", "ckpt")
544 | hparam = make_hparam_string(config)
545 | ckpt_dir = os.path.join(ckpt_dir, hparam)
546 | ckpt_path = os.path.join(ckpt_dir, 'best.ckpt')
547 | log_path = os.path.join(ckpt_dir, 'bert_res.log')
548 |
549 | logger = logging.getLogger()
550 | logger.setLevel(logging.INFO)
551 | handler = logging.FileHandler(log_path, 'w')
552 | handler.setLevel(logging.INFO)
553 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S')
554 | handler.setFormatter(formatter)
555 | logger.addHandler(handler)
556 |
557 | results = evaluation_all(config, ckpt_path)
558 | print(results)
559 |
--------------------------------------------------------------------------------
/evaluation/evaluation_all_context.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import os
4 | import logging
5 | import tqdm
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 |
11 | import numpy as np
12 | import configparser
13 | from model.models import *
14 | from utils.data_helper_4context import Dataset
15 | from utils.loader import Testdataset
16 | from scipy import stats
17 | from gensim.models import Word2Vec
18 | from sklearn.metrics import average_precision_score,precision_recall_curve
19 |
20 | SIEGE_EVALUATIONS = [
21 | ("bless", "data/bless.tsv"),
22 | ("eval", "data/eval.tsv"),
23 | ("leds", "data/leds.tsv"),
24 | ("shwartz", "data/shwartz.tsv"),
25 | ("weeds", "data/wbless.tsv"),
26 | ]
27 |
28 | CORRELATION_EVAL_DATASETS = [("hyperlex", "data/hyperlex_rnd.tsv"),
29 | ("hyperlex_noun", "data/hyperlex_noun.tsv")]
30 |
31 |
32 | def make_hparam_string(config):
33 | hparam = "{}/s{}_h{}-{}_n{}_c{}-{}_b{}".format(
34 | config.get("hyperparameters", "model"),
35 | config.get("hyperparameters", "svd_dimension"),
36 | config.get("hyperparameters", "number_hidden_layers"),
37 | config.get("hyperparameters", "hidden_layer_size"),
38 | config.get("hyperparameters", "negative_num"),
39 | # config.get("hyperparameters", "weight_decay"),
40 | config.get("hyperparameters", "context_num"),
41 | config.get("hyperparameters", "context_len"),
42 | config.get("hyperparameters", "batch_size")
43 | )
44 | return hparam
45 |
46 | def init_model(config, ckpt_path, init_w2v_embedding, device):
47 |
48 | encoder_type = config.get("hyperparameters", "model")
49 | number_hidden_layers = int(config.getfloat("hyperparameters", "number_hidden_layers"))
50 | hidden_layer_size = int(config.getfloat("hyperparameters", "hidden_layer_size"))
51 |
52 | model = Context2Score(encoder_type, 300, hidden_layer_size, device)
53 |
54 | pretrain = torch.load(ckpt_path)
55 | # pretrain.pop("word_embedding.weight")
56 | model.load_state_dict(pretrain)
57 | model.init_emb(torch.FloatTensor(init_w2v_embedding))
58 | model.eval()
59 |
60 | return model
61 |
62 |
63 | def predict_many(data, model, hypos, hypers, reverse, attention):
64 |
65 | num = 0
66 | result = []
67 | result_svd = []
68 | count_oop = 0
69 | count_pair = 0
70 | for hypon, hyper in zip(hypos, hypers):
71 | count_pair += 1
72 | if hypon in data.vocab and hyper in data.vocab:
73 | l = data.word2id[hypon]
74 | r = data.word2id[hyper]
75 |
76 | if reverse:
77 | pred = data.U[r].dot(data.V[l])
78 | else:
79 | pred = data.U[l].dot(data.V[r])
80 | result_svd.append(pred)
81 |
82 | else:
83 | # out of pattern mode
84 | result_svd.append(0.0)
85 | count_oop += 1
86 | try:
87 | hypon_id = data.context_w2i[hypon]
88 | hyper_id = data.context_w2i[hyper]
89 | hypon_word_context = data.context_dict[hypon_id]
90 | hyper_word_context = data.context_dict[hyper_id]
91 | inputs = torch.tensor(np.asarray([[hypon_word_context, hyper_word_context]]), dtype=torch.long)
92 |
93 | if reverse:
94 | inputs = torch.tensor(np.asarray([[hyper_word_context, hypon_word_context]]), dtype=torch.long)
95 | if attention:
96 | pred = model(inputs)[0].detach().cpu().numpy()[0]
97 | else:
98 | pred = model(inputs).detach().cpu().numpy()[0]
99 | else:
100 | inputs = torch.tensor(np.asarray([[hypon_word_context, hyper_word_context]]), dtype=torch.long)
101 | if attention:
102 | pred = model(inputs)[0].detach().cpu().numpy()[0]
103 | else:
104 | pred = model(inputs).detach().cpu().numpy()[0]
105 | except Exception:
106 | num +=1
107 | pred = 0.0
108 |
109 | result.append(pred)
110 | # num == 0 -> every out-of-pattern word had contexts available
111 | oop_rate = count_oop * 1.0 / count_pair
112 | return np.array(result, dtype=np.float32), np.array(result_svd, dtype=np.float32), oop_rate
113 |
114 |
115 | def detection_setup(file_name, model, matrix_data ,device, attention):
116 |
117 | logger.info("-" * 80)
118 | logger.info("processing dataset :{}".format(file_name))
119 | ds = Testdataset(file_name, matrix_data.vocab)
120 |
121 | m_val = ds.val_mask
122 | m_test = ds.test_mask
123 |
124 | h = np.zeros(len(ds))
125 |
126 | h_ip = np.zeros(len(ds))
127 |
128 | print(len(ds))
129 |
130 | predict_mask = np.full(len(ds), True)
131 | inpattern_mask = np.full(len(ds), True)
132 |
133 | true_prediction = []
134 | in_pattern_prediction = []
135 |
136 | count_context = 0
137 |
138 | mask_idx = 0
139 | for x,y in zip(ds.hypos, ds.hypers):
140 | if x in matrix_data.vocab and y in matrix_data.vocab:
141 |
142 | l = matrix_data.word2id[x]
143 | r = matrix_data.word2id[y]
144 | score = matrix_data.U[l].dot(matrix_data.V[r])
145 |
146 | in_pattern_prediction.append(score)
147 | true_prediction.append(score)
148 | else:
149 |
150 | inpattern_mask[mask_idx] = False
151 | # out of pattern
152 | try:
153 | hypon_id = matrix_data.context_w2i[x]
154 | hyper_id = matrix_data.context_w2i[y]
155 |
156 | hypon_word_context = matrix_data.context_dict[hypon_id]
157 | hyper_word_context = matrix_data.context_dict[hyper_id]
158 | inputs = torch.tensor(np.asarray([[hypon_word_context, hyper_word_context]]), dtype=torch.long)
159 | #if model.name == "self":
160 | if attention:
161 | score = model(inputs)[0].detach().cpu().numpy()[0]
162 | else:
163 | score = model(inputs).detach().cpu().numpy()[0]
164 | count_context +=1
165 |
166 | true_prediction.append(score)
167 |
168 | except Exception as e:
169 | print(repr(e))
170 | print(file_name)
171 | predict_mask[mask_idx] = False
172 | mask_idx +=1
173 |
174 | h[predict_mask] = np.array(true_prediction, dtype=np.float32)
175 | h[~predict_mask] = h[predict_mask].min()
176 |
177 | h_ip[inpattern_mask] = np.array(in_pattern_prediction, dtype=np.float32)
178 | h_ip[~inpattern_mask] = h_ip[inpattern_mask].min()
179 |
180 | y = ds.y
181 |
182 |
183 | result= {
184 | "ap_val": average_precision_score(y[m_val],h[m_val]),
185 | "ap_test": average_precision_score(y[m_test],h[m_test]),
186 | }
187 |
188 | result['true_oov'] = int(np.sum(ds.oov_mask & ds.y))
189 |
190 | result['oov_rate'] = np.mean(ds.oov_mask)
191 | result['predict_num'] = int(np.sum(predict_mask))
192 | result['oov_num'] = int(np.sum(ds.oov_mask))
193 |
194 | logger.info("{:2d}/{:2d} OOV pairs have context".format(count_context, result['oov_num']))
195 | logger.info("Context : AP for validation is :{} || for test is :{}".format(average_precision_score(y[m_val],h[m_val]),
196 | average_precision_score(y[m_test],h[m_test]) ))
197 | logger.info("Svdppmi : AP for validation is :{} || for test is :{}".format(average_precision_score(y[m_val],h_ip[m_val]),
198 | average_precision_score(y[m_test],h_ip[m_test]) ))
199 | logger.info("OOV true number is {}".format(result['true_oov']))
200 |
201 | return result
202 |
203 | def hyperlex_setup(file_name, model, matrix_data,device, attention):
204 |
205 | logger.info("-" * 80)
206 | logger.info("processing dataset :{}".format(file_name))
207 |
208 | ds = Testdataset(file_name, matrix_data.vocab, ycolumn='score')
209 |
210 | h = np.zeros(len(ds))
211 | h_ip = np.zeros(len(ds))
212 |
213 | predict_mask = np.full(len(ds), True)
214 | inpattern_mask = np.full(len(ds), True)
215 |
216 | true_prediction = []
217 | in_pattern_prediction = []
218 |
219 | mask_idx = 0
220 | for x,y in zip(ds.hypos, ds.hypers):
221 | if x in matrix_data.vocab and y in matrix_data.vocab:
222 |
223 | l = matrix_data.word2id[x]
224 | r = matrix_data.word2id[y]
225 | score = matrix_data.U[l].dot(matrix_data.V[r])
226 |
227 | true_prediction.append(score)
228 | in_pattern_prediction.append(score)
229 |
230 | else:
231 | # out of pattern
232 | inpattern_mask[mask_idx] = False
233 | try:
234 | hypon_id = matrix_data.context_w2i[x]
235 | hyper_id = matrix_data.context_w2i[y]
236 |
237 | hypon_word_context = matrix_data.context_dict[hypon_id]
238 | hyper_word_context = matrix_data.context_dict[hyper_id]
239 | inputs = torch.tensor(np.asarray([[hypon_word_context, hyper_word_context]]), dtype=torch.long)
240 |
241 | if attention:
242 | score = model(inputs)[0].detach().cpu().numpy()[0]
243 | else:
244 | score = model(inputs).detach().cpu().numpy()[0]
245 |
246 | true_prediction.append(score)
247 |
248 | except Exception as e:
249 | print(repr(e))
250 | print(file_name)
251 | predict_mask[mask_idx] = False
252 |
253 | mask_idx +=1
254 |
255 | h[predict_mask] = np.array(true_prediction, dtype=np.float32)
256 | h[~predict_mask] = np.median(h[predict_mask])
257 |
258 | h_ip[inpattern_mask] = np.array(in_pattern_prediction, dtype=np.float32)
259 | h_ip[~inpattern_mask] = np.median(h_ip[inpattern_mask])
260 |
261 |
262 | y = ds.labels
263 |
264 | m_train = ds.train_mask
265 | m_val = ds.val_mask
266 | m_test = ds.test_mask
267 |
268 | result = {
269 | "spearman_train": stats.spearmanr(y[m_train], h[m_train])[0],
270 | "spearman_val": stats.spearmanr(y[m_val], h[m_val])[0],
271 | "spearman_test": stats.spearmanr(y[m_test], h[m_test])[0],
272 | }
273 |
274 |
275 | result['oov_rate'] = np.mean(ds.oov_mask)
276 | result['predict_num'] = int(np.sum(predict_mask))
277 | result['oov_num'] = int(np.sum(ds.oov_mask))
278 |
279 | svd_train = stats.spearmanr(y[m_train], h_ip[m_train])[0]
280 | svd_test = stats.spearmanr(y[m_test], h_ip[m_test])[0]
281 |
282 | oov_train = stats.spearmanr(y[ds.oov_mask], h[ds.oov_mask])
283 |
284 | logger.info("Context: train cor: {} | test cor:{}".format(result['spearman_train'],result['spearman_test']))
285 | logger.info("OOV cor: {}".format(oov_train))
286 |
287 | logger.info("Svdppmi: train cor: {} | test cor:{}".format(svd_train, svd_test))
288 |
289 | return result
290 |
291 |
292 | def dir_bless_setup(model, matrix_data ,device, attention):
293 |
294 | logger.info("-" * 80)
295 | logger.info("processing dataset : dir_bless")
296 | ds = Testdataset("data/bless.tsv", matrix_data.vocab)
297 |
298 | hypos = ds.hypos[ds.y]
299 | hypers = ds.hypers[ds.y]
300 |
301 | m_val = ds.val_mask[ds.y]
302 | m_test = ds.test_mask[ds.y]
303 |
304 | h = np.zeros(len(ds))
305 |
306 | pred_score_list = []
307 | svd_pred_list = []
308 | count_oop = 0
309 | count_pair = 0
310 |
311 | for hypon, hyper in zip(hypos, hypers):
312 | if hypon in matrix_data.vocab and hyper in matrix_data.vocab:
313 | l = matrix_data.word2id[hypon]
314 | r = matrix_data.word2id[hyper]
315 |
316 | forward_pred = matrix_data.U[l].dot(matrix_data.V[r])
317 | reverse_pred = matrix_data.U[r].dot(matrix_data.V[l])
318 |
319 | if forward_pred > reverse_pred:
320 | pred_score_list.append(1)
321 | svd_pred_list.append(1)
322 | else:
323 | pred_score_list.append(0)
324 | svd_pred_list.append(0)
325 | else:
326 | # out of pattern mode
327 | svd_pred_list.append(0)
328 | count_oop += 1
329 | try:
330 | hypon_id = matrix_data.context_w2i[hypon]
331 | hyper_id = matrix_data.context_w2i[hyper]
332 | hypon_word_context = matrix_data.context_dict[hypon_id]
333 | hyper_word_context = matrix_data.context_dict[hyper_id]
334 | inputs = torch.tensor(np.asarray([[hypon_word_context, hyper_word_context]]), dtype=torch.long)
335 |
336 | if attention:
337 | forward_pred = model(inputs)[0].detach().cpu().numpy()[0]
338 | else:
339 | forward_pred = model(inputs).detach().cpu().numpy()[0]
340 |
341 | inputs = torch.tensor(np.asarray([[hyper_word_context, hypon_word_context]]), dtype=torch.long)
342 |
343 | if attention:
344 | reverse_pred = model(inputs)[0].detach().cpu().numpy()[0]
345 | else:
346 | reverse_pred = model(inputs).detach().cpu().numpy()[0]
347 |
348 | if forward_pred > reverse_pred:
349 | pred_score_list.append(1)
350 | else:
351 | pred_score_list.append(0)
352 | except Exception as e:
353 | print(repr(e))
354 | pred_score_list.append(0)
355 |
356 | acc = np.mean(np.asarray(pred_score_list))
357 | acc_val = np.mean(np.asarray(pred_score_list)[m_val])
358 | acc_test = np.mean(np.asarray(pred_score_list)[m_test])
359 |
360 | s_acc = np.mean(np.asarray(svd_pred_list))
361 |
362 | logger.info("Val Acc : {} || Test Acc: {} ".format(acc_val, acc_test))
363 | logger.info("Sppmi Acc: {} ".format(s_acc))
364 |
365 |
366 | def dir_wbless_setup(model, data ,device,attention):
367 |
368 | logger.info("-" * 80)
369 | logger.info("processing dataset : dir_wbless")
370 | data_path = "data/wbless.tsv"
371 | ds = Testdataset(data_path, data.vocab)
372 |
373 |
374 | rng = np.random.RandomState(42)
375 | VAL_PROB = .02
376 | NUM_TRIALS = 1000
377 |
378 | # We have no way of handling oov
379 | h, h_svd, _ = predict_many(data, model, ds.hypos, ds.hypers, False, attention)
380 | y = ds.y
381 |
382 | val_scores = []
383 | test_scores = []
384 |
385 | for _ in range(NUM_TRIALS):
386 | # Generate a new mask every time
387 | m_val = rng.rand(len(y)) < VAL_PROB
388 | # Test is everything except val
389 | m_test = ~m_val
390 | _, _, t = precision_recall_curve(y[m_val], h[m_val])
391 | # pick the highest accuracy on the validation set
392 | thr_accs = np.mean((h[m_val, np.newaxis] >= t) == y[m_val, np.newaxis], axis=0)
393 | best_t = t[thr_accs.argmax()]
394 | preds_val = h[m_val] >= best_t
395 | preds_test = h[m_test] >= best_t
396 | # Evaluate
397 | val_scores.append(np.mean(preds_val == y[m_val]))
398 | test_scores.append(np.mean(preds_test == y[m_test]))
399 | # sanity check
400 | assert np.allclose(val_scores[-1], thr_accs.max())
401 |
402 | # report average across many folds
403 | logger.info("w2v: acc_val_inv: {} acc_test_inv: {}".format(np.mean(val_scores), np.mean(test_scores)))
404 |
405 | val_scores = []
406 | test_scores = []
407 |
408 | for _ in range(NUM_TRIALS):
409 | # Generate a new mask every time
410 | m_val = rng.rand(len(y)) < VAL_PROB
411 | # Test is everything except val
412 | m_test = ~m_val
413 | _, _, t = precision_recall_curve(y[m_val], h_svd[m_val])
414 | # pick the highest accuracy on the validation set
415 | thr_accs = np.mean((h_svd[m_val, np.newaxis] >= t) == y[m_val, np.newaxis], axis=0)
416 | best_t = t[thr_accs.argmax()]
417 | preds_val = h_svd[m_val] >= best_t
418 | preds_test = h_svd[m_test] >= best_t
419 | # Evaluate
420 | val_scores.append(np.mean(preds_val == y[m_val]))
421 | test_scores.append(np.mean(preds_test == y[m_test]))
422 | # sanity check
423 | assert np.allclose(val_scores[-1], thr_accs.max())
424 |
425 | # report average across many folds
426 | logger.info("sppmi: acc_val_inv: {} acc_test_inv: {}".format(np.mean(val_scores), np.mean(test_scores)))
427 |
428 |
429 | def dir_bibless_setup(model, data ,device, attention):
430 |
431 | logger.info("-" * 80)
432 | logger.info("processing dataset : dir_bibless")
433 | data_path = "data/bibless.tsv"
434 | ds = Testdataset(data_path, data.vocab)
435 |
436 |
437 | rng = np.random.RandomState(42)
438 | VAL_PROB = .02
439 | NUM_TRIALS = 1000
440 |
441 |
442 | #y = ds.y[ds.invocab_mask]
443 | y = ds.y
444 | # hypernymy could be either direction
445 | yh = y != 0
446 |
447 | # get forward and backward predictions
448 | hf, hf_svd, oop_rate = predict_many(data, model, ds.hypos, ds.hypers, False,attention)
449 | hr, hr_svd, _ = predict_many(data, model, ds.hypos, ds.hypers, True, attention)
450 | logger.info('OOP Rate: {}'.format(oop_rate))
451 | h = np.max([hf, hr], axis=0)
452 | h_svd = np.max([hf_svd, hr_svd], axis=0)
453 |
454 | dir_pred = 2 * np.float32(hf >= hr) - 1
455 | dir_pred_svd = 2 * np.float32(hf_svd >= hr_svd) - 1
456 |
457 | val_scores = []
458 | test_scores = []
459 | for _ in range(NUM_TRIALS):
460 | # Generate a new mask every time
461 | m_val = rng.rand(len(y)) < VAL_PROB
462 | # Test is everything except val
463 | m_test = ~m_val
464 |
465 | # set the threshold based on the maximum score
466 | _, _, t = precision_recall_curve(yh[m_val], h[m_val])
467 | thr_accs = np.mean((h[m_val, np.newaxis] >= t) == yh[m_val, np.newaxis], axis=0)
468 | best_t = t[thr_accs.argmax()]
469 |
470 | det_preds_val = h[m_val] >= best_t
471 | det_preds_test = h[m_test] >= best_t
472 |
473 | fin_preds_val = det_preds_val * dir_pred[m_val]
474 | fin_preds_test = det_preds_test * dir_pred[m_test]
475 |
476 | val_scores.append(np.mean(fin_preds_val == y[m_val]))
477 | test_scores.append(np.mean(fin_preds_test == y[m_test]))
478 |
479 | # report average across many folds
480 | logger.info("w2v: acc_val_all: {}, acc_test_all: {}".format(np.mean(val_scores),np.mean(test_scores)))
481 |
482 | val_scores = []
483 | test_scores = []
484 | for _ in range(NUM_TRIALS):
485 | # Generate a new mask every time
486 | m_val = rng.rand(len(y)) < VAL_PROB
487 | # Test is everything except val
488 | m_test = ~m_val
489 |
490 | # set the threshold based on the maximum score
491 | _, _, t = precision_recall_curve(yh[m_val], h_svd[m_val])
492 | thr_accs = np.mean((h_svd[m_val, np.newaxis] >= t) == yh[m_val, np.newaxis], axis=0)
493 | best_t = t[thr_accs.argmax()]
494 |
495 | det_preds_val = h_svd[m_val] >= best_t
496 | det_preds_test = h_svd[m_test] >= best_t
497 |
498 | fin_preds_val = det_preds_val * dir_pred_svd[m_val]
499 | fin_preds_test = det_preds_test * dir_pred_svd[m_test]
500 |
501 | val_scores.append(np.mean(fin_preds_val == y[m_val]))
502 | test_scores.append(np.mean(fin_preds_test == y[m_test]))
503 |
504 | # report average across many folds
505 | logger.info("sppmi: acc_val_all: {}, acc_test_all: {}".format(np.mean(val_scores),np.mean(test_scores)))
506 |
507 |
508 |
509 |
510 | def evaluation_all(config, ckpt_path):
511 |
512 | #embedding = load_gensim_word2vec()
513 |
514 | matrix_data = Dataset(config,train=False)
515 | gpu_device = config.get("hyperparameters", "gpu_device")
516 | device = torch.device('cuda:{}'.format(gpu_device) if torch.cuda.is_available() else 'cpu')
517 |
518 | model = init_model(config, ckpt_path, matrix_data.context_word_emb, device)
519 |
520 | results = {}
521 |
522 | if "self" in config.get("hyperparameters", "model") or "han" in config.get("hyperparameters", "model"):
523 | attention = True
524 | else:
525 | attention = False
526 |
527 |
528 | for taskname, filename in SIEGE_EVALUATIONS:
529 | result = detection_setup(filename, model, matrix_data ,device, attention)
530 | results["detec_{}".format(taskname)] = result
531 |
532 | for taskname, filename in CORRELATION_EVAL_DATASETS:
533 | result = hyperlex_setup(filename, model, matrix_data, device, attention)
534 | results["corr_{}".format(taskname)] = result
535 |
536 | dir_bless_setup(model,matrix_data, device, attention)
537 | dir_wbless_setup(model, matrix_data, device,attention)
538 | dir_bibless_setup(model, matrix_data, device,attention)
539 |
540 | return results
541 |
542 | if __name__ == "__main__":
543 |
544 | config_file = sys.argv[1]
545 | config = configparser.RawConfigParser()
546 | config.read(config_file)
547 |
548 | ckpt_dir = config.get("data", "ckpt")
549 | hparam = make_hparam_string(config)
550 | ckpt_dir = os.path.join(ckpt_dir, hparam)
551 | ckpt_path = os.path.join(ckpt_dir, 'best.ckpt')
552 | log_path = os.path.join(ckpt_dir, 'context.log')
553 |
554 | logger = logging.getLogger()
555 | logger.setLevel(logging.INFO)
556 | handler = logging.FileHandler(log_path, 'w')
557 | handler.setLevel(logging.INFO)
558 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S')
559 | handler.setFormatter(formatter)
560 | logger.addHandler(handler)
561 |
562 | results = evaluation_all(config, ckpt_path)
563 | print(results)
564 |
--------------------------------------------------------------------------------
/evaluation/evaluation_all_word.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import os
4 | import logging
5 | import tqdm
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 |
11 | import numpy as np
12 | import configparser
13 | from model.models import *
14 | from model.models import Word2Score
15 | from utils.data_helper import Dataset
16 | from utils.loader import Testdataset
17 | from scipy import stats
18 | from gensim.models import Word2Vec
19 | from sklearn.metrics import average_precision_score,precision_recall_curve
20 |
21 | SIEGE_EVALUATIONS = [
22 | ("bless", "data/bless.tsv"),
23 | ("eval", "data/eval.tsv"),
24 | ("leds", "data/leds.tsv"),
25 | ("shwartz", "data/shwartz.tsv"),
26 | ("weeds", "data/wbless.tsv"),
27 | ]
28 |
29 | CORRELATION_EVAL_DATASETS = [("hyperlex", "data/hyperlex_rnd.tsv"),
30 | ("hyperlex_noun", "data/hyperlex_noun.tsv")]
31 |
32 |
33 | def predict_many(data, model, hypos, hypers, embedding, device, reverse=False):
34 |
35 | num = 0
36 | result = []
37 | result_svd = []
38 | count_oop = 0
39 | count_pair = 0
40 | for hypon, hyper in zip(hypos, hypers):
41 | count_pair += 1
42 | if hypon in data.vocab and hyper in data.vocab:
43 | l = data.word2id[hypon]
44 | r = data.word2id[hyper]
45 |
46 | if reverse:
47 | pred = data.U[r].dot(data.V[l])
48 | else:
49 | pred = data.U[l].dot(data.V[r])
50 | result_svd.append(pred)
51 |
52 | else:
53 | # out of pattern mode
54 | result_svd.append(0.0)
55 | count_oop += 1
56 | if hypon in embedding and hyper in embedding:
57 | hypon_tensor = torch.from_numpy(embedding[hypon]).view(1,300).to(device)
58 | hyper_tensor = torch.from_numpy(embedding[hyper]).view(1,300).to(device)
59 |
60 | if reverse:
61 | # pred = inference(saved_model,hyper_tensor, hypon_tensor)
62 | pred = model.inference(hyper_tensor, hypon_tensor).detach().cpu().numpy()[0]
63 | else:
64 | # pred = inference(saved_model,hypon_tensor, hyper_tensor)
65 | pred = model.inference(hypon_tensor, hyper_tensor).detach().cpu().numpy()[0]
66 | else:
67 | num +=1
68 | pred = 0.0
69 |
70 | result.append(pred)
71 | # num == 0 -> all out-of-pattern words are covered by the embedding
72 | oop_rate = count_oop * 1.0 / count_pair
73 | return np.array(result, dtype=np.float32), np.array(result_svd, dtype=np.float32), oop_rate
74 |
75 |
76 |
77 | def make_hparam_string(config):
78 | hparam = "{}/s{}_h{}-{}_n{}_w{}".format(
79 | config.get("hyperparameters", "model"),
80 | config.get("hyperparameters", "svd_dimension"),
81 | config.get("hyperparameters", "number_hidden_layers"),
82 | config.get("hyperparameters", "hidden_layer_size"),
83 | config.get("hyperparameters", "negative_num"),
84 | # config.get("hyperparameters", "batch_size"),
85 | config.get("hyperparameters", "weight_decay"),
86 | # config.get("hyperparameters", "context_num"),
87 | # config.get("hyperparameters", "context_len")
88 | )
89 | return hparam
90 |
91 | def init_model(config):
92 |
93 | hidden_layer_size = int(config.getfloat("hyperparameters", "hidden_layer_size"))
94 | number_hidden_layers = int(config.getfloat("hyperparameters", "number_hidden_layers"))
95 |
96 | model = Word2Score(hidden_layer_size, number_hidden_layers)
97 | return model
98 |
99 | def load_gensim_word2vec():
100 |
101 | print("Loading pretrained word embedding ... ")
102 | wv_model = Word2Vec.load("/home/shared/embedding/ukwac.model")
103 | embedding = wv_model.wv
104 |
105 | return embedding
106 |
107 |
108 | def detection_setup(file_name, model, matrix_data, embedding,device):
109 |
110 | ds = Testdataset(file_name, matrix_data.vocab)
111 | logger.info("-" * 80)
112 | logger.info("processing dataset :{}".format(file_name))
113 |
114 | m_val = ds.val_mask
115 | m_test = ds.test_mask
116 |
117 | h = np.zeros(len(ds))
118 | h_ip = np.zeros(len(ds))
119 |
120 | predict_mask = np.full(len(ds), True)
121 | inpattern_mask = np.full(len(ds), True)
122 |
123 | true_prediction = []
124 | in_pattern_prediction = []
125 |
126 | count_w2v = 0
127 |
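   | # predict_mask marks pairs that get a score from either the pattern matrix or the
   | # embedding backoff; inpattern_mask marks pairs scored by the pattern matrix alone.
   | # Pairs left unscored are later filled with the minimum predicted score.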
128 | mask_idx = 0
129 | for x,y in zip(ds.hypos, ds.hypers):
130 | if x in matrix_data.vocab and y in matrix_data.vocab:
131 |
132 | l = matrix_data.word2id[x]
133 | r = matrix_data.word2id[y]
134 | score = matrix_data.U[l].dot(matrix_data.V[r])
135 |
136 | true_prediction.append(score)
137 | in_pattern_prediction.append(score)
138 |
139 | else:
140 | # out of pattern
141 | inpattern_mask[mask_idx] = False
142 |
143 | if x in embedding and y in embedding:
144 | hypon_tensor = torch.from_numpy(embedding[x]).view(1,300).to(device)
145 | hyper_tensor = torch.from_numpy(embedding[y]).view(1,300).to(device)
146 | score = model.inference(hypon_tensor, hyper_tensor).detach().cpu().numpy()[0]
147 | true_prediction.append(score)
148 |
149 | count_w2v +=1
150 |
151 | else:
152 | predict_mask[mask_idx] = False
153 | mask_idx +=1
154 |
155 | h[predict_mask] = np.array(true_prediction, dtype=np.float32)
156 | h[~predict_mask] = h[predict_mask].min()
157 |
158 | h_ip[inpattern_mask] = np.array(in_pattern_prediction, dtype=np.float32)
159 | h_ip[~inpattern_mask] = h_ip[inpattern_mask].min()
160 |
161 | y = ds.y
162 |
163 | result= {
164 | "ap_val": average_precision_score(y[m_val],h[m_val]),
165 | "ap_test": average_precision_score(y[m_test],h[m_test]),
166 | }
167 | result['oov_rate'] = np.mean(ds.oov_mask)
168 | result['predict_num'] = int(np.sum(predict_mask))
169 | result['oov_num'] = int(np.sum(ds.oov_mask))
170 |
171 | logger.info("there are {:2d}/{:2d} pairs appeared in the trained embedding".format(count_w2v, result['oov_num']))
172 | logger.info("Word2Vec : AP for validation is :{} || for test is :{}".format(result['ap_val'], result['ap_test']))
173 | logger.info("Svdppmi : AP for validation is :{} || for test is :{}".format(average_precision_score(y[m_val],h_ip[m_val]),
174 | average_precision_score(y[m_test],h_ip[m_test]) ))
175 |
176 | return result
177 |
178 | def hyperlex_setup(file_name, model, matrix_data, embedding,device):
179 |
180 | logger.info("-" * 80)
181 | logger.info("processing dataset :{}".format(file_name))
182 |
183 | ds = Testdataset(file_name, matrix_data.vocab, ycolumn='score')
184 |
185 | h = np.zeros(len(ds))
186 |
187 | predict_mask = np.full(len(ds), True)
188 |
189 | true_prediction = []
190 |
191 | mask_idx = 0
192 | for x,y in zip(ds.hypos, ds.hypers):
193 | if x in matrix_data.vocab and y in matrix_data.vocab:
194 |
195 | l = matrix_data.word2id[x]
196 | r = matrix_data.word2id[y]
197 | score = matrix_data.U[l].dot(matrix_data.V[r])
198 |
199 | true_prediction.append(score)
200 |
201 | else:
202 | # out of pattern
203 | if x in embedding and y in embedding:
204 | hypon_tensor = torch.from_numpy(embedding[x]).view(1,300).to(device)
205 | hyper_tensor = torch.from_numpy(embedding[y]).view(1,300).to(device)
206 | score = model.inference(hypon_tensor, hyper_tensor).detach().cpu().numpy()[0]
207 | true_prediction.append(score)
208 | else:
209 | predict_mask[mask_idx] = False
210 |
211 | mask_idx +=1
212 |
213 | h[predict_mask] = np.array(true_prediction, dtype=np.float32)
214 | h[~predict_mask] = np.median(h[predict_mask])
215 |
216 | y = ds.labels
217 |
218 | m_train = ds.train_mask
219 | m_val = ds.val_mask
220 | m_test = ds.test_mask
221 |
222 | result = {
223 | "spearman_train": stats.spearmanr(y[m_train], h[m_train])[0],
224 | "spearman_val": stats.spearmanr(y[m_val], h[m_val])[0],
225 | "spearman_test": stats.spearmanr(y[m_test], h[m_test])[0],
226 | }
227 |
228 | result['oov_rate'] = np.mean(ds.oov_mask)
229 | result['predict_num'] = int(np.sum(predict_mask))
230 | result['oov_num'] = int(np.sum(ds.oov_mask))
231 |
232 | logger.info("Word2Vec: train cor: {} | test cor:{}".format(result['spearman_train'],result['spearman_test']))
233 |
234 | return result
235 |
236 | def dir_bless_setup(model, matrix_data, embedding, device):
237 |
238 | logger.info("-" * 80)
239 | logger.info("processing dataset : dir_bless")
240 | ds = Testdataset("data/bless.tsv", matrix_data.vocab)
241 |
242 | hypos = ds.hypos[ds.y]
243 | hypers = ds.hypers[ds.y]
244 |
245 | m_val = ds.val_mask[ds.y]
246 | m_test = ds.test_mask[ds.y]
247 |
248 | h = np.zeros(len(ds))
249 |
250 | pred_score_list = []
251 | svd_pred_list = []
252 | count_oop = 0
253 | count_pair = 0
254 |
255 | for hypon, hyper in zip(hypos, hypers):
256 | if hypon in matrix_data.vocab and hyper in matrix_data.vocab:
257 | l = matrix_data.word2id[hypon]
258 | r = matrix_data.word2id[hyper]
259 |
260 | forward_pred = matrix_data.U[l].dot(matrix_data.V[r])
261 | reverse_pred = matrix_data.U[r].dot(matrix_data.V[l])
262 |
263 | if forward_pred > reverse_pred:
264 | pred_score_list.append(1)
265 | svd_pred_list.append(1)
266 | else:
267 | pred_score_list.append(0)
268 | svd_pred_list.append(0)
269 | else:
270 | # out of pattern mode
271 | svd_pred_list.append(0)
272 | count_oop += 1
273 |
274 | if hypon in embedding and hyper in embedding:
275 | hypon_tensor = torch.from_numpy(embedding[hypon]).view(1,300).to(device)
276 | hyper_tensor = torch.from_numpy(embedding[hyper]).view(1,300).to(device)
277 | forward_pred = model.inference(hypon_tensor, hyper_tensor).detach().cpu().numpy()[0]
278 | reverse_pred = model.inference(hyper_tensor, hypon_tensor).detach().cpu().numpy()[0]
279 |
280 | if forward_pred > reverse_pred:
281 | pred_score_list.append(1)
282 | else:
283 | pred_score_list.append(0)
284 | else:
285 | pred_score_list.append(0)
286 |
287 | acc = np.mean(np.asarray(pred_score_list))
288 | acc_val = np.mean(np.asarray(pred_score_list)[m_val])
289 | acc_test = np.mean(np.asarray(pred_score_list)[m_test])
290 |
291 | s_acc = np.mean(np.asarray(svd_pred_list))
292 |
293 | logger.info("Val Acc : {} || Test Acc: {} ".format(acc_val, acc_test))
294 | logger.info("Sppmi Acc: {} ".format(s_acc))
295 |
296 |
297 | def dir_wbless_setup(model, data, embedding,device):
298 |
299 | logger.info("-" * 80)
300 | logger.info("processing dataset : dir_wbless")
301 | data_path = "data/wbless.tsv"
302 | ds = Testdataset(data_path, data.vocab)
303 |
304 | rng = np.random.RandomState(42)
305 | VAL_PROB = .02
306 | NUM_TRIALS = 1000
307 |
308 | # We have no way of handling oov
309 | h, h_svd, _ = predict_many(data, model, ds.hypos, ds.hypers, embedding, device)
310 | y = ds.y
311 |
312 | val_scores = []
313 | test_scores = []
314 |
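   | # Threshold-selection protocol: in each trial, sample ~2% of the pairs as a validation
   | # split, pick the score threshold with the best validation accuracy, and evaluate that
   | # threshold on the remaining pairs; accuracies are averaged over NUM_TRIALS runs.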
315 | for _ in range(NUM_TRIALS):
316 | # Generate a new mask every time
317 | m_val = rng.rand(len(y)) < VAL_PROB
318 | # Test is everything except val
319 | m_test = ~m_val
320 | _, _, t = precision_recall_curve(y[m_val], h[m_val])
321 | # pick the highest accuracy on the validation set
322 | thr_accs = np.mean((h[m_val, np.newaxis] >= t) == y[m_val, np.newaxis], axis=0)
323 | best_t = t[thr_accs.argmax()]
324 | preds_val = h[m_val] >= best_t
325 | preds_test = h[m_test] >= best_t
326 | # Evaluate
327 | val_scores.append(np.mean(preds_val == y[m_val]))
328 | test_scores.append(np.mean(preds_test == y[m_test]))
329 | # sanity check
330 | assert np.allclose(val_scores[-1], thr_accs.max())
331 |
332 | # report average across many folds
333 | logger.info("w2v: acc_val_inv: {} acc_test_inv: {}".format(np.mean(val_scores), np.mean(test_scores)))
334 |
335 | val_scores = []
336 | test_scores = []
337 |
338 | for _ in range(NUM_TRIALS):
339 | # Generate a new mask every time
340 | m_val = rng.rand(len(y)) < VAL_PROB
341 | # Test is everything except val
342 | m_test = ~m_val
343 | _, _, t = precision_recall_curve(y[m_val], h_svd[m_val])
344 | # pick the highest accuracy on the validation set
345 | thr_accs = np.mean((h_svd[m_val, np.newaxis] >= t) == y[m_val, np.newaxis], axis=0)
346 | best_t = t[thr_accs.argmax()]
347 | preds_val = h_svd[m_val] >= best_t
348 | preds_test = h_svd[m_test] >= best_t
349 | # Evaluate
350 | val_scores.append(np.mean(preds_val == y[m_val]))
351 | test_scores.append(np.mean(preds_test == y[m_test]))
352 | # sanity check
353 | assert np.allclose(val_scores[-1], thr_accs.max())
354 |
355 | # report average across many folds
356 | logger.info("sppmi: acc_val_inv: {} acc_test_inv: {}".format(np.mean(val_scores), np.mean(test_scores)))
357 |
358 |
359 | def dir_bibless_setup(model, data, embedding, device):
360 |
361 | logger.info("-" * 80)
362 | logger.info("processing dataset : dir_bibless")
363 | data_path = "data/bibless.tsv"
364 | ds = Testdataset(data_path, data.vocab)
365 |
366 |
367 | rng = np.random.RandomState(42)
368 | VAL_PROB = .02
369 | NUM_TRIALS = 1000
370 |
371 |
372 | #y = ds.y[ds.invocab_mask]
373 | y = ds.y
374 | # hypernymy could be either direction
375 | yh = y != 0
376 |
377 | # get forward and backward predictions
378 | hf, hf_svd, oop_rate = predict_many(data, model, ds.hypos, ds.hypers, embedding, device, reverse=False)
379 | hr, hr_svd, _ = predict_many(data, model, ds.hypos, ds.hypers, embedding, device, reverse=True)
380 | logger.info('OOP Rate: {}'.format(oop_rate))
381 | h = np.max([hf, hr], axis=0)
382 | h_svd = np.max([hf_svd, hr_svd], axis=0)
383 |
384 | dir_pred = 2 * np.float32(hf >= hr) - 1
385 | dir_pred_svd = 2 * np.float32(hf_svd >= hr_svd) - 1
386 |
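   | # BiBLESS scoring: detection (is the pair hypernymy in either direction?) thresholds the
   | # max of the forward/backward scores, while dir_pred in {-1, +1} picks the direction;
   | # their product is compared against the labels y in {-1, 0, +1}.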
387 | val_scores = []
388 | test_scores = []
389 | for _ in range(NUM_TRIALS):
390 | # Generate a new mask every time
391 | m_val = rng.rand(len(y)) < VAL_PROB
392 | # Test is everything except val
393 | m_test = ~m_val
394 |
395 | # set the threshold based on the maximum score
396 | _, _, t = precision_recall_curve(yh[m_val], h[m_val])
397 | thr_accs = np.mean((h[m_val, np.newaxis] >= t) == yh[m_val, np.newaxis], axis=0)
398 | best_t = t[thr_accs.argmax()]
399 |
400 | det_preds_val = h[m_val] >= best_t
401 | det_preds_test = h[m_test] >= best_t
402 |
403 | fin_preds_val = det_preds_val * dir_pred[m_val]
404 | fin_preds_test = det_preds_test * dir_pred[m_test]
405 |
406 | val_scores.append(np.mean(fin_preds_val == y[m_val]))
407 | test_scores.append(np.mean(fin_preds_test == y[m_test]))
408 |
409 | # report average across many folds
410 | logger.info("w2v: acc_val_all: {}, acc_test_all: {}".format(np.mean(val_scores),np.mean(test_scores)))
411 |
412 | val_scores = []
413 | test_scores = []
414 | for _ in range(NUM_TRIALS):
415 | # Generate a new mask every time
416 | m_val = rng.rand(len(y)) < VAL_PROB
417 | # Test is everything except val
418 | m_test = ~m_val
419 |
420 | # set the threshold based on the maximum score
421 | _, _, t = precision_recall_curve(yh[m_val], h_svd[m_val])
422 | thr_accs = np.mean((h_svd[m_val, np.newaxis] >= t) == yh[m_val, np.newaxis], axis=0)
423 | best_t = t[thr_accs.argmax()]
424 |
425 | det_preds_val = h_svd[m_val] >= best_t
426 | det_preds_test = h_svd[m_test] >= best_t
427 |
428 | fin_preds_val = det_preds_val * dir_pred_svd[m_val]
429 | fin_preds_test = det_preds_test * dir_pred_svd[m_test]
430 |
431 | val_scores.append(np.mean(fin_preds_val == y[m_val]))
432 | test_scores.append(np.mean(fin_preds_test == y[m_test]))
433 |
434 | # report average across many folds
435 | logger.info("sppmi: acc_val_all: {}, acc_test_all: {}".format(np.mean(val_scores),np.mean(test_scores)))
436 |
437 |
438 |
439 | def evaluation_all(model_config):
440 |
441 | embedding = load_gensim_word2vec()
442 | config = configparser.RawConfigParser()
443 |
444 | config.read(model_config)
445 |
446 | gpu_device = config.get("hyperparameters", "gpu_device")
447 | device = torch.device('cuda:{}'.format(gpu_device) if torch.cuda.is_available() else 'cpu')
448 |
449 | matrix_data = Dataset(config)
450 |
451 | model = init_model(config)
452 | model.to(device)
453 |
454 | #pretrain = torch.load("/home/shared/acl-data/hype_detection/checkpoints/mlp_unisample_svd/s50_h2-300_n400_w0/best.ckpt")
455 | pretrain = torch.load("/home/cyuaq/comHyper/checkpoints/mlp_unisample_svd/s50_h2-300_n400_b128/best.ckpt")
456 | pretrain.pop("embs.weight")
457 | model.load_state_dict(pretrain)
458 | model.eval()
459 |
460 | results = {}
461 |
462 | for taskname, filename in SIEGE_EVALUATIONS:
463 | result = detection_setup(filename, model, matrix_data, embedding,device)
464 | results["detec_{}".format(taskname)] = result
465 |
466 | for taskname, filename in CORRELATION_EVAL_DATASETS:
467 | result = hyperlex_setup(filename, model, matrix_data, embedding, device)
468 | results["corr_{}".format(taskname)] = result
469 |
470 | dir_bless_setup(model, matrix_data, embedding, device)
471 | dir_wbless_setup(model, matrix_data, embedding, device)
472 | dir_bibless_setup(model, matrix_data, embedding, device)
473 |
474 | return results
475 |
476 | if __name__ == "__main__":
477 |
478 | config_file = sys.argv[1]
479 |
480 | log_path = "/home/cyuaq/comHyper/checkpoints/mlp_unisample_svd/s50_h2-300_n400_b128/word2score.log"
481 | logger = logging.getLogger()
482 | logger.setLevel(logging.INFO)
483 | handler = logging.FileHandler(log_path, 'w')
484 | handler.setLevel(logging.INFO)
485 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S')
486 | handler.setFormatter(formatter)
487 | logger.addHandler(handler)
488 |
489 | results = evaluation_all(config_file)
490 | print(results)
491 |
--------------------------------------------------------------------------------
/img/dis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-KnowComp/ComHyper/d67cdfb409a8b6ddef45d4d5457182c24057acc7/img/dis.png
--------------------------------------------------------------------------------
/img/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-KnowComp/ComHyper/d67cdfb409a8b6ddef45d4d5457182c24057acc7/img/framework.png
--------------------------------------------------------------------------------
/model/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import copy
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from transformers import BertTokenizer, BertModel
7 |
8 | def linear_block(input_dim, hidden_dim):
9 |
10 | linear = nn.Sequential(
11 | nn.Linear(input_dim, hidden_dim),
12 | nn.LeakyReLU(0.5))
13 |
14 | return linear
15 |
16 | class MLP(nn.Module):
17 | def __init__(self, input_dim, hidden_dim, num_layers):
18 | super(MLP, self).__init__()
19 | self.num_layers = num_layers
20 | self.hidden_size = hidden_dim
21 |
22 | layers = []
23 | for i in range(num_layers-1):
24 | layers.extend(
25 | linear_block(hidden_dim if i> 0 else input_dim, hidden_dim)
26 | )
27 | layers.extend([nn.Linear(hidden_dim, input_dim)])
28 |
29 | self.model = nn.Sequential(*layers)
30 |
31 | ## initialize the model weights
32 | for m in self.modules():
33 | if isinstance(m, nn.Linear):
34 | nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
35 | fan_in,_ = nn.init._calculate_fan_in_and_fan_out(m.weight)
36 | bound = 1/math.sqrt(fan_in)
37 | nn.init.uniform_(m.bias, -bound, bound)
38 |
39 | def forward(self,x):
40 | out = self.model(x)
41 | return out
42 |
43 |
44 | class SDSN(nn.Module):
45 | """docstring for SDSNA"""
46 | # Replace simple dot product with SDSNA
47 | # Scoring Lexical Entailment with a supervised directional similarity network
48 | def __init__(self, hidden_dim, num_layers):
49 | super(SDSN, self).__init__()
50 |
51 | self.emb_dim = 300
52 | self.hidden_dim = hidden_dim
53 | self.num_layers = num_layers
54 | self.map_linear_left = self.mlp(self.emb_dim, self.hidden_dim, self.num_layers)
55 | self.map_linear_right = self.mlp(self.emb_dim, self.hidden_dim, self.num_layers)
56 |
57 | self.final_linear = nn.Linear(2 * self.hidden_dim + self.emb_dim, 1)
58 |
59 | def init_embs(self, w2v_weight):
60 | self.embs = nn.Embedding.from_pretrained(w2v_weight, freeze=True)
61 |
62 | def forward(self, inputs):
63 |
64 | batch_size, _ = inputs.size()
65 | left_w2v = self.embs(inputs[:,0])
66 | right_w2v = self.embs(inputs[:,1])
67 |
68 | left_trans = self.map_linear_left(left_w2v)
69 | right_trans = self.map_linear_right(right_w2v)
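   | # NOTE: this forward pass is left unfinished in the repository; self.final_linear is
   | # defined above but never applied, and SDSN does not appear to be used by the
   | # training or evaluation scripts.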
70 |
71 | def mlp(self, input_dim, hidden_dim, num_layers):
72 | layers = []
73 | for i in range(num_layers-1):
74 | layers.extend(
75 | linear_block(hidden_dim if i> 0 else input_dim, hidden_dim)
76 | )
77 | layers.extend([nn.Linear(hidden_dim, input_dim)])
78 |
79 | return nn.Sequential(*layers)
80 |
81 |
82 | class Word2Score(nn.Module):
83 | """docstring for Word2Score"""
84 | def __init__(self, hidden_dim, num_layers):
85 | super(Word2Score, self).__init__()
86 |
87 | self.emb_dim = 300
88 | self.hidden_dim = hidden_dim
89 | self.num_layers = num_layers
90 | self.map_linear_left = self.mlp(self.emb_dim, self.hidden_dim, self.num_layers)
91 | self.map_linear_right = self.mlp(self.emb_dim, self.hidden_dim, self.num_layers)
92 |
93 | def init_emb(self, w2v_weight):
94 | self.embs = nn.Embedding.from_pretrained(w2v_weight, freeze=True)
95 |
96 | def mlp(self, input_dim, hidden_dim, num_layers):
97 | layers = []
98 | for i in range(num_layers-1):
99 | layers.extend(
100 | linear_block(hidden_dim if i> 0 else input_dim, hidden_dim)
101 | )
102 | layers.extend([nn.Linear(hidden_dim, input_dim)])
103 |
104 | return nn.Sequential(*layers)
105 |
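   | # Word2Score: the hyponym and hypernym embeddings are projected by two separate MLPs
   | # and scored by their dot product (einsum 'ij,ij->i'); the summed norms of the two
   | # projections are also returned so the caller can optionally use them as a
   | # regularization penalty.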
106 | def forward(self, inputs):
107 |
108 | # inputs: [batch_size, 2]
109 | batch_size, _ = inputs.size()
110 | left_w2v = self.embs(inputs[:,0])
111 | right_w2v = self.embs(inputs[:,1])
112 |
113 | left_trans = self.map_linear_left(left_w2v)
114 | right_trans = self.map_linear_right(right_w2v)
115 |
116 | output = torch.einsum('ij,ij->i', [left_trans, right_trans])
117 |
118 | left_norm = torch.norm(left_trans, dim=1).sum()
119 | right_norm = torch.norm(right_trans, dim=1).sum()
120 |
121 | return output, (left_norm+right_norm)
122 |
123 | def inference(self, left_w2v, right_w2v):
124 |
125 | left_trans = self.map_linear_left(left_w2v)
126 | right_trans = self.map_linear_right(right_w2v)
127 |
128 | output = torch.einsum('ij,ij->i', [left_trans, right_trans])
129 |
130 | return output
131 |
132 | class MEAN_Max(nn.Module):
133 | """docstring for MEAN"""
134 | def __init__(self, input_dim, hidden_dim):
135 | super(MEAN_Max, self).__init__()
136 | self.input_dim = input_dim
137 | self.hidden_dim = hidden_dim
138 | self.dropout_layer = nn.Dropout(0)
139 | self.output_layer = nn.Sequential(
140 | nn.Linear(input_dim, hidden_dim),
141 | nn.ReLU(),
142 | nn.Linear(hidden_dim, input_dim)
143 | )
144 |
145 | def forward(self, embed_input_left, embed_input_right):
146 | # input: [batch, context, seq, emb]
147 | batch_size, num_context, seqlen, emb_dim = embed_input_left.size()
148 |
149 | # [batch, context, seq, emb]
150 | embed_input_left = self.dropout_layer(embed_input_left)
151 | embed_input_right = self.dropout_layer(embed_input_right)
152 |
153 | oe = torch.cat((embed_input_left, embed_input_right), 2)
154 | oe = oe.mean(2)
155 | oe = self.output_layer(oe)
156 | oe = oe.max(1)[0]
157 | return oe
158 |
159 |
160 | class MEAN(nn.Module):
161 | """docstring for MEAN"""
162 | def __init__(self, input_dim, hidden_dim):
163 | super(MEAN, self).__init__()
164 | self.input_dim = input_dim
165 | self.hidden_dim = hidden_dim
166 | self.dropout_layer = nn.Dropout(0)
167 | self.output_layer = nn.Sequential(
168 | nn.Linear(input_dim, hidden_dim),
169 | nn.ReLU(),
170 | nn.Linear(hidden_dim, input_dim)
171 | )
172 |
173 | def forward(self, embed_input_left, embed_input_right):
174 | # input: [batch, context, seq, emb]
175 | batch_size, num_context, seqlen, emb_dim = embed_input_left.size()
176 |
177 | # [batch, context, seq, emb]
178 | embed_input_left = self.dropout_layer(embed_input_left)
179 | embed_input_right = self.dropout_layer(embed_input_right)
180 |
181 | oe = torch.cat((embed_input_left, embed_input_right), 2)
182 | oe = oe.mean(2)
183 | oe = self.output_layer(oe)
184 | oe = oe.mean(1)
185 | return oe
186 |
187 | class LSTM(nn.Module):
188 | """docstring for LSTM"""
189 | def __init__(self, input_dim, hidden_dim):
190 | super(LSTM, self).__init__()
191 | self.input_dim = input_dim
192 | self.hidden_dim = hidden_dim
193 | self.dropout_layer = nn.Dropout(p=0)
194 | self.left_context_encoder = nn.LSTM(input_dim, hidden_dim, 1, batch_first=True)
195 | self.right_context_encoder = nn.LSTM(input_dim, hidden_dim, 1, batch_first=True)
196 | self.output_layer = nn.Sequential(
197 | nn.Linear(hidden_dim*2, hidden_dim*2),
198 | nn.ReLU(),
199 | nn.Linear(hidden_dim*2, input_dim)
200 | )
201 |
202 | def forward(self, embed_input_left, embed_input_right):
203 | # input: [batch, context, seq, emb]
204 | batch_size, num_context, seqlen, emb_dim = embed_input_left.size()
205 |
206 | # [batch, context, seq, dim]
207 | embed_input_left = embed_input_left.view(-1, seqlen, self.input_dim)
208 | embed_input_left = self.dropout_layer(embed_input_left)
209 |
210 | embed_input_right = embed_input_right.view(-1, seqlen, self.input_dim)
211 | embed_input_right = self.dropout_layer(embed_input_right)
212 |
213 | # hidden = (torch.zeros(1, batch_size*num_context, self.hidden_dim),
214 | # torch.zeros(1, batch_size*num_context, self.hidden_dim))
215 |
216 | output_left, (final_hidden_state_left, final_cell_state_left) = self.left_context_encoder(embed_input_left) #, hidden)
217 | output_right, (final_hidden_state_right, final_cell_state_right) = self.right_context_encoder(embed_input_right) #, hidden)
218 |
219 | encode_context_left = final_hidden_state_left.view(-1, num_context, self.hidden_dim)
220 | encode_context_right = final_hidden_state_right.view(-1, num_context, self.hidden_dim)
221 |
222 | # concat + mean_pooling + fully_connect
223 | oe = torch.cat((encode_context_left, encode_context_right), 2)
224 | oe = self.output_layer(oe)
225 | oe = oe.mean(1)
226 | return oe
227 |
228 | class SelfAttention(nn.Module):
229 | """docstring for SelfAttention"""
230 | def __init__(self, input_dim, hidden_dim):
231 | super(SelfAttention, self).__init__()
232 | self.input_dim = input_dim
233 | self.hidden_dim = hidden_dim
234 | self.dropout_layer = nn.Dropout(0)
235 |
236 | self.att_w = nn.Linear(input_dim, hidden_dim)
237 | self.att_v = nn.Parameter(torch.rand(hidden_dim))
238 |
239 | self.output_layer = nn.Sequential(
240 | nn.Linear(hidden_dim, hidden_dim),
241 | nn.ReLU(),
242 | nn.Linear(hidden_dim, input_dim)
243 | )
244 |
245 | def forward(self, embed_input_left, embed_input_right):
246 |
247 | batch_size, num_context, seqlen, emb_dim = embed_input_left.size()
248 |
249 | # [batch, context, seq, dim]
250 | embed_input_left = self.dropout_layer(embed_input_left)
251 | embed_input_right = self.dropout_layer(embed_input_right)
252 |
253 | # [batch_size, context_num, seq_length, dim]
254 | left_right_context = torch.cat((embed_input_left, embed_input_right),2)
255 | #print(left_right_context.size())
256 |
257 | att_weight = torch.matmul(self.att_w(left_right_context), self.att_v)
258 | att_weight = nn.functional.softmax(att_weight, dim=2).view(batch_size, num_context, 2*seqlen, 1)
259 | #print(att_weight.size())
260 |
261 | oe = (left_right_context * att_weight).sum(2)
262 |
263 | oe = self.output_layer(oe)
264 |
265 | oe = oe.mean(1)
266 |
267 | return oe ,att_weight
268 |
269 |
270 | class HierAttention(nn.Module):
271 |
272 | def __init__(self, input_dim, hidden_dim):
273 | super(HierAttention, self).__init__()
274 | self.input_dim = input_dim
275 | self.hidden_dim = hidden_dim
276 | self.dropout_layer = nn.Dropout(0)
277 |
278 | self.att_w = nn.Linear(input_dim, hidden_dim)
279 | self.att_v = nn.Parameter(torch.rand(hidden_dim))
280 |
281 | self.att_h = nn.Linear(input_dim, hidden_dim)
282 | self.att_hv = nn.Parameter(torch.rand(hidden_dim))
283 |
284 | self.output_layer = nn.Sequential(
285 | nn.Linear(input_dim, hidden_dim),
286 | nn.ReLU(),
287 | nn.Linear(hidden_dim, input_dim)
288 | )
289 |
290 | def forward(self, embed_input_left, embed_input_right):
291 |
292 | batch_size, num_context, seqlen, emb_dim = embed_input_left.size()
293 |
294 | # [batch, context, seq, dim]
295 | embed_input_left = self.dropout_layer(embed_input_left)
296 | embed_input_right = self.dropout_layer(embed_input_right)
297 |
298 | # [batch_size, context_num, seq_length, dim]
299 | left_right_context = torch.cat((embed_input_left, embed_input_right),2)
300 | #print(left_right_context.size())
301 |
302 |
303 | att_weight = torch.matmul(self.att_w(left_right_context), self.att_v)
304 | att_weight = nn.functional.softmax(att_weight, dim=2).view(batch_size, num_context, 2*seqlen, 1)
305 |
306 | oe = (left_right_context * att_weight).sum(2)
307 |
308 | #print(oe.size())
309 |
310 | hier_att_weight = torch.matmul(self.att_h(oe), self.att_hv)
311 | #print(hier_att_weight.size())
312 |
313 | hier_att_weight = nn.functional.softmax(hier_att_weight, dim=1).view(batch_size, num_context, 1)
314 | #print(hier_att_weight.size())
315 |
316 | oe = (oe * hier_att_weight).sum(1)
317 |
318 | oe = self.output_layer(oe)
319 |
320 | return oe, att_weight, hier_att_weight
321 |
322 |
323 |
324 | class HierAttentionEnsemble(nn.Module):
325 |
326 | def __init__(self, input_dim, hidden_dim):
327 | super(HierAttentionEnsemble, self).__init__()
328 | self.input_dim = input_dim
329 | self.hidden_dim = hidden_dim
330 | self.dropout_layer = nn.Dropout(0)
331 |
332 | self.att_w = nn.Linear(input_dim, hidden_dim)
333 | self.att_v = nn.Parameter(torch.rand(hidden_dim))
334 |
335 | self.att_h = nn.Linear(input_dim, hidden_dim)
336 | self.att_hv = nn.Parameter(torch.rand(hidden_dim))
337 |
338 | self.output_layer = nn.Sequential(
339 | nn.Linear(input_dim, hidden_dim),
340 | nn.ReLU(),
341 | nn.Linear(hidden_dim, input_dim)
342 | )
343 |
344 | def forward(self, embed_input_left, embed_input_right):
345 |
346 | batch_size, num_context, seqlen, emb_dim = embed_input_left.size()
347 |
348 | # [batch, context, seq, dim]
349 | embed_input_left = self.dropout_layer(embed_input_left)
350 | embed_input_right = self.dropout_layer(embed_input_right)
351 |
352 | # [batch_size, context_num, seq_length, dim]
353 | left_right_context = torch.cat((embed_input_left, embed_input_right),2)
354 | #print(left_right_context.size())
355 |
356 |
357 | att_weight = torch.matmul(self.att_w(left_right_context), self.att_v)
358 | att_weight = nn.functional.softmax(att_weight, dim=2).view(batch_size, num_context, 2*seqlen, 1)
359 |
360 | oe = (left_right_context * att_weight).sum(2)
361 |
362 | #print(oe.size())
363 |
364 | hier_att_weight = torch.matmul(self.att_h(oe), self.att_hv)
365 | #print(hier_att_weight.size())
366 |
367 | hier_att_weight = nn.functional.softmax(hier_att_weight, dim=1).view(batch_size, num_context, 1)
368 | #print(hier_att_weight.size())
369 |
370 | oe = (oe * hier_att_weight).sum(1)
371 |
372 | oe = self.output_layer(oe)
373 |
374 | return oe, att_weight, hier_att_weight
375 |
376 |
377 | class ATTENTION(nn.Module):
378 | """docstring for ATTENTION"""
379 | def __init__(self, input_dim, hidden_dim):
380 | super(ATTENTION, self).__init__()
381 | self.input_dim = input_dim
382 | self.hidden_dim = hidden_dim
383 | self.dropout_layer = nn.Dropout(0)
384 | self.left_context_encoder = nn.LSTM(input_dim, hidden_dim, 1, batch_first=True)
385 | self.right_context_encoder = nn.LSTM(input_dim, hidden_dim, 1, batch_first=True)
386 | self.att_w = nn.Linear(hidden_dim*2, hidden_dim)
387 | self.att_v = nn.Parameter(torch.rand(hidden_dim))
388 | self.output_layer = nn.Sequential(
389 | nn.Linear(hidden_dim*2, hidden_dim*2),
390 | nn.ReLU(),
391 | nn.Linear(hidden_dim*2, input_dim)
392 | )
393 |
394 | def forward(self, embed_input_left, embed_input_right):
395 | # input: [batch, context, seq, emb]
396 | batch_size, num_context, seqlen, emb_dim = embed_input_left.size()
397 |
398 | # [batch, context, seq, dim] -> [batch*context, seq, dim]
399 | embed_input_left = embed_input_left.view(-1, seqlen, self.input_dim)
400 | embed_input_left = self.dropout_layer(embed_input_left)
401 | embed_input_right = embed_input_right.view(-1, seqlen, self.input_dim)
402 | embed_input_right = self.dropout_layer(embed_input_right)
403 |
404 | # hidden = (torch.zeros(1, batch_size*num_context, self.hidden_dim),
405 | # torch.zeros(1, batch_size*num_context, self.hidden_dim))
406 |
407 | output_left, (final_hidden_state_left, final_cell_state_left) = self.left_context_encoder(embed_input_left) #, hidden)
408 | output_right, (final_hidden_state_right, final_cell_state_right) = self.right_context_encoder(embed_input_right) #, hidden)
409 |
410 | encode_context_left = final_hidden_state_left.view(-1, num_context, self.hidden_dim)
411 | encode_context_right = final_hidden_state_right.view(-1, num_context, self.hidden_dim)
412 |
413 | # concat + mean_pooling + fully_connect
414 | oe = torch.cat((encode_context_left, encode_context_right), 2)
415 | print(oe.size())
416 | att_weight = torch.matmul(self.att_w(oe), self.att_v)
417 | print(att_weight.size())
418 | att_weight = nn.functional.softmax(att_weight, dim=1).view(batch_size, num_context, 1)
419 | print(att_weight.size())
420 | oe = (oe * att_weight).sum(1)
421 |
422 | print("--------")
423 |
424 | oe = self.output_layer(oe)
425 | return oe
426 |
427 | class BertEncoder(nn.Module):
428 |
429 | def __init__(self, bert_dir, model_type="base"):
430 | super(BertEncoder, self).__init__()
431 | self.model_type = model_type
432 | self.model = BertModel.from_pretrained(bert_dir)
433 | self.set_finetune("full")
434 |
435 | def set_finetune(self, finetune_type):
436 |
437 | if finetune_type == "none":
438 | for param in self.model.parameters():
439 | param.requires_grad = False
440 | elif finetune_type == "full":
441 | for param in self.model.parameters():
442 | param.requires_grad = True
443 | elif finetune_type == "last":
444 | for param in self.model.parameters():
445 | param.requires_grad = False
446 | for param in self.model.encoder.layer[-1].parameters():
447 | param.requires_grad = True
448 |
449 | def forward(self, input_ids, mask=None):
450 |
451 | # [batch_size, context_num, seq_length]
452 | batch_size, context_num, seq_length = input_ids.size()
453 | flat_input_ids = input_ids.reshape(-1, input_ids.size(-1))
454 | flat_mask = mask.reshape(-1, mask.size(-1))
455 | pooled_cls = self.model(input_ids = flat_input_ids, attention_mask=flat_mask)[1]
456 | # [batch_size * context_num, dim]
457 | #print(pooled_cls.size())
458 |
459 | reshaped_pooled_cls = pooled_cls.view(batch_size, context_num, -1)
460 | # [batch_size, context_num, dim]
461 | output = reshaped_pooled_cls.mean(1)
462 | # [batch_size, dim]
463 | return output
464 |
465 | def get_output_dim(self):
466 | if self.model_type == "large":
467 | return 1024
468 | else:
469 | return 768
470 |
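   | # Bert2Score: each word is represented by the mean of the pooled [CLS] vectors of its
   | # sampled contexts (BertEncoder above); two linear maps then project the hyponym and
   | # hypernym representations and the score is their dot product.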
471 | class Bert2Score(nn.Module):
472 |
473 | def __init__(self, encoder, bert_dir, hidden_dim, drop_prob):
474 | super(Bert2Score, self).__init__()
475 | self.hidden_dim = hidden_dim
476 | if "large" in encoder:
477 | self.encoder = BertEncoder(bert_dir, "large")
478 | else:
479 | self.encoder = BertEncoder(bert_dir)
480 |
481 | bert_dim = self.encoder.get_output_dim()
482 | self.mlp1 = nn.Linear(bert_dim, hidden_dim)
483 | self.mlp2 = nn.Linear(bert_dim, hidden_dim)
484 | self.dropout = nn.Dropout(drop_prob)
485 |
486 | def forward(self, input_ids, masks):
487 | ## input: [batch_size, 2, context, seq]
488 | left_ids = input_ids[:,0,:,:]
489 | right_ids = input_ids[:,1,:,:]
490 |
491 | left_masks = masks[:,0,:,:]
492 | right_masks = masks[:,1,:,:]
493 |
494 | left_emb = self.encoder(left_ids, left_masks)
495 | right_emb = self.encoder(right_ids, right_masks)
496 |
497 | # [batch_size, hidden_dim]
498 | tran_left = self.mlp1(self.dropout(left_emb))
499 | tran_right = self.mlp2(self.dropout(right_emb))
500 |
501 | output = torch.einsum('ij,ij->i', [tran_left, tran_right])
502 | return output
503 |
504 | class Context2Score(nn.Module):
505 | """docstring for Context2Score"""
506 | def __init__(self, encoder, input_dim, hidden_dim, device, multiple=False):
507 | super(Context2Score, self).__init__()
508 | self.input_dim = input_dim
509 | self.hidden_dim = hidden_dim
510 | self.device = device
511 | self.attention = False
512 | self.hier = False
513 | #self.name = encoder
514 | if 'lstm' in encoder:
515 | if multiple:
516 | self.encoder1 = nn.DataParallel(LSTM(input_dim, hidden_dim), device_ids=[0,1,2,3])
517 | self.encoder2 = nn.DataParallel(LSTM(input_dim, hidden_dim), device_ids=[0,1,2,3])
518 | else:
519 | self.encoder1 = LSTM(input_dim, hidden_dim).to(device)
520 | self.encoder2 = LSTM(input_dim, hidden_dim).to(device)
521 | elif 'attention' in encoder:
522 | if multiple:
523 | self.encoder1 = ATTENTION(input_dim, hidden_dim)
524 | self.encoder2 = ATTENTION(input_dim, hidden_dim)
525 | else:
526 | self.encoder1 = ATTENTION(input_dim, hidden_dim).to(device)
527 | self.encoder2 = ATTENTION(input_dim, hidden_dim).to(device)
528 | elif 'max' in encoder:
529 | self.encoder1 = MEAN_Max(input_dim, hidden_dim).to(device)
530 | self.encoder2 = MEAN_Max(input_dim, hidden_dim).to(device)
531 | elif 'self' in encoder:
532 | #self.encoder1, self.atten1 = SelfAttention(input_dim, hidden_dim).to(device)
533 | self.encoder1 = SelfAttention(input_dim, hidden_dim).to(device)
534 | self.encoder2 = SelfAttention(input_dim, hidden_dim).to(device)
535 | self.attention = True
536 |
537 | elif 'han' in encoder:
538 | self.encoder1 = HierAttention(input_dim, hidden_dim).to(device)
539 | self.encoder2 = HierAttention(input_dim, hidden_dim).to(device)
540 | self.hier = True
541 |
542 | else:
543 | if multiple:
544 | self.encoder1 = MEAN(input_dim, hidden_dim)
545 | self.encoder2 = MEAN(input_dim, hidden_dim)
546 | else:
547 | self.encoder1 = MEAN(input_dim, hidden_dim).to(device)
548 | self.encoder2 = MEAN(input_dim, hidden_dim).to(device)
549 |
550 |
551 | def init_emb(self, w2v_weight):
552 | self.word_embedding = nn.Embedding.from_pretrained(w2v_weight, freeze=True)
553 |
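   | # Context2Score: input_idx packs, for each pair, the left/right context windows of the
   | # hyponym and of the hypernym; each side is encoded into a single vector and the score
   | # is their dot product. The self-attention / hierarchical-attention encoders also
   | # return their attention weights for visualization (see model/visual_attention.py).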
554 | def forward(self, input_idx):
555 | # input: [batch, 2, context, 2, seq]
556 |
557 | embed_input1_left = self.word_embedding(input_idx[:, 0, :, 0]).to(self.device)
558 | embed_input1_right = self.word_embedding(input_idx[:, 0, :, 1]).to(self.device)
559 | embed_input2_left = self.word_embedding(input_idx[:, 1, :, 0]).to(self.device)
560 | embed_input2_right = self.word_embedding(input_idx[:, 1, :, 1]).to(self.device)
561 |
562 | if self.attention:
563 | embed_hypo, atten1 = self.encoder1(embed_input1_left, embed_input1_right)
564 | embed_hype, atten2 = self.encoder2(embed_input2_left, embed_input2_right)
565 |
566 | output = torch.einsum('ij,ij->i', [embed_hypo, embed_hype])
567 | return output, atten1, atten2
568 |
569 | elif self.hier:
570 |
571 | embed_hypo, atten1, hier_atten1 = self.encoder1(embed_input1_left, embed_input1_right)
572 | embed_hype, atten2, hier_atten2 = self.encoder2(embed_input2_left, embed_input2_right)
573 |
574 | output = torch.einsum('ij,ij->i', [embed_hypo, embed_hype])
575 |
576 | atten_w = (atten1, hier_atten1, atten2, hier_atten2)
577 |
578 | return output, atten_w
579 |
580 | else:
581 | embed_hypo = self.encoder1(embed_input1_left, embed_input1_right)
582 | embed_hype = self.encoder2(embed_input2_left,embed_input2_right)
583 | output = torch.einsum('ij,ij->i', [embed_hypo, embed_hype])
584 |
585 | return output
586 |
587 |
--------------------------------------------------------------------------------
/model/visual_attention.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import random
4 | import logging
5 | import torch
6 | import scipy
7 | import numpy as np
8 | import pandas as pd
9 | from utils.util import oe_score
10 | from utils.util import cosine_distance
11 | from utils.util import asymmetric_distance
12 | from utils.util import load_word_vectors
13 | from model.models import *
14 | from utils.data_helper_4context import Dataset
15 | from gensim.models import Word2Vec
16 | from sklearn.metrics import average_precision_score,precision_recall_curve
17 | import configparser
18 |
19 | config = configparser.RawConfigParser()
20 | config.read(sys.argv[1])
21 |
22 | from codecs import open
23 | def createHTML(texts, weights, fileName):
24 | """
25 | Creates a html file with text heat.
26 | weights: attention weights for visualizing
27 | texts: text on which attention weights are to be visualized
28 | """
29 | fOut = open(fileName, "w", encoding="utf-8")
30 | part1 = """
31 |
32 |
33 |
34 |
39 |
40 |
41 |
42 | Heatmaps
43 |
44 |
45 |
95 | """
96 | putQuote = lambda x: "\"%s\""%x
97 | #putQuote = lambda x: "%s"%x
98 | textsString = "var any_text = [%s];\n"%(",".join(map(putQuote, texts)))
99 | weightsString = "var trigram_weights = [%s];\n"%(",".join(map(str,weights)))
100 | fOut.write(part1)
101 | fOut.write(textsString)
102 | fOut.write(weightsString)
103 | fOut.write(part2)
104 | fOut.close()
105 |
106 | return 0
107 |
108 |
109 | def make_hparam_string(config):
110 | hparam = "{}/s{}_h{}-{}_n{}_c{}-{}_b{}".format(
111 | config.get("hyperparameters", "model"),
112 | config.get("hyperparameters", "svd_dimension"),
113 | config.get("hyperparameters", "number_hidden_layers"),
114 | config.get("hyperparameters", "hidden_layer_size"),
115 | config.get("hyperparameters", "negative_num"),
116 | # config.get("hyperparameters", "weight_decay"),
117 | config.get("hyperparameters", "context_num"),
118 | config.get("hyperparameters", "context_len"),
119 | config.get("hyperparameters", "batch_size")
120 | )
121 | return hparam
122 |
123 | def init_model(config, ckpt_path, init_w2v_embedding, device):
124 |
125 | encoder_type = config.get("hyperparameters", "model")
126 | number_hidden_layers = int(config.getfloat("hyperparameters", "number_hidden_layers"))
127 | hidden_layer_size = int(config.getfloat("hyperparameters", "hidden_layer_size"))
128 |
129 | model = Context2Score(encoder_type, 300, hidden_layer_size, device)
130 |
131 | pretrain = torch.load(ckpt_path)
132 | # pretrain.pop("word_embedding.weight")
133 | model.load_state_dict(pretrain)
134 | model.init_emb(torch.FloatTensor(init_w2v_embedding))
135 | model.eval()
136 |
137 | return model
138 |
139 | ckpt_dir = config.get("data", "ckpt")
140 | hparam = make_hparam_string(config)
141 | ckpt_dir = os.path.join(ckpt_dir, hparam)
142 | log_path = os.path.join(ckpt_dir, 'eval_last_p.log')
143 | ckpt_path = os.path.join(ckpt_dir, 'best.ckpt')
144 |
145 | dataset = Dataset(config, train=False)
146 | gpu_device = config.get("hyperparameters", "gpu_device")
147 | device = torch.device('cuda:{}'.format(gpu_device) if torch.cuda.is_available() else 'cpu')
148 |
149 | model = init_model(config, ckpt_path, dataset.context_word_emb, device)
150 |
151 | ## To visualize attention weights
152 |
153 | #hypon = "vicarage"
154 | #hyper = "building"
155 | #hypon = "calamus"
156 | #hyper = "specie"
157 | #hypon = "pontoon"
158 | #hyper = "boat"
159 | #hypon = "polymerase"
160 | #hyper = "enzyme"
161 | #hyper = "chemical"
162 |
163 | hypon = "kinetoscope"
164 | hyper = "device"
165 |
166 | hypon_id = dataset.context_w2i[hypon]
167 | hyper_id = dataset.context_w2i[hyper]
168 | print(hypon_id)
169 | hypon_word_context = dataset.context_dict[hypon_id]
170 |
171 | print(hypon_word_context)
172 |
173 | hypon_word = dataset.load_prediction_word_context(hypon_id)
174 | print(hypon_word)
175 |
176 |
177 | hyper_word_context = dataset.context_dict[hyper_id]
178 | hyper_word = dataset.load_prediction_word_context(hyper_id)
179 | print(hyper_word)
180 |
181 | model_name = config.get("hyperparameters", "model")
182 |
183 | inputs = torch.tensor(np.asarray([[hypon_word_context, hyper_word_context]]), dtype=torch.long)
184 | output = model(inputs)
185 | score = output[0].detach().cpu().numpy()[0]
186 |
187 | if "han" in model_name:
188 | attention1 = torch.squeeze(output[1][0]).detach().cpu().numpy()
189 | h_att1 = torch.squeeze(output[1][1]).detach().cpu().numpy()
190 | print("The attention weights of hyponymy is : ")
191 | print(h_att1)
192 |
193 | array = np.asarray(h_att1, dtype=np.float32)
194 | tmp = array.argsort()
195 |
196 | print(tmp)
197 | ranks = np.empty_like(tmp)
198 | ranks[tmp] = np.arange(len(array))
199 |
200 | print(ranks)
201 |
202 |
203 | attention2 = torch.squeeze(output[1][2]).detach().cpu().numpy()
204 | h_att2 = torch.squeeze(output[1][3]).detach().cpu().numpy()
205 | print("The attention weights of hypernymy is : ")
206 | print(h_att2)
207 |
208 | text = [hypon_word[tmp[i]][0] + hypon_word[tmp[i]][1] for i in range(len(hypon_word))]
209 |
210 | weights = [attention1[tmp[i]].tolist() for i in range(len(attention1))]
211 |
212 | text2 = [hyper_word[i][0] + hyper_word[i][1] for i in range(len(hyper_word))]
213 |
214 | for i in range(len(text)):
215 | for j in range(len(text[0])):
216 | if '"' or "'" in text[i][j]:
217 | text[i][j] = text[i][j].replace('"', "/").replace("'","/")
218 |
219 | weights2 = [attention2[i].tolist() for i in range(len(attention2))]
220 | print(text)
221 | #print(weights)
222 | file_name = "vis_" + model_name + "_" + hypon + ".html"
223 | createHTML(text, weights, file_name)
224 |
225 | else:
226 | attention1 = torch.squeeze(output[1]).detach().cpu().numpy()
227 |
228 | attention2 = torch.squeeze(output[2]).detach().cpu().numpy()
229 |
230 |
231 | text = [hypon_word[i][0] + hypon_word[i][1][::-1] for i in range(len(hypon_word))]
232 |
233 | print(text)
234 |
235 | att_weights = [attention1[i].tolist() for i in range(len(attention1))]
236 |
237 | print(att_weights)
238 |
239 | weights = [att_weights[i][:10] + att_weights[i][10:][::-1] for i in range(len(att_weights))]
240 |
241 | print(weights)
242 |
243 | text2 = [hyper_word[i][0] + hyper_word[i][1] for i in range(len(hyper_word))]
244 |
245 | for i in range(len(text)):
246 | for j in range(len(text[0])):
247 | if '"' or "'" in text[i][j]:
248 | text[i][j] = text[i][j].replace('"', "/").replace("'","/")
249 |
250 | weights2 = [attention2[i].tolist() for i in range(len(attention2))]
251 | print(text)
252 | #print(weights)
253 | file_name = "vis_" + model_name + "_" + hypon + ".html"
254 | createHTML(text, weights, file_name)
255 |
256 |
257 |
258 |
--------------------------------------------------------------------------------
/train_bert2score.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import os
4 | import logging
5 | import tqdm
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | import torch.utils.data as Dataset
12 | import torch.distributed as dist
13 | import torch.multiprocessing as mp
14 |
15 | import configparser
16 | from model.models import *
17 | from utils.data_helper_4bert import Dataset
18 |
19 | logger = logging.getLogger()
20 |
21 | def make_hparam_string(config):
22 | hparam = "{}/s{}_h{}-{}_n{}_c{}-{}_b{}".format(
23 | config.get("hyperparameters", "model"),
24 | config.get("hyperparameters", "svd_dimension"),
25 | config.get("hyperparameters", "number_hidden_layers"),
26 | config.get("hyperparameters", "hidden_layer_size"),
27 | config.get("hyperparameters", "negative_num"),
28 | # config.get("hyperparameters", "weight_decay"),
29 | config.get("hyperparameters", "context_num"),
30 | config.get("hyperparameters", "context_len"),
31 | config.get("hyperparameters", "batch_size")
32 | )
33 | return hparam
34 |
35 | def init_model(config, device):
36 |
37 | encoder_type = config.get("hyperparameters", "model")
38 | number_hidden_layers = int(config.getfloat("hyperparameters", "number_hidden_layers"))
39 | hidden_layer_size = int(config.getfloat("hyperparameters", "hidden_layer_size"))
40 | bert_dir = config.get("data", "bert_path")
41 | model = Bert2Score(encoder_type, bert_dir, hidden_layer_size, 0.1)
42 | #torch.distributed.init_process_group(backend="nccl")
43 | #model = nn.DistributedDataParallel(model)
44 | model = nn.DataParallel(model)
45 | model.to(device)
46 | return model
47 |
48 | def evaluation(model, loss_func, dataset, device):
49 |
50 | model.eval()
51 | pred_score = []
52 | for batch_data in dataset.sample_batch_dev():
53 | batch_context, batch_mask = batch_data
54 | context_tensor = torch.tensor(batch_context, dtype=torch.long)
55 | mask_tensor = torch.tensor(batch_mask, dtype=torch.long)
56 | output = model(context_tensor, mask_tensor).detach().cpu().numpy()
57 | pred_score.extend(output)
58 |
59 | dev_input = torch.tensor(np.asarray(pred_score), dtype=torch.float).to(device)
60 | dev_label = torch.tensor(dataset.dev_label, dtype=torch.float).to(device)
61 |
62 | loss = loss_func(dev_input, dev_label)
63 |
64 | return float(loss.data)
65 |
66 |
67 | if __name__ == "__main__":
68 | config = configparser.RawConfigParser()
69 | config.read(sys.argv[1])
70 |
71 | gpu_device = config.get("hyperparameters", "gpu_device")
72 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
73 |
74 | ckpt_dir = config.get("data", "ckpt")
75 | hparam = make_hparam_string(config)
76 | ckpt_dir = os.path.join(ckpt_dir, hparam)
77 | if not os.path.exists(ckpt_dir):
78 | os.makedirs(ckpt_dir)
79 | log_path = os.path.join(ckpt_dir, 'train.log')
80 | best_ckpt_path = os.path.join(ckpt_dir, 'best.ckpt')
81 | last_ckpt_path = os.path.join(ckpt_dir, 'last.ckpt')
82 |
83 | logger.setLevel(logging.INFO)
84 | handler = logging.FileHandler(log_path, 'w')
85 | handler.setLevel(logging.INFO)
86 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S')
87 | handler.setFormatter(formatter)
88 | logger.addHandler(handler)
89 |
90 | epochs = int(config.getfloat("hyperparameters", "max_epochs"))
91 | learning_rate = config.getfloat("hyperparameters", "learning_rate")
92 | weight_decay = config.getfloat("hyperparameters", "weight_decay")
93 |
94 | dataset = Dataset(config)
95 |
96 | model = init_model(config, device)
97 | logger.info(model)
98 | parameters = filter(lambda p: p.requires_grad, model.parameters())
99 | #total_parameters = sum(p.numel() for p in parameters)
100 | #logging.info("| there are totally {} trainable parameters".format(total_parameters))
101 |
102 | for name, param in model.named_parameters():
103 | print(name)
104 |
105 | loss_func = nn.MSELoss()
106 | optimizer = optim.Adam(parameters, lr=learning_rate)
107 |
108 | least_loss = 9999999
109 |
110 | # model.to(device)
111 | loss_func.to(device)
112 |
113 | for epoch in range(epochs):
114 |
115 | model.train()
116 | total_loss = 0
117 | total_mse = 0
118 | step = 0
119 | for batch_data in dataset.sample_batch():
120 |
121 | batch_x, batch_mask, batch_y = batch_data
122 |
123 | context_inputs = torch.tensor(batch_x, dtype=torch.long).to(device)
124 | mask_inputs = torch.tensor(batch_mask, dtype=torch.long).to(device)
125 | batch_y = torch.tensor(batch_y, dtype=torch.float).to(device)
126 |
127 | output = model(context_inputs, mask_inputs)
128 |
129 | labels = batch_y.to(device)
130 | mse_loss = loss_func(output, labels)
131 | loss = mse_loss #+ weight_decay * norm
132 | loss.backward()
133 | optimizer.step()
134 | optimizer.zero_grad()
135 |
136 | step +=1
137 | total_loss += float(loss.data)
138 | total_mse += float(mse_loss.data)
139 | if step % 200 == 0:
140 | logger.info('| Epoch: {} | step: {} | mse {:5f}'.format(epoch, step, float(mse_loss.data)))
141 |
142 | logger.info('| Epoch: {} | mean mse {:.5f}'.format(epoch, total_mse /step))
143 |
144 | dev_loss = evaluation(model, loss_func, dataset, device)
145 | if dev_loss < least_loss:
146 | least_loss = dev_loss
147 | # torch.save([model, optimizer, loss_func], ckpt_path)
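   | # the model is wrapped in nn.DataParallel, so the underlying weights live in model.module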
148 | save_model = model.module.state_dict().copy()
149 | torch.save(save_model, best_ckpt_path)
150 | # torch.save(model.state_dict(), ckpt_path)
151 | logger.info('| Epoch: {} | mean dev mse: {:.5f} | saved'.format(epoch, dev_loss))
152 | else:
153 | save_model = model.module.state_dict().copy()
154 | #save_model.pop('word_embedding.weight')
155 | torch.save(save_model, last_ckpt_path)
156 | logger.info('| Epoch: {} | mean dev mse: {:.5f} |'.format(epoch, dev_loss))
157 |
--------------------------------------------------------------------------------
/train_context2score.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import os
4 | import logging
5 | import tqdm
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | import torch.utils.data as Dataset
12 | import torch.distributed as dist
13 | import torch.multiprocessing as mp
14 |
15 | import configparser
16 | from model.models import *
17 | from utils.data_helper_4context import Dataset
18 |
19 | logger = logging.getLogger()
20 |
21 | def make_hparam_string(config):
22 | hparam = "{}/s{}_h{}-{}_n{}_c{}-{}_b{}".format(
23 | config.get("hyperparameters", "model"),
24 | config.get("hyperparameters", "svd_dimension"),
25 | config.get("hyperparameters", "number_hidden_layers"),
26 | config.get("hyperparameters", "hidden_layer_size"),
27 | config.get("hyperparameters", "negative_num"),
28 | # config.get("hyperparameters", "weight_decay"),
29 | config.get("hyperparameters", "context_num"),
30 | config.get("hyperparameters", "context_len"),
31 | config.get("hyperparameters", "batch_size")
32 | )
33 | return hparam
34 |
35 | def init_model(config, init_w2v_embedding, device, multiple):
36 |
37 | encoder_type = config.get("hyperparameters", "model")
38 | number_hidden_layers = int(config.getfloat("hyperparameters", "number_hidden_layers"))
39 | hidden_layer_size = int(config.getfloat("hyperparameters", "hidden_layer_size"))
40 |
41 | model = Context2Score(encoder_type, 300, hidden_layer_size, device, multiple)
42 | model.init_emb(torch.FloatTensor(init_w2v_embedding))
43 | return model
44 |
45 | def evaluation(model, loss_func, dataset, device):
46 |
47 | model.eval()
48 | pred_score = []
49 | for batch_data in dataset.sample_batch_dev():
50 | batch_x = batch_data
51 | inputs = torch.tensor(batch_x, dtype=torch.long)
52 | output = model(inputs).detach().cpu().numpy()
53 | pred_score.extend(output)
54 |
55 | dev_input = torch.tensor(np.asarray(pred_score), dtype=torch.float).to(device)
56 | dev_label = torch.tensor(dataset.dev_label, dtype=torch.float).to(device)
57 |
58 | loss = loss_func(dev_input, dev_label)
59 |
60 | return float(loss.data)
61 |
62 | def evaluation_attention(model, loss_func, dataset, device):
63 |
64 | model.eval()
65 | pred_score = []
66 | for batch_data in dataset.sample_batch_dev():
67 | batch_x = batch_data
68 | inputs = torch.tensor(batch_x, dtype=torch.long)
69 | output = model(inputs)[0].detach().cpu().numpy()
70 | pred_score.extend(output)
71 |
72 | dev_input = torch.tensor(np.asarray(pred_score), dtype=torch.float).to(device)
73 | dev_label = torch.tensor(dataset.dev_label, dtype=torch.float).to(device)
74 |
75 | loss = loss_func(dev_input, dev_label)
76 |
77 | return float(loss.data)
78 |
79 | if __name__ == "__main__":
80 | config = configparser.RawConfigParser()
81 | config.read(sys.argv[1])
82 |
83 | gpu_device = config.get("hyperparameters", "gpu_device")
84 | device = torch.device('cuda:{}'.format(gpu_device) if torch.cuda.is_available() else 'cpu')
85 |
86 | ckpt_dir = config.get("data", "ckpt")
87 | hparam = make_hparam_string(config)
88 | ckpt_dir = os.path.join(ckpt_dir, hparam)
89 | if not os.path.exists(ckpt_dir):
90 | os.makedirs(ckpt_dir)
91 | log_path = os.path.join(ckpt_dir, 'train.log')
92 | best_ckpt_path = os.path.join(ckpt_dir, 'best.ckpt')
93 | last_ckpt_path = os.path.join(ckpt_dir, 'last.ckpt')
94 |
95 | logger.setLevel(logging.INFO)
96 | handler = logging.FileHandler(log_path, 'w')
97 | handler.setLevel(logging.INFO)
98 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S')
99 | handler.setFormatter(formatter)
100 | logger.addHandler(handler)
101 |
102 | epochs = int(config.getfloat("hyperparameters", "max_epochs"))
103 | learning_rate = config.getfloat("hyperparameters", "learning_rate")
104 | weight_decay = config.getfloat("hyperparameters", "weight_decay")
105 |
106 | dataset = Dataset(config)
107 |
108 | model = init_model(config, dataset.context_word_emb, device, False)
109 | print(dataset.context_word_emb.shape)
110 | dataset.context_word_emb = None
111 | logger.info(model)
112 | parameters = filter(lambda p: p.requires_grad, model.parameters())
113 |
114 | for name, param in model.named_parameters():
115 | print(name)
116 |
117 |
118 | loss_func = nn.MSELoss()
119 | optimizer = optim.Adam(parameters, lr=learning_rate)
120 |
121 | least_loss = 9999999
122 |
123 | # model.to(device)
124 | loss_func.to(device)
125 |
126 | for epoch in range(epochs):
127 |
128 | model.train()
129 | total_loss = 0
130 | total_mse = 0
131 | step = 0
132 | for batch_data in dataset.sample_batch():
133 |
134 | batch_x, batch_y = batch_data
135 |
136 | inputs = torch.tensor(batch_x, dtype=torch.long)
137 | batch_y = torch.tensor(batch_y, dtype=torch.float).to(device)
138 |
139 | if "self" or "han" in config.get("hyperparameters", "model"):
140 | output = model(inputs)[0]
141 | else:
142 | output = model(inputs)
143 |
144 | print(inputs.size())
145 | print(batch_y.size())
146 | print(output.size())
147 |
148 | labels = batch_y.to(device)
149 | mse_loss = loss_func(output, labels)
150 | loss = mse_loss #+ weight_decay * norm
151 | loss.backward()
152 | optimizer.step()
153 | optimizer.zero_grad()
154 |
155 | step +=1
156 | total_loss += float(loss.data)
157 | total_mse += float(mse_loss.data)
158 | if step % 200 == 0:
159 | logger.info('| Epoch: {} | step: {} | mse {:5f}'.format(epoch, step, float(mse_loss.data)))
160 |
161 | logger.info('| Epoch: {} | mean mse {:.5f}'.format(epoch, total_mse /step))
162 |
163 |
164 | if "self" or "han" in config.get("hyperparameters", "model"):
165 | dev_loss = evaluation_attention(model, loss_func, dataset, device)
166 | else:
167 | dev_loss = evaluation(model, loss_func, dataset, device)
168 | if dev_loss < least_loss:
169 | least_loss = dev_loss
170 | # torch.save([model, optimizer, loss_func], ckpt_path)
171 | save_model = model.state_dict().copy()
172 | save_model.pop('word_embedding.weight')
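   | # the frozen pretrained embedding matrix is dropped from the checkpoint to keep it
   | # small; it is restored separately via model.init_emb(...) before evaluation.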
173 | torch.save(save_model, best_ckpt_path)
174 | # torch.save(model.state_dict(), ckpt_path)
175 | logger.info('| Epoch: {} | mean dev mse: {:.5f} | saved'.format(epoch, dev_loss))
176 | else:
177 | save_model = model.state_dict().copy()
178 | save_model.pop('word_embedding.weight')
179 | torch.save(save_model, last_ckpt_path)
180 | logger.info('| Epoch: {} | mean dev mse: {:.5f} |'.format(epoch, dev_loss))
181 |
--------------------------------------------------------------------------------
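Each `train_*2score.py` script is driven by an INI-style config passed as `sys.argv[1]` (see the `config/` directory). A minimal sketch of how the `[data]` and `[hyperparameters]` keys read above are consumed via `configparser`; the values here are placeholders for illustration, not the shipped defaults in `config/word.cfg`.

```
import configparser

# Placeholder values for the keys the training scripts read; the real
# settings live in config/*.cfg and will differ.
cfg_text = """
[data]
pattern_filename = data/hearst_counts.txt.gz
ckpt = checkpoints/

[hyperparameters]
gpu_device = 0
model = word2score
svd_dimension = 50
number_hidden_layers = 2
hidden_layer_size = 300
negative_num = 5
batch_size = 128
max_epochs = 100
learning_rate = 1e-3
weight_decay = 1e-5
"""

config = configparser.RawConfigParser()
config.read_string(cfg_text)

# The scripts read numeric options via getfloat() and cast to int where needed.
epochs = int(config.getfloat("hyperparameters", "max_epochs"))
learning_rate = config.getfloat("hyperparameters", "learning_rate")
print(epochs, learning_rate)
```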
/train_word2score.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import os
4 | import logging
5 | import tqdm
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | import torch.utils.data
11 |
12 | import configparser
13 | from model.models import *
14 | from model.models import Word2Score
15 | from utils.data_helper import Dataset
16 |
17 | logger = logging.getLogger()
18 |
19 | def make_hparam_string(config):
20 | hparam = "{}/s{}_h{}-{}_n{}_b{}".format(
21 | config.get("hyperparameters", "model"),
22 | config.get("hyperparameters", "svd_dimension"),
23 | config.get("hyperparameters", "number_hidden_layers"),
24 | config.get("hyperparameters", "hidden_layer_size"),
25 | config.get("hyperparameters", "negative_num"),
26 | config.get("hyperparameters", "batch_size"),
27 | # config.get("hyperparameters", "context_num"),
28 | # config.get("hyperparameters", "context_len")
29 | )
30 | return hparam
31 |
32 | def init_model(config, init_w2v_embedding, device):
33 |
34 | number_hidden_layers = int(config.getfloat("hyperparameters", "number_hidden_layers"))
35 | hidden_layer_size = int(config.getfloat("hyperparameters", "hidden_layer_size"))
36 |
37 | model = Word2Score(hidden_layer_size, number_hidden_layers)
38 | model.init_emb(torch.FloatTensor(init_w2v_embedding))
39 | return model
40 |
41 | def evaluation(model, loss_func, dataset, device):
42 |
43 | model.eval()
44 | dev_input = torch.tensor(dataset.dev_data, dtype=torch.long).to(device)
45 | dev_label = torch.tensor(dataset.dev_label, dtype=torch.float).to(device)
46 |
47 | output, _ = model(dev_input)
48 | loss = loss_func(output, dev_label)
49 |
50 | return float(loss.data)
51 |
52 | if __name__ == "__main__":
53 | config = configparser.RawConfigParser()
54 | config.read(sys.argv[1])
55 |
56 | gpu_device = config.get("hyperparameters", "gpu_device")
57 | device = torch.device('cuda:{}'.format(gpu_device) if torch.cuda.is_available() else 'cpu')
58 |
59 | ckpt_dir = config.get("data", "ckpt")
60 | hparam = make_hparam_string(config)
61 | ckpt_dir = os.path.join(ckpt_dir, hparam)
62 | if not os.path.exists(ckpt_dir):
63 | os.makedirs(ckpt_dir)
64 | log_path = os.path.join(ckpt_dir, 'train.log')
65 | ckpt_path = os.path.join(ckpt_dir, 'best.ckpt')
66 |
67 | logger.setLevel(logging.INFO)
68 | handler = logging.FileHandler(log_path, 'w')
69 | handler.setLevel(logging.INFO)
70 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S')
71 | handler.setFormatter(formatter)
72 | logger.addHandler(handler)
73 |
74 | epochs = int(config.getfloat("hyperparameters", "max_epochs"))
75 | learning_rate = config.getfloat("hyperparameters", "learning_rate")
76 | weight_decay = config.getfloat("hyperparameters", "weight_decay")
77 |
78 | dataset = Dataset(config)
79 |
80 | model = init_model(config, dataset.wordvec_weights, device)
81 | dataset.wordvec_weights = None
82 | logger.info(model)
83 | parameters = filter(lambda p: p.requires_grad, model.parameters())
84 |
85 | loss_func = nn.MSELoss()
86 | optimizer = optim.Adam(parameters, lr=learning_rate)
87 |
88 | least_loss = 9999999
89 |
90 | model.to(device)
91 | loss_func.to(device)
92 |
93 | for epoch in range(epochs):
94 |
95 | model.train()
96 | total_loss = 0
97 | total_mse = 0
98 | step = 0
99 | for batch_data in dataset.sample_pos_neg_batch():
100 |
101 | batch_x, batch_y = batch_data
102 |
103 | inputs = torch.tensor(batch_x, dtype=torch.long).to(device)
104 | batch_y = torch.tensor(batch_y, dtype=torch.float).to(device)
105 |
106 | output, norm = model(inputs)
107 | labels = batch_y.to(device)
108 | mse_loss = loss_func(output, labels)
109 | loss = mse_loss + weight_decay * norm
110 | loss.backward()
111 | optimizer.step()
112 | optimizer.zero_grad()
113 |
114 | step +=1
115 | total_loss += float(loss.data)
116 | total_mse += float(mse_loss.data)
117 | if step % 200 == 0:
118 | logger.info('| Epoch: {} | step: {} | mse {:.5f} | loss {:.5f}'.format(epoch, step, float(mse_loss.data), float(loss.data)))
119 |
120 | logger.info('| Epoch: {} | mean mse {:.5f}, loss {:.5f}'.format(epoch, total_mse /step, total_loss / step))
121 |
122 | dev_loss = evaluation(model, loss_func, dataset, device)
123 | if dev_loss < least_loss:
124 | least_loss = dev_loss
125 | # torch.save([model, optimizer, loss_func], ckpt_path)
126 | torch.save(model.state_dict(), ckpt_path)
127 | logger.info('| Epoch: {} | mean dev mse: {:.5f} | best'.format(epoch, dev_loss))
128 | else:
129 | logger.info('| Epoch: {} | mean dev mse: {:.5f} |'.format(epoch, dev_loss))
130 |
--------------------------------------------------------------------------------
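`train_word2score.py` above imports `Word2Score` from `model/models.py`. Purely for orientation, the sketch below is an illustrative stand-in with the interface the script relies on: `init_emb()` loads pretrained 300-d word vectors, and `forward()` maps a `[batch, 2]` tensor of word ids to a score plus a norm term that enters the loss as `weight_decay * norm`. The two-MLP / dot-product architecture is an assumption for illustration, not the shipped implementation.

```
import torch
import torch.nn as nn

def _mlp(in_dim, hidden_size, num_layers):
    layers = []
    for _ in range(num_layers):
        layers += [nn.Linear(in_dim, hidden_size), nn.ReLU()]
        in_dim = hidden_size
    return nn.Sequential(*layers)

class TinyWord2Score(nn.Module):
    """Illustrative stand-in for model.models.Word2Score (interface only)."""

    def __init__(self, hidden_layer_size, number_hidden_layers, emb_dim=300):
        super().__init__()
        self.word_embedding = nn.Embedding(1, emb_dim)  # replaced in init_emb
        self.hypo_mlp = _mlp(emb_dim, hidden_layer_size, number_hidden_layers)
        self.hyper_mlp = _mlp(emb_dim, hidden_layer_size, number_hidden_layers)

    def init_emb(self, weights):
        # Pretrained word vectors, frozen so only the projections are trained.
        self.word_embedding = nn.Embedding.from_pretrained(weights, freeze=True)

    def forward(self, pairs):
        hypo = self.hypo_mlp(self.word_embedding(pairs[:, 0]))
        hyper = self.hyper_mlp(self.word_embedding(pairs[:, 1]))
        score = (hypo * hyper).sum(dim=1)                 # predicted SVD-PPMI score
        norm = hypo.norm(dim=1).mean() + hyper.norm(dim=1).mean()
        return score, norm

# usage sketch with random stand-in embeddings
model = TinyWord2Score(hidden_layer_size=300, number_hidden_layers=2)
model.init_emb(torch.randn(5000, 300))
score, norm = model(torch.tensor([[3, 17], [42, 7]]))
```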
/utils/data_helper.py:
--------------------------------------------------------------------------------
1 | from .loader import read_sparse_matrix
2 | from .util import load_phrase_word2vec
3 | from .util import load_gensim_word2vec
4 |
5 | import os, pickle  # used by load_vocab (os.path.join, pickle.load)
6 | import torch
7 | import numpy as np
8 | import scipy.sparse as sparse
9 | class Dataset(object):
10 | """docstring for Dataset"""
11 | def __init__(self, config, svd=False, train=True):
12 | # build the PPMI matrix from the Hearst-pattern co-occurrence counts
13 | pattern_filename = config.get("data", "pattern_filename")
14 |
15 | k = int(config.getfloat("hyperparameters", "svd_dimension"))
16 | self.batch_size = int(config.getfloat("hyperparameters", "batch_size"))
17 | self.negative_num = int(config.getfloat("hyperparameters", "negative_num"))
18 |
19 | csr_m, self.id2word, self.vocab, _ = read_sparse_matrix(
20 | pattern_filename, same_vocab=True)
21 |
22 | self.word2id = {}
23 | for i in range(len(self.id2word)):
24 | self.word2id[self.id2word[i]] = i
25 |
26 | self.matrix = csr_m.todok()
27 | self.p_w = csr_m.sum(axis=1).A[:,0]
28 | self.p_c = csr_m.sum(axis=0).A[0,:]
29 | self.N = self.p_w.sum()
30 |
31 | # for w2v: load a local gensim word2vec model trained on the corpus (ukWaC here); adjust the path below
32 | if train:
33 | #self.wordvecs = load_phrase_word2vec("/home/shared/acl-data/embedding/ukwac.model", self.vocab)
34 | self.wordvecs = load_gensim_word2vec("/home/shared/acl-data/embedding/ukwac.model", self.vocab)
35 | #print(self.wordvecs["united_states"])
36 |
37 | self.wordvec_weights = self.build_emb()
38 |
39 | tr_matrix = sparse.dok_matrix(self.matrix.shape)
40 | #print(self.matrix.shape)
41 |
42 | self.left_has = {}
43 | self.right_has = {}
44 | for (l,r) in self.matrix.keys():
45 | pmi_lr = (np.log(self.N) + np.log(self.matrix[(l,r)])
46 | - np.log(self.p_w[l]) - np.log(self.p_c[r]))
47 |
48 | ppmi_lr = np.clip(pmi_lr, 0.0, 1e12)
49 | tr_matrix[(l,r)] = ppmi_lr
50 |
51 | if l not in self.left_has:
52 | self.left_has[l] = []
53 | self.left_has[l].append(r)
54 | if r not in self.right_has:
55 | self.right_has[r] = []
56 | self.right_has[r].append(l)
57 |
58 | self.ppmi_matrix = tr_matrix
59 |
60 | U, S, V = sparse.linalg.svds(self.ppmi_matrix.tocsr(), k=k)
61 | self.U = U.dot(np.diag(S))
62 | self.V = V.T
63 |
64 | if train:
65 | # self.positive_data, self.positive_label = self.generate_positive()
66 | self.get_avail_vocab()
67 |
68 | def get_avail_vocab(self):
69 | avail_vocab = []
70 | for idx in range(len(self.vocab)):
71 | if self.id2word[idx] in self.wordvecs:
72 | avail_vocab.append(idx)
73 | self.avail_vocab = np.asarray(avail_vocab)
74 | shuffle_indices_left = np.random.permutation(len(self.avail_vocab))[:20000]
75 | shuffle_indices_right = np.random.permutation(len(self.avail_vocab))[:20000]
76 | dev_data = []
77 | dev_label = []
78 | self.dev_dict = {}
79 | for id_case in range(20000):
80 | id_left = self.avail_vocab[shuffle_indices_left[id_case]]
81 | id_right = self.avail_vocab[shuffle_indices_right[id_case]]
82 | dev_data.append([self.w2embid[id_left],self.w2embid[id_right]])
83 | dev_label.append(self.U[id_left].dot(self.V[id_right]))
84 | self.dev_dict[(id_left, id_right)] = 1
85 | self.dev_data = np.asarray(dev_data)
86 | self.dev_label = np.asarray(dev_label)
87 |
88 | def build_emb(self):
89 |
90 | tensors = []
91 | ivocab = []
92 | self.w2embid = {}
93 | self.embid2w = {}
94 |
95 | for word in self.wordvecs:
96 | vec = torch.from_numpy(self.wordvecs[word])
97 | self.w2embid[self.word2id[word]] = len(ivocab)
98 | self.embid2w[len(ivocab)] = self.word2id[word]
99 |
100 | ivocab.append(word)
101 | tensors.append(vec)
102 |
103 | assert len(tensors) == len(ivocab)
104 | print(len(tensors))
105 | tensors = torch.cat(tensors).view(len(ivocab), 300)
106 |
107 | return tensors
108 |
109 | def load_vocab(self, w2v_dir, data_dir):
110 | i2w_path = os.path.join(data_dir, 'ukwac_id2word.pkl')
111 | w2i_path = os.path.join(data_dir, 'ukwac_word2id.pkl')
112 | with open(i2w_path, 'rb') as fr:
113 | self.context_i2w = pickle.load(fr)
114 | with open(w2i_path, 'rb') as fr:
115 | self.context_w2i = pickle.load(fr)
116 |
117 | self.PAD = 0
118 | self.UNK = 1
119 |
120 | # w2v_model = Word2Vec.load(w2v_path)
121 | # emb = w2v_model.wv
122 | # oi2ni = {}
123 | # new_embedding = []
124 | # new_embedding.append(np.zeros(300))
125 | # new_embedding.append(np.zeros(300))
126 | # cnt_ni = 2
127 | # for _id, word in i2w.items():
128 | # if word in emb:
129 | # oi2ni[_id] = cnt_ni
130 | # cnt_ni += 1
131 | # new_embedding.append(emb[word])
132 | # else:
133 | # oi2ni[_id] = self.UNK
134 |
135 | oi2ni_path = os.path.join(w2v_dir, 'context_word_oi2ni.pkl')
136 | w2v_path = os.path.join(w2v_dir, 'context_word_w2v.model.npy')
137 | with open(oi2ni_path, 'rb') as fr:
138 | self.context_i2embid = pickle.load(fr)
139 | self.context_word_emb = np.load(w2v_path)
140 |
141 |
142 | def generate_positive(self):
143 |
144 | positive = []
145 | label = []
146 | key_list = list(self.ppmi_matrix.keys())
147 | shuffle_indices = np.random.permutation(len(key_list))
148 |
149 | for shuffle_id in shuffle_indices:
150 | (l, r) = key_list[shuffle_id]
151 | if self.id2word[l] in self.wordvecs and self.id2word[r] in self.wordvecs:
152 | positive.append([self.w2embid[l],self.w2embid[r]])
153 | # if l in self.context_dict and r in self.context_dict:
154 | # positive.append([l, r])
155 | score = self.U[l].dot(self.V[r])
156 | label.append(score)
157 | # label.append(self.ppmi_matrix[(l,r)])
158 | # 119448 positive score
159 | positive_train = np.asarray(positive)[:-2000]
160 |
161 | self.dev_data = np.asarray(positive)[-2000:]
162 |
163 | label_train = np.asarray(label)[:-2000]
164 | self.dev_label = np.asarray(label)[-2000:]
165 |
166 | return positive_train, label_train
167 |
168 | def generate_negative(self, batch_data, negative_num):
169 |
170 | negative = []
171 | label = []
172 |
173 | batch_size = batch_data.shape[0]
174 |
175 | for i in range(batch_size):
176 | # random_idx = np.random.choice(len(self.vocab), 150 , replace=False)
177 | l = batch_data[i][0]
178 | l_w = self.embid2w[l]
179 | r = batch_data[i][1]
180 | r_w = self.embid2w[r]
181 |
182 | l_neg = l_w
183 | r_neg = r_w
184 |
185 | num = 0
186 | for j in range(negative_num):
187 | left_prob = np.random.binomial(1, 0.5)
188 | # while True:
189 | if left_prob:
190 | l_neg = np.random.choice(self.avail_vocab, 1)[0]
191 | else:
192 | r_neg = np.random.choice(self.avail_vocab, 1)[0]
193 | # if (l_neg, r_neg) not in self.matrix.keys() and self.id2word[l_neg] in self.wordvecs and self.id2word[r_neg] in self.wordvecs:
194 | # if (l_neg, r_neg) not in self.matrix.keys() and self.l_neg in self.context_dict and self.r_neg in self.context_dict:
195 | # break
196 |
197 | negative.append([self.w2embid[l_neg], self.w2embid[r_neg]])
198 | # negative.append([self.context_dict[l_neg], self.context_dict[r_neg]])
199 | score = self.U[l_neg].dot(self.V[r_neg])
200 | # score = 0
201 | label.append(score)
202 |
203 | negative = np.asarray(negative)
204 | label = np.asarray(label)
205 | return negative, label
206 |
207 |
208 | def get_batch(self):
209 |
210 |
211 | num_positive = len(self.positive_data)
212 |
213 | batch_size = self.batch_size
214 |
215 | if num_positive% batch_size == 0:
216 | batch_num = num_positive // batch_size
217 | else:
218 | batch_num = num_positive // batch_size + 1
219 |
220 | shuffle_indices = np.random.permutation(num_positive)
221 |
222 | for batch in range(batch_num):
223 |
224 | start_index = batch * batch_size
225 | end_index = min((batch+1) * batch_size, num_positive)
226 |
227 | batch_idx = shuffle_indices[start_index:end_index]
228 |
229 | batch_positive_data = self.positive_data[batch_idx]
230 | batch_positive_label = self.positive_label[batch_idx]
231 |
232 | batch_negative_data, batch_negative_label = self.generate_negative(batch_positive_data, self.negative_num)
233 |
234 | # batch_positive_data = []
235 | # for [l, r] in batch_positive_data:
236 | # batch_positive_data.append(self.context_dict[l], self.context_dict[r])
237 |
238 | # [batch, 2, doc, 2, seq]
239 | batch_input = np.concatenate((batch_positive_data, batch_negative_data), axis=0)
240 | batch_label = np.concatenate((batch_positive_label,batch_negative_label), axis=0)
241 |
242 | yield batch_input, batch_label
243 |
244 | def sample_batch(self):
245 | num_data = len(self.avail_vocab)
246 |
247 | batch_size = self.batch_size
248 |
249 | if num_data % batch_size == 0:
250 | batch_num = num_data // batch_size
251 | else:
252 | batch_num = num_data // batch_size + 1
253 |
254 | shuffle_indices = np.random.permutation(num_data)
255 |
256 | for batch in range(batch_num):
257 |
258 | start_index = batch * batch_size
259 | end_index = min((batch+1) * batch_size, num_data)
260 |
261 | batch_idx = shuffle_indices[start_index:end_index]
262 | batch_data_pair = []
263 | batch_data_score = []
264 | batch_data = self.avail_vocab[batch_idx]
265 |
266 | for idx_i in batch_data:
267 | for j in range(self.negative_num):
268 | left_prob = np.random.binomial(1, 0.5)
269 | if left_prob:
270 | while True:
271 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
272 | if (idx_i, idx_j) not in self.dev_dict:
273 | break
274 | batch_data_pair.append([self.w2embid[idx_i], self.w2embid[idx_j]])
275 | score = self.U[idx_i].dot(self.V[idx_j])
276 | else:
277 | while True:
278 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
279 | if (idx_j, idx_i) not in self.dev_dict:
280 | break
281 | batch_data_pair.append([self.w2embid[idx_j], self.w2embid[idx_i]])
282 | score = self.U[idx_j].dot(self.V[idx_i])
283 | batch_data_score.append(score)
284 | yield np.asarray(batch_data_pair), np.asarray(batch_data_score)
285 |
286 | def sample_pos_neg_batch(self):
287 | num_data = len(self.avail_vocab)
288 |
289 | batch_size = self.batch_size
290 |
291 | if num_data % batch_size == 0:
292 | batch_num = num_data // batch_size
293 | else:
294 | batch_num = num_data // batch_size + 1
295 |
296 | shuffle_indices = np.random.permutation(num_data)
297 |
298 | for batch in range(batch_num):
299 |
300 | start_index = batch * batch_size
301 | end_index = min((batch+1) * batch_size, num_data)
302 |
303 | batch_idx = shuffle_indices[start_index:end_index]
304 | batch_data_pair = []
305 | batch_data_score = []
306 | batch_data = self.avail_vocab[batch_idx]
307 |
308 | for idx_i in batch_data:
309 | if idx_i in self.left_has:
310 | idx_j_list = np.random.permutation(self.left_has[idx_i])
311 | for idx_j in idx_j_list:
312 | if idx_j in self.avail_vocab:
313 | batch_data_pair.append([self.w2embid[idx_i], self.w2embid[idx_j]])
314 | score = self.U[idx_i].dot(self.V[idx_j])
315 | batch_data_score.append(score)
316 | break
317 |
318 | if idx_i in self.right_has:
319 | idx_j_list = np.random.permutation(self.right_has[idx_i])
320 | for idx_j in idx_j_list:
321 | if idx_j in self.avail_vocab:
322 | batch_data_pair.append([self.w2embid[idx_j], self.w2embid[idx_i]])
323 | score = self.U[idx_j].dot(self.V[idx_i])
324 | batch_data_score.append(score)
325 | break
326 |
327 | for j in range(self.negative_num):
328 | # left_prob = np.random.binomial(1, 0.5)
329 | # if left_prob:
330 | while True:
331 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
332 | if (idx_i, idx_j) not in self.dev_dict:
333 | break
334 | batch_data_pair.append([self.w2embid[idx_i], self.w2embid[idx_j]])
335 | score = self.U[idx_i].dot(self.V[idx_j])
336 | batch_data_score.append(score)
337 | # else:
338 | while True:
339 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
340 | if (idx_j, idx_i) not in self.dev_dict:
341 | break
342 | batch_data_pair.append([self.w2embid[idx_j], self.w2embid[idx_i]])
343 | score = self.U[idx_j].dot(self.V[idx_i])
344 | batch_data_score.append(score)
345 | yield np.asarray(batch_data_pair), np.asarray(batch_data_score)
346 |
347 |
--------------------------------------------------------------------------------
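All three `utils/data_helper*` variants build the same supervision signal, visible in `Dataset.__init__` above: positive PMI over the Hearst-pattern count matrix, a rank-k truncated SVD, and `U[l]·V[r]` as the regression target for a candidate (hyponym, hypernym) pair. A self-contained sketch of that computation on toy counts (made-up values, not the real `hearst_counts.txt.gz`):

```
import numpy as np
import scipy.sparse as sparse
from scipy.sparse.linalg import svds

# Toy Hearst counts: rows = hyponyms, columns = hypernyms (shared vocab).
counts = sparse.dok_matrix((5, 5))
counts[0, 1] = 3.0
counts[2, 1] = 5.0
counts[3, 2] = 1.0
counts[0, 2] = 2.0
counts[1, 3] = 4.0

csr = counts.tocsr()
p_w = csr.sum(axis=1).A[:, 0]   # row marginals
p_c = csr.sum(axis=0).A[0, :]   # column marginals
N = p_w.sum()

# Positive PMI, as in Dataset.__init__.
ppmi = sparse.dok_matrix(counts.shape)
for (l, r) in counts.keys():
    pmi = np.log(N) + np.log(counts[(l, r)]) - np.log(p_w[l]) - np.log(p_c[r])
    ppmi[(l, r)] = np.clip(pmi, 0.0, 1e12)

# Rank-k truncated SVD; the regression label for a pair (l, r) is U[l] . V[r].
U, S, Vt = svds(ppmi.tocsr(), k=2)
U = U.dot(np.diag(S))
V = Vt.T
print(U[0].dot(V[1]))   # reconstructed "hypernymy score" for the pair (0, 1)
```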
/utils/data_helper_4bert.py:
--------------------------------------------------------------------------------
1 | from .loader import read_sparse_matrix
2 | from .util import load_gensim_word2vec
3 |
4 | import torch
5 | import numpy as np
6 | import scipy.sparse as sparse
7 | import os
8 | from os.path import isfile, join
9 | import pickle
10 | from transformers import BertTokenizer
11 | import time
12 |
13 | class Dataset(object):
14 | """docstring for Dataset"""
15 | def __init__(self, config, svd=False, train=True):
16 | # build the PPMI matrix from the Hearst-pattern co-occurrence counts
17 | pattern_filename = config.get("data", "pattern_filename")
18 | bert_dir = config.get("data", "bert_path")
19 |
20 | self.context_num = int(config.getfloat("hyperparameters", "context_num"))
21 | self.context_len = int(config.getfloat("hyperparameters", "context_len"))
22 | self.max_seq_length = int(config.getfloat("hyperparameters", "max_seq_length"))
23 |
24 | k = int(config.getfloat("hyperparameters", "svd_dimension"))
25 | self.batch_size = int(config.getfloat("hyperparameters", "batch_size"))
26 | self.negative_num = int(config.getfloat("hyperparameters", "negative_num"))
27 |
28 | self.tokenizer = BertTokenizer.from_pretrained(bert_dir)
29 | csr_m, self.id2word, self.vocab, _ = read_sparse_matrix(
30 | pattern_filename, same_vocab=True)
31 |
32 | self.word2id = {}
33 | for i in range(len(self.id2word)):
34 | self.word2id[self.id2word[i]] = i
35 |
36 | self.matrix = csr_m.todok()
37 | self.p_w = csr_m.sum(axis=1).A[:,0]
38 | self.p_c = csr_m.sum(axis=0).A[0,:]
39 | self.N = self.p_w.sum()
40 |
41 | # for w2v
42 | # self.wordvecs = load_gensim_word2vec("/home/shared/acl-data/embedding/ukwac.model",
43 | # self.vocab)
44 |
45 | # self.wordvec_weights = self.build_emb()
46 |
47 | #print(self.matrix.shape)
48 |
49 | print('SVD matrix...')
50 | tr_matrix = sparse.dok_matrix(self.matrix.shape)
51 | self.left_has = {}
52 | self.right_has = {}
53 | for (l,r) in self.matrix.keys():
54 | pmi_lr = (np.log(self.N) + np.log(self.matrix[(l,r)])
55 | - np.log(self.p_w[l]) - np.log(self.p_c[r]))
56 |
57 | ppmi_lr = np.clip(pmi_lr, 0.0, 1e12)
58 | tr_matrix[(l,r)] = ppmi_lr
59 |
60 | if l not in self.left_has:
61 | self.left_has[l] = []
62 | self.left_has[l].append(r)
63 | if r not in self.right_has:
64 | self.right_has[r] = []
65 | self.right_has[r].append(l)
66 |
67 | self.ppmi_matrix = tr_matrix
68 |
69 | U, S, V = sparse.linalg.svds(self.ppmi_matrix.tocsr(), k=k)
70 | self.U = U.dot(np.diag(S))
71 | self.V = V.T
72 |
73 | # for context
74 | w2v_dir = "/home/shared/acl-data/embedding/"
75 | vocab_path = "/home/shared/acl-data/corpus/"
76 | print('Loading vocab...')
77 | self.load_vocab(w2v_dir, vocab_path)
78 | print('Loading context...')
79 | if train:
80 | self.context_dir = config.get("data", "context")
81 | has_context_word_id_list = self.load_target_word(self.context_dir)
82 | self.context_dict = {}
83 | for matrix_id in range(len(self.vocab)):
84 | word = self.id2word[matrix_id]
85 | if word in self.context_w2i:
86 | context_id = self.context_w2i[word]
87 | if context_id in has_context_word_id_list:
88 | self.context_dict[matrix_id] = self.load_word_context_for_bert(context_id)
89 |
90 | # self.positive_data, self.positive_label = self.generate_positive()
91 | self.get_avail_vocab()
92 | else:
93 | self.context_dict = {}
94 | self.context_dir = config.get("data", "context_oov")
95 | has_context_word_id_list = self.load_target_word(self.context_dir)
96 | for context_id in has_context_word_id_list:
97 | self.context_dict[context_id] = self.load_word_context_for_bert(context_id)
98 |
99 |
100 | def load_target_word(self, data_dir):
101 | target_word_list = [int(f.split('.')[0]) for f in os.listdir(data_dir) if isfile(join(data_dir, f))]
102 | return np.asarray(target_word_list)
103 |
104 | def get_avail_vocab(self):
105 | avail_vocab = []
106 | for idx in range(len(self.vocab)):
107 | if idx in self.context_dict:
108 | avail_vocab.append(idx)
109 | self.avail_vocab = np.asarray(avail_vocab)
110 | print('Available word num: {}'.format(len(avail_vocab)))
111 | shuffle_indices_left = np.random.permutation(len(self.avail_vocab))[:2000]
112 | shuffle_indices_right = np.random.permutation(len(self.avail_vocab))[:2000]
113 | dev_data = []
114 | dev_label = []
115 | self.dev_dict = {}
116 | for id_case in range(2000):
117 | id_left = self.avail_vocab[shuffle_indices_left[id_case]]
118 | id_right = self.avail_vocab[shuffle_indices_right[id_case]]
119 | dev_data.append([id_left,id_right])
120 | dev_label.append(self.U[id_left].dot(self.V[id_right]))
121 | self.dev_dict[(id_left, id_right)] = 1
122 | self.dev_data = np.asarray(dev_data)
123 | self.dev_label = np.asarray(dev_label)
124 |
125 | def build_emb(self):
126 |
127 | self.word2id = {}
128 | for i in range(len(self.id2word)):
129 | self.word2id[self.id2word[i]] = i
130 |
131 | tensors = []
132 | ivocab = []
133 | self.w2embid = {}
134 | self.embid2w = {}
135 |
136 | for word in self.wordvecs:
137 | vec = torch.from_numpy(self.wordvecs[word])
138 | self.w2embid[self.word2id[word]] = len(ivocab)
139 | self.embid2w[len(ivocab)] = self.word2id[word]
140 |
141 | ivocab.append(word)
142 | tensors.append(vec)
143 |
144 | assert len(tensors) == len(ivocab)
145 | tensors = torch.cat(tensors).view(len(ivocab), 300)
146 |
147 | return tensors
148 |
149 | def load_vocab(self, w2v_dir, data_dir):
150 | i2w_path = os.path.join(data_dir, 'ukwac_id2word.pkl')
151 | w2i_path = os.path.join(data_dir, 'ukwac_word2id.pkl')
152 | with open(i2w_path, 'rb') as fr:
153 | self.context_i2w = pickle.load(fr)
154 | with open(w2i_path, 'rb') as fr:
155 | self.context_w2i = pickle.load(fr)
156 |
157 | self.PAD = 0
158 | self.UNK = 1
159 |
160 | # w2v_model = Word2Vec.load(w2v_path)
161 | # emb = w2v_model.wv
162 | # oi2ni = {}
163 | # new_embedding = []
164 | # new_embedding.append(np.zeros(300))
165 | # new_embedding.append(np.zeros(300))
166 | # cnt_ni = 2
167 | # for _id, word in i2w.items():
168 | # if word in emb:
169 | # oi2ni[_id] = cnt_ni
170 | # cnt_ni += 1
171 | # new_embedding.append(emb[word])
172 | # else:
173 | # oi2ni[_id] = self.UNK
174 |
175 | oi2ni_path = os.path.join(w2v_dir, 'context_word_oi2ni.pkl')
176 | w2v_path = os.path.join(w2v_dir, 'context_word_w2v.model.npy')
177 | with open(oi2ni_path, 'rb') as fr:
178 | self.context_i2embid = pickle.load(fr)
179 | self.context_word_emb = np.load(w2v_path)
180 |
181 |
182 | def load_word_context(self, word_idx):
183 | context_path = os.path.join(self.context_dir, '{}.txt'.format(word_idx))
184 |
185 | context_list = []
186 | with open(context_path, 'r') as fr:
187 | flag_right = False
188 | cnt_line = 0
189 |
190 | for line in fr:
191 | line = line.strip()
192 |
193 | if len(line) != 0:
194 | context = [int(num) for num in line.split(' ')]
195 | else:
196 | context = []
197 | context = [self.context_i2embid[num] for num in context]
198 | if not flag_right:
199 | left_context = [self.PAD] * self.context_len
200 | if len(context) >= self.context_len:
201 | left_context = context[(len(context) - self.context_len):]
202 | else:
203 | left_context[(self.context_len-len(context)):] = context
204 | flag_right = True
205 | else:
206 | right_context = [self.PAD] * self.context_len
207 | if len(context) >= self.context_len:
208 | right_context = list(reversed(context[:self.context_len]))
209 | else:
210 | right_context[self.context_len-len(context):] = list(reversed(context))
211 |
212 | context_list.append([left_context, right_context])
213 | flag_right = False
214 | cnt_line += 1
215 | if cnt_line == 2* self.context_num:
216 | break
217 |
218 | if len(context_list) <= self.context_num:
219 | for i in range(self.context_num - len(context_list)):
220 | context_list.append([[self.PAD]*self.context_len, [self.PAD]*self.context_len])
221 |
222 | return context_list
223 |
224 | def load_word_context_for_bert(self, word_idx):
225 | context_path = os.path.join(self.context_dir, '{}.txt'.format(word_idx))
226 | context_list = []
227 | with open(context_path, 'r') as fr:
228 | flag_right = False
229 | cnt_line = 0
230 | for line in fr:
231 | line = line.strip()
232 | if len(line) != 0:
233 | context = [int(num) for num in line.split(' ')]
234 | else:
235 | context = []
236 |
237 | context = [self.context_i2w[num] for num in context]
238 | context = [each for each in context if each != "@card@"]
239 | context = [each for each in context if "http" not in each]
240 | context = [each for each in context if "JavaScript" not in each]
241 | if not flag_right:
242 | if len(context) >= self.context_len:
243 | left_context = context[(len(context) - self.context_len):]
244 | else:
245 | left_context = context
246 | flag_right = True
247 | left_context.append(self.context_i2w[word_idx])
248 | else:
249 | if len(context) >= self.context_len:
250 | right_context = context[:self.context_len]
251 | else:
252 | right_context = context
253 |
254 | full_context = '[CLS] ' + " ".join(left_context + right_context) + ' [SEP]'
255 | context_list.append(full_context)
256 | flag_right = False
257 | cnt_line += 1
258 | if cnt_line == 2 * self.context_num:
259 | break
260 | if word_idx == self.context_w2i['kinetoscope'] or word_idx == self.context_w2i['device']:
261 | print(context_list)
262 | if len(context_list) <= self.context_num:
263 | for i in range(self.context_num - len(context_list)):
264 | context_list.append(" ".join(['[PAD]']*(2 *self.context_len + 1)))
265 |
266 | batched_output = self.tokenizer.batch_encode_plus(context_list, \
267 | add_special_tokens=False, padding='max_length', truncation=True, max_length= self.max_seq_length, \
268 | return_attention_mask=True)
269 | if np.asarray(batched_output['input_ids']).dtype == "object":
270 | print("something is wrong on word :" + str(word_idx))
271 | print(context_list)
272 | print(len(context_list))
273 | print("--------------------------------------------")
274 |
275 | outputs = dict({'ids': batched_output['input_ids'], "mask": batched_output['attention_mask']})
276 | return outputs
277 |
278 | def generate_positive(self):
279 |
280 | positive = []
281 | label = []
282 | key_list = list(self.ppmi_matrix.keys())
283 | shuffle_indices = np.random.permutation(len(key_list))
284 |
285 | for shuffle_id in shuffle_indices:
286 | (l, r) = key_list[shuffle_id]
287 | # if self.id2word[l] in self.wordvecs and self.id2word[r] in self.wordvecs:
288 | # positive.append([self.w2embid[l],self.w2embid[r]])
289 | if l in self.context_dict and r in self.context_dict:
290 | positive.append([l, r])
291 | score = self.U[l].dot(self.V[r])
292 | label.append(score)
293 | # label.append(self.ppmi_matrix[(l,r)])
294 | # 119448 positive score
295 | positive_train = np.asarray(positive)[:-2000]
296 |
297 | self.dev_data = np.asarray(positive)[-2000:]
298 |
299 | label_train = np.asarray(label)[:-2000]
300 | self.dev_label = np.asarray(label)[-2000:]
301 |
302 | return positive_train, label_train
303 |
304 | def generate_negative(self, batch_data, negative_num):
305 |
306 | negative = []
307 | label = []
308 |
309 | batch_size = batch_data.shape[0]
310 |
311 | for i in range(batch_size):
312 | # random_idx = np.random.choice(len(self.vocab), 150 , replace=False)
313 | l = batch_data[i][0]
314 | # l_w = self.embid2w[l]
315 | r = batch_data[i][1]
316 | # r_w = self.embid2w[r]
317 |
318 | l_neg = l
319 | r_neg = r
320 |
321 | num = 0
322 | for j in range(negative_num):
323 | left_prob = np.random.binomial(1, 0.5)
324 | while True:
325 | if left_prob:
326 | l_neg = np.random.choice(len(self.vocab), 1)[0]
327 | else:
328 | r_neg = np.random.choice(len(self.vocab), 1)[0]
329 | # if (l_neg, r_neg) not in self.matrix.keys() and self.id2word[l_neg] in self.wordvecs and self.id2word[r_neg] in self.wordvecs:
330 | if (l_neg, r_neg) not in self.matrix.keys() and l_neg in self.context_dict and r_neg in self.context_dict:
331 | break
332 |
333 | # negative.append([self.w2embid[l_neg], self.w2embid[r_neg]])
334 | negative.append([self.context_dict[l_neg], self.context_dict[r_neg]])
335 | score = self.U[l_neg].dot(self.V[r_neg])
336 | # score = 0
337 | label.append(score)
338 |
339 | negative = np.asarray(negative)
340 | label = np.asarray(label)
341 | return negative, label
342 |
343 |
344 | def get_batch(self):
345 |
346 |
347 | num_positive = len(self.positive_data)
348 |
349 | batch_size = self.batch_size
350 |
351 | if num_positive% batch_size == 0:
352 | batch_num = num_positive // batch_size
353 | else:
354 | batch_num = num_positive // batch_size + 1
355 |
356 | shuffle_indices = np.random.permutation(num_positive)
357 |
358 | for batch in range(batch_num):
359 |
360 | start_index = batch * batch_size
361 | end_index = min((batch+1) * batch_size, num_positive)
362 |
363 | batch_idx = shuffle_indices[start_index:end_index]
364 |
365 | batch_positive_data = self.positive_data[batch_idx]
366 | batch_positive_label = self.positive_label[batch_idx]
367 |
368 | batch_negative_data, batch_negative_label = self.generate_negative(batch_positive_data, self.negative_num)
369 |
370 | batch_positive_context = []
371 | for [l, r] in batch_positive_data:
372 | batch_positive_context.append([self.context_dict[l], self.context_dict[r]])
373 |
374 | # [batch, 2, doc, 2, seq]
375 | batch_input = np.concatenate((batch_positive_context, batch_negative_data), axis=0)
376 | batch_label = np.concatenate((batch_positive_label,batch_negative_label), axis=0)
377 |
378 | yield batch_input, batch_label
379 |
380 | def sample_batch(self):
381 | num_data = len(self.avail_vocab)
382 |
383 | batch_size = self.batch_size
384 |
385 | if num_data % batch_size == 0:
386 | batch_num = num_data // batch_size
387 | else:
388 | batch_num = num_data // batch_size + 1
389 |
390 | shuffle_indices = np.random.permutation(num_data)
391 |
392 | for batch in range(batch_num):
393 |
394 | start_index = batch * batch_size
395 | end_index = min((batch+1) * batch_size, num_data)
396 |
397 | batch_idx = shuffle_indices[start_index:end_index]
398 | batch_data_context = []
399 | batch_data_mask = []
400 | batch_data_score = []
401 | batch_data = self.avail_vocab[batch_idx]
402 |
403 | for idx_i in batch_data:
404 | for j in range(self.negative_num):
405 | left_prob = np.random.binomial(1, 0.5)
406 | if left_prob:
407 | while True:
408 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
409 | if (idx_i, idx_j) not in self.dev_dict:
410 | break
411 | batch_data_context.append([self.context_dict[idx_i]['ids'], self.context_dict[idx_j]['ids']])
412 | batch_data_mask.append([self.context_dict[idx_i]['mask'], self.context_dict[idx_j]['mask']])
413 | score = self.U[idx_i].dot(self.V[idx_j])
414 | else:
415 | while True:
416 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
417 | if (idx_j, idx_i) not in self.dev_dict:
418 | break
419 | batch_data_context.append([self.context_dict[idx_j]['ids'], self.context_dict[idx_i]['ids']])
420 | batch_data_mask.append([self.context_dict[idx_j]['mask'], self.context_dict[idx_i]['mask']])
421 | score = self.U[idx_j].dot(self.V[idx_i])
422 | batch_data_score.append(score)
423 | yield np.asarray(batch_data_context), np.asarray(batch_data_mask), np.asarray(batch_data_score)
424 |
425 | def sample_batch_dev(self):
426 | num_data = len(self.dev_data)
427 |
428 | batch_size = self.batch_size
429 |
430 | if num_data % batch_size == 0:
431 | batch_num = num_data // batch_size
432 | else:
433 | batch_num = num_data // batch_size + 1
434 |
435 | # shuffle_indices = np.random.permutation(num_data)
436 |
437 | for batch in range(batch_num):
438 | start_index = batch * batch_size
439 | end_index = min((batch+1) * batch_size, num_data)
440 |
441 | batch_data = self.dev_data[start_index:end_index]
442 | # batch_data_score = self.dev_label[start_index:end_index]
443 | batch_data_context = []
444 | batch_data_mask = []
445 | for pair in batch_data:
446 | idx_i, idx_j = pair
447 |
448 | batch_data_context.append([self.context_dict[idx_i]['ids'], self.context_dict[idx_j]['ids']])
449 | batch_data_mask.append([self.context_dict[idx_i]['mask'], self.context_dict[idx_j]['mask']])
450 | yield np.asarray(batch_data_context), np.asarray(batch_data_mask)
451 |
452 |
453 | def sample_pos_neg_batch(self):
454 | num_data = len(self.avail_vocab)
455 |
456 | batch_size = self.batch_size
457 |
458 | if num_data % batch_size == 0:
459 | batch_num = num_data // batch_size
460 | else:
461 | batch_num = num_data // batch_size + 1
462 |
463 | shuffle_indices = np.random.permutation(num_data)
464 |
465 | for batch in range(batch_num):
466 |
467 | start_index = batch * batch_size
468 | end_index = min((batch+1) * batch_size, num_data)
469 |
470 | batch_idx = shuffle_indices[start_index:end_index]
471 | batch_data_context = []
472 | batch_data_score = []
473 | batch_data = self.avail_vocab[batch_idx]
474 |
475 | for idx_i in batch_data:
476 | if idx_i in self.left_has:
477 | idx_j_list = np.random.permutation(self.left_has[idx_i])
478 | for idx_j in idx_j_list:
479 | if idx_j in self.avail_vocab:
480 | # batch_data_pair.append([self.w2embid[idx_i], self.w2embid[idx_j]])
481 | batch_data_context.append([self.context_dict[idx_i], self.context_dict[idx_j]])
482 | score = self.U[idx_i].dot(self.V[idx_j])
483 | batch_data_score.append(score)
484 | break
485 |
486 | if idx_i in self.right_has:
487 | idx_j_list = np.random.permutation(self.right_has[idx_i])
488 | for idx_j in idx_j_list:
489 | if idx_j in self.avail_vocab:
490 | # batch_data_pair.append([self.w2embid[idx_j], self.w2embid[idx_i]])
491 | batch_data_context.append([self.context_dict[idx_j], self.context_dict[idx_i]])
492 | score = self.U[idx_j].dot(self.V[idx_i])
493 | batch_data_score.append(score)
494 | break
495 |
496 | for j in range(self.negative_num):
497 | # left_prob = np.random.binomial(1, 0.5)
498 | # if left_prob:
499 | while True:
500 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
501 | if (idx_i, idx_j) not in self.dev_dict:
502 | break
503 | # batch_data_pair.append([self.w2embid[idx_i], self.w2embid[idx_j]])
504 | batch_data_context.append([self.context_dict[idx_i], self.context_dict[idx_j]])
505 | score = self.U[idx_i].dot(self.V[idx_j])
506 | batch_data_score.append(score)
507 | # else:
508 | while True:
509 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
510 | if (idx_j, idx_i) not in self.dev_dict:
511 | break
512 | # batch_data_pair.append([self.w2embid[idx_j], self.w2embid[idx_i]])
513 | batch_data_context.append([self.context_dict[idx_j], self.context_dict[idx_i]])
514 | score = self.U[idx_j].dot(self.V[idx_i])
515 | batch_data_score.append(score)
516 | yield np.asarray(batch_data_context), np.asarray(batch_data_score)
517 |
518 |
--------------------------------------------------------------------------------
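`load_word_context_for_bert` in `utils/data_helper_4bert.py` above packs each pair of context lines into a single string `[CLS] left-window + target word + right-window [SEP]`, then encodes the whole list of contexts for a word in one `tokenizer.batch_encode_plus` call with fixed-length padding. A minimal usage sketch; `bert-base-uncased` is only a stand-in for whatever checkpoint the `[data] bert_path` option points to:

```
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # stand-in for bert_path

contexts = [
    "[CLS] the kinetoscope is an early motion picture device [SEP]",
    "[CLS] " + " ".join(["[PAD]"] * 9) + " [SEP]",   # placeholder for a missing context
]

batch = tokenizer.batch_encode_plus(
    contexts,
    add_special_tokens=False,   # [CLS]/[SEP] are already written into the strings
    padding="max_length",
    truncation=True,
    max_length=32,
    return_attention_mask=True,
)
# Same structure as the dict returned by load_word_context_for_bert.
outputs = {"ids": batch["input_ids"], "mask": batch["attention_mask"]}
print(len(outputs["ids"]), len(outputs["ids"][0]))   # 2 contexts, 32 ids each
```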
/utils/data_helper_4context.py:
--------------------------------------------------------------------------------
1 | from .loader import read_sparse_matrix
2 | from .util import load_gensim_word2vec
3 |
4 | import torch
5 | import numpy as np
6 | import scipy.sparse as sparse
7 | import os
8 | from os.path import isfile, join
9 | import pickle
10 |
11 | class Dataset(object):
12 | """docstring for Dataset"""
13 | def __init__(self, config, svd=False, train=True):
14 | # build the PPMI matrix from the Hearst-pattern co-occurrence counts
15 | pattern_filename = config.get("data", "pattern_filename")
16 |
17 | self.context_num = int(config.getfloat("hyperparameters", "context_num"))
18 | self.context_len = int(config.getfloat("hyperparameters", "context_len"))
19 |
20 | k = int(config.getfloat("hyperparameters", "svd_dimension"))
21 | self.batch_size = int(config.getfloat("hyperparameters", "batch_size"))
22 | self.negative_num = int(config.getfloat("hyperparameters", "negative_num"))
23 |
24 | csr_m, self.id2word, self.vocab, _ = read_sparse_matrix(
25 | pattern_filename, same_vocab=True)
26 |
27 | self.word2id = {}
28 | for i in range(len(self.id2word)):
29 | self.word2id[self.id2word[i]] = i
30 |
31 | self.matrix = csr_m.todok()
32 | self.p_w = csr_m.sum(axis=1).A[:,0]
33 | self.p_c = csr_m.sum(axis=0).A[0,:]
34 | self.N = self.p_w.sum()
35 |
36 | # for w2v
37 | # self.wordvecs = load_gensim_word2vec("/home/shared/acl-data/embedding/ukwac.model",
38 | # self.vocab)
39 |
40 | # self.wordvec_weights = self.build_emb()
41 |
42 | #print(self.matrix.shape)
43 |
44 | print('SVD matrix...')
45 | tr_matrix = sparse.dok_matrix(self.matrix.shape)
46 | self.left_has = {}
47 | self.right_has = {}
48 | for (l,r) in self.matrix.keys():
49 | pmi_lr = (np.log(self.N) + np.log(self.matrix[(l,r)])
50 | - np.log(self.p_w[l]) - np.log(self.p_c[r]))
51 |
52 | ppmi_lr = np.clip(pmi_lr, 0.0, 1e12)
53 | tr_matrix[(l,r)] = ppmi_lr
54 |
55 | if l not in self.left_has:
56 | self.left_has[l] = []
57 | self.left_has[l].append(r)
58 | if r not in self.right_has:
59 | self.right_has[r] = []
60 | self.right_has[r].append(l)
61 |
62 | self.ppmi_matrix = tr_matrix
63 |
64 | U, S, V = sparse.linalg.svds(self.ppmi_matrix.tocsr(), k=k)
65 | self.U = U.dot(np.diag(S))
66 | self.V = V.T
67 |
68 | # for context
69 | w2v_dir = "/home/shared/acl-data/embedding/"
70 | vocab_path = "/home/shared/acl-data/corpus/"
71 | print('Loading vocab...')
72 | self.load_vocab(w2v_dir, vocab_path)
73 | print('Loading context...')
74 | if train:
75 | self.context_dir = config.get("data", "context")
76 | has_context_word_id_list = self.load_target_word(self.context_dir)
77 | self.context_dict = {}
78 | for matrix_id in range(len(self.vocab)):
79 | word = self.id2word[matrix_id]
80 | if word in self.context_w2i:
81 | context_id = self.context_w2i[word]
82 | if context_id in has_context_word_id_list:
83 | self.context_dict[matrix_id] = self.load_word_context(context_id)
84 |
85 | # self.positive_data, self.positive_label = self.generate_positive()
86 | self.get_avail_vocab()
87 | else:
88 | self.context_dict = {}
89 | self.context_dir = config.get("data", "context_oov")
90 | has_context_word_id_list = self.load_target_word(self.context_dir)
91 | for context_id in has_context_word_id_list:
92 | self.context_dict[context_id] = self.load_word_context(context_id)
93 |
94 |
95 | def load_target_word(self, data_dir):
96 | target_word_list = [int(f.split('.')[0]) for f in os.listdir(data_dir) if isfile(join(data_dir, f))]
97 | return np.asarray(target_word_list)
98 |
99 | def get_avail_vocab(self):
100 | avail_vocab = []
101 | for idx in range(len(self.vocab)):
102 | if idx in self.context_dict:
103 | avail_vocab.append(idx)
104 | self.avail_vocab = np.asarray(avail_vocab)
105 | print('Available word num: {}'.format(len(avail_vocab)))
106 | shuffle_indices_left = np.random.permutation(len(self.avail_vocab))[:2000]
107 | shuffle_indices_right = np.random.permutation(len(self.avail_vocab))[:2000]
108 | dev_data = []
109 | dev_label = []
110 | self.dev_dict = {}
111 | for id_case in range(2000):
112 | id_left = self.avail_vocab[shuffle_indices_left[id_case]]
113 | id_right = self.avail_vocab[shuffle_indices_right[id_case]]
114 | dev_data.append([id_left,id_right])
115 | dev_label.append(self.U[id_left].dot(self.V[id_right]))
116 | self.dev_dict[(id_left, id_right)] = 1
117 | self.dev_data = np.asarray(dev_data)
118 | self.dev_label = np.asarray(dev_label)
119 |
120 | def build_emb(self):
121 |
122 | self.word2id = {}
123 | for i in range(len(self.id2word)):
124 | self.word2id[self.id2word[i]] = i
125 |
126 | tensors = []
127 | ivocab = []
128 | self.w2embid = {}
129 | self.embid2w = {}
130 |
131 | for word in self.wordvecs:
132 | vec = torch.from_numpy(self.wordvecs[word])
133 | self.w2embid[self.word2id[word]] = len(ivocab)
134 | self.embid2w[len(ivocab)] = self.word2id[word]
135 |
136 | ivocab.append(word)
137 | tensors.append(vec)
138 |
139 | assert len(tensors) == len(ivocab)
140 | tensors = torch.cat(tensors).view(len(ivocab), 300)
141 |
142 | return tensors
143 |
144 | def load_vocab(self, w2v_dir, data_dir):
145 | i2w_path = os.path.join(data_dir, 'ukwac_id2word.pkl')
146 | w2i_path = os.path.join(data_dir, 'ukwac_word2id.pkl')
147 | with open(i2w_path, 'rb') as fr:
148 | self.context_i2w = pickle.load(fr)
149 | with open(w2i_path, 'rb') as fr:
150 | self.context_w2i = pickle.load(fr)
151 |
152 | self.PAD = 0
153 | self.UNK = 1
154 |
155 | # w2v_model = Word2Vec.load(w2v_path)
156 | # emb = w2v_model.wv
157 | # oi2ni = {}
158 | # new_embedding = []
159 | # new_embedding.append(np.zeros(300))
160 | # new_embedding.append(np.zeros(300))
161 | # cnt_ni = 2
162 | # for _id, word in i2w.items():
163 | # if word in emb:
164 | # oi2ni[_id] = cnt_ni
165 | # cnt_ni += 1
166 | # new_embedding.append(emb[word])
167 | # else:
168 | # oi2ni[_id] = self.UNK
169 |
170 | oi2ni_path = os.path.join(w2v_dir, 'context_word_oi2ni.pkl')
171 | w2v_path = os.path.join(w2v_dir, 'context_word_w2v.model.npy')
172 | with open(oi2ni_path, 'rb') as fr:
173 | self.context_i2embid = pickle.load(fr)
174 | self.context_word_emb = np.load(w2v_path)
175 |
176 |
177 | def load_word_context(self, word_idx):
178 | context_path = os.path.join(self.context_dir, '{}.txt'.format(word_idx))
179 |
180 | context_list = []
181 | with open(context_path, 'r') as fr:
182 | flag_right = False
183 | cnt_line = 0
184 |
185 | for line in fr:
186 | line = line.strip()
187 |
188 | if len(line) != 0:
189 | context = [int(num) for num in line.split(' ')]
190 | else:
191 | context = []
192 | context = [self.context_i2embid[num] for num in context]
193 | if not flag_right:
194 | left_context = [self.PAD] * self.context_len
195 | if len(context) >= self.context_len:
196 | left_context = context[(len(context) - self.context_len):]
197 | else:
198 | left_context[(self.context_len-len(context)):] = context
199 | flag_right = True
200 | else:
201 | right_context = [self.PAD] * self.context_len
202 | if len(context) >= self.context_len:
203 | right_context = list(reversed(context[:self.context_len]))
204 | else:
205 | right_context[self.context_len-len(context):] = list(reversed(context))
206 |
207 | context_list.append([left_context, right_context])
208 | flag_right = False
209 | cnt_line += 1
210 | if cnt_line == 2* self.context_num:
211 | break
212 |
213 | if len(context_list) <= self.context_num:
214 | for i in range(self.context_num - len(context_list)):
215 | context_list.append([[self.PAD]*self.context_len, [self.PAD]*self.context_len])
216 |
217 | return context_list
218 |
219 | def load_prediction_word_context(self, word_idx):
220 | context_path = os.path.join(self.context_dir, '{}.txt'.format(word_idx))
221 | context_list = []
222 | with open(context_path, 'r') as fr:
223 | flag_right = False
224 | cnt_line = 0
225 | for line in fr:
226 | line = line.strip()
227 | if len(line) != 0:
228 | context = [int(num) for num in line.split(' ')]
229 | else:
230 | context = []
231 |
232 | print(line)
233 | context = [self.context_i2w[num] for num in context]
234 | print(context)
235 | if not flag_right:
236 | left_context = [''] * self.context_len
237 | if len(context) >= self.context_len:
238 | left_context = context[(len(context) - self.context_len):]
239 | else:
240 | left_context[(self.context_len-len(context)):] = context
241 | flag_right = True
242 | else:
243 | right_context = [''] * self.context_len
244 | if len(context) >= self.context_len:
245 | right_context = list(reversed(context[:self.context_len]))
246 | else:
247 | right_context[self.context_len-len(context):] = list(reversed(context))
248 |
249 | context_list.append([left_context, right_context])
250 | flag_right = False
251 | cnt_line += 1
252 | if cnt_line == 2 * self.context_num:
253 | break
254 |
255 | if len(context_list) <= self.context_num:
256 | for i in range(self.context_num - len(context_list)):
257 | context_list.append([['']*self.context_len, ['']*self.context_len])
258 |
259 | return context_list
260 |
261 |
262 | def generate_positive(self):
263 |
264 | positive = []
265 | label = []
266 | key_list = list(self.ppmi_matrix.keys())
267 | shuffle_indices = np.random.permutation(len(key_list))
268 |
269 | for shuffle_id in shuffle_indices:
270 | (l, r) = key_list[shuffle_id]
271 | # if self.id2word[l] in self.wordvecs and self.id2word[r] in self.wordvecs:
272 | # positive.append([self.w2embid[l],self.w2embid[r]])
273 | if l in self.context_dict and r in self.context_dict:
274 | positive.append([l, r])
275 | score = self.U[l].dot(self.V[r])
276 | label.append(score)
277 | # label.append(self.ppmi_matrix[(l,r)])
278 | # 119448 positive score
279 | positive_train = np.asarray(positive)[:-2000]
280 |
281 | self.dev_data = np.asarray(positive)[-2000:]
282 |
283 | label_train = np.asarray(label)[:-2000]
284 | self.dev_label = np.asarray(label)[-2000:]
285 |
286 | return positive_train, label_train
287 |
288 | def generate_negative(self, batch_data, negative_num):
289 |
290 | negative = []
291 | label = []
292 |
293 | batch_size = batch_data.shape[0]
294 |
295 | for i in range(batch_size):
296 | # random_idx = np.random.choice(len(self.vocab), 150 , replace=False)
297 | l = batch_data[i][0]
298 | # l_w = self.embid2w[l]
299 | r = batch_data[i][1]
300 | # r_w = self.embid2w[r]
301 |
302 | l_neg = l
303 | r_neg = r
304 |
305 | num = 0
306 | for j in range(negative_num):
307 | left_prob = np.random.binomial(1, 0.5)
308 | while True:
309 | if left_prob:
310 | l_neg = np.random.choice(len(self.vocab), 1)[0]
311 | else:
312 | r_neg = np.random.choice(len(self.vocab), 1)[0]
313 | # if (l_neg, r_neg) not in self.matrix.keys() and self.id2word[l_neg] in self.wordvecs and self.id2word[r_neg] in self.wordvecs:
314 | if (l_neg, r_neg) not in self.matrix.keys() and l_neg in self.context_dict and r_neg in self.context_dict:
315 | break
316 |
317 | # negative.append([self.w2embid[l_neg], self.w2embid[r_neg]])
318 | negative.append([self.context_dict[l_neg], self.context_dict[r_neg]])
319 | score = self.U[l_neg].dot(self.V[r_neg])
320 | # score = 0
321 | label.append(score)
322 |
323 | negative = np.asarray(negative)
324 | label = np.asarray(label)
325 | return negative, label
326 |
327 |
328 | def get_batch(self):
329 |
330 |
331 | num_positive = len(self.positive_data)
332 |
333 | batch_size = self.batch_size
334 |
335 | if num_positive% batch_size == 0:
336 | batch_num = num_positive // batch_size
337 | else:
338 | batch_num = num_positive // batch_size + 1
339 |
340 | shuffle_indices = np.random.permutation(num_positive)
341 |
342 | for batch in range(batch_num):
343 |
344 | start_index = batch * batch_size
345 | end_index = min((batch+1) * batch_size, num_positive)
346 |
347 | batch_idx = shuffle_indices[start_index:end_index]
348 |
349 | batch_positive_data = self.positive_data[batch_idx]
350 | batch_positive_label = self.positive_label[batch_idx]
351 |
352 | batch_negative_data, batch_negative_label = self.generate_negative(batch_positive_data, self.negative_num)
353 |
354 | batch_positive_context = []
355 | for [l, r] in batch_positive_data:
356 | batch_positive_context.append([self.context_dict[l], self.context_dict[r]])
357 |
358 | # [batch, 2, doc, 2, seq]
359 | batch_input = np.concatenate((batch_positive_context, batch_negative_data), axis=0)
360 | batch_label = np.concatenate((batch_positive_label,batch_negative_label), axis=0)
361 |
362 | yield batch_input, batch_label
363 |
364 | def sample_batch(self):
365 | num_data = len(self.avail_vocab)
366 |
367 | batch_size = self.batch_size
368 |
369 | if num_data % batch_size == 0:
370 | batch_num = num_data // batch_size
371 | else:
372 | batch_num = num_data // batch_size + 1
373 |
374 | shuffle_indices = np.random.permutation(num_data)
375 |
376 | for batch in range(batch_num):
377 |
378 | start_index = batch * batch_size
379 | end_index = min((batch+1) * batch_size, num_data)
380 |
381 | batch_idx = shuffle_indices[start_index:end_index]
382 | batch_data_context = []
383 | batch_data_score = []
384 | batch_data = self.avail_vocab[batch_idx]
385 |
386 | for idx_i in batch_data:
387 | for j in range(self.negative_num):
388 | left_prob = np.random.binomial(1, 0.5)
389 | if left_prob:
390 | while True:
391 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
392 | if (idx_i, idx_j) not in self.dev_dict:
393 | break
394 | batch_data_context.append([self.context_dict[idx_i], self.context_dict[idx_j]])
395 | score = self.U[idx_i].dot(self.V[idx_j])
396 | else:
397 | while True:
398 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
399 | if (idx_j, idx_i) not in self.dev_dict:
400 | break
401 | batch_data_context.append([self.context_dict[idx_j], self.context_dict[idx_i]])
402 | score = self.U[idx_j].dot(self.V[idx_i])
403 | batch_data_score.append(score)
404 | yield np.asarray(batch_data_context), np.asarray(batch_data_score)
405 |
406 | def sample_batch_dev(self):
407 | num_data = len(self.dev_data)
408 |
409 | batch_size = self.batch_size
410 |
411 | if num_data % batch_size == 0:
412 | batch_num = num_data // batch_size
413 | else:
414 | batch_num = num_data // batch_size + 1
415 |
416 | # shuffle_indices = np.random.permutation(num_data)
417 |
418 | for batch in range(batch_num):
419 | start_index = batch * batch_size
420 | end_index = min((batch+1) * batch_size, num_data)
421 |
422 | batch_data = self.dev_data[start_index:end_index]
423 | # batch_data_score = self.dev_label[start_index:end_index]
424 | batch_data_context = []
425 | for pair in batch_data:
426 | idx_i, idx_j = pair
427 |
428 | batch_data_context.append([self.context_dict[idx_i], self.context_dict[idx_j]])
429 | yield np.asarray(batch_data_context)
430 |
431 |
432 | def sample_pos_neg_batch(self):
433 | num_data = len(self.avail_vocab)
434 |
435 | batch_size = self.batch_size
436 |
437 | if num_data % batch_size == 0:
438 | batch_num = num_data // batch_size
439 | else:
440 | batch_num = num_data // batch_size + 1
441 |
442 | shuffle_indices = np.random.permutation(num_data)
443 |
444 | for batch in range(batch_num):
445 |
446 | start_index = batch * batch_size
447 | end_index = min((batch+1) * batch_size, num_data)
448 |
449 | batch_idx = shuffle_indices[start_index:end_index]
450 | batch_data_context = []
451 | batch_data_score = []
452 | batch_data = self.avail_vocab[batch_idx]
453 |
454 | for idx_i in batch_data:
455 | if idx_i in self.left_has:
456 | idx_j_list = np.random.permutation(self.left_has[idx_i])
457 | for idx_j in idx_j_list:
458 | if idx_j in self.avail_vocab:
459 | # batch_data_pair.append([self.w2embid[idx_i], self.w2embid[idx_j]])
460 | batch_data_context.append([self.context_dict[idx_i], self.context_dict[idx_j]])
461 | score = self.U[idx_i].dot(self.V[idx_j])
462 | batch_data_score.append(score)
463 | break
464 |
465 | if idx_i in self.right_has:
466 | idx_j_list = np.random.permutation(self.right_has[idx_i])
467 | for idx_j in idx_j_list:
468 | if idx_j in self.avail_vocab:
469 | # batch_data_pair.append([self.w2embid[idx_j], self.w2embid[idx_i]])
470 | batch_data_context.append([self.context_dict[idx_j], self.context_dict[idx_i]])
471 | score = self.U[idx_j].dot(self.V[idx_i])
472 | batch_data_score.append(score)
473 | break
474 |
475 | for j in range(self.negative_num):
476 | # left_prob = np.random.binomial(1, 0.5)
477 | # if left_prob:
478 | while True:
479 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
480 | if (idx_i, idx_j) not in self.dev_dict:
481 | break
482 | # batch_data_pair.append([self.w2embid[idx_i], self.w2embid[idx_j]])
483 | batch_data_context.append([self.context_dict[idx_i], self.context_dict[idx_j]])
484 | score = self.U[idx_i].dot(self.V[idx_j])
485 | batch_data_score.append(score)
486 | # else:
487 | while True:
488 | idx_j = np.random.choice(self.avail_vocab, 1)[0]
489 | if (idx_j, idx_i) not in self.dev_dict:
490 | break
491 | # batch_data_pair.append([self.w2embid[idx_j], self.w2embid[idx_i]])
492 | batch_data_context.append([self.context_dict[idx_j], self.context_dict[idx_i]])
493 | score = self.U[idx_j].dot(self.V[idx_i])
494 | batch_data_score.append(score)
495 | yield np.asarray(batch_data_context), np.asarray(batch_data_score)
496 |
497 |
--------------------------------------------------------------------------------
/utils/loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright (c) 2017-present, Facebook, Inc.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | """
10 | Utility module for easily reading a sparse matrix.
11 | """
12 |
13 | from __future__ import absolute_import
14 | from __future__ import division
15 | from __future__ import print_function
16 | from __future__ import unicode_literals
17 |
18 | import numpy as np
19 | import logging
20 | import os
21 | import gzip
22 | import pandas as pd
23 | import scipy.sparse as sp
24 | from nltk.stem.wordnet import WordNetLemmatizer
25 | lemmatizer = WordNetLemmatizer()
26 |
27 |
28 | try:
29 | import cPickle as pickle
30 | except ImportError:
31 | import pickle
32 |
33 |
34 | def __try_three_columns(string):
35 | fields = string.split("\t")
36 | if len(fields) > 3:
37 | fields = fields[:3]
38 | if len(fields) == 3:
39 | return fields[0], fields[1], float(fields[2])
40 | if len(fields) == 2:
41 | return fields[0], fields[1], 1.0
42 | else:
43 | raise ValueError("Invalid number of fields {}".format(len(fields)))
44 |
45 |
46 | def __load_sparse_matrix(filename, same_vocab):
47 | """
48 | Actual workhorse for loading a sparse matrix. See docstring for
49 | read_sparse_matrix.
50 |
51 | """
52 | objects = [""]
53 | rowvocab = {"": 0}
54 | if same_vocab:
55 | colvocab = rowvocab
56 | else:
57 | colvocab = {}
58 | _is = []
59 | _js = []
60 | _vs = []
61 |
62 | # Read gzip files
63 | if filename.endswith(".gz"):
64 | f = gzip.open(filename, "r")
65 | else:
66 | f = open(filename, "rb")
67 |
68 | for line in f:
69 | line = line.decode("utf-8")
70 | target, context, weight = __try_three_columns(line)
71 | if target not in rowvocab:
72 | rowvocab[target] = len(rowvocab)
73 | objects.append(target)
74 | if context not in colvocab:
75 | colvocab[context] = len(colvocab)
76 | if same_vocab:
77 | objects.append(context)
78 |
79 | _is.append(rowvocab[target])
80 | _js.append(colvocab[context])
81 | _vs.append(weight)
82 |
83 | # clean up
84 | f.close()
85 |
86 | _shape = (len(rowvocab), len(colvocab))
87 | spmatrix = sp.csr_matrix((_vs, (_is, _js)), shape=_shape, dtype=np.float64)
88 | return spmatrix, objects, rowvocab, colvocab
89 |
90 |
91 | def read_sparse_matrix(filename, allow_binary_cache=False, same_vocab=False):
92 | """
93 | Reads in a 3 column file as a sparse matrix, where each line (x, y, v)
94 | gives the name of the row x, the column y, and the value v.
95 |
96 | If filename ends with .gz, will assume the file is gzip compressed.
97 |
98 | Args:
99 | filename: str. The filename containing sparse matrix in 3-col format.
100 | allow_binary_cache: bool. If true, caches the matrix in a pkl file with
101 | the same filename for faster reads. If cache doesn't exist, will
102 | create it.
103 | same_vocab: bool. Indicates whether rows and columns have the same vocab.
104 |
105 | Returns:
106 | A tuple containing (spmatrix, id2row, row2id, col2id):
107 | spmatrix: a scipy.sparse matrix with the entries
108 | id2row: a list[str] containing the names for the rows of the matrix
109 | row2id: a dict[str,int] mapping words to row indices
110 | col2id: a dict[str,int] mapping words to col indices. If same_vocab,
111 | this is identical to row2id.
112 | """
113 | # make sure the cache is new enough
114 | cache_filename = filename + ".pkl"
115 | cache_exists = os.path.exists(cache_filename)
116 | cache_fresh = cache_exists and os.path.getmtime(filename) <= os.path.getmtime(
117 | cache_filename
118 | )
119 | if allow_binary_cache and cache_fresh:
120 | logging.debug("Using space cache {}".format(cache_filename))
121 | with open(cache_filename, "rb") as pklf:  # cache_filename already ends in ".pkl"
122 | return pickle.load(pklf)
123 | else:
124 | # binary cache is not allowed, or it's stale
125 | result = __load_sparse_matrix(filename, same_vocab=same_vocab)
126 | if allow_binary_cache:
127 | logging.warning("Dumping the binary cache {}.pkl".format(filename))
128 | with open(filename + ".pkl", "wb") as pklf:
129 | pickle.dump(result, pklf)
130 | return result
131 |
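# Added illustration (not part of the original source): read_sparse_matrix expects
# tab-separated lines of (row word, column word, value); for the Hearst-pattern
# count files this is roughly (hyponym, hypernym, count), e.g.
#
#     dog<TAB>animal<TAB>42.0
#     oak<TAB>tree<TAB>17.0
#
# A typical call on such a file (the path is only a placeholder) would be:
#
#     spmatrix, id2row, row2id, col2id = read_sparse_matrix(
#         "hearst_counts.txt.gz", same_vocab=True)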
132 |
133 | class Testdataset(object):
134 | """
135 |     Represents a hypernymy dataset, which contains a left-hand side (LHS) of hyponyms
136 |     and a right-hand side (RHS) of hypernyms.
137 |
138 | Params:
139 | filename: str. Filename on disk corresponding to the TSV file
140 | vocabdict: dict[str,*]. Dictionary whose keys are the vocabulary of the
141 | model to test
142 | ycolumn: str. Optional name of the label column.
143 | """
144 |
145 | def __init__(self, filename, vocabdict, ycolumn="label"):
146 | #if "" not in vocabdict:
147 | # raise ValueError("Reserved word must appear in vocabulary.")
148 |
149 | table = pd.read_table(filename)
150 |
151 | # some things require the part of speech, which may not be explicitly
152 | # given in the dataset.
153 | if "pos" not in table.columns:
154 | table["pos"] = "N"
155 | table = table[table.pos.str.lower() == "n"]
156 |
157 | # Handle MWEs by replacing the space
158 | table["word1"] = table.word1.apply(lambda x: x.replace(" ", "_").lower())
159 | table["word2"] = table.word2.apply(lambda x: x.replace(" ", "_").lower())
160 |
161 | if vocabdict:
162 | self.word1_inv = table.word1.apply(vocabdict.__contains__)
163 | self.word2_inv = table.word2.apply(vocabdict.__contains__)
164 | else:
165 | self.word1_inv = table.word1.apply(lambda x: True)
166 | self.word2_inv = table.word2.apply(lambda x: True)
167 |
168 | # Always evaluate on lemmas
169 | #table["word1"] = table.word1
170 | #table["word2"] = table.word2
171 | table["word1"] = table.word1.apply(lemmatizer.lemmatize)
172 | table["word2"] = table.word2.apply(lemmatizer.lemmatize)
173 |
174 | self.table = table
175 | self.labels = np.array(table[ycolumn])
176 | if "fold" in table:
177 | self.folds = table["fold"]
178 | else:
179 | self.folds = np.array(["test"] * len(self.table))
180 |
181 | self.table["is_oov"] = self.oov_mask
182 |
183 | def __len__(self):
184 | return len(self.table)
185 |
186 | @property
187 | def hypos(self):
188 | return np.array(self.table.word1)
189 |
190 | @property
191 | def hypers(self):
192 | return np.array(self.table.word2)
193 |
194 | @property
195 | def invocab_mask(self):
196 | return self.word1_inv & self.word2_inv
197 |
198 | @property
199 | def oov_mask(self):
200 | return ~self.invocab_mask
201 |
202 | @property
203 | def val_mask(self):
204 | return np.array(self.folds == "val")
205 |
206 | @property
207 | def test_mask(self):
208 | return np.array(self.folds == "test")
209 |
210 |
211 | @property
212 | def train_mask(self):
213 | return np.array(self.folds == "train")
214 |
215 | @property
216 | def train_inv_mask(self):
217 | return self.invocab_mask & self.train_mask
218 |
219 | @property
220 | def val_inv_mask(self):
221 | return self.invocab_mask & self.val_mask
222 |
223 | @property
224 | def test_inv_mask(self):
225 | return self.invocab_mask & self.test_mask
226 |
227 | @property
228 | def y(self):
229 | return self.labels
230 |
231 |
232 |
233 |
234 |
235 |
236 |
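# Added usage sketch (not part of the original file): a minimal illustration of
# how Testdataset might be driven by the evaluation scripts. The TSV path and
# the toy vocabulary below are hypothetical placeholders.
if __name__ == "__main__":
    toy_vocab = {"dog": 0, "animal": 1, "cat": 2}
    ds = Testdataset("data/bless.tsv", toy_vocab)  # TSV with word1 / word2 / label columns
    print(len(ds), "noun pairs;", ds.oov_mask.sum(), "pairs contain an out-of-vocabulary word")
    print("in-vocabulary pairs in the test fold:", ds.test_inv_mask.sum())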
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gensim
3 | import time
4 | from gensim.models import Word2Vec
5 | from scipy.linalg import norm
6 |
7 |
8 | class MySentences(object):
9 | def __init__(self, fileName):
10 | self.fileName = fileName
11 |
12 | def __iter__(self):
13 | for line in open(self.fileName, "r"):
14 | yield line.split()
15 |
16 |
17 | def cosine_distance(v1, v2):
18 |     # Note: despite its name, this returns cosine *similarity*, not a distance.
19 |     return np.dot(v1, v2) / (norm(v1) * norm(v2))
20 |
21 |
22 | def oe_score(hypo, hyper):
23 |     # Order-embedding style penalty: ||max(hypo - hyper, 0)|| / (||hypo|| + ||hyper||)
24 | sub = np.subtract(hypo, hyper)
25 |
26 | mid = np.maximum(sub,0)
27 | pw = norm(mid, 2)
28 |
29 | norm1 = norm(hypo, 2)
30 | norm2 = norm(hyper, 2)
31 | norm_sum = norm1 + norm2
32 |
33 | return float(pw)/norm_sum
34 |
35 |
36 | def asymmetric_distance(v1, v2, distance_metric):
37 | """
38 |     Directly copied from the LEAR code
39 | """
40 | #return distance(v1, v2) + norm(v1) - norm(v2)
41 |
42 | cosine_similarity = cosine_distance(v1,v2)
43 |
44 | norm1 = norm(v1, ord=2)
45 | norm2 = norm(v2, ord=2)
46 |
47 |     if distance_metric == "metric_1":
48 |         # cos(x, y) + (|y| - |x|)
49 |         return cosine_similarity + (norm2 - norm1)
50 |
51 |     elif distance_metric == "metric_2":
52 |         # cos(x, y) + (|y| - |x|) / (|x| + |y|)
53 |
54 |         norm_difference = norm2 - norm1
55 |         norm_sum = norm1 + norm2
56 |
57 |         return cosine_similarity + (norm_difference / norm_sum)
58 |
59 |     elif distance_metric == "metric_3":
60 |         # cos(x, y) + (|y| - |x|) / max(|x|, |y|)
61 |         max_norm = np.maximum(norm1, norm2)
62 |         norm_difference = norm2 - norm1
63 |
64 |         return cosine_similarity + (norm_difference / max_norm)
65 |
66 |
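# Added illustration (not part of the original source): a quick sanity check of
# the scoring functions on toy vectors. oe_score is the order-embedding style
# penalty implemented above (it is 0 whenever every coordinate of hypo is <= the
# corresponding coordinate of hyper), while asymmetric_distance adds a
# norm-difference term to the cosine similarity, following the LEAR metrics.
#
#     hypo = np.array([1.0, 0.0, 0.5])
#     hyper = np.array([1.0, 1.0, 1.0])
#     oe_score(hypo, hyper)                        # -> 0.0
#     asymmetric_distance(hypo, hyper, "metric_2") # cos + (|hyper| - |hypo|) / (|hypo| + |hyper|)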
67 |
68 | def load_word_vectors(file_path):
69 |
70 | print("loading vectors from ", file_path)
71 | input_dict = {}
72 |
73 | with open(file_path, "r") as in_file:
74 | lines = in_file.readlines()
75 |
76 | in_file.close()
77 |
78 | for line in lines:
79 | item = line.strip().split()
80 | dkey = item.pop(0)
81 |         if len(item) != 300:  # skip malformed lines (300-d vectors assumed)
82 | continue
83 | vectors = np.array(item, dtype='float32')
84 | input_dict[dkey] = vectors
85 |
86 |     print(len(input_dict), "vectors loaded from", file_path)
87 |
88 | return input_dict
89 |
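# Added illustration (not part of the original source): load_word_vectors expects
# plain-text embeddings in the GloVe-style text format, one word per line
# followed by exactly 300 floats; lines with any other dimensionality are
# silently skipped, e.g.
#
#     dog 0.123 -0.045 ... 0.067
#     animal 0.201 0.011 ... -0.093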
90 | def train_word2vec(file_path, saved_path):
91 |     # Note: size/iter are the gensim 3.x parameter names; gensim >= 4 uses vector_size/epochs.
92 | st = time.time()
93 | sentences = MySentences(file_path) # a memory-friendly iterator
94 | model = gensim.models.Word2Vec(sentences, size=300, min_count=1, workers=30, iter=5)
95 | model.save(saved_path + "/unwak.model")
96 |     print('Finished in {:.2f}s'.format(time.time() - st))
97 |
98 |
99 | def load_gensim_word2vec(model_name, word_list, saved_path=None):
100 |
101 |     print("Loading word embeddings ...")
102 |
103 | model = Word2Vec.load(model_name)
104 | emb = model.wv
105 | #emb = load_word_vectors("/home/cyuaq/embeddings/glove.840B.300d.txt")
106 |
107 | input_dict = {}
108 |
109 | #out = open(saved_path, "w")
110 | num = 0
111 | for word in word_list:
112 | if word in emb:
113 | num +=1
114 | input_dict[word] = emb[word]
115 | #vec = " ".join([str(each) for each in emb[word]])
116 | #out.write(word + " " + vec + "\n")
117 | #out.close()
118 |     print("Words covered by word2vec:", num)
119 | return input_dict
120 |
121 |
122 | def load_phrase_word2vec(model_name, word_list, saved_path=None):
123 |
124 |     print("Loading word embeddings ...")
125 | model = Word2Vec.load(model_name)
126 |
127 | emb = model.wv
128 | input_dict = {}
129 |
130 | num = 0
131 | for word in word_list:
132 | if word in emb:
133 | num+=1
134 | input_dict[word] = emb[word]
135 | else:
136 | if '_' in word:
137 | tmp = word.split("_")
138 | tmp_vec = np.zeros(300,dtype=np.float32)
139 | flag = True
140 | for w in tmp:
141 | if w not in emb:
142 | flag = False
143 | break
144 | else:
145 | tmp_vec += emb[w]
146 | if flag:
147 | num +=1
148 | input_dict[word] = tmp_vec/len(tmp)
149 |
150 |     print("Words covered by word2vec:", num)
151 | assert num == len(input_dict)
152 | return input_dict
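# Added usage sketch (not part of the original file): how the helpers above chain
# together. The corpus path, output directory and word list are hypothetical
# placeholders; train_word2vec saves the model as "<saved_path>/unwak.model".
if __name__ == "__main__":
    # Train word2vec on a whitespace-tokenised corpus, one sentence per line.
    train_word2vec("corpus/wiki_giga.txt", "saved_models")
    # Load only the embeddings needed for the evaluation vocabulary; for MWEs
    # such as "hot_dog", load_phrase_word2vec falls back to averaging the
    # vectors of the constituent words.
    vectors = load_phrase_word2vec("saved_models/unwak.model",
                                   ["dog", "animal", "hot_dog"])
    print(len(vectors), "words covered by the trained embeddings")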
--------------------------------------------------------------------------------