├── LICENSE
├── README.md
├── bertSP
│   ├── README.md
│   ├── data
│   │   └── .gitignore
│   ├── logs
│   │   └── .gitignore
│   ├── models
│   │   └── .gitignore
│   ├── requirements.txt
│   └── src
│       ├── distributed.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── adam.py
│       │   ├── data_loader.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   ├── loss.py
│       │   ├── model_builder.py
│       │   ├── neural.py
│       │   ├── optimizers.py
│       │   ├── predictor.py
│       │   ├── reporter.py
│       │   └── trainer.py
│       ├── others
│       │   ├── __init__.py
│       │   ├── constants.py
│       │   ├── logging.py
│       │   ├── tokenization.py
│       │   └── utils.py
│       ├── prepro
│       │   ├── __init__.py
│       │   ├── data_builder.py
│       │   └── utils.py
│       ├── preprocess.py
│       ├── train.py
│       ├── train_baseline.py
│       └── translate
│           ├── __init__.py
│           ├── beam.py
│           └── penalties.py
├── dataset
│   ├── README.md
│   ├── SparqlResults.py
│   ├── SparqlServer.py
│   ├── ner
│   │   ├── allennlp_ner
│   │   │   ├── README.md
│   │   │   ├── createlist.py
│   │   │   ├── nel.py
│   │   │   └── ner_stats.py
│   │   └── strner
│   │       ├── README.md
│   │       ├── create_entity_list.py
│   │       ├── createlist.py
│   │       ├── redump_ascii_disamb_list.py
│   │       ├── str_nel.py
│   │       ├── str_tag.sh
│   │       └── unnormalized_entity_counts.py
│   ├── precompute_local_subgraphs.py
│   └── precompute_local_types.py
├── evaluation
│   ├── README.md
│   ├── actions.py
│   ├── constants.py
│   ├── executor.py
│   ├── meters.py
│   ├── run_subtype_lf.py
│   └── summarise_results.py
├── lasagneSP
│   ├── LICENCE
│   ├── README.md
│   ├── args.py
│   ├── dataset.py
│   ├── execute_all.sh
│   ├── graph.py
│   ├── inference.py
│   ├── knowledge_graph
│   │   └── knowledge_graph.py
│   ├── model.py
│   ├── myconstants.py
│   ├── prreprocess_command.sh
│   ├── requirements.txt
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── bert_embeddings.py
│   │   └── csqa_elasticse.py
│   ├── train.py
│   └── utils.py
└── sparql-server
    ├── README.md
    ├── RWStore.properties
    ├── json_to_triples.py
    ├── load_ttl.sh
    └── wd_prefix.ttl
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, Edinburgh NLP
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semantic Parsing for Conversational Question Answering over Knowledge Graphs
2 |
3 | This repository contains data and code for our semantic parsing dataset SPICE. Please contact Laura Perez-Beltrachini (lperez@ed.ac.uk) or Parag Jain (parag.jain@ed.ac.uk) if you have any questions.
4 |
5 |
6 | ## SPICE Dataset
7 |
8 | The dataset is available for download from [here](https://drive.google.com/file/d/1v7sZrRrdNmnjBrvHStCMX4TQIzDN6z4K/view?usp=sharing).
9 |
10 | The scripts for the different annotations (e.g., NER/NEL) on the dataset can be found [here](./dataset).
11 |
12 | ## Sparql Server for SPICE Knowledge Graph
13 |
14 | Instructions on how to set up the SPICE knowledge graph can be found [here](./sparql-server).
15 |
16 | ## Baselines
17 |
18 | ### BertSP
19 |
20 | You will find code and instructions related to our BertSP baseline [here](./bertSP).
21 |
22 | ### LasagneSP
23 |
24 | You will find code and instructions related to our LasagneSP baseline [here](./lasagneSP).
25 |
26 |
27 | ## Evaluation
28 |
29 | For instructions on the evaluation scripts see [here](./evaluation).
30 |
31 |
--------------------------------------------------------------------------------
/bertSP/README.md:
--------------------------------------------------------------------------------
1 | # BertSP Baseline
2 |
3 | Here you will find the commands needed to prepare the data and to train, validate, and run inference with the BertSP model.
4 |
5 | ## Preprocess Data
6 | You need to run the following two steps. Note that our data preprocessing can use either the Bert tokenizer or the tokenizations present in the "context" json field generated by the original Lasagne NER annotations [here](https://github.com/endrikacupaj/LASAGNE/tree/master/annotate_csqa/ner_annotators). We used the "context" field in our experiments.
7 |
8 | ### Prepare data to be used by the BertSP baseline
9 |
10 | Below is an example command to prepare the SPICE dataset in the data format required by the BertSP model.
11 | This will produce a new set of .json files.
12 |
13 | ```commandline
14 | python src/preprocess.py \
15 | -mode format_to_lines \
16 | -save_path ${JSON_PATH_DATA} \
17 | -n_cpus 20 \
18 | -use_bert_basic_tokenizer false \
19 | -data_path ${DATA_PATH} \
20 | -tgt_dict ${TGT_DICT} \
21 | -shard_size 5000 \
22 | -dataset ${DATASET} \
23 | -log_file ${HOME}/${LOGFILE} \
24 | -kb_graph ${HOME}'/knowledge_graph' \
25 | -nentities 'strnel' \
26 | -types 'linked'
27 | ```
28 |
29 | For the ```-nentities``` flag there are four possible modes: *gold* (uses all available gold entity identifier
30 | annotations), *lgnel* (uses gold entity identifiers whose names can be found in the user utterances), *allennel*
31 | (uses entity identifiers linked through AllenNLP NER + ElasticSearch; this requires the input data,
32 | ```-data_path```, to be annotated accordingly, see the annotation script [here](../dataset)), and *strnel*
33 | (uses entity identifiers linked through String Match with KG symbols; this also requires the input data,
34 | ```-data_path```, to be annotated accordingly, see the annotation script [here](../dataset)).
35 |
36 | For the ```-types``` flag possible values are: *gold* (uses gold type annotations) and *linked* (uses
37 | types that were linked through String Match with KG symbols).
38 |
39 | Note that the target fixed dictionary (required by the ```-tgt_dict``` flag) is distributed with the dataset files; the file is named *global2.dict*.
40 |
41 | #### Prepare data for generalisation splits
42 |
43 | To prepare data for the generalisation splits instead of the original train/valid/test splits, add to the
44 | *preprocess.py* command above the flag ```-mapsplits```, together with the file that contains the conversation IDs that
45 | go into each of the generalisation splits, given via the ```-mapfile``` flag. An example command is sketched below.
46 |
47 | You can find more about the files for the generalisation splits in the dataset description folder [here](../dataset).
48 |
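For example, a sketch of the earlier *format_to_lines* command extended with these flags (```-mapsplits``` is taken here to be a boolean switch, and ${MAPFILE} is a placeholder for the split-assignment file shipped with the dataset):

```commandline
python src/preprocess.py \
-mode format_to_lines \
-save_path ${JSON_PATH_DATA} \
-n_cpus 20 \
-use_bert_basic_tokenizer false \
-data_path ${DATA_PATH} \
-tgt_dict ${TGT_DICT} \
-shard_size 5000 \
-dataset ${DATASET} \
-log_file ${HOME}/${LOGFILE} \
-kb_graph ${HOME}'/knowledge_graph' \
-nentities 'strnel' \
-types 'linked' \
-mapsplits \
-mapfile ${MAPFILE}
```
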
49 | ### Generate binary files
50 |
51 | This step takes the .json files prepared above and generates binary .pt files.
52 |
53 | ```commandline
54 | python src/preprocess.py \
55 | -mode format_to_bert \
56 | -raw_path ${JSON_PATH} \
57 | -save_path ${BERT_DATA_PATH} \
58 | -data_path ${DATA_PATH} \
59 | -tgt_dict ${TGT_DICT} \
60 | -lower \
61 | -n_cpus 20 \
62 | -dataset ${DATASET} \
63 | -log_file ${HOME}/${LOGFILE}
64 | ```
65 |
66 | ## Train BertSP
67 |
68 | Run the following command to train BertSP. This will train for 100k steps, saving a checkpoint every 2,000 steps (```-save_checkpoint_steps```).
69 |
70 | ```commandline
71 | python src/train.py \
72 | -mode train \
73 | -tgt_dict ${TGT_DICT} \
74 | -bert_data_path ${BERT_DATA_PATH} \
75 | -dec_dropout 0.2 \
76 | -sep_optim true \
77 | -lr_bert 0.00002 \
78 | -lr_dec 0.001 \
79 | -save_checkpoint_steps 2000 \
80 | -batch_size 1 \
81 | -train_steps 100000 \
82 | -report_every 50 \
83 | -accum_count 5 \
84 | -use_bert_emb true \
85 | -use_interval true \
86 | -warmup_steps_bert 20000 \
87 | -warmup_steps_dec 10000 \
88 | -max_pos 512 \
89 | -max_length 512 \
90 | -min_length 10 \
91 | -beam_size 1 \
92 | -alpha 0.95 \
93 | -visible_gpus 0,1,2,3 \
94 | -label_smoothing 0 \
95 | -model_path ${MODEL_PATH} \
96 | -log_file ${LOG_PATH}/${LOGFILE}
97 | ```
98 |
99 | ### Validate Checkpoints
100 |
101 | Run the following command for checkpoint selection. The command validates one checkpoint at a time; a sketch for sweeping all checkpoints follows the command.
102 |
103 | ```commandline
104 | python src/train.py \
105 | -mode validate \
106 | -valid_from ${MODEL_PATH}/${CHECKPOINT} \
107 | -batch_size 10 \
108 | -test_batch_size 10 \
109 | -tgt_dict ${TGT_DICT} \
110 | -bert_data_path ${BERT_DATA_PATH} \
112 | -model_path ${MODEL_PATH} \
113 | -sep_optim true \
114 | -use_interval true \
115 | -visible_gpus 1 \
116 | -max_pos 512 \
117 | -max_length 512 \
118 | -min_length 20 \
119 | -test_split ${SPLIT} \
120 | -log_file ${LOG_PATH}/stats.${STEP}.log
121 | ```
122 |
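To sweep all saved checkpoints, the command above can be wrapped in a loop. A sketch (assuming checkpoints are saved under ${MODEL_PATH} as *model_step_<STEP>.pt*):

```commandline
for CHECKPOINT in ${MODEL_PATH}/model_step_*.pt; do
STEP=$(basename ${CHECKPOINT} .pt | sed 's/model_step_//')
python src/train.py -mode validate -valid_from ${CHECKPOINT} \
-log_file ${LOG_PATH}/stats.${STEP}.log  # plus the remaining flags shown above
done
```
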
123 | ## Run Inference with BertSP
124 |
125 | The following command generates a .json file with a Sparql prediction (parse) for each user utterance.
126 | The file follows the format required by the evaluation scripts [here](../evaluation).
127 |
128 | ```commandline
129 | python src/train.py \
130 | -mode test \
131 | -test_from ${MODEL_PATH}/${CHECKPOINT} \
132 | -batch_size 1 \
133 | -test_batch_size 1 \
134 | -tgt_dict ${TGT_DICT} \
135 | -bert_data_path ${BERT_DATA_PATH} \
136 | -log_file logs/base_bert_sparql_csqa_val \
137 | -model_path ${MODEL_PATH} \
138 | -test_split test \
139 | -sep_optim true \
140 | -use_interval true \
141 | -visible_gpus 1 \
142 | -max_pos 512 \
143 | -max_length 512 \
144 | -min_length 10 \
145 | -alpha 0.95 \
146 | -beam_size 5 \
147 | -dosubset ${TEST_VALID_SUBSET} \
148 | -result_path results/${MODEL_NAME}/baseline.${TEST_VALID_SUBSET_STR}
149 | ```
150 |
151 | Note that the flag ```-dosubset``` is used to run inference on a subset of files from the test split.
152 | The pre-processing (discussed at the beginning of this README) shards conversations from each split into .json/.pt
153 | files (e.g., *json_data.valid.41.json*). The ```-dosubset``` flag takes a regular expression that specifies
154 | a range of shard IDs to run inference on; e.g., for our configuration of shards (0 to 49) we can use the following
155 | to run inference on shards whose IDs start with *4*.
156 | ```
157 | TEST_VALID_SUBSET='4[0-9]+'
158 | TEST_VALID_SUBSET_STR='40-49'
159 | ```
160 | Predictions will be saved in .json files named with the shard IDs, e.g., *baseline.40-49.96000.test_Logical Reasoning (All).json*.
161 | If the option ```-dosubset``` is not used, a single file containing predictions for all shards will be created, e.g.,
162 | *baseline.96000.test_Logical Reasoning (All).json*.
163 |
164 | ## Evaluation
165 |
166 | For evaluation of generated outputs see the evaluation scripts [here](../evaluation).
167 |
--------------------------------------------------------------------------------
/bertSP/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/bertSP/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/bertSP/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/bertSP/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.7.0
2 | Unidecode==1.2.0
3 | torchtext==0.4.0
4 | ujson==4.0.2
5 | elasticsearch==7.8.1
6 | numpy==1.17.4
7 | torch==1.6.0
8 | tensorboardX==1.9
--------------------------------------------------------------------------------
/bertSP/src/distributed.py:
--------------------------------------------------------------------------------
1 | """ Pytorch Distributed utils
2 | This piece of code was heavily inspired by the equivalent of Fairseq-py
3 | https://github.com/pytorch/fairseq
4 | """
5 |
6 |
7 | from __future__ import print_function
8 |
9 | import math
10 | import pickle
11 |
12 | import torch.distributed
13 |
14 | from others.logging import logger
15 |
16 |
17 | def is_master(gpu_ranks, device_id):
18 | return gpu_ranks[device_id] == 0
19 |
20 |
21 | def multi_init(device_id, world_size,gpu_ranks):
22 | print(gpu_ranks)
23 | dist_init_method = 'tcp://localhost:10000'
24 | dist_world_size = world_size
25 | torch.distributed.init_process_group(
26 | backend='nccl', init_method=dist_init_method,
27 | world_size=dist_world_size, rank=gpu_ranks[device_id])
28 | gpu_rank = torch.distributed.get_rank()
29 | if not is_master(gpu_ranks, device_id):
30 | # print('not master')
31 | logger.disabled = True
32 |
33 | return gpu_rank
34 |
35 |
36 |
37 | def all_reduce_and_rescale_tensors(tensors, rescale_denom,
38 | buffer_size=10485760):
39 | """All-reduce and rescale tensors in chunks of the specified size.
40 |
41 | Args:
42 | tensors: list of Tensors to all-reduce
43 | rescale_denom: denominator for rescaling summed Tensors
44 | buffer_size: all-reduce chunk size in bytes
45 | """
46 | # buffer size in bytes, determine equiv. # of elements based on data type
47 | buffer_t = tensors[0].new(
48 | math.ceil(buffer_size / tensors[0].element_size())).zero_()
49 | buffer = []
50 |
51 | def all_reduce_buffer():
52 | # copy tensors into buffer_t
53 | offset = 0
54 | for t in buffer:
55 | numel = t.numel()
56 | buffer_t[offset:offset+numel].copy_(t.view(-1))
57 | offset += numel
58 |
59 | # all-reduce and rescale
60 | torch.distributed.all_reduce(buffer_t[:offset])
61 | buffer_t.div_(rescale_denom)
62 |
63 | # copy all-reduced buffer back into tensors
64 | offset = 0
65 | for t in buffer:
66 | numel = t.numel()
67 | t.view(-1).copy_(buffer_t[offset:offset+numel])
68 | offset += numel
69 |
70 | filled = 0
71 | for t in tensors:
72 | sz = t.numel() * t.element_size()
73 | if sz > buffer_size:
74 | # tensor is bigger than buffer, all-reduce and rescale directly
75 | torch.distributed.all_reduce(t)
76 | t.div_(rescale_denom)
77 | elif filled + sz > buffer_size:
78 | # buffer is full, all-reduce and replace buffer with grad
79 | all_reduce_buffer()
80 | buffer = [t]
81 | filled = sz
82 | else:
83 | # add tensor to buffer
84 | buffer.append(t)
85 | filled += sz
86 |
87 | if len(buffer) > 0:
88 | all_reduce_buffer()
89 |
90 |
91 | def all_gather_list(data, max_size=4096):
92 | """Gathers arbitrary data from all nodes into a list."""
93 | world_size = torch.distributed.get_world_size()
94 | if not hasattr(all_gather_list, '_in_buffer') or \
95 | max_size != all_gather_list._in_buffer.size():
96 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
97 | all_gather_list._out_buffers = [
98 | torch.cuda.ByteTensor(max_size)
99 | for i in range(world_size)
100 | ]
101 | in_buffer = all_gather_list._in_buffer
102 | out_buffers = all_gather_list._out_buffers
103 |
104 | enc = pickle.dumps(data)
105 | enc_size = len(enc)
106 | if enc_size + 2 > max_size:
107 | raise ValueError(
108 | 'encoded data exceeds max_size: {}'.format(enc_size + 2))
109 | assert max_size < 255*256
110 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
111 | in_buffer[1] = enc_size % 255
112 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
113 |
114 | torch.distributed.all_gather(out_buffers, in_buffer.cuda())
115 |
116 | results = []
117 | for i in range(world_size):
118 | out_buffer = out_buffers[i]
119 | size = (255 * out_buffer[0].item()) + out_buffer[1].item()
120 |
121 | bytes_list = bytes(out_buffer[2:size+2].tolist())
122 | result = pickle.loads(bytes_list)
123 | results.append(result)
124 | return results
125 |
--------------------------------------------------------------------------------
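A minimal sketch (not part of the repository) of how these helpers are typically driven from training code, assuming the script runs from *bertSP/src* and `multi_init` has already initialised the NCCL process group:

```python
import distributed  # bertSP/src/distributed.py


def sync_gradients(model, world_size):
    """Average gradients across workers after a backward pass."""
    # Collect the gradients of all trainable parameters...
    grads = [p.grad.data for p in model.parameters()
             if p.requires_grad and p.grad is not None]
    # ...then all-reduce them in buffer-sized chunks and rescale by the
    # number of workers (see all_reduce_and_rescale_tensors above).
    distributed.all_reduce_and_rescale_tensors(grads, float(world_size))
```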
/bertSP/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EdinburghNLP/SPICE/4afa4404b02f59d175976b7e02583fdf41c23c3a/bertSP/src/models/__init__.py
--------------------------------------------------------------------------------
/bertSP/src/models/adam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class Adam(Optimizer):
7 | r"""Implements Adam algorithm.
8 |
9 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
10 |
11 | Arguments:
12 | params (iterable): iterable of parameters to optimize or dicts defining
13 | parameter groups
14 | lr (float, optional): learning rate (default: 1e-3)
15 | betas (Tuple[float, float], optional): coefficients used for computing
16 | running averages of gradient and its square (default: (0.9, 0.999))
17 | eps (float, optional): term added to the denominator to improve
18 | numerical stability (default: 1e-8)
19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
20 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
21 | algorithm from the paper `On the Convergence of Adam and Beyond`_
22 | (default: False)
23 |
24 | .. _Adam\: A Method for Stochastic Optimization:
25 | https://arxiv.org/abs/1412.6980
26 | .. _On the Convergence of Adam and Beyond:
27 | https://openreview.net/forum?id=ryQu7f-RZ
28 | """
29 |
30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
31 | weight_decay=0, amsgrad=False):
32 | if not 0.0 <= lr:
33 | raise ValueError("Invalid learning rate: {}".format(lr))
34 | if not 0.0 <= eps:
35 | raise ValueError("Invalid epsilon value: {}".format(eps))
36 | if not 0.0 <= betas[0] < 1.0:
37 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
38 | if not 0.0 <= betas[1] < 1.0:
39 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
40 | defaults = dict(lr=lr, betas=betas, eps=eps,
41 | weight_decay=weight_decay, amsgrad=amsgrad)
42 | super(Adam, self).__init__(params, defaults)
43 |
44 | def __setstate__(self, state):
45 | super(Adam, self).__setstate__(state)
46 | for group in self.param_groups:
47 | group.setdefault('amsgrad', False)
48 |
49 | def step(self, closure=None):
50 | """Performs a single optimization step.
51 | Arguments:
52 | closure (callable, optional): A closure that reevaluates the model
53 | and returns the loss.
54 | """
55 | loss = None
56 | if closure is not None:
57 | loss = closure()
58 |
59 |
60 | for group in self.param_groups:
61 | for p in group['params']:
62 | if p.grad is None:
63 | continue
64 | grad = p.grad.data
65 | if grad.is_sparse:
66 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
67 |
68 | state = self.state[p]
69 |
70 | # State initialization
71 | if len(state) == 0:
72 | state['step'] = 0
73 | # Exponential moving average of gradient values
74 | state['next_m'] = torch.zeros_like(p.data)
75 | # Exponential moving average of squared gradient values
76 | state['next_v'] = torch.zeros_like(p.data)
77 |
78 | next_m, next_v = state['next_m'], state['next_v']
79 | beta1, beta2 = group['betas']
80 |
81 | # Decay the first and second moment running average coefficient
82 | # In-place operations to update the averages at the same time
83 | next_m.mul_(beta1).add_(1 - beta1, grad)
84 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
85 | update = next_m / (next_v.sqrt() + group['eps'])
86 |
87 | # Just adding the square of the weights to the loss function is *not*
88 | # the correct way of using L2 regularization/weight decay with Adam,
89 | # since that will interact with the m and v parameters in strange ways.
90 | #
91 | # Instead we want to decay the weights in a manner that doesn't interact
92 | # with the m/v parameters. This is equivalent to adding the square
93 | # of the weights to the loss with plain (non-momentum) SGD.
94 | if group['weight_decay'] > 0.0:
95 | update += group['weight_decay'] * p.data
96 |
97 | lr_scheduled = group['lr']
98 |
99 | update_with_lr = lr_scheduled * update
100 | p.data.add_(-update_with_lr)
101 |
102 | state['step'] += 1
103 |
104 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
105 | # No bias correction
106 | # bias_correction1 = 1 - beta1 ** state['step']
107 | # bias_correction2 = 1 - beta2 ** state['step']
108 |
109 | return loss
--------------------------------------------------------------------------------
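A short usage sketch (assuming it runs from *bertSP/src* with the pinned torch==1.6.0). Note that this Adam variant applies decoupled weight decay but no bias correction, so in this codebase it is paired with an external warmup schedule rather than used bare:

```python
import torch
from models.adam import Adam

model = torch.nn.Linear(10, 2)
opt = Adam(model.parameters(), lr=2e-5, weight_decay=0.01)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()  # moment updates + weight decay, no bias correction
```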
/bertSP/src/models/decoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of "Attention is All You Need"
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 |
9 | from models.encoder import PositionalEncoding
10 | from models.neural import MultiHeadedAttention, PositionwiseFeedForward, DecoderState
11 |
12 | MAX_SIZE = 5000
13 |
14 |
15 | class TransformerDecoderLayer(nn.Module):
16 | """
17 | Args:
18 | d_model (int): the dimension of keys/values/queries in
19 | MultiHeadedAttention, also the input size of
20 | the first-layer of the PositionwiseFeedForward.
21 | heads (int): the number of heads for MultiHeadedAttention.
22 | d_ff (int): the second-layer of the PositionwiseFeedForward.
23 | dropout (float): dropout probability(0-1.0).
24 | self_attn_type (string): type of self-attention scaled-dot, average
25 | """
26 |
27 | def __init__(self, d_model, heads, d_ff, dropout):
28 | super(TransformerDecoderLayer, self).__init__()
29 |
30 |
31 | self.self_attn = MultiHeadedAttention(
32 | heads, d_model, dropout=dropout)
33 |
34 | self.context_attn = MultiHeadedAttention(
35 | heads, d_model, dropout=dropout)
36 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
37 | self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
38 | self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
39 | self.drop = nn.Dropout(dropout)
40 | mask = self._get_attn_subsequent_mask(MAX_SIZE)
41 | # Register self.mask as a buffer in TransformerDecoderLayer, so
42 | # it gets TransformerDecoderLayer's cuda behavior automatically.
43 | self.register_buffer('mask', mask)
44 |
45 | def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
46 | previous_input=None, layer_cache=None, step=None):
47 | """
48 | Args:
49 | inputs (`FloatTensor`): `[batch_size x 1 x model_dim]`
50 | memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
51 | src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]`
52 | tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]`
53 |
54 | Returns:
55 | (`FloatTensor`, `FloatTensor`, `FloatTensor`):
56 |
57 | * output `[batch_size x 1 x model_dim]`
58 | * attn `[batch_size x 1 x src_len]`
59 | * all_input `[batch_size x current_step x model_dim]`
60 |
61 | """
62 | dec_mask = torch.gt(tgt_pad_mask +
63 | self.mask[:, :tgt_pad_mask.size(1),
64 | :tgt_pad_mask.size(1)], 0)
65 | input_norm = self.layer_norm_1(inputs)
66 | all_input = input_norm
67 | if previous_input is not None:
68 | all_input = torch.cat((previous_input, input_norm), dim=1)
69 | dec_mask = None
70 |
71 | query = self.self_attn(all_input, all_input, input_norm,
72 | mask=dec_mask,
73 | layer_cache=layer_cache,
74 | type="self")
75 |
76 | query = self.drop(query) + inputs
77 |
78 | query_norm = self.layer_norm_2(query)
79 | mid = self.context_attn(memory_bank, memory_bank, query_norm,
80 | mask=src_pad_mask,
81 | layer_cache=layer_cache,
82 | type="context")
83 | output = self.feed_forward(self.drop(mid) + query)
84 |
85 | return output, all_input
86 | # return output
87 |
88 | def _get_attn_subsequent_mask(self, size):
89 | """
90 | Get an attention mask to avoid using the subsequent info.
91 |
92 | Args:
93 | size: int
94 |
95 | Returns:
96 | (`LongTensor`):
97 |
98 | * subsequent_mask `[1 x size x size]`
99 | """
100 | attn_shape = (1, size, size)
101 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
102 | subsequent_mask = torch.from_numpy(subsequent_mask)
103 | return subsequent_mask
104 |
105 | class TransformerDecoder(nn.Module):
106 | """
107 | The Transformer decoder from "Attention is All You Need".
108 |
109 |
110 | .. mermaid::
111 |
112 | graph BT
113 | A[input]
114 | B[multi-head self-attn]
115 | BB[multi-head src-attn]
116 | C[feed forward]
117 | O[output]
118 | A --> B
119 | B --> BB
120 | BB --> C
121 | C --> O
122 |
123 |
124 | Args:
125 | num_layers (int): number of encoder layers.
126 | d_model (int): size of the model
127 | heads (int): number of heads
128 | d_ff (int): size of the inner FF layer
129 | dropout (float): dropout parameters
130 | embeddings (:obj:`onmt.modules.Embeddings`):
131 | embeddings to use, should have positional encodings
132 | attn_type (str): if using a seperate copy attention
133 | """
134 |
135 | def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
136 | super(TransformerDecoder, self).__init__()
137 |
138 | # Basic attributes.
139 | self.decoder_type = 'transformer'
140 | self.num_layers = num_layers
141 | self.embeddings = embeddings
142 | self.pos_emb = PositionalEncoding(dropout,self.embeddings.embedding_dim)
143 |
144 |
145 | # Build TransformerDecoder.
146 | self.transformer_layers = nn.ModuleList(
147 | [TransformerDecoderLayer(d_model, heads, d_ff, dropout)
148 | for _ in range(num_layers)])
149 |
150 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
151 |
152 | def forward(self, tgt, memory_bank, state, memory_lengths=None,
153 | step=None, cache=None,memory_masks=None):
154 | """
155 | See :obj:`onmt.modules.RNNDecoderBase.forward()`
156 | """
157 |
158 | src_words = state.src
159 | tgt_words = tgt
160 | src_batch, src_len = src_words.size()
161 | tgt_batch, tgt_len = tgt_words.size()
162 |
163 | # Run the forward pass of the TransformerDecoder.
164 | # emb = self.embeddings(tgt, step=step)
165 | emb = self.embeddings(tgt)
166 | assert emb.dim() == 3 # len x batch x embedding_dim
167 |
168 | output = self.pos_emb(emb, step)
169 |
170 | src_memory_bank = memory_bank
171 | padding_idx = self.embeddings.padding_idx
172 | tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \
173 | .expand(tgt_batch, tgt_len, tgt_len)
174 |
175 | if (not memory_masks is None):
176 | src_len = memory_masks.size(-1)
177 | src_pad_mask = memory_masks.expand(src_batch, tgt_len, src_len)
178 |
179 | else:
180 | src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
181 | .expand(src_batch, tgt_len, src_len)
182 |
183 | if state.cache is None:
184 | saved_inputs = []
185 |
186 | for i in range(self.num_layers):
187 | prev_layer_input = None
188 | if state.cache is None:
189 | if state.previous_input is not None:
190 | prev_layer_input = state.previous_layer_inputs[i]
191 | output, all_input \
192 | = self.transformer_layers[i](
193 | output, src_memory_bank,
194 | src_pad_mask, tgt_pad_mask,
195 | previous_input=prev_layer_input,
196 | layer_cache=state.cache["layer_{}".format(i)]
197 | if state.cache is not None else None,
198 | step=step)
199 | if state.cache is None:
200 | saved_inputs.append(all_input)
201 |
202 | if state.cache is None:
203 | saved_inputs = torch.stack(saved_inputs)
204 |
205 | output = self.layer_norm(output)
206 |
207 | # Process the result and update the attentions.
208 |
209 | if state.cache is None:
210 | state = state.update_state(tgt, saved_inputs)
211 |
212 | return output, state
213 |
214 | def init_decoder_state(self, src, memory_bank,
215 | with_cache=False):
216 | """ Init decoder state """
217 | state = TransformerDecoderState(src)
218 | if with_cache:
219 | state._init_cache(memory_bank, self.num_layers)
220 | return state
221 |
222 | class TransformerDecoderState(DecoderState):
223 | """ Transformer Decoder state base class """
224 |
225 | def __init__(self, src):
226 | """
227 | Args:
228 | src (FloatTensor): a sequence of source words tensors
229 | with optional feature tensors, of size (len x batch).
230 | """
231 | self.src = src
232 | self.previous_input = None
233 | self.previous_layer_inputs = None
234 | self.cache = None
235 |
236 | @property
237 | def _all(self):
238 | """
239 | Contains attributes that need to be updated in self.beam_update().
240 | """
241 | if (self.previous_input is not None
242 | and self.previous_layer_inputs is not None):
243 | return (self.previous_input,
244 | self.previous_layer_inputs,
245 | self.src)
246 | else:
247 | return (self.src,)
248 |
249 | def detach(self):
250 | if self.previous_input is not None:
251 | self.previous_input = self.previous_input.detach()
252 | if self.previous_layer_inputs is not None:
253 | self.previous_layer_inputs = self.previous_layer_inputs.detach()
254 | self.src = self.src.detach()
255 |
256 | def update_state(self, new_input, previous_layer_inputs):
257 | state = TransformerDecoderState(self.src)
258 | state.previous_input = new_input
259 | state.previous_layer_inputs = previous_layer_inputs
260 | return state
261 |
262 | def _init_cache(self, memory_bank, num_layers):
263 | self.cache = {}
264 |
265 | for l in range(num_layers):
266 | layer_cache = {
267 | "memory_keys": None,
268 | "memory_values": None
269 | }
270 | layer_cache["self_keys"] = None
271 | layer_cache["self_values"] = None
272 | self.cache["layer_{}".format(l)] = layer_cache
273 |
274 | def repeat_beam_size_times(self, beam_size):
275 | """ Repeat beam_size times along batch dimension. """
276 | self.src = self.src.data.repeat(1, beam_size, 1)
277 |
278 | def map_batch_fn(self, fn):
279 | def _recursive_map(struct, batch_dim=0):
280 | for k, v in struct.items():
281 | if v is not None:
282 | if isinstance(v, dict):
283 | _recursive_map(v)
284 | else:
285 | struct[k] = fn(v, batch_dim)
286 |
287 | self.src = fn(self.src, 0)
288 | if self.cache is not None:
289 | _recursive_map(self.cache)
290 |
291 |
292 |
293 |
--------------------------------------------------------------------------------
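A shape-level sketch of driving the decoder (assumptions: run from *bertSP/src*, and any `nn.Embedding` suffices since only its `embedding_dim` and `padding_idx` are read):

```python
import torch
import torch.nn as nn
from models.decoder import TransformerDecoder

emb = nn.Embedding(100, 512, padding_idx=0)
dec = TransformerDecoder(num_layers=2, d_model=512, heads=8,
                         d_ff=2048, dropout=0.1, embeddings=emb)

src = torch.randint(1, 100, (2, 7))   # (batch, src_len) token ids
tgt = torch.randint(1, 100, (2, 5))   # (batch, tgt_len) token ids
memory = torch.randn(2, 7, 512)       # encoder output for src

state = dec.init_decoder_state(src, memory)
out, state = dec(tgt, memory, state)  # out: (2, 5, 512)
```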
/bertSP/src/models/encoder.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from models.neural import MultiHeadedAttention, PositionwiseFeedForward
7 |
8 |
9 | class Classifier(nn.Module):
10 | def __init__(self, hidden_size):
11 | super(Classifier, self).__init__()
12 | self.linear1 = nn.Linear(hidden_size, 1)
13 | self.sigmoid = nn.Sigmoid()
14 |
15 | def forward(self, x, mask_cls):
16 | h = self.linear1(x).squeeze(-1)
17 | sent_scores = self.sigmoid(h) * mask_cls.float()
18 | return sent_scores
19 |
20 |
21 | class PositionalEncoding(nn.Module):
22 |
23 | def __init__(self, dropout, dim, max_len=5000):
24 | pe = torch.zeros(max_len, dim)
25 | position = torch.arange(0, max_len).unsqueeze(1)
26 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
27 | -(math.log(10000.0) / dim)))
28 | pe[:, 0::2] = torch.sin(position.float() * div_term)
29 | pe[:, 1::2] = torch.cos(position.float() * div_term)
30 | pe = pe.unsqueeze(0)
31 | super(PositionalEncoding, self).__init__()
32 | self.register_buffer('pe', pe)
33 | self.dropout = nn.Dropout(p=dropout)
34 | self.dim = dim
35 |
36 | def forward(self, emb, step=None):
37 | emb = emb * math.sqrt(self.dim)
38 | if (step):
39 | emb = emb + self.pe[:, step][:, None, :]
40 |
41 | else:
42 | emb = emb + self.pe[:, :emb.size(1)]
43 | emb = self.dropout(emb)
44 | return emb
45 |
46 | def get_emb(self, emb):
47 | return self.pe[:, :emb.size(1)]
48 |
49 |
50 | class TransformerEncoderLayer(nn.Module):
51 | def __init__(self, d_model, heads, d_ff, dropout):
52 | super(TransformerEncoderLayer, self).__init__()
53 |
54 | self.self_attn = MultiHeadedAttention(
55 | heads, d_model, dropout=dropout)
56 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
57 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
58 | self.dropout = nn.Dropout(dropout)
59 |
60 | def forward(self, iter, query, inputs, mask):
61 | if (iter != 0):
62 | input_norm = self.layer_norm(inputs)
63 | else:
64 | input_norm = inputs
65 |
66 | mask = mask.unsqueeze(1)
67 | context = self.self_attn(input_norm, input_norm, input_norm,
68 | mask=mask)
69 | out = self.dropout(context) + inputs
70 | return self.feed_forward(out)
71 |
72 |
73 | class ExtTransformerEncoder(nn.Module):
74 | def __init__(self, d_model, d_ff, heads, dropout, num_inter_layers=0):
75 | super(ExtTransformerEncoder, self).__init__()
76 | self.d_model = d_model
77 | self.num_inter_layers = num_inter_layers
78 | self.pos_emb = PositionalEncoding(dropout, d_model)
79 | self.transformer_inter = nn.ModuleList(
80 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout)
81 | for _ in range(num_inter_layers)])
82 | self.dropout = nn.Dropout(dropout)
83 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
84 | self.wo = nn.Linear(d_model, 1, bias=True)
85 | self.sigmoid = nn.Sigmoid()
86 |
87 | def forward(self, top_vecs, mask):
88 | """ See :obj:`EncoderBase.forward()`"""
89 |
90 | batch_size, n_sents = top_vecs.size(0), top_vecs.size(1)
91 | pos_emb = self.pos_emb.pe[:, :n_sents]
92 | x = top_vecs * mask[:, :, None].float()
93 | x = x + pos_emb
94 |
95 | for i in range(self.num_inter_layers):
96 | x = self.transformer_inter[i](i, x, x, 1 - mask) # all_sents * max_tokens * dim
97 |
98 | x = self.layer_norm(x)
99 | sent_scores = self.sigmoid(self.wo(x))
100 | sent_scores = sent_scores.squeeze(-1) * mask.float()
101 |
102 | return sent_scores
103 |
104 |
--------------------------------------------------------------------------------
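A small sketch of `PositionalEncoding` from above (note that, as written, `step=0` falls into the `else` branch because of the truthiness test `if (step):`):

```python
import torch
from models.encoder import PositionalEncoding

pe = PositionalEncoding(dropout=0.1, dim=512)
emb = torch.randn(2, 9, 512)  # (batch, seq_len, dim)

out = pe(emb)                 # scales by sqrt(dim), adds positions 0..8
one = pe(emb[:, :1], step=3)  # adds only the position-3 encoding
```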
/bertSP/src/models/optimizers.py:
--------------------------------------------------------------------------------
1 | """ Optimizers class """
2 | import torch
3 | import torch.optim as optim
4 | from torch.nn.utils import clip_grad_norm_
5 |
6 |
7 | # from onmt.utils import use_gpu
8 | # from models.adam import Adam
9 |
10 |
11 | def use_gpu(opt):
12 | """
13 | Creates a boolean if gpu used
14 | """
15 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \
16 | (hasattr(opt, 'gpu') and opt.gpu > -1)
17 |
18 | def build_optim(model, opt, checkpoint):
19 | """ Build optimizer """
20 | saved_optimizer_state_dict = None
21 |
22 | if opt.train_from:
23 | optim = checkpoint['optim']
24 | # We need to save a copy of optim.optimizer.state_dict() for setting
25 | # the optimizer state later on in Stage 2 in this method, since
26 | # the method optim.set_parameters(model.parameters()) will overwrite
27 | # optim.optimizer, and with it the values stored in
28 | # optim.optimizer.state_dict()
29 | saved_optimizer_state_dict = optim.optimizer.state_dict()
30 | else:
31 | optim = Optimizer(
32 | opt.optim, opt.learning_rate, opt.max_grad_norm,
33 | lr_decay=opt.learning_rate_decay,
34 | start_decay_steps=opt.start_decay_steps,
35 | decay_steps=opt.decay_steps,
36 | beta1=opt.adam_beta1,
37 | beta2=opt.adam_beta2,
38 | adagrad_accum=opt.adagrad_accumulator_init,
39 | decay_method=opt.decay_method,
40 | warmup_steps=opt.warmup_steps)
41 |
42 | optim.set_parameters(model.named_parameters())
43 |
44 | if opt.train_from:
45 | optim.optimizer.load_state_dict(saved_optimizer_state_dict)
46 | if use_gpu(opt):
47 | for state in optim.optimizer.state.values():
48 | for k, v in state.items():
49 | if torch.is_tensor(v):
50 | state[k] = v.cuda()
51 |
52 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1):
53 | raise RuntimeError(
54 | "Error: loaded Adam optimizer from existing model" +
55 | " but optimizer state is empty")
56 |
57 | return optim
58 |
59 |
60 | class MultipleOptimizer(object):
61 | """ Implement multiple optimizers needed for sparse adam """
62 |
63 | def __init__(self, op):
64 | """ ? """
65 | self.optimizers = op
66 |
67 | def zero_grad(self):
68 | """ ? """
69 | for op in self.optimizers:
70 | op.zero_grad()
71 |
72 | def step(self):
73 | """ ? """
74 | for op in self.optimizers:
75 | op.step()
76 |
77 | @property
78 | def state(self):
79 | """ ? """
80 | return {k: v for op in self.optimizers for k, v in op.state.items()}
81 |
82 | def state_dict(self):
83 | """ ? """
84 | return [op.state_dict() for op in self.optimizers]
85 |
86 | def load_state_dict(self, state_dicts):
87 | """ ? """
88 | assert len(state_dicts) == len(self.optimizers)
89 | for i in range(len(state_dicts)):
90 | self.optimizers[i].load_state_dict(state_dicts[i])
91 |
92 |
93 | class Optimizer(object):
94 | """
95 | Controller class for optimization. Mostly a thin
96 | wrapper for `optim`, but also useful for implementing
97 | rate scheduling beyond what is currently available.
98 | Also implements necessary methods for training RNNs such
99 | as grad manipulations.
100 |
101 | Args:
102 | method (:obj:`str`): one of [sgd, adagrad, adadelta, adam]
103 | lr (float): learning rate
104 | lr_decay (float, optional): learning rate decay multiplier
105 | start_decay_steps (int, optional): step to start learning rate decay
106 | beta1, beta2 (float, optional): parameters for adam
107 | adagrad_accum (float, optional): initialization parameter for adagrad
108 | decay_method (str, option): custom decay options
109 | warmup_steps (int, option): parameter for `noam` decay
110 | model_size (int, option): parameter for `noam` decay
111 |
112 | We use the default parameters for Adam that are suggested by
113 | the original paper https://arxiv.org/pdf/1412.6980.pdf
114 | These values are also used by other established implementations,
115 | e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
116 | https://keras.io/optimizers/
117 | Recently there are slightly different values used in the paper
118 | "Attention is all you need"
119 | https://arxiv.org/pdf/1706.03762.pdf, particularly the value beta2=0.98
120 | was used there however, beta2=0.999 is still arguably the more
121 | established value, so we use that here as well
122 | """
123 |
124 | def __init__(self, method, learning_rate, max_grad_norm,
125 | lr_decay=1, start_decay_steps=None, decay_steps=None,
126 | beta1=0.9, beta2=0.999,
127 | adagrad_accum=0.0,
128 | decay_method=None,
129 | warmup_steps=4000, weight_decay=0):
130 | self.last_ppl = None
131 | self.learning_rate = learning_rate
132 | self.original_lr = learning_rate
133 | self.max_grad_norm = max_grad_norm
134 | self.method = method
135 | self.lr_decay = lr_decay
136 | self.start_decay_steps = start_decay_steps
137 | self.decay_steps = decay_steps
138 | self.start_decay = False
139 | self._step = 0
140 | self.betas = [beta1, beta2]
141 | self.adagrad_accum = adagrad_accum
142 | self.decay_method = decay_method
143 | self.warmup_steps = warmup_steps
144 | self.weight_decay = weight_decay
145 |
146 | def set_parameters(self, params):
147 | """ ? """
148 | self.params = []
149 | self.sparse_params = []
150 | for k, p in params:
151 | if p.requires_grad:
152 | if self.method != 'sparseadam' or "embed" not in k:
153 | self.params.append(p)
154 | else:
155 | self.sparse_params.append(p)
156 | if self.method == 'sgd':
157 | self.optimizer = optim.SGD(self.params, lr=self.learning_rate)
158 | elif self.method == 'adagrad':
159 | self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate)
160 | for group in self.optimizer.param_groups:
161 | for p in group['params']:
162 | self.optimizer.state[p]['sum'] = self.optimizer\
163 | .state[p]['sum'].fill_(self.adagrad_accum)
164 | elif self.method == 'adadelta':
165 | self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate)
166 | elif self.method == 'adam':
167 | self.optimizer = optim.Adam(self.params, lr=self.learning_rate,
168 | betas=self.betas, eps=1e-9)
169 | else:
170 | raise RuntimeError("Invalid optim method: " + self.method)
171 |
172 | def _set_rate(self, learning_rate):
173 | self.learning_rate = learning_rate
174 | if self.method != 'sparseadam':
175 | self.optimizer.param_groups[0]['lr'] = self.learning_rate
176 | else:
177 | for op in self.optimizer.optimizers:
178 | op.param_groups[0]['lr'] = self.learning_rate
179 |
180 | def step(self):
181 | """Update the model parameters based on current gradients.
182 |
183 | Optionally, will employ gradient modification or update learning
184 | rate.
185 | """
186 | self._step += 1
187 |
188 | # Decay method used in tensor2tensor.
189 | if self.decay_method == "noam":
190 | self._set_rate(
191 | self.original_lr *
192 | min(self._step ** (-0.5),
193 | self._step * self.warmup_steps**(-1.5)))
194 |
195 | else:
196 | if ((self.start_decay_steps is not None) and (
197 | self._step >= self.start_decay_steps)):
198 | self.start_decay = True
199 | if self.start_decay:
200 | if ((self._step - self.start_decay_steps)
201 | % self.decay_steps == 0):
202 | self.learning_rate = self.learning_rate * self.lr_decay
203 |
204 | if self.method != 'sparseadam':
205 | self.optimizer.param_groups[0]['lr'] = self.learning_rate
206 |
207 | if self.max_grad_norm:
208 | clip_grad_norm_(self.params, self.max_grad_norm)
209 | self.optimizer.step()
210 |
211 |
212 |
--------------------------------------------------------------------------------
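A worked example of the `noam` branch in `Optimizer.step()` above: the learning rate grows linearly until `warmup_steps`, then decays as `step ** -0.5`, with the two terms crossing exactly at `step == warmup_steps`:

```python
def noam_lr(original_lr, step, warmup_steps):
    # Mirrors: original_lr * min(step**-0.5, step * warmup_steps**-1.5)
    return original_lr * min(step ** -0.5, step * warmup_steps ** -1.5)

for s in (1, 5000, 20000, 80000):
    print(s, noam_lr(0.002, s, warmup_steps=20000))
```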
/bertSP/src/models/reporter.py:
--------------------------------------------------------------------------------
1 | """ Report manager utility """
2 | from __future__ import print_function
3 | from datetime import datetime
4 |
5 | import time
6 | import math
7 | import sys
8 |
9 | from distributed import all_gather_list
10 | from others.logging import logger
11 |
12 |
13 | def build_report_manager(opt):
14 | if opt.tensorboard:
15 | from tensorboardX import SummaryWriter
16 | writer = SummaryWriter(opt.tensorboard_log_dir
17 | + datetime.now().strftime("/%b-%d_%H-%M-%S"),
18 | comment="Unmt")
19 | else:
20 | writer = None
21 |
22 | report_mgr = ReportMgr(opt.report_every, start_time=-1,
23 | tensorboard_writer=writer)
24 | return report_mgr
25 |
26 |
27 | class ReportMgrBase(object):
28 | """
29 | Report Manager Base class
30 | Inherited classes should override:
31 | * `_report_training`
32 | * `_report_step`
33 | """
34 |
35 | def __init__(self, report_every, start_time=-1.):
36 | """
37 | Args:
38 | report_every(int): Report status every this many sentences
39 | start_time(float): manually set report start time. Negative values
40 | means that you will need to set it later or use `start()`
41 | """
42 | self.report_every = report_every
43 | self.progress_step = 0
44 | self.start_time = start_time
45 |
46 | def start(self):
47 | self.start_time = time.time()
48 |
49 | def log(self, *args, **kwargs):
50 | logger.info(*args, **kwargs)
51 |
52 | def report_training(self, step, num_steps, learning_rate,
53 | report_stats, multigpu=False):
54 | """
55 | This is the user-defined batch-level training progress
56 | report function.
57 |
58 | Args:
59 | step(int): current step count.
60 | num_steps(int): total number of batches.
61 | learning_rate(float): current learning rate.
62 | report_stats(Statistics): old Statistics instance.
63 | Returns:
64 | report_stats(Statistics): updated Statistics instance.
65 | """
66 | if self.start_time < 0:
67 | raise ValueError("""ReportMgr needs to be started
68 | (set 'start_time' or use 'start()')""")
69 |
70 | if multigpu:
71 | report_stats = Statistics.all_gather_stats(report_stats)
72 |
73 | if step % self.report_every == 0:
74 | self._report_training(
75 | step, num_steps, learning_rate, report_stats)
76 | self.progress_step += 1
77 | return Statistics()
78 |
79 | def _report_training(self, *args, **kwargs):
80 | """ To be overridden """
81 | raise NotImplementedError()
82 |
83 | def report_step(self, lr, step, train_stats=None, valid_stats=None):
84 | """
85 | Report stats of a step
86 |
87 | Args:
88 | train_stats(Statistics): training stats
89 | valid_stats(Statistics): validation stats
90 | lr(float): current learning rate
91 | """
92 | self._report_step(
93 | lr, step, train_stats=train_stats, valid_stats=valid_stats)
94 |
95 | def _report_step(self, *args, **kwargs):
96 | raise NotImplementedError()
97 |
98 |
99 | class ReportMgr(ReportMgrBase):
100 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None):
101 | """
102 | A report manager that writes statistics on standard output as well as
103 | (optionally) TensorBoard
104 |
105 | Args:
106 | report_every(int): Report status every this many sentences
107 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`):
108 | The TensorBoard Summary writer to use or None
109 | """
110 | super(ReportMgr, self).__init__(report_every, start_time)
111 | self.tensorboard_writer = tensorboard_writer
112 |
113 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step):
114 | if self.tensorboard_writer is not None:
115 | stats.log_tensorboard(
116 | prefix, self.tensorboard_writer, learning_rate, step)
117 |
118 | def _report_training(self, step, num_steps, learning_rate,
119 | report_stats):
120 | """
121 | See base class method `ReportMgrBase.report_training`.
122 | """
123 | report_stats.output(step, num_steps,
124 | learning_rate, self.start_time)
125 |
126 | # Log the progress using the number of batches on the x-axis.
127 | self.maybe_log_tensorboard(report_stats,
128 | "progress",
129 | learning_rate,
130 | step)
131 | report_stats = Statistics()
132 |
133 | return report_stats
134 |
135 | def _report_step(self, lr, step, train_stats=None, valid_stats=None):
136 | """
137 | See base class method `ReportMgrBase.report_step`.
138 | """
139 | if train_stats is not None:
140 | self.log('Train perplexity: %g' % train_stats.ppl())
141 | self.log('Train accuracy: %g' % train_stats.accuracy())
142 |
143 | self.maybe_log_tensorboard(train_stats,
144 | "train",
145 | lr,
146 | step)
147 |
148 | if valid_stats is not None:
149 | self.log('Validation perplexity: %g' % valid_stats.ppl())
150 | self.log('Validation accuracy: %g' % valid_stats.accuracy())
151 |
152 | self.maybe_log_tensorboard(valid_stats,
153 | "valid",
154 | lr,
155 | step)
156 |
157 |
158 | class Statistics(object):
159 | """
160 | Accumulator for loss statistics.
161 | Currently calculates:
162 |
163 | * accuracy
164 | * perplexity
165 | * elapsed time
166 | """
167 |
168 | def __init__(self, loss=0, n_words=0, n_correct=0):
169 | self.loss = loss
170 | self.n_words = n_words
171 | self.n_docs = 0
172 | self.n_correct = n_correct
173 | self.n_src_words = 0
174 | self.start_time = time.time()
175 |
176 | @staticmethod
177 | def all_gather_stats(stat, max_size=4096):
178 | """
179 | Gather a `Statistics` object across multiple processes/nodes
180 |
181 | Args:
182 | stat(:obj:Statistics): the statistics object to gather
183 | across all processes/nodes
184 | max_size(int): max buffer size to use
185 |
186 | Returns:
187 | `Statistics`, the updated stats object
188 | """
189 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size)
190 | return stats[0]
191 |
192 | @staticmethod
193 | def all_gather_stats_list(stat_list, max_size=4096):
194 | from torch.distributed import get_rank
195 |
196 | """
197 | Gather a `Statistics` list across all processes/nodes
198 |
199 | Args:
200 | stat_list(list([`Statistics`])): list of statistics objects to
201 | gather across all processes/nodes
202 | max_size(int): max buffer size to use
203 |
204 | Returns:
205 | our_stats(list([`Statistics`])): list of updated stats
206 | """
207 | # Get a list of world_size lists with len(stat_list) Statistics objects
208 | all_stats = all_gather_list(stat_list, max_size=max_size)
209 |
210 | our_rank = get_rank()
211 | our_stats = all_stats[our_rank]
212 | for other_rank, stats in enumerate(all_stats):
213 | if other_rank == our_rank:
214 | continue
215 | for i, stat in enumerate(stats):
216 | our_stats[i].update(stat, update_n_src_words=True)
217 | return our_stats
218 |
219 | def update(self, stat, update_n_src_words=False):
220 | """
221 | Update statistics by summing values with another `Statistics` object
222 |
223 | Args:
224 | stat: another statistic object
225 | update_n_src_words(bool): whether to update (sum) `n_src_words`
226 | or not
227 |
228 | """
229 | self.loss += stat.loss
230 | self.n_words += stat.n_words
231 | self.n_correct += stat.n_correct
232 | self.n_docs += stat.n_docs
233 |
234 | if update_n_src_words:
235 | self.n_src_words += stat.n_src_words
236 |
237 | def accuracy(self):
238 | """ compute accuracy """
239 | return 100 * (self.n_correct / self.n_words)
240 |
241 | def xent(self):
242 | """ compute cross entropy """
243 | return self.loss / self.n_words
244 |
245 | def ppl(self):
246 | """ compute perplexity """
247 | return math.exp(min(self.loss / self.n_words, 100))
248 |
249 | def elapsed_time(self):
250 | """ compute elapsed time """
251 | return time.time() - self.start_time
252 |
253 | def output(self, step, num_steps, learning_rate, start):
254 | """Write out statistics to stdout.
255 |
256 | Args:
257 | step (int): current step
258 | num_steps (int): total number of steps
259 | learning_rate (float): current learning rate
260 | start (float): start time of the step.
260 | """
261 | t = self.elapsed_time()
262 | logger.info(
263 | ("Step %2d/%5d; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
264 | "lr: %7.8f; %3.0f/%3.0f tok/s; %6.0f sec")
265 | % (step, num_steps,
266 | self.accuracy(),
267 | self.ppl(),
268 | self.xent(),
269 | learning_rate,
270 | self.n_src_words / (t + 1e-5),
271 | self.n_words / (t + 1e-5),
272 | time.time() - start))
273 | sys.stdout.flush()
274 |
275 | def log_tensorboard(self, prefix, writer, learning_rate, step):
276 | """ display statistics to tensorboard """
277 | t = self.elapsed_time()
278 | writer.add_scalar(prefix + "/xent", self.xent(), step)
279 | writer.add_scalar(prefix + "/ppl", self.ppl(), step)
280 | writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
281 | writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
282 | writer.add_scalar(prefix + "/lr", learning_rate, step)
283 |
--------------------------------------------------------------------------------
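A quick sketch of the `Statistics` accumulator on its own (run from *bertSP/src*):

```python
from models.reporter import Statistics

stats = Statistics(loss=12.3, n_words=40, n_correct=25)
stats.update(Statistics(loss=6.1, n_words=20, n_correct=15))

print(stats.accuracy())  # 100 * 40 / 60 ~= 66.67
print(stats.xent())      # (12.3 + 6.1) / 60
print(stats.ppl())       # exp(min(xent, 100))
```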
/bertSP/src/others/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EdinburghNLP/SPICE/4afa4404b02f59d175976b7e02583fdf41c23c3a/bertSP/src/others/__init__.py
--------------------------------------------------------------------------------
/bertSP/src/others/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import re
4 | from pathlib import Path
5 |
6 | # set root path
7 | ROOT_PATH = Path(os.path.dirname(__file__))
8 |
9 | # model name
10 | MODEL_NAME = 'BertSP'
11 |
12 | # fields
13 | INPUT = 'input'
14 | STR_LOGICAL_FORM = 'str_logical_form'
15 | LOGICAL_FORM = 'logical_form'
16 | GRAPH = 'graph'
17 | NEL = 'nel'
18 | SEGMENT = 'segment'
19 | START = 'start'
20 | END = 'end'
21 | DYNBIN = 'dynbin'
22 | ENTITY_MAP = 'entity_map'
23 | GLOBAL_SYNTAX = 'global_syntax'
24 |
25 | # json annotations fields, used in precompute_local_subgraphs.py
26 | ES_LINKS = 'es_links'
27 | SPANS = 'tagged_words'
28 | ALLEN_SPANS = 'allen_tagged_words'
29 | ALLEN_TAGS = 'allennlp_tags'
30 | ALLEN_ES_LINKS = 'allen_es_links'
31 | STR_ES_LINKS = 'str_es_links'
32 | STR_SPANS = 'str_tagged_words'
33 |
34 | # helper tokens
35 | BOS_TOKEN = '[BOS]'
36 | BOS_TOKEN_BERT = '[unused6]'
37 | EOS_TOKEN = '[EOS]'
38 | EOS_TOKEN_BERT = '[unused7]'
39 | CTX_TOKEN = '[CTX]'
40 | CTX_TOKEN_BERT = '[unused2]'
41 | PAD_TOKEN = '[PAD]'
42 | UNK_TOKEN = '[UNK]'
43 | SEP_TOKEN = '[SEP]'
44 | CLS_TOKEN = '[CLS]'
45 | NA_TOKEN = 'NA'
46 | NA_TOKEN_BERT = '[unused3]'
47 | OBJ_TOKEN = '[OBJ]'
48 | SUBJ_TOKEN = '[SUBJ]'
49 | PRED_TOKEN = '[PRED]'
50 | TYPE_TOKEN = '[TYPE]'
51 | KGELEM_DELIMITER = ';'
52 | KGELEM_DELIMITER_BERT = '[unused4]'
53 | TMP_DELIMITER = '||||'
54 | RELARGS_DELIMITER = '[->]'
55 | RELARGS_DELIMITER_BERT = '[unused5]'
56 |
57 | # KB graph
58 | SUBJECT = 'subject'
59 | OBJECT = 'object'
60 | TYPE = 'type'
61 | TYPE_SUBJOBJ = 'typesubjobj'
62 |
63 | # ner tag
64 | B = 'B'
65 | I = 'I'
66 | O = 'O'
67 |
68 | # question types
69 | TOTAL = 'total'
70 | OVERALL = 'Overall'
71 | CLARIFICATION = 'Clarification'
72 | COMPARATIVE = 'Comparative Reasoning (All)'
73 | LOGICAL = 'Logical Reasoning (All)'
74 | QUANTITATIVE = 'Quantitative Reasoning (All)'
75 | SIMPLE_COREFERENCED = 'Simple Question (Coreferenced)'
76 | SIMPLE_DIRECT = 'Simple Question (Direct)'
77 | SIMPLE_ELLIPSIS = 'Simple Question (Ellipsis)'
78 | VERIFICATION = 'Verification (Boolean) (All)'
79 | QUANTITATIVE_COUNT = 'Quantitative Reasoning (Count) (All)'
80 | COMPARATIVE_COUNT = 'Comparative Reasoning (Count) (All)'
81 |
82 | # action related
83 | ENTITY = 'entity'
84 | RELATION = 'relation'
85 | TYPE = 'type'
86 | VALUE = 'value'
87 | ACTION = 'action'
88 |
89 | # other
90 | UTTERANCE = 'utterance'
91 | QUESTION_TYPE = 'question_type'
92 | DESCRIPTION = 'description'
93 | IS_CORRECT = 'is_correct'
94 | QUESTION = 'question'
95 | ANSWER = 'answer'
96 | ACTIONS = 'actions'
97 | GOLD_ACTIONS = 'sparql_delex'
98 | RESULTS = 'results'
99 | PREV_RESULTS = 'prev_results'
100 | CONTEXT_QUESTION = 'context_question'
101 | CONTEXT_ENTITIES = 'context_entities'
102 | BERT_BASE_UNCASED = 'bert-base-uncased'
103 | TURN_ID = 'turnID'
104 | USER = 'USER'
105 | SYSTEM = 'SYSTEM'
106 |
107 | # ENTITY and TYPE annotations options, defined in preprocess.py
108 | TGOLD = 'gold'
109 | TLINKED = 'linked'
110 | TNONE = 'none'
111 | NEGOLD = 'gold'
112 | NELGNEL = 'lgnel'
113 | NEALLENNEL = 'allennel'
114 | NESTRNEL = 'strnel'
115 |
116 | # max limits, truncations in inputs used in data_builder.py
117 | MAX_TYPE_RESTRICTIONS = 5
118 | MAX_LINKED_TYPES = 3 # graph from type linking; we know that on average there are 2.3 gold types
119 | MAX_INPUTSEQ_LEN = 508
120 |
121 | QTYPE_DICT = {
122 | 'Comparative Reasoning (All)': 0,
123 | 'Logical Reasoning (All)': 1,
124 | 'Quantitative Reasoning (All)': 2,
125 | 'Simple Question (Coreferenced)': 3,
126 | 'Simple Question (Direct)': 4,
127 | 'Simple Question (Ellipsis)': 5,
128 | 'Verification (Boolean) (All)': 6,
129 | 'Quantitative Reasoning (Count) (All)': 7,
130 | 'Comparative Reasoning (Count) (All)': 8,
131 | 'Clarification': 9
132 | }
133 |
134 | INV_QTYPE_DICT = {}
135 | for k, v in QTYPE_DICT.items():
136 | INV_QTYPE_DICT[v] = k
137 |
138 |
139 | def get_value(question):
140 | if 'min' in question.split():
141 | value = '0'
142 | elif 'max' in question.split():
143 | value = '0'
144 | elif 'exactly' in question.split():
145 | value = re.search(r'\d+', question.split('exactly')[1])
146 | if value:
147 | value = value.group()
148 | elif 'approximately' in question.split():
149 | value = re.search(r'\d+', question.split('approximately')[1])
150 | if value:
151 | value = value.group()
152 | elif 'around' in question.split():
153 | value = re.search(r'\d+', question.split('around')[1])
154 | if value:
155 | value = value.group()
156 | elif 'atmost' in question.split():
157 | value = re.search(r'\d+', question.split('atmost')[1])
158 | if value:
159 | value = value.group()
160 | elif 'atleast' in question.split():
161 | value = re.search(r'\d+', question.split('atleast')[1])
162 | if value:
163 | value = value.group()
164 | else:
165 | print(f'Could not extract value from question: {question}')
166 | value = '0'
167 |
168 | return value
--------------------------------------------------------------------------------
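For instance, `get_value` extracts the first number after the comparison keyword, and maps *min*/*max* to '0':

```python
from others.constants import get_value

get_value('which rivers flow through exactly 3 countries ?')  # -> '3'
get_value('which lakes have min depth ?')                     # -> '0'
```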
/bertSP/src/others/logging.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import
3 |
4 | import logging
5 |
6 | logger = logging.getLogger()
7 |
8 |
9 | def init_logger(log_file=None, log_file_level=logging.NOTSET):
10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
11 | logger = logging.getLogger()
12 | logger.setLevel(logging.INFO)
13 |
14 | console_handler = logging.StreamHandler()
15 | console_handler.setFormatter(log_format)
16 | logger.handlers = [console_handler]
17 |
18 | if log_file and log_file != '':
19 | file_handler = logging.FileHandler(log_file)
20 | file_handler.setLevel(log_file_level)
21 | file_handler.setFormatter(log_format)
22 | logger.addHandler(file_handler)
23 |
24 | return logger
25 |
--------------------------------------------------------------------------------
/bertSP/src/others/utils.py:
--------------------------------------------------------------------------------
1 |
2 | def tile(x, count, dim=0):
3 | """
4 | Tiles x on dimension dim count times.
5 | """
6 | perm = list(range(len(x.size())))
7 | if dim != 0:
8 | perm[0], perm[dim] = perm[dim], perm[0]
9 | x = x.permute(perm).contiguous()
10 | out_size = list(x.size())
11 | out_size[0] *= count
12 | batch = x.size(0)
13 | x = x.view(batch, -1) \
14 | .transpose(0, 1) \
15 | .repeat(count, 1) \
16 | .transpose(0, 1) \
17 | .contiguous() \
18 | .view(*out_size)
19 | if dim != 0:
20 | x = x.permute(perm).contiguous()
21 | return x
22 |
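# Usage sketch for tile (equivalent to torch.repeat_interleave along `dim`):
#   x = torch.arange(6).view(2, 3)
#   tile(x, 2, dim=0)
#   -> tensor([[0, 1, 2],
#              [0, 1, 2],
#              [3, 4, 5],
#              [3, 4, 5]])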
23 | def ids_to_tokens_dynamic(w, vocab, entmap):
24 | if w < len(vocab):
25 | return vocab.itos[w]
26 | else:
27 | return list(entmap.keys())[w - len(vocab)]
28 |
--------------------------------------------------------------------------------
/bertSP/src/prepro/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EdinburghNLP/SPICE/4afa4404b02f59d175976b7e02583fdf41c23c3a/bertSP/src/prepro/__init__.py
--------------------------------------------------------------------------------
/bertSP/src/prepro/utils.py:
--------------------------------------------------------------------------------
1 |
2 | def _get_ngrams(n, text):
3 | """Calcualtes n-grams.
4 |
5 | Args:
6 | n: which n-grams to calculate
7 | text: An array of tokens
8 |
9 | Returns:
10 | A set of n-grams
11 | """
12 | ngram_set = set()
13 | text_length = len(text)
14 | max_index_ngram_start = text_length - n
15 | for i in range(max_index_ngram_start + 1):
16 | ngram_set.add(tuple(text[i:i + n]))
17 | return ngram_set
18 |
19 |
20 | def _get_word_ngrams(n, sentences):
21 | """Calculates word n-grams for multiple sentences.
22 | """
23 | assert len(sentences) > 0
24 | assert n > 0
25 |
26 | words = sum(sentences, [])
27 | return _get_ngrams(n, words)
28 |
29 |
--------------------------------------------------------------------------------
/bertSP/src/preprocess.py:
--------------------------------------------------------------------------------
1 | #encoding=utf-8
2 |
3 |
4 | import argparse
5 | import time
6 |
7 | from others.logging import init_logger
8 | from prepro import data_builder
9 |
10 |
11 | def do_format_to_lines(args):
12 |     print(time.time())  # time.clock() was removed in Python 3.8
13 |     data_builder.format_to_lines(args)
14 |     print(time.time())
15 | 
16 | def do_format_to_bert(args):
17 |     print(time.time())
18 |     data_builder.format_to_bert(args)
19 |     print(time.time())
20 |
21 |
22 | def str2bool(v):
23 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
24 | return True
25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
26 | return False
27 | else:
28 | raise argparse.ArgumentTypeError('Boolean value expected.')
29 |
30 |
31 | if __name__ == '__main__':
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("-pretrained_model", default='bert', type=str)
34 |
35 | parser.add_argument("-mode", default='', type=str)
36 | parser.add_argument("-select_mode", default='greedy', type=str)
37 | parser.add_argument("-data_path", default='../../data/')
38 | parser.add_argument("-raw_path", default='../../line_data')
39 | parser.add_argument("-save_path", default='../../data/')
40 | parser.add_argument("-tgt_dict", default='')
41 |
42 | parser.add_argument("-shard_size", default=2000, type=int)
43 | parser.add_argument('-min_src_nsents', default=3, type=int)
44 | parser.add_argument('-max_src_nsents', default=100, type=int)
45 | parser.add_argument('-min_src_ntokens_per_sent', default=5, type=int)
46 | parser.add_argument('-max_src_ntokens_per_sent', default=200, type=int)
47 | parser.add_argument('-min_tgt_ntokens', default=5, type=int)
48 | parser.add_argument('-max_tgt_ntokens', default=500, type=int)
49 |
50 | parser.add_argument("-lower", type=str2bool, nargs='?',const=True,default=True)
51 | parser.add_argument("-use_bert_basic_tokenizer", type=str2bool, nargs='?',const=True,default=False)
52 |
53 | parser.add_argument('-log_file', default='logs/preprocess.log')
54 |
55 | parser.add_argument('-dataset', default='')
56 | parser.add_argument('-mapsplits', type=str2bool, nargs='?',const=True,default=False)
57 | parser.add_argument('-mapfile', default='')
58 | parser.add_argument('-types', default='gold', choices=['gold', 'linked', 'none'])
59 | parser.add_argument('-nentities', default='gold', choices=['gold', 'lgnel', 'allennel', 'strnel'])
60 | parser.add_argument('-kb_graph', default='knowledge_graph')
61 |
62 | parser.add_argument('-n_cpus', default=2, type=int)
63 |
64 | args = parser.parse_args()
65 | init_logger(args.log_file)
66 |     getattr(data_builder, args.mode)(args)
67 |
--------------------------------------------------------------------------------
/bertSP/src/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Main training workflow
4 | """
5 | from __future__ import division
6 |
7 | import argparse
8 | import os
9 | from others.logging import init_logger
10 | from train_baseline import validate_abs, train_abs, test_abs #, test_text_abs
11 |
12 | model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers', 'enc_hidden_size', 'enc_ff_size',
13 | 'dec_layers', 'dec_hidden_size', 'dec_ff_size', 'encoder', 'ff_actv', 'use_interval']
14 |
15 |
16 | def str2bool(v):
17 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
18 | return True
19 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
20 | return False
21 | else:
22 | raise argparse.ArgumentTypeError('Boolean value expected.')
23 |
24 |
25 |
26 |
27 | if __name__ == '__main__':
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("-mode", default='train', type=str, choices=['train', 'validate', 'test'])
30 | parser.add_argument("-test_split", default='test', help='Used for test (i.e., inference) mode. Expect valid or test.')
31 | parser.add_argument("-data_path", default='dataset/spice')
32 | parser.add_argument("-bert_data_path", default='dataset/spice')
33 | parser.add_argument("-model_path", default='models/baseline')
34 | parser.add_argument("-result_path", default='results/baseline')
35 | parser.add_argument("-tgt_dict", default='')
36 | parser.add_argument("-temp_dir", default='temp')
37 |
38 | parser.add_argument("-batch_size", default=140, type=int)
39 | parser.add_argument("-test_batch_size", default=200, type=int)
40 |
41 | parser.add_argument("-max_pos", default=512, type=int)
42 | parser.add_argument("-use_interval", type=str2bool, nargs='?',const=True,default=True)
43 | parser.add_argument("-large", type=str2bool, nargs='?',const=True,default=False)
44 |
45 | # These two are not in use
46 | parser.add_argument("-input_syntax_voc", type=str2bool, nargs='?', const=True, default=False)
47 | parser.add_argument("-predict_qtype", type=str2bool, nargs='?', const=True, default=False)
48 |
49 | parser.add_argument("-sep_optim", type=str2bool, nargs='?',const=True,default=False)
50 | parser.add_argument("-lr_bert", default=2e-3, type=float)
51 | parser.add_argument("-lr_dec", default=2e-3, type=float)
52 | parser.add_argument("-use_bert_emb", type=str2bool, nargs='?',const=True,default=False)
53 |
54 | #parser.add_argument("-share_emb", type=str2bool, nargs='?', const=True, default=False)
55 | parser.add_argument("-finetune_bert", type=str2bool, nargs='?', const=True, default=True)
56 | parser.add_argument("-dec_dropout", default=0.2, type=float)
57 | parser.add_argument("-dec_layers", default=6, type=int)
58 | parser.add_argument("-dec_hidden_size", default=768, type=int)
59 | parser.add_argument("-dec_heads", default=8, type=int)
60 | parser.add_argument("-dec_ff_size", default=2048, type=int)
61 | parser.add_argument("-enc_hidden_size", default=512, type=int)
62 | parser.add_argument("-enc_ff_size", default=512, type=int)
63 | parser.add_argument("-enc_dropout", default=0.2, type=float)
64 | parser.add_argument("-enc_layers", default=6, type=int)
65 |
66 | parser.add_argument("-label_smoothing", default=0.0, type=float)
67 | parser.add_argument("-generator_shard_size", default=32, type=int)
68 | parser.add_argument("-alpha", default=1, type=float)
69 | parser.add_argument("-beam_size", default=1, type=int)
70 | parser.add_argument("-min_length", default=10, type=int)
71 | parser.add_argument("-max_length", default=150, type=int)
72 | parser.add_argument("-max_tgt_len", default=400, type=int)
73 | parser.add_argument("-dosubset", default='', type=str, help='used to select a subset of test/valid '
74 | 'files to run on. Should be a regular expression to '
75 | 'cover the desired shard numbers., e.g. 2[0-9] for'
76 | 'all 20, 21, ..., 29 shard files.')
77 |
78 | parser.add_argument("-param_init", default=0, type=float)
79 | parser.add_argument("-param_init_glorot", type=str2bool, nargs='?',const=True,default=True)
80 | parser.add_argument("-optim", default='adam', type=str)
81 | parser.add_argument("-lr", default=1, type=float)
82 | parser.add_argument("-beta1", default= 0.9, type=float)
83 | parser.add_argument("-beta2", default=0.999, type=float)
84 | parser.add_argument("-warmup_steps", default=8000, type=int)
85 | parser.add_argument("-warmup_steps_bert", default=8000, type=int)
86 | parser.add_argument("-warmup_steps_dec", default=8000, type=int)
87 | parser.add_argument("-max_grad_norm", default=0, type=float)
88 |
89 | parser.add_argument("-save_checkpoint_steps", default=5, type=int)
90 | parser.add_argument("-accum_count", default=1, type=int)
91 | parser.add_argument("-report_every", default=1, type=int)
92 | parser.add_argument("-train_steps", default=1000, type=int)
93 | parser.add_argument("-recall_eval", type=str2bool, nargs='?',const=True,default=False)
94 |
95 |
96 | parser.add_argument('-visible_gpus', default='-1', type=str)
97 | parser.add_argument('-gpu_ranks', default='0', type=str)
98 | parser.add_argument('-log_file', default='logs/log.log')
99 | parser.add_argument('-seed', default=222, type=int)
100 |
101 | parser.add_argument("-test_all", type=str2bool, nargs='?',const=True,default=False)
102 | parser.add_argument("-test_from", default='', type=str)
103 | parser.add_argument("-test_start_from", default=-1, type=int)
104 | parser.add_argument("-valid_from", default='', type=str)
105 |
106 | parser.add_argument("-train_from", default='')
107 | parser.add_argument("-ban_unk_token", type=str2bool, nargs='?', const=True, default=False)
108 |
109 |
110 | args = parser.parse_args()
111 | args.gpu_ranks = [int(i) for i in range(len(args.visible_gpus.split(',')))]
112 | args.world_size = len(args.gpu_ranks)
113 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus
114 |
115 | init_logger(args.log_file)
116 | device = "cpu" if args.visible_gpus == '-1' else "cuda"
117 | device_id = 0 if device == "cuda" else -1
118 |
119 | if (args.mode == 'train'):
120 | train_abs(args, device_id)
121 | elif (args.mode == 'validate'):
122 | validate_abs(args, device_id)
123 | elif (args.mode == 'test'):
124 | cp = args.test_from
125 | try:
126 | step = int(cp.split('.')[-2].split('_')[-1])
127 |         except (IndexError, ValueError):
128 | step = 0
129 | test_abs(args, device_id, cp, step)
--------------------------------------------------------------------------------
/bertSP/src/translate/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EdinburghNLP/SPICE/4afa4404b02f59d175976b7e02583fdf41c23c3a/bertSP/src/translate/__init__.py
--------------------------------------------------------------------------------
/bertSP/src/translate/beam.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | from translate import penalties
4 |
5 |
6 | class Beam(object):
7 | """
8 | Class for managing the internals of the beam search process.
9 |
10 | Takes care of beams, back pointers, and scores.
11 |
12 | Args:
13 | size (int): beam size
14 | pad, bos, eos (int): indices of padding, beginning, and ending.
15 | n_best (int): nbest size to use
16 | cuda (bool): use gpu
17 | global_scorer (:obj:`GlobalScorer`)
18 | """
19 |
20 | def __init__(self, size, pad, bos, eos,
21 | n_best=1, cuda=False,
22 | global_scorer=None,
23 | min_length=0,
24 | stepwise_penalty=False,
25 | block_ngram_repeat=0,
26 | exclusion_tokens=set()):
27 |
28 | self.size = size
29 | self.tt = torch.cuda if cuda else torch
30 |
31 | # The score for each translation on the beam.
32 | self.scores = self.tt.FloatTensor(size).zero_()
33 | self.all_scores = []
34 |
35 | # The backpointers at each time-step.
36 | self.prev_ks = []
37 |
38 | # The outputs at each time-step.
39 | self.next_ys = [self.tt.LongTensor(size)
40 | .fill_(pad)]
41 | self.next_ys[0][0] = bos
42 |
43 | # Has EOS topped the beam yet.
44 | self._eos = eos
45 | self.eos_top = False
46 |
47 | # The attentions (matrix) for each time.
48 | self.attn = []
49 |
50 | # Time and k pair for finished.
51 | self.finished = []
52 | self.n_best = n_best
53 |
54 | # Information for global scoring.
55 | self.global_scorer = global_scorer
56 | self.global_state = {}
57 |
58 | # Minimum prediction length
59 | self.min_length = min_length
60 |
61 | # Apply Penalty at every step
62 | self.stepwise_penalty = stepwise_penalty
63 | self.block_ngram_repeat = block_ngram_repeat
64 | self.exclusion_tokens = exclusion_tokens
65 |
66 | def get_current_state(self):
67 | "Get the outputs for the current timestep."
68 | return self.next_ys[-1]
69 |
70 | def get_current_origin(self):
71 | "Get the backpointers for the current timestep."
72 | return self.prev_ks[-1]
73 |
74 | def advance(self, word_probs, attn_out):
75 | """
76 | Given prob over words for every last beam `wordLk` and attention
77 | `attn_out`: Compute and update the beam search.
78 |
79 | Parameters:
80 |
81 | * `word_probs`- probs of advancing from the last step (K x words)
82 | * `attn_out`- attention at the last step
83 |
84 | Returns: True if beam search is complete.
85 | """
86 | num_words = word_probs.size(1)
87 | if self.stepwise_penalty:
88 | self.global_scorer.update_score(self, attn_out)
89 | # force the output to be longer than self.min_length
90 | cur_len = len(self.next_ys)
91 | if cur_len < self.min_length:
92 | for k in range(len(word_probs)):
93 | word_probs[k][self._eos] = -1e20
94 | # Sum the previous scores.
95 | if len(self.prev_ks) > 0:
96 | beam_scores = word_probs + \
97 | self.scores.unsqueeze(1).expand_as(word_probs)
98 | # Don't let EOS have children.
99 | for i in range(self.next_ys[-1].size(0)):
100 | if self.next_ys[-1][i] == self._eos:
101 | beam_scores[i] = -1e20
102 |
103 | # Block ngram repeats
104 | if self.block_ngram_repeat > 0:
105 | ngrams = []
106 | le = len(self.next_ys)
107 | for j in range(self.next_ys[-1].size(0)):
108 | hyp, _ = self.get_hyp(le - 1, j)
109 | ngrams = set()
110 | fail = False
111 | gram = []
112 | for i in range(le - 1):
113 | # Last n tokens, n = block_ngram_repeat
114 | gram = (gram +
115 | [hyp[i].item()])[-self.block_ngram_repeat:]
116 | # Skip the blocking if it is in the exclusion list
117 | if set(gram) & self.exclusion_tokens:
118 | continue
119 | if tuple(gram) in ngrams:
120 | fail = True
121 | ngrams.add(tuple(gram))
122 | if fail:
123 | beam_scores[j] = -10e20
124 | else:
125 | beam_scores = word_probs[0]
126 | flat_beam_scores = beam_scores.view(-1)
127 | best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0,
128 | True, True)
129 |
130 | self.all_scores.append(self.scores)
131 | self.scores = best_scores
132 |
133 | # best_scores_id is flattened beam x word array, so calculate which
134 | # word and beam each score came from
135 |         prev_k = best_scores_id // num_words  # integer division to recover the beam index
136 | self.prev_ks.append(prev_k)
137 | self.next_ys.append((best_scores_id - prev_k * num_words))
138 | self.attn.append(attn_out.index_select(0, prev_k))
139 | self.global_scorer.update_global_state(self)
140 |
141 | for i in range(self.next_ys[-1].size(0)):
142 | if self.next_ys[-1][i] == self._eos:
143 | global_scores = self.global_scorer.score(self, self.scores)
144 | s = global_scores[i]
145 | self.finished.append((s, len(self.next_ys) - 1, i))
146 |
147 | # End condition is when top-of-beam is EOS and no global score.
148 | if self.next_ys[-1][0] == self._eos:
149 | self.all_scores.append(self.scores)
150 | self.eos_top = True
151 |
152 | def done(self):
153 | return self.eos_top and len(self.finished) >= self.n_best
154 |
155 | def sort_finished(self, minimum=None):
156 | if minimum is not None:
157 | i = 0
158 | # Add from beam until we have minimum outputs.
159 | while len(self.finished) < minimum:
160 | global_scores = self.global_scorer.score(self, self.scores)
161 | s = global_scores[i]
162 | self.finished.append((s, len(self.next_ys) - 1, i))
163 | i += 1
164 |
165 | self.finished.sort(key=lambda a: -a[0])
166 | scores = [sc for sc, _, _ in self.finished]
167 | ks = [(t, k) for _, t, k in self.finished]
168 | return scores, ks
169 |
170 | def get_hyp(self, timestep, k):
171 | """
172 | Walk back to construct the full hypothesis.
173 | """
174 | hyp, attn = [], []
175 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
176 | hyp.append(self.next_ys[j + 1][k])
177 | attn.append(self.attn[j][k])
178 | k = self.prev_ks[j][k]
179 | return hyp[::-1], torch.stack(attn[::-1])
180 |
181 |
182 | class GNMTGlobalScorer(object):
183 | """
184 | NMT re-ranking score from
185 | "Google's Neural Machine Translation System" :cite:`wu2016google`
186 |
187 | Args:
188 | alpha (float): length parameter
189 |         length_penalty (str): name of the length penalty to use ("wu", "avg" or none)
190 | """
191 |
192 | def __init__(self, alpha, length_penalty):
193 | self.alpha = alpha
194 | penalty_builder = penalties.PenaltyBuilder(length_penalty)
195 | # Term will be subtracted from probability
196 | # Probability will be divided by this
197 | self.length_penalty = penalty_builder.length_penalty()
198 |
199 | def score(self, beam, logprobs):
200 | """
201 | Rescores a prediction based on penalty functions
202 | """
203 | normalized_probs = self.length_penalty(beam,
204 | logprobs,
205 | self.alpha)
206 |
207 | return normalized_probs
208 |
209 |
--------------------------------------------------------------------------------
/bertSP/src/translate/penalties.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 |
4 |
5 | class PenaltyBuilder(object):
6 | """
7 |     Returns the Length Penalty function for Beam Search.
8 | 
9 |     Args:
10 |         length_pen (str): option name of the length penalty
11 |             ("wu", "avg", or anything else for no penalty)
12 |     """
13 |
14 | def __init__(self, length_pen):
15 | self.length_pen = length_pen
16 |
17 | def length_penalty(self):
18 | if self.length_pen == "wu":
19 | return self.length_wu
20 | elif self.length_pen == "avg":
21 | return self.length_average
22 | else:
23 | return self.length_none
24 |
25 | """
26 | Below are all the different penalty terms implemented so far
27 | """
28 |
29 |
30 | def length_wu(self, beam, logprobs, alpha=0.):
31 | """
32 | NMT length re-ranking score from
33 | "Google's Neural Machine Translation System" :cite:`wu2016google`.
34 | """
35 |
36 | modifier = (((5 + len(beam.next_ys)) ** alpha) /
37 | ((5 + 1) ** alpha))
38 | return (logprobs / modifier)
39 |
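    # Worked example (sketch): with alpha = 0.6 and len(beam.next_ys) = 10,
    # modifier = ((5 + 10) ** 0.6) / (6 ** 0.6) = 2.5 ** 0.6 ~= 1.73, so the
    # log-probabilities are divided by ~1.73, which favours longer hypotheses.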
40 | def length_average(self, beam, logprobs, alpha=0.):
41 | """
42 | Returns the average probability of tokens in a sequence.
43 | """
44 | return logprobs / len(beam.next_ys)
45 |
46 | def length_none(self, beam, logprobs, alpha=0., beta=0.):
47 | """
48 | Returns unmodified scores.
49 | """
50 | return logprobs
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 | # SPICE dataset
2 |
3 | Dataset description
4 |
5 |
6 | # Annotations on the SPICE dataset
7 |
8 | ## Entity Neighborhood Sub-Graphs
9 |
10 | Input is a SPICE dataset; this script extracts entity neighborhood sub-graphs for the gold entities and outputs a copy of the SPICE dataset annotated with them (an added json field ```'local_subgraph'```; the local_subgraph for each turn is constructed from the entities in the previous question, the previous answer and the current question).
11 |
12 | ``` bash
13 | python precompute_local_subgraphs.py \
14 | --partition train \
15 | --read_folder ${SPICE_CONVERSATIONS} \
16 | --write_folder ${ANNOTATED_SPICE_CONVERSATIONS} \
17 | --json_kg_folder ${PATH_JSON_KG}
18 | ```
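
The annotated conversations remain plain json, so the new field can be inspected directly. A minimal sketch (the file path below is hypothetical):

``` python
import json

# hypothetical conversation file written to ${ANNOTATED_SPICE_CONVERSATIONS}
with open('annotated/train/QA_0/QA_0.json') as f:
    conversation = json.load(f)

for turn in conversation:
    if 'local_subgraph' in turn:
        print(turn['utterance'])
        print(turn['local_subgraph'])  # entity neighborhood sub-graph for this turn
```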
19 |
20 | Once annotations are done for the gold entities, it is possible to add entity neighborhood sub-graphs for NER/NEL entities (e.g., AllenNLP). For this you need to specify the ```--nel_entities``` flag and ```--allennlpNER_folder```, which contains the conversations annotated with AllenNLP NER/NEL (see the instructions for this script below).
21 |
22 | This script also generates the global vocabulary file: it writes a file named ```expansion_vocab.json``` to the folder ```ANNOTATED_SPICE_CONVERSATIONS```.
23 | ``` bash
24 | python precompute_local_subgraphs.py --write_folder ${SPICE_CONVERSATIONS} --task vocab
25 | ```
26 |
27 |
28 |
29 | ## Type Sub-Graphs
30 |
31 | Input is a SPICE dataset; the script finds KG type candidates mentioned in utterances, links them to types in the KG and extracts a set of relations for each of them. It outputs a copy of the SPICE dataset annotated with type sub-graphs (the added json field is ```'type_subgraph'```).
32 |
33 | ``` bash
34 | python precompute_local_types.py \
35 | --partition train \
36 | --read_folder ${SPICE_CONVERSATIONS} \
37 | --write_folder ${ANNOTATED_SPICE_CONVERSATIONS} \
38 | --json_kg_folder ${PATH_JSON_KG}
39 | ```
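
As with the entity sub-graphs, the added annotation can be inspected directly (a sketch; the path is again hypothetical):

``` python
import json

with open('annotated/train/QA_0/QA_0.json') as f:  # hypothetical path
    conversation = json.load(f)

print([turn.get('type_subgraph') for turn in conversation])
```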
40 |
41 |
42 | ## AllenNLP-based NER
43 | 
44 | AllenNLP-based NER/NEL scripts are available [here](./ner/allennlp_ner)
45 | 
46 | ## String-match-based NER
47 | 
48 | String-match-based NER/NEL scripts are available [here](./ner/strner)
49 |
50 |
51 |
--------------------------------------------------------------------------------
/dataset/SparqlResults.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class SparqlResults:
4 |
5 | @staticmethod
6 | def getEntitySetFromBindings(results):
7 | """
8 | :param results: {'head': {'vars': ['x']}, 'results':
9 | {'bindings': [{'x': {'type': 'uri', 'value': 'http://www.wikidata.org/entity/Q155'}},
10 | {'x': {'type': 'uri', 'value': 'http://www.wikidata.org/entity/Q159'}},
11 | {'x': {'type': 'uri', 'value': 'http://www.wikidata.org/entity/Q183'}}]}}
12 | :return: {'x': [Q155, Q159, Q183]}
13 |
14 | :param results: {'head': {}, 'boolean': False}
15 | :return: {'boolean': False}
16 | """
17 | if 'boolean' in results.keys():
18 | return {'boolean': results['boolean']}
19 | else:
20 | varBindings = {}
21 |             for var in results['head']['vars']:  # we expect a single variable
22 |                 varBindings[var] = []
23 |                 for binding in results['results']['bindings']:
24 |                     if var in binding.keys():
25 |                         varBindings[var].append(binding[var]['value'].split('/')[-1])
26 |
27 | return varBindings
28 |
29 |
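# Usage sketch, reusing the example from the docstring above:
#   results = {'head': {'vars': ['x']}, 'results': {'bindings': [
#       {'x': {'type': 'uri', 'value': 'http://www.wikidata.org/entity/Q155'}}]}}
#   SparqlResults.getEntitySetFromBindings(results)  # -> {'x': ['Q155']}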
--------------------------------------------------------------------------------
/dataset/SparqlServer.py:
--------------------------------------------------------------------------------
1 | from pymantic import sparql
2 |
3 | class SparqlServer(object):
4 | _instance = None
5 |
6 | def __init__(self):
7 | raise RuntimeError('Call instance() instead')
8 |
9 | @classmethod
10 | def instance(cls, renew=False):
11 | if cls._instance is None or renew:
12 |             if renew and cls._instance is not None:
13 |                 cls._instance.s.close()
14 | print('Creating new instance')
15 | cls._instance = sparql.SPARQLServer('http://localhost:9999/blazegraph/namespace/wd/sparql')
16 | return cls._instance
17 |
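# Usage sketch (assumes the Blazegraph endpoint configured above is running;
# pymantic's SPARQLServer exposes a query() method):
#   server = SparqlServer.instance()
#   results = server.query('SELECT ?x WHERE { ?x ?p ?o } LIMIT 3')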
--------------------------------------------------------------------------------
/dataset/ner/allennlp_ner/README.md:
--------------------------------------------------------------------------------
1 | ### Tag based on AllenNLP NER and Elasticsearch
2 | 
3 | - Start an Elasticsearch server on localhost, port 9200
4 |
5 | ```
6 | python createlist.py
7 | python nel.py \
8 | -data_path "data_path/" \
9 | -save_path "tag_data_ner" \
10 | -file_path "trainlist.txt" \
11 | -dataset 'train' \
12 | -start 0 \
13 | -end -1
14 | ```
--------------------------------------------------------------------------------
/dataset/ner/allennlp_ner/createlist.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from os.path import exists
3 | p='data_path/train/*'
4 | outfile='trainlist.txt'
5 | assert not exists(outfile)
6 |
7 |
8 | fp=open(outfile, 'w')
9 |
10 | files = glob(p + '/*.json')
11 |
12 | for f in files:
13 | fp.write(f + '\n')
14 |
15 | fp.close()
16 |
--------------------------------------------------------------------------------
/dataset/ner/allennlp_ner/nel.py:
--------------------------------------------------------------------------------
1 | import json
2 | import traceback
3 | import os
4 | from glob import glob
5 | from multiprocess import Pool
6 | from tqdm import tqdm
7 | import argparse
8 | import pathlib
9 | from os.path import exists
10 | from unidecode import unidecode
11 | from elasticsearch import Elasticsearch
12 | from allennlp.predictors.predictor import Predictor
13 |
14 | '''
15 | B - 'beginning'
16 | I - 'inside'
17 | L - 'last'
18 | O - 'outside'
19 | U - 'unit'
20 | '''
21 |
22 | elastic_search = Elasticsearch([{'host': 'localhost', 'port': 9200}]) # connect to elastic search server
23 | #allennlp_predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/ner-model-2020.02.10.tar.gz")
24 | # do not lower case
25 |
26 | def allen_simple_token_ner(utterance, predictor_elmo):
27 | simple_tokens = utterance.lower().split()
28 | result = predictor_elmo.predict(sentence=utterance)
29 | new_ner = []
30 | s_idx = 0
31 | cache_w = ''
32 | combining = False
33 | allennlp_words = result['words']
34 | allennlp_tags = result['tags']
35 | tagged_words = []
36 |
37 | for t_idx, t in enumerate(result['tags']):
38 | if t == 'O':
39 | continue
40 | if t.startswith('B-'):
41 | cache_w += result['words'][t_idx]
42 | if t.startswith('I-'):
43 | cache_w += ' '+result['words'][t_idx]
44 | if t.startswith('L-'):
45 | cache_w += ' ' +result['words'][t_idx]
46 | tagged_words.append(cache_w)
47 | cache_w = ''
48 | if t.startswith('U-'):
49 | tagged_words.append(result['words'][t_idx])
50 |
51 |
52 | return new_ner, allennlp_words, allennlp_tags, tagged_words
53 |
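# Worked example (sketch) of the BILOU merging above: for
#   words = ['Barack', 'Obama', 'visited', 'Paris']
#   tags  = ['B-PER', 'L-PER', 'O', 'U-LOC']
# tagged_words becomes ['Barack Obama', 'Paris'].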
54 |
55 | def elasticsearch_query(query, res_size=1):
56 | res = elastic_search.search(index='csqa_wikidata', doc_type='entities', body={'size': res_size, 'query': {'match': {'label': {'query': unidecode(query), 'fuzziness': '1'}}}})
57 | results = []
58 | for hit in res['hits']['hits']: results.append(hit['_source']['id'])
59 | return results
60 |
61 |
62 | def add_nel_file(data_file,outfile, predictor_elmo):
63 | if exists(outfile):
64 |         print('File {f} already exists ... '.format(f=outfile))
65 | return
66 | print('File {f} will be processed ... '.format(f=outfile))
67 | try:
68 | return _add_nel_file(data_file,outfile, predictor_elmo)
69 | except Exception as e:
70 | print(traceback.format_exc())
71 | print('Failed ', data_file)
72 | return 0
73 |
74 |
75 | def _add_nel_file(data_file, outfile, predictor_elmo):
76 |
77 | conversation_triple_length = []
78 | conversation_triples = set()
79 | all_kg_element_set = set()
80 | input_data = []
81 |
82 | try:
83 | data = json.load(open(data_file, 'r'))
84 | except json.decoder.JSONDecodeError as e:
85 | print('Failed loading json file: ', data_file)
86 | raise e
87 | for conversation in [data]:
88 | is_clarification = False
89 | prev_user_conv = None
90 | prev_system_conv = None
91 | turns = len(conversation) // 2
92 |
93 | for i in range(turns):
94 | input = []
95 | logical_form = []
96 | # If the previous was a clarification question we basically took next
97 | # logical form so need to skip
98 | if is_clarification:
99 | is_clarification = False
100 | continue
101 | user = conversation[2*i]
102 | system = conversation[2*i + 1]
103 | utterance = user['utterance']
104 | _, _, allennlp_tags, tagged_words = allen_simple_token_ner(utterance, predictor_elmo)
105 | es_links = []
106 | for w in tagged_words:
107 | r=elasticsearch_query(query=w, res_size=5)
108 | es_links.append(r)
109 |
110 | user['tagged_words'] = tagged_words
111 | user['allennlp_tags'] = allennlp_tags
112 | user['es_links'] = es_links
113 | if 'utterance' in system.keys():
114 | utterance = system['utterance']
115 | _, _, allennlp_tags, tagged_words = allen_simple_token_ner(utterance, predictor_elmo)
116 | es_links = []
117 | for w in tagged_words:
118 | r=elasticsearch_query(query=w, res_size=5)
119 | es_links.append(r)
120 |
121 | system['tagged_words'] = tagged_words
122 | system['allennlp_tags'] = allennlp_tags
123 | system['es_links'] = es_links
124 |
125 | if exists(outfile):
126 |         print('File {f} already exists ... '.format(f=outfile))
127 | #raise Exception('File exist check input')
128 | base_folder = os.path.dirname(outfile)
129 | pathlib.Path(base_folder).mkdir(exist_ok=True, parents=True)
130 |
131 | json.dump(data, open(outfile, 'w', encoding='utf8'), indent=2, ensure_ascii=False)
132 |
133 |
134 | def process_files(args_list, colour=None):
135 | print('Process files ...')
136 | predictor_elmo = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/ner-elmo.2021-02-12.tar.gz")
137 | for ifile, ofile in tqdm(args_list, total=len(args_list), colour=colour):
138 | add_nel_file(ifile, ofile, predictor_elmo)
139 |
140 | def tag_data(args):
141 | allfiles = []
142 | for f in open(args.file_path, 'r').readlines():
143 | allfiles.append(f.rstrip())
144 | if args.end > 0:
145 | allfiles = allfiles[args.start:args.end]
146 |     print('Processing from {s} to {e}, will save at {p}'.format(s=args.start, e=args.end, p=args.save_path))
147 | #args_list = [(inp_file, os.path.join(args.save_path, corpus_type, os.path.basename(inp_file))) for inp_file in allfiles]
148 | def get_inp_out_file(filename, local_path, out_dir):
149 | lsys='/disk/scratch/parag/tagged/' # set this path according to the list
150 | filename = filename.replace('.tagged', '')
151 | inp_file = filename.replace(lsys, local_path)
152 | outfile = filename.replace(lsys, out_dir) + '.tagged'
153 | return (inp_file, outfile)
154 | args_list = [get_inp_out_file(inp_file, args.data_path, args.save_path) for inp_file in allfiles]
155 | process_files(args_list)
156 |
157 |
158 | def str2bool(v):
159 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
160 | return True
161 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
162 | return False
163 | else:
164 | raise argparse.ArgumentTypeError('Boolean value expected.')
165 |
166 |
167 | if __name__ == '__main__':
168 | # python preprocess.py -data_path /home/s1959796/csqparsing/dataset/data/version_splits -save_path /home/s1959796/csqparsing/nel_data -dataset train -debug 1
169 | parser = argparse.ArgumentParser()
170 | parser.add_argument("-data_path", required=False)
171 | parser.add_argument("-save_path", required=True)
172 | parser.add_argument("-file_path", required=False)
173 |
174 | parser.add_argument('-dataset', default='')
175 | parser.add_argument('-n_cpus', default=10, type=int)
176 | parser.add_argument('-start', default=10, type=int)
177 | parser.add_argument('-end', default=10, type=int)
178 | parser.add_argument('-debug', nargs='?',const=False,default=False)
179 | args = parser.parse_args()
180 | tag_data(args)
--------------------------------------------------------------------------------
/dataset/ner/allennlp_ner/ner_stats.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures import process
2 | import math
3 | import sys
4 | from glob import glob
5 | from multiprocess import Pool
6 | import json
7 | from multiprocessing.dummy import Pool as ThreadPool
8 | #import matplotlib.pyplot as plt
9 | import numpy as np
10 | import traceback
11 | from collections import OrderedDict
12 | import pickle
13 | import argparse
14 | import gc
15 | from tqdm import tqdm
16 | import math
17 |
18 | RELARGS_DELIMITER_BERT = '[SEP]'
19 |
20 |
21 | class F1scoreMeter(object):
22 | def __init__(self):
23 | self.reset()
24 |
25 | def reset(self):
26 | self.tp = 0
27 | self.fp = 0
28 | self.fn = 0
29 | self.precision = 0
30 | self.recall = 0
31 | self.f1_score = 0
32 | self.exact_match_acc = 0
33 | self.correct_exact_match = 0.0
34 | self.number_of_instance = 0.0
35 |
36 | def update(self, gold, result):
37 | self.number_of_instance += 1
38 |
39 | self.tp += len(result.intersection(gold))
40 | self.fp += len(result.difference(gold))
41 | self.fn += len(gold.difference(result))
42 | if self.tp > 0 or self.fp > 0:
43 | self.precision = self.tp / (self.tp + self.fp)
44 | if self.tp > 0 or self.fn > 0:
45 | self.recall = self.tp / (self.tp + self.fn)
46 | if self.precision > 0 or self.recall > 0:
47 | self.f1_score = 2 * self.precision * self.recall / (self.precision + self.recall)
48 |
49 | self.exact_match_acc = self.correct_exact_match / self.number_of_instance
50 |
51 |
52 |
53 | def lineariseAsTriples(subgraph):
54 | label_triples = []
55 | wikidata_entity_triples = []
56 | entity_label_dict = {}
57 | entity_type_dict = {}
58 | all_entity_set = set()
59 | all_entity_relation_set = set()
60 | all_rel_set = set()
61 | all_type_set = set()
62 | edge_set = set()
63 | all_ent_rel_types_text = []
64 |
65 | def preprocess_file(data_file):
66 | try:
67 | return _preprocess_file(data_file)
68 | except Exception as e:
69 | print(traceback.format_exc())
70 | print('Failed ', data_file)
71 | return 0
72 |
73 |
74 | def _preprocess_file(data_file):
75 |
76 | try:
77 | data = json.load(open(data_file, 'r'))
78 | except json.decoder.JSONDecodeError as e:
79 | print('Failed loading json file: ', data_file)
80 | raise e
81 |
82 | user_tp = []
83 | user_fp = []
84 | user_fn = []
85 | sys_tp = []
86 | sys_fp = []
87 | sys_fn = []
88 | for conversation in [data]:
89 |
90 | turns = len(conversation) // 2
91 |
92 | for i in range(turns):
93 |
94 | user = conversation[2*i]
95 | system = conversation[2*i + 1]
96 | if 'entities_in_utterance' in user.keys() or 'entities' in user.keys():
97 | user_tags = set(user['es_links'])
98 | if 'entities_in_utterance' in user.keys():
99 | user_gold = set(user['entities_in_utterance'])
100 | else:
101 | user_gold = set(user['entities'])
102 | utp = len(user_tags.intersection(user_gold))
103 | ufp = len(user_tags.difference(user_gold))
104 | ufn = len(user_gold.difference(user_tags))
105 | user_tp.append(utp)
106 | user_fp.append(ufp)
107 | user_fn.append(ufn)
108 | #else:
109 | # print(user['utterance'], data_file)
110 |
111 |
112 | if 'entities_in_utterance' in system.keys():
113 | system_tags = set(system['es_links'])
114 | system_gold = set(system['entities_in_utterance'])
115 | utp = len(system_tags.intersection(system_gold))
116 | ufp = len(system_tags.difference(system_gold))
117 | ufn = len(system_gold.difference(system_tags))
118 | sys_tp.append(utp)
119 | sys_fp.append(ufp)
120 | sys_fn.append(ufn)
121 | else:
122 | print('entities_in_utterance not present', data_file)
123 |
124 | return user_tp, user_fp, user_fn, sys_tp, sys_fp, sys_fn
125 |
126 |
127 | def main(args):
128 | if (args.dataset != ''):
129 | split_path = args.data_path + f'/{args.dataset}/*'
130 | split_files = glob(split_path + '/*.tagged')
131 | if args.debug:
132 | split_files = split_files[:500]
133 | print('Loading files from ', split_path, len(split_files))
134 |
135 | corpora = {f'{args.dataset}': split_files}
136 | else:
137 | # do all
138 | train_path = args.data_path + '/train/*'
139 | val_path = args.data_path + '/valid/*'
140 | test_path = args.data_path + '/test/*'
141 | train_files = glob(train_path + '/*.tagged')
142 | print('Train files ', train_path, len(train_files))
143 | valid_files = glob(val_path + '/*.tagged')
144 | print('Valid files ', val_path, len(valid_files))
145 | test_files = glob(test_path + '/*.tagged')
146 | print('Test files ', test_path, len(test_files))
147 | corpora = {'train': train_files, 'valid': valid_files, 'test': test_files}
148 |
149 | for corpus_type in corpora.keys():
150 | #a_lst = [(f, args, csqa, corpus_type) for f in corpora[corpus_type]]
151 | filelist = corpora[corpus_type]
152 | pool = Pool(args.n_cpus)
153 | user_tp = []
154 | user_fp = []
155 | user_fn = []
156 | sys_tp = []
157 | sys_fp = []
158 | sys_fn = []
159 | for processed_input in tqdm(pool.imap_unordered(preprocess_file, filelist), total=len(filelist)):
160 | p = processed_input
161 | user_tp.extend(p[0])
162 | user_fp.extend(p[1])
163 | user_fn.extend(p[2])
164 | sys_tp.extend(p[3])
165 | sys_fp.extend(p[4])
166 | sys_fn.extend(p[5])
167 |
168 | pool.close()
169 | pool.join()
170 | utp = sum(user_tp)
171 | ufp = sum(user_fp)
172 | ufn = sum(user_fn)
173 | stp = sum(sys_tp)
174 | sfp = sum(sys_fp)
175 | sfn = sum(sys_fn)
176 | uprec = utp / (utp + ufp)
177 | urecall = utp / (utp + ufn)
178 | sprec = stp / (stp + sfp)
179 | srecall = stp / (stp + sfn)
180 | uf1 = (2 * uprec * urecall) / (uprec + urecall)
181 | sf1 = (2 * sprec * srecall) / (sprec + srecall)
182 | print('User F1: {f1} precision {p} recall {r} '.format(f1=uf1,p=uprec,r=urecall))
183 | print('Sys F1: {f1} precision {p} recall {r} '.format(f1=sf1,p=sprec,r=srecall))
184 |
185 |
186 |
187 |
188 |
189 | def str2bool(v):
190 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
191 | return True
192 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
193 | return False
194 | else:
195 | raise argparse.ArgumentTypeError('Boolean value expected.')
196 |
197 |
198 | if __name__ == '__main__':
199 | # python preprocess.py -data_path /home/s1959796/csqparsing/dataset/data/version_splits -save_path /home/s1959796/csqparsing/processed_data -dataset train -debug 0
200 | parser = argparse.ArgumentParser()
201 | parser.add_argument("-data_path", required=True)
202 | parser.add_argument('-dataset', default='')
203 | parser.add_argument('-n_cpus', default=10, type=int)
204 | parser.add_argument('-debug', nargs='?',const=False,default=False)
205 | args = parser.parse_args()
206 | main(args)
207 |
208 |
--------------------------------------------------------------------------------
/dataset/ner/strner/README.md:
--------------------------------------------------------------------------------
1 | ## Scripts for string-match-based NER and NEL
2 | Tested with Python 3.8.13 and pyahocorasick 1.4.4
3 |
4 |
5 | To avoid encoding issues, redump the json files first:
6 | ```
7 | # redump json files
8 | python redump_ascii_disamb_list.py
9 | ```
10 |
11 | ### Generate entity count file for disambiguation
12 | ```
13 | # This will create a json file with entity counts
14 | python unnormalized_entity_counts.py -data_path path
15 | ```
16 |
17 | ### Run automaton creation and tagging
18 | ```
19 | # create list of files to annotate
20 | python createlist.py
21 | bash str_tag.sh
22 | ```
23 |
--------------------------------------------------------------------------------
/dataset/ner/strner/create_entity_list.py:
--------------------------------------------------------------------------------
1 | import ahocorasick
2 | import json
3 | import os, re
4 | from tqdm import tqdm
5 |
6 | data_path='wikidata_proc_json/wikidata_proc_json_2/'
7 | file_list = ['filtered_property_wikidata4.json' , 'items_wikidata_n.json']
8 |
9 | def preprocess(text):
10 | return '_{}_'.format(re.sub('[^a-z]', '_', text.lower()))
11 |
12 | def create_index(data_path, file_list):
13 | index = ahocorasick.Automaton()
14 | for filename in file_list:
15 | fpath = os.path.join(data_path, filename)
16 | print('Loading json file from ', fpath)
17 | id_val_dict = json.load(open(fpath, 'r'))
18 | count = 0
19 | for id, val in tqdm(id_val_dict.items(), total=len(id_val_dict)):
20 | index.add_word(preprocess(val), (id, val))
21 | count += 1
22 | print(f'Added {count} items.')
23 |
24 | index.make_automaton()
25 | return index
26 |
27 |
28 | def find_indices_and_position(text, searcher):
29 | result = dict()
30 | preptext = preprocess(text)
31 | print(preptext)
32 | for end_index, found_value in searcher.iter_long(preptext):
33 | print(found_value)
34 | text_value = found_value[1]
35 |         id = found_value[0]  # found_value is an (id, label) tuple
36 | print(end_index, text_value)
37 | end = end_index - 1
38 | start = end - len(text_value)
39 | occurrence_text = text[start:end]
40 | print('occurrence_text ', occurrence_text)
41 | result[(start, end)] = text_value
42 | return result
43 |
44 |
45 | index = create_index(data_path=data_path, file_list=file_list)
46 | utterances = open('example_utterances.txt' , 'r').readlines()
47 |
48 | for u in utterances:
49 | print(find_indices_and_position(u, index))
50 |
--------------------------------------------------------------------------------
/dataset/ner/strner/createlist.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from os.path import exists
3 | p='/home/s1959796/csqparsing/dataversion_aug27_2022/CSQA_v9_skg.v6_compar_spqres9_subkg2_tyTop_nelctx_cleaned/valid/*'
4 | outfile='validlist.txt'
5 | assert not exists(outfile)
6 |
7 |
8 | fp=open(outfile, 'w')
9 |
10 | files = glob(p + '/*.json')
11 |
12 | for f in files:
13 | fp.write(f + '\n')
14 |
15 | fp.close()
16 |
17 | p='/home/s1959796/csqparsing/dataversion_aug27_2022/CSQA_v9_skg.v6_compar_spqres9_subkg2_tyTop_nelctx_cleaned/test/*'
18 | outfile='testlist.txt'
19 | assert not exists(outfile)
20 |
21 |
22 | fp=open(outfile, 'w')
23 |
24 | files = glob(p + '/*.json')
25 |
26 | for f in files:
27 | fp.write(f + '\n')
28 |
29 | fp.close()
30 |
31 | p='/home/s1959796/csqparsing/dataversion_aug27_2022/CSQA_v9_skg.v6_compar_spqres9_subkg2_tyTop_nelctx_cleaned/train/*'
32 | outfile='trainlist.txt'
33 | assert not exists(outfile)
34 |
35 |
36 | fp=open(outfile, 'w')
37 |
38 | files = glob(p + '/*.json')
39 |
40 | for f in files:
41 | fp.write(f + '\n')
42 |
43 | fp.close()
44 |
--------------------------------------------------------------------------------
/dataset/ner/strner/redump_ascii_disamb_list.py:
--------------------------------------------------------------------------------
1 | import json, os
2 |
3 |
4 | data_path='/home/s1959796/spice_dataset_project/spice_project/nel_scripts/wikidata_proc_json/wikidata_proc_json_2/'
5 | file_list = ['filtered_property_wikidata4.json' , 'items_wikidata_n.json']
6 |
7 | for filename in file_list:
8 | fpath = os.path.join(data_path, filename)
9 | print('Loading json file from ', fpath)
10 | id_val_dict = json.load(open(fpath, 'r'))
11 | ofpath = fpath+'.redump'
12 | new_entity_id = {}
13 | for id, val in id_val_dict.items():
14 | if val in new_entity_id.keys():
15 | new_entity_id[val].append(id)
16 | else:
17 | new_entity_id[val] = [id]
18 | json.dump(new_entity_id, open(ofpath, 'w', encoding='utf8'), indent=2, ensure_ascii=False)
19 |
20 |
--------------------------------------------------------------------------------
/dataset/ner/strner/str_nel.py:
--------------------------------------------------------------------------------
1 | import json
2 | import traceback
3 | import os, re
4 | from glob import glob
5 | from multiprocessing import Pool
6 | from multiprocessing.pool import ThreadPool
7 | from tqdm import tqdm
8 | import argparse
9 | import pathlib
10 | import ahocorasick
11 | from os.path import exists
12 | from unidecode import unidecode
13 | import pickle
14 |
15 | basic_stops = ['where', 'did', 'how', 'many', 'when', 'which']
16 |
17 | entity_count = json.load(open('entity_count.json', 'r'))
18 | def disambiguate(listofids):
19 | counts = []
20 | for id in listofids:
21 | c=entity_count.get(id, -1)
22 | #if c > -1:
23 | # print(id)
24 | counts.append(c)
25 | max_id = counts.index(max(counts))
26 | return listofids[max_id]
27 |
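# Worked example (sketch, with hypothetical counts): if entity_count were
# {'Q90': 42, 'Q167646': 3}, then disambiguate(['Q167646', 'Q90']) returns
# 'Q90', i.e. the most frequent id.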
28 | def preprocess(text):
29 | text = text.lower()
30 | text = text.translate(str.maketrans('', '', ",.?"))
31 | text = ' '.join([t for t in text.split() if t not in basic_stops])
32 | return text.lower()
33 |
34 | def create_index(data_path, file_list):
35 | index = ahocorasick.Automaton()
36 | for filename in file_list:
37 | fpath = os.path.join(data_path, filename)
38 | print('Loading json file from ', fpath)
39 | id_val_dict = json.load(open(fpath, 'r'))
40 | count = 0
41 | for val, idlist in tqdm(id_val_dict.items(), total=len(id_val_dict)):
42 | ## you could disambiguate later based on the counts... but if we are taking the top count then we might just add the top one in our index. This way the whole process is faster.
43 | disambiguated_id = disambiguate(idlist)
44 | index.add_word(preprocess(val), (disambiguated_id, val))
45 | count += 1
46 | print(f'Added {count} items.')
47 |
48 | index.make_automaton()
49 | return index
50 |
51 |
52 | def str_nel_long(utterance, automaton):
53 | tagged_words = []
54 | elinks = []
55 | start_end_pos = []
56 | preptext = preprocess(utterance)
57 | for end_index, found_value in automaton.iter_long(preptext):
58 | text_value = found_value[1]
59 | id = found_value[0]
60 | end = end_index - 1
61 | start = end - len(text_value)
62 | start_end_pos.append((start, end))
63 | tagged_words.append(text_value)
64 | elinks.append(id)
65 | return tagged_words, elinks, start_end_pos
66 |
67 |
68 | def str_nel(utterance, automaton):
69 | tagged_words = []
70 | elinks = []
71 | start_end_pos = []
72 | preptext = preprocess(utterance)
73 | matched_items = [] # start, end, len, value
74 | for end_index, found_value in automaton.iter(preptext):
75 | text_value = found_value[1]
76 | id = found_value[0]
77 | start_index = end_index - len(text_value) + 1
78 | if (start_index - 1 < 0 or preptext[start_index - 1] == ' ') and (end_index == len(preptext) - 1 or preptext[end_index+1] == ' '):
79 | matched_items.append((start_index, end_index, end_index - start_index, found_value))
80 |
81 | keep_items = []
82 | matched_items = sorted(matched_items, key=lambda x: x[2], reverse=True)
83 | for m in matched_items:
84 | start_index, end_index, vlen, v = m
85 | flag=True
86 | for k in keep_items:
87 |             if (start_index >= k[0] and start_index <= k[1]) or (end_index >= k[0] and end_index <= k[1]):  # the case where the current match is encapsulated by a kept one is not needed, since matches are sorted and longer strings are kept first
88 | flag=False
89 | break
90 | if flag:
91 | keep_items.append(m)
92 | tagged_words.append(m[3][1])
93 | elinks.append(m[3][0])
94 | #start_end_pos.append((m[0], m[1]))
95 |
96 | return tagged_words, elinks
97 |
98 |
99 | def add_nel_file(params):
100 | data_file,outfile, automaton = params
101 | if exists(outfile):
102 |         print('File {f} already exists ... '.format(f=outfile))
103 | return
104 | try:
105 | return _add_nel_file(data_file,outfile, automaton)
106 | except Exception as e:
107 | print(traceback.format_exc())
108 | print('Failed ', data_file)
109 | return 0
110 |
111 |
112 | def _add_nel_file(data_file, outfile, automaton):
113 |
114 | try:
115 | data = json.load(open(data_file, 'r'))
116 | except json.decoder.JSONDecodeError as e:
117 | print('Failed loading json file: ', data_file)
118 | raise e
119 | for conversation in [data]:
120 | is_clarification = False
121 | prev_user_conv = None
122 | prev_system_conv = None
123 | turns = len(conversation) // 2
124 |
125 | for i in range(turns):
126 | input = []
127 | logical_form = []
128 | # If the previous was a clarification question we basically took next
129 | # logical form so need to skip
130 | if is_clarification:
131 | is_clarification = False
132 | continue
133 | user = conversation[2*i]
134 | system = conversation[2*i + 1]
135 | utterance = user['utterance']
136 | tagged_words, es_links = str_nel(utterance, automaton)
137 | user['tagged_words'] = tagged_words
138 | user['es_links'] = es_links
139 |
140 | if 'utterance' in system.keys():
141 | utterance = system['utterance']
142 | tagged_words, es_links = str_nel(utterance, automaton)
143 | system['tagged_words'] = tagged_words
144 | system['es_links'] = es_links
145 |
146 |
147 |
148 | if exists(outfile):
149 |         print('File {f} already exists ... '.format(f=outfile))
150 | #raise Exception('File exist check input')
151 | base_folder = os.path.dirname(outfile)
152 | pathlib.Path(base_folder).mkdir(exist_ok=True, parents=True)
153 |
154 | json.dump(data, open(outfile, 'w', encoding='utf8'), indent=2, ensure_ascii=False)
155 | return outfile
156 |
157 |
158 | def process_files_parallel(args_list, automaton, colour=None):
159 | print('Process files ...')
160 | a_lst = [(ifile, ofile, automaton) for ifile, ofile in args_list]
161 | pool = ThreadPool(args.n_cpus)
162 | outfiles = set()
163 | for processed_filename in tqdm(pool.imap_unordered(add_nel_file, a_lst), total=len(a_lst)):
164 | if processed_filename in outfiles:
165 | print('processed_filename', processed_filename)
166 | outfiles.add(processed_filename)
167 |
168 |
169 | def tag_str_all_files(args):
170 | '''
171 | after some issues in data were fixed some more files could be tagged and added
172 | '''
173 | automaton_filename = 'automaton_noproperty.pkl'
174 | data_path='wikidata_proc_json/wikidata_proc_json_2/'
175 | #file_list = ['filtered_property_wikidata4.json.redump' , 'items_wikidata_n.json.redump']
176 | file_list = ['items_wikidata_n.json.redump']
177 | if not os.path.exists(automaton_filename):
178 | automaton = create_index(data_path=data_path, file_list=file_list)
179 | pickle.dump(automaton, open(automaton_filename, 'wb'))
180 | else:
181 | automaton = pickle.load(open(automaton_filename, 'rb'))
182 | allfiles = []
183 | for f in open(args.file_path, 'r').readlines():
184 | allfiles.append(f.rstrip())
185 | if args.end > 0:
186 | allfiles = allfiles[args.start:args.end]
187 |     print('Processing from {s} to {e}, will save at {p}'.format(s=args.start, e=args.end, p=args.save_path))
188 | #args_list = [(inp_file, os.path.join(args.save_path, corpus_type, os.path.basename(inp_file))) for inp_file in allfiles]
189 | def get_inp_out_file(filename, local_path, out_dir):
190 | lsys=args.data_path
191 | inp_file = filename
192 | outfile = filename.replace(lsys, out_dir) + '.strtaggedwithoutproperty'
193 | return (inp_file, outfile)
194 | args_list = [get_inp_out_file(inp_file, args.data_path, args.save_path) for inp_file in allfiles]
195 | #print(args_list)
196 | process_files_parallel(args_list, automaton)
197 |
198 |
199 | def str2bool(v):
200 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
201 | return True
202 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
203 | return False
204 | else:
205 | raise argparse.ArgumentTypeError('Boolean value expected.')
206 |
207 |
208 | if __name__ == '__main__':
209 |
210 | parser = argparse.ArgumentParser()
211 | parser.add_argument("-data_path", required=False)
212 | parser.add_argument("-save_path", required=True)
213 | parser.add_argument("-file_path", required=False)
214 |
215 | parser.add_argument('-dataset', default='')
216 | parser.add_argument('-n_cpus', default=10, type=int)
217 | parser.add_argument('-start', default=10, type=int)
218 | parser.add_argument('-end', default=10, type=int)
219 | parser.add_argument('-debug', nargs='?',const=False,default=False)
220 | args = parser.parse_args()
221 | tag_str_all_files(args)
222 |
--------------------------------------------------------------------------------
/dataset/ner/strner/str_tag.sh:
--------------------------------------------------------------------------------
1 | start=0
2 | end=-1
3 | python str_nel.py \
4 | -data_path "datapath/" \
5 | -save_path "strnel_data/" \
6 | -file_path "validlist.txt" \
7 | -dataset 'valid' \
8 | -start $start \
9 | -end $end
10 |
11 | python str_nel.py \
12 | -data_path "datapath/" \
13 | -save_path "strnel_data/" \
14 | -file_path "testlist.txt" \
15 | -dataset 'test' \
16 | -start $start \
17 | -end $end \
18 | -n_cpus 5
19 |
20 | python str_nel.py \
21 | -data_path "datapath" \
22 | -save_path "strnel_data/" \
23 | -file_path "trainlist.txt" \
24 | -dataset 'train' \
25 | -start $start \
26 | -end $end \
27 | -n_cpus 5
28 |
--------------------------------------------------------------------------------
/dataset/ner/strner/unnormalized_entity_counts.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures import process
2 | import math
3 | import sys
4 | from glob import glob
5 | from multiprocess import Pool
6 | import json
7 | from multiprocessing.dummy import Pool as ThreadPool
8 | #import matplotlib.pyplot as plt
9 | import numpy as np
10 | import traceback
11 | from collections import OrderedDict
12 | import pickle
13 | import argparse
14 | import gc
15 | from tqdm import tqdm
16 | import math
17 |
18 | RELARGS_DELIMITER_BERT = '[SEP]'
19 |
20 |
21 | class F1scoreMeter(object):
22 | def __init__(self):
23 | self.reset()
24 |
25 | def reset(self):
26 | self.tp = 0
27 | self.fp = 0
28 | self.fn = 0
29 | self.precision = 0
30 | self.recall = 0
31 | self.f1_score = 0
32 | self.exact_match_acc = 0
33 | self.correct_exact_match = 0.0
34 | self.number_of_instance = 0.0
35 |
36 | def update(self, gold, result):
37 | self.number_of_instance += 1
38 |
39 | self.tp += len(result.intersection(gold))
40 | self.fp += len(result.difference(gold))
41 | self.fn += len(gold.difference(result))
42 | if self.tp > 0 or self.fp > 0:
43 | self.precision = self.tp / (self.tp + self.fp)
44 | if self.tp > 0 or self.fn > 0:
45 | self.recall = self.tp / (self.tp + self.fn)
46 | if self.precision > 0 or self.recall > 0:
47 | self.f1_score = 2 * self.precision * self.recall / (self.precision + self.recall)
48 |
49 | self.exact_match_acc = self.correct_exact_match / self.number_of_instance
50 |
51 |
52 |
53 |
54 | def count_file(data_file):
55 | try:
56 | return _getcounts(data_file)
57 | except Exception as e:
58 | print(traceback.format_exc())
59 | print('Failed ', data_file)
60 | return 0
61 |
62 | def _getcounts(data_file):
63 |
64 | try:
65 | data = json.load(open(data_file, 'r'))
66 | except json.decoder.JSONDecodeError as e:
67 | print('Failed loading json file: ', data_file)
68 | raise e
69 |
70 | entity_counts = {}
71 | def add_to_dict(e):
72 | if e in entity_counts.keys():
73 | entity_counts[e] += 1
74 | else:
75 | entity_counts[e] = 1
76 |
77 | for conversation in [data]:
78 |
79 | turns = len(conversation) // 2
80 |
81 | for i in range(turns):
82 |
83 | user = conversation[2*i]
84 | system = conversation[2*i + 1]
85 | if 'entities_in_utterance' in user.keys() or 'entities' in user.keys():
86 | if 'entities_in_utterance' in user.keys():
87 | user_gold = user['entities_in_utterance']
88 | for e in user_gold:
89 | add_to_dict(e)
90 | else:
91 | user_gold = set(user['entities'])
92 | for e in user_gold:
93 | add_to_dict(e)
94 |
95 | # print(user['utterance'], data_file)
96 |
97 |
98 | if 'entities_in_utterance' in system.keys():
99 | system_gold = set(system['entities_in_utterance'])
100 | for e in system_gold:
101 | add_to_dict(e)
102 | else:
103 | print('entities_in_utterance not present', data_file)
104 |
105 | return entity_counts
106 |
107 |
108 | def main(args):
109 | global_entity_counts= {}
110 | def add_to_global_dict(e, c):
111 | if e in global_entity_counts.keys():
112 | global_entity_counts[e] += c
113 | else:
114 | global_entity_counts[e] = c
115 |
116 | if (args.dataset != ''):
117 | split_path = args.data_path + f'/{args.dataset}/*'
118 | split_files = glob(split_path + '/*.json')
119 | if args.debug:
120 | split_files = split_files[:500]
121 | print('Loading files from ', split_path, len(split_files))
122 |
123 | corpora = {f'{args.dataset}': split_files}
124 | else:
125 | # do all
126 | train_path = args.data_path + '/train/*'
127 | val_path = args.data_path + '/valid/*'
128 | test_path = args.data_path + '/test/*'
129 | train_files = glob(train_path + '/*.json')
130 | print('Train files ', train_path, len(train_files))
131 | valid_files = glob(val_path + '/*.json')
132 | print('Valid files ', val_path, len(valid_files))
133 | test_files = glob(test_path + '/*.json')
134 | print('Test files ', test_path, len(test_files))
135 | corpora = {'train': train_files, 'valid': valid_files, 'test': test_files}
136 |
137 | for corpus_type in corpora.keys():
138 | #a_lst = [(f, args, csqa, corpus_type) for f in corpora[corpus_type]]
139 | filelist = corpora[corpus_type]
140 | pool = Pool(args.n_cpus)
141 | for entity_count in tqdm(pool.imap_unordered(count_file, filelist), total=len(filelist)):
142 | for e, c in entity_count.items():
143 | add_to_global_dict(e, c)
144 |
145 |
146 | pool.close()
147 | pool.join()
148 |
149 | json.dump(global_entity_counts, open('entity_count.json', 'w', encoding='utf8'), indent=2, ensure_ascii=False)
150 |
151 | def str2bool(v):
152 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
153 | return True
154 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
155 | return False
156 | else:
157 | raise argparse.ArgumentTypeError('Boolean value expected.')
158 |
159 |
160 | if __name__ == '__main__':
161 | # python unnormalized_entity_counts.py -data_path /home/s1959796/csqparsing/dataversion_aug27_2022/CSQA_v9_skg.v6_compar_spqres9_subkg2_tyTop_nelctx_cleaned/
162 | parser = argparse.ArgumentParser()
163 | parser.add_argument("-data_path", required=True)
164 | parser.add_argument('-dataset', default='')
165 | parser.add_argument('-n_cpus', default=10, type=int)
166 |     parser.add_argument('-debug', nargs='?', const=True, default=False, type=str2bool)  # plain '-debug' now enables debug mode; const=False made the flag a no-op
167 | args = parser.parse_args()
168 | main(args)
169 |
170 |
171 |
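172 | # Illustrative note (shape only; the ids and counts below are placeholders):
173 | # the dumped entity_count.json maps entity ids to raw, unnormalised frequencies
174 | # accumulated over user and system turns, e.g. {"Q34770": 12, "Q733": 3}.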
--------------------------------------------------------------------------------
/dataset/precompute_local_types.py:
--------------------------------------------------------------------------------
1 | import json
2 | from glob import glob
3 | import os
4 | import sys
5 | import requests
6 | import numpy as np
7 | import argparse
8 | from tqdm import tqdm
9 | from datetime import datetime
10 | from multiprocessing import Pool
11 |
12 | from transformers import BertTokenizer
13 |
14 | import nltk
15 | from nltk.stem import WordNetLemmatizer
16 | from nltk.corpus import stopwords
17 | from nltk.corpus import wordnet
18 |
19 |
20 | from SparqlServer import SparqlServer
21 | from SparqlResults import SparqlResults
22 |
23 | # KB graph
24 | SUBJECT = 'subject'
25 | OBJECT = 'object'
26 | TYPE = 'type'
27 |
28 | ROOT_PATH_JSON_KG = ''
29 | ROOT_PATH = ''
30 | DST_ROOT_PATH = ''
31 |
32 | # add arguments to parser
33 | parser = argparse.ArgumentParser(description='Pre-compute types sub-graphs')
34 | parser.add_argument('--partition', default='train', choices=['train', 'valid', 'test'], type=str, help='Partition to preprocess.')
35 | parser.add_argument('--read_folder', default=ROOT_PATH, help='Folder to read conversations.')
36 | parser.add_argument('--write_folder', default=DST_ROOT_PATH, help='Folder to write annotated conversations.')
37 | parser.add_argument('--refine', default=False, action='store_true', help='Refine existing type_subgraph field. DEPRECATED')
38 | parser.add_argument('--json_kg_folder', default=ROOT_PATH_JSON_KG, help='Folder that contains KG in .json format. used for faster annotation')
39 |
40 | args = parser.parse_args()
41 | ROOT_PATH, DST_ROOT_PATH = args.read_folder, args.write_folder  # honour the CLI paths; these globals were previously left empty
42 | # set tokenizer
43 | bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased').tokenize
44 | lemmatizer = WordNetLemmatizer()
45 | stops = set(stopwords.words('english'))
46 |
47 | # we'll use these for efficiency, as they contain pre-computed type-relations.
48 | print('Loading KG .json files')
49 | id_relation = json.loads(open(f'{args.json_kg_folder}/knowledge_graph/filtered_property_wikidata4.json').read())
50 | id_entity = json.loads(open(f'{args.json_kg_folder}/knowledge_graph/items_wikidata_n.json').read())
51 | TYPE_TRIPLES = json.loads(open(f'{args.json_kg_folder}/knowledge_graph/wikidata_type_dict.json').read())
52 | REV_TYPE_TRIPLES = json.loads(open(f'{args.json_kg_folder}/knowledge_graph/wikidata_rev_type_dict.json').read())
53 | print('DONE')
54 |
55 | def loadTypeIDLabelDict():
56 | typeIDLabelDict = {}
57 | for k in TYPE_TRIPLES.keys():
58 | typeIDLabelDict[k] = id_entity[k]
59 | for k in REV_TYPE_TRIPLES.keys():
60 | typeIDLabelDict[k] = id_entity[k]
61 | return typeIDLabelDict
62 |
63 | TYPE_ID_LABEL = loadTypeIDLabelDict()
64 |
65 | def getTypesGraph(gold_types):
66 |     """Retrieve relations for which the types in gold_types are either domain or range."""
67 | tgraph = {}
68 | for t in gold_types:
69 | # just take the type and associated relations
70 | try:
71 | tg = [rel for rel in TYPE_TRIPLES[t].keys()]
72 | except KeyError:
73 | tg = []
74 |
75 | try:
76 | tg.extend([rel for rel in REV_TYPE_TRIPLES[t].keys()])
77 | except KeyError:
78 | pass
79 |
80 | tgraph[t] = list(set(tg))
81 | return tgraph
82 |
83 | def nltk_pos_tagger(nltk_tag):
84 | if nltk_tag.startswith('J'):
85 | return wordnet.ADJ
86 | elif nltk_tag.startswith('V'):
87 | return wordnet.VERB
88 | elif nltk_tag.startswith('N'):
89 | return wordnet.NOUN
90 | elif nltk_tag.startswith('R'):
91 | return wordnet.ADV
92 | else:
93 | return None
94 |
95 | def lemmatize_sentence(sentence):
96 | nltk_tagged = nltk.pos_tag(nltk.word_tokenize(sentence))
97 | wordnet_tagged = map(lambda x: (x[0], nltk_pos_tagger(x[1])), nltk_tagged)
98 | lemmatized_sentence = []
99 |
100 | for word, tag in wordnet_tagged:
101 | if tag is None:
102 | lemmatized_sentence.append(word)
103 | else:
104 | lemmatized_sentence.append(lemmatizer.lemmatize(word, tag))
105 | return " ".join(lemmatized_sentence)
106 |
107 | def getLinkTypesSubgraph(utterance, existing_type_set=None):
108 |
109 | types = []
110 | ori_utt = utterance
111 | utterance = lemmatize_sentence(utterance).split()
112 | if existing_type_set:
113 | id_label_set = {}
114 | for t in existing_type_set:
115 | id_label_set[t] = TYPE_ID_LABEL[t]
116 | else:
117 | id_label_set = TYPE_ID_LABEL
118 |
119 | for k, v in id_label_set.items():
120 |         type_tokens = lemmatize_sentence(v).split()
121 |         type_non_stop = [w for w in type_tokens if w not in stops.union({'number'})]  # 'number' is a wd type but usually marks the count operator in utterances
122 | inter = (set(utterance) & set(type_non_stop))
123 | if type_non_stop and len(inter) == len(set(type_non_stop)):
124 | types.append((k, len(set(type_non_stop))))
125 |
126 |     # we know that type 'people'/'Q2472587' is never used; 'common name'/'Q502895' is used instead
127 |     filtered = [x for x in types if x != ('Q2472587', 1)]
128 |     if len(filtered) == len(types) - 1:  # 'people' was matched, so swap in 'common name'
129 |         filtered.append(('Q502895', 2))
130 |     types = filtered
131 |
132 |     # we know that type 'occupation'/'Q528892' is never used; only 'occupation'/'Q12737077' is
133 |     types = [x for x in types if x != ('Q528892', 1)]
134 |
135 |     # remove types covered by longer labels: from ['work of art', 'art', 'work'], drop 'art' and 'work'
136 | types = sorted(types, key=lambda x: x[1], reverse=True)
137 | if types:
138 | keep = [(TYPE_ID_LABEL[types[0][0]], types[0][0])]
139 | for t, _ in types[1:]:
140 | present = False
141 | for l, _ in keep:
142 | incl = set(TYPE_ID_LABEL[t].split()).intersection(set(l.split()))
143 | if len(incl) == len(set(TYPE_ID_LABEL[t].split())):
144 | present = True
145 | break
146 | if not present:
147 | keep.append((TYPE_ID_LABEL[t], t))
148 | types = [t for t,_ in types if t in [k for _, k in keep]]
149 | types_graph = getTypesGraph(types)
150 |
151 | return types_graph
152 |
153 | def getLabelJson(r):
154 | return id_relation[r]
155 |
156 | splits = [args.partition]
157 |
158 | # new splits directory
159 | if not os.path.isdir(DST_ROOT_PATH):
160 | os.mkdir(DST_ROOT_PATH)
161 | for sp in splits:
162 | os.mkdir(os.path.join(DST_ROOT_PATH, sp))
163 | print(f'Directory "{DST_ROOT_PATH}" created')
164 |
165 | def annotate_conversation(f):
166 | print(f)
167 | with open(f) as json_file:
168 | try:
169 | fileName = f.split('/')[-1]
170 | dirName = f.split('/')[-2]
171 | # load conversation
172 | conversation = json.load(json_file)
173 | new_conversation = []
174 | is_clarification = False
175 | turns = len(conversation) // 2
176 | for i in range(turns):
177 |
178 | if is_clarification:
179 | is_clarification = False
180 | continue
181 |
182 | user = conversation[2 * i]
183 | system = conversation[2 * i + 1]
184 |
185 | if user['question-type'] == 'Clarification':
186 | new_conversation.append(user)
187 | new_conversation.append(system)
188 |
189 | # get next context
190 | is_clarification = True
191 | next_user = conversation[2 * (i + 1)]
192 | next_system = conversation[2 * (i + 1) + 1]
193 |
194 | if args.refine and 'type_subgraph' in next_system.keys():
195 | next_system['type_subgraph'] = getLinkTypesSubgraph(user['utterance'],
196 | next_system['type_subgraph'])
197 | else:
198 | next_system['type_subgraph'] = getLinkTypesSubgraph(user['utterance'])
199 |
200 | new_conversation.append(next_user)
201 | new_conversation.append(next_system)
202 | else:
203 | if args.refine and 'type_subgraph' in system.keys():
204 | system['type_subgraph'] = getLinkTypesSubgraph(user['utterance'],
205 | system['type_subgraph'])
206 | else:
207 | system['type_subgraph'] = getLinkTypesSubgraph(user['utterance'])
208 |
209 | new_conversation.append(user)
210 | new_conversation.append(system)
211 |
212 | # write conversation
213 | assert len(conversation) == len(new_conversation)
214 |
215 | if not os.path.isdir(os.path.join(DST_ROOT_PATH, sp, dirName)):
216 | os.mkdir(os.path.join(DST_ROOT_PATH, sp, dirName))
217 | with open(f'{DST_ROOT_PATH}/{sp}/{dirName}/{fileName}', 'w') as formatted_json_file:
218 | json.dump(new_conversation, formatted_json_file, ensure_ascii=False, indent=4)
219 |
220 | except json.decoder.JSONDecodeError:
221 | print('Fail', f)
222 |
223 | for sp in splits:
224 | # read data
225 | files = glob(f'{ROOT_PATH}/{sp}/*' + '/*.json')
226 |
227 |     print(f'Remaining to process: {len(files)}')
228 | with Pool(20) as pool:
229 | res = pool.map(annotate_conversation, files)
230 |
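231 | # Illustrative sketch of the type-linking step above (comment only, not run):
232 | # getLinkTypesSubgraph lemmatizes the utterance, keeps every KG type whose
233 | # non-stopword label lemmas all occur in it, drops types subsumed by longer
234 | # matching labels (see the 'work of art' comment above), and returns
235 | # {type_id: [relations having that type as domain or range]}.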
--------------------------------------------------------------------------------
/evaluation/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Evaluation
3 |
4 | ## Run evaluation script on predicted Sparqls (i.e., models' outputs)
5 |
6 | The following script evaluates one question type at a time. It will print a table with averaged results per question type and sub-type.
7 | ```--file_path``` indicates the json file with the models' outputs. If the file name contains a "*" and
8 | ```--read_all``` is set, all matching files will be read and combined for evaluation (e.g., we can read models' outputs saved
9 | in several files such as *baseline.0-9.100000.test_Simple Question (Direct).json*,
10 | *baseline.10-19.100000.test_Simple Question (Direct).json*, etc.). See format of models' predictions files below.
11 |
12 | ```commandline
13 | python run_subtype_lf.py \
14 | --file_path "results/${MODEL_NAME}/baseline.*.${CHECKPOINT}.${SPLIT}_Simple Question (Direct).json" \
15 | --read_all \
16 | --question_type "Simple Question (Direct)" \
17 | --em_only \
18 | --out_eval_file ${PATH_AND_NAME}
19 | ```
20 |
21 | ```--out_eval_file``` gives the name of the file where detailed evaluation results will be saved. This is a dictionary
22 | whose entries are the evaluated aspects (i.e., question type, question sub-types, and linguistic phenomena -- Ctx=-1,
23 | Ctx<-1, ellipsis, and multiple entities). For each entry, the different computed metrics will be stored (e.g., exact
24 | match, f1-score, etc.).
25 |
26 | When using the flag ```--context_dist_file``` it will also aggregate metrics per 'context distance'.
27 | This can be used for question types involving Coreference. If specified, it will aggregate scores for all coreference
28 | questions where the antecedent is at distance 1, distance 2, and so on.
29 | For each conversation turn, we have precomputed a file that contains the distance information used for this flag,
30 | namely, *CSQA_v9_skg.v7_context_distance_test.log*. Each line of the file is of the form:
31 | ```
32 | test#QA_81#QA_16#2 2 Which french administrative division was that person born at ?
33 | ```
34 | The first column gives the turn and conversation identifier (see next section). The second gives the distance back in
35 | the conversation at which the referent of the question is mentioned (here 2 means the referent was introduced 2 turns
36 | back, i.e., in turn 0). The third column is the user question (see the parsing sketch at the end of this README).
37 |
38 | #### Expected input json file containing models' predictions
39 |
40 | The format is the following. See below descriptions of the fields.
41 | ```
42 | [
43 | {
44 | "question_type": "Simple Question (Direct)",
45 | "description": "Simple Question|Single Entity",
46 | "question": "What is the official language of Paraguay ?",
47 | "answer": "Spanish",
48 | "actions": "SELECT ?x WHERE { wd: Q34770 wdt: P37 ?x . ?x wdt: P31 wd: Q34770 . }",
49 | "results": [
50 | "Q1321"
51 | ],
52 | "sparql_delex": "SELECT ?x WHERE { wd: Q733 wdt: P37 ?x . ?x wdt: P31 wd: Q34770 . }",
53 | "turnID": "test#QA_282#QA_72#3"
54 | },
55 |
56 | ...
57 |
58 | ]
59 | ```
60 |
61 | * question_type: the question type
62 | * description: the question sub-type
63 | * question: user utterance
64 | * answer: system utterance
65 | * actions: predicted Sparql
66 | * results: GOLD results/answer
67 | * sparql_delex: GOLD Sparql
68 | * turnID: Identifier of the conversation (e.g., ```test#QA_282#QA_72#3``` means turn position 3 in conversation
69 | of file test/QA_282/QA_72.json). Note that turn positions start from 0.
70 |
71 | #### Summarising evaluation results
72 |
73 | The following script reads the evaluation details generated for each question type as described above (i.e., with
74 | *run_subtype_lf.py*) and produces a final summary that aggregates these aspects across all question types (currently it summarises the linguistic phenomena).
75 | ```--file_path``` is the path to the folder that contains all .json files generated for each question with
76 | *run_subtype_lf.py*.
77 |
78 | ```commandline
79 | python summarise_results.py --file_path ${PATH_AND_NAME}
80 | ```
81 |
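82 | #### Example: parsing turn identifiers and context-distance lines
83 | 
84 | A minimal sketch (not part of the evaluation scripts) showing how the
85 | ```turnID``` field and the context-distance log lines described above can be
86 | parsed; the field layout follows the examples given in this README.
87 | 
88 | ```python
89 | def parse_turn_id(turn_id):
90 |     # "test#QA_282#QA_72#3" -> ("test/QA_282/QA_72.json", 3)
91 |     split, outer, inner, turn = turn_id.split('#')
92 |     return f'{split}/{outer}/{inner}.json', int(turn)
93 | 
94 | def parse_context_distance_line(line):
95 |     # "test#QA_81#QA_16#2 2 Which french administrative division ... ?"
96 |     # -> ("test#QA_81#QA_16#2", 2, "Which french administrative division ... ?")
97 |     turn_id, rest = line.split(maxsplit=1)
98 |     distance, question = rest.split(maxsplit=1)
99 |     return turn_id, int(distance), question
100 | ```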
--------------------------------------------------------------------------------
/evaluation/actions.py:
--------------------------------------------------------------------------------
1 | class ActionOperator:
2 | def __init__(self, kg):
3 | self.kg = kg
4 |
5 | def find(self, e, p):
6 | if isinstance(e, list):
7 | return self.find_set(e, p)
8 |
9 | if e is None or p is None:
10 | return None
11 |
12 | if e not in self.kg.triples['subject'] or p not in self.kg.triples['subject'][e]:
13 | return set()
14 |
15 | return set(self.kg.triples['subject'][e][p])
16 |
17 | def find_reverse(self, e, p):
18 | if isinstance(e, list):
19 | return self.find_reverse_set(e, p)
20 |
21 | if e is None or p is None:
22 | return None
23 |
24 | if e not in self.kg.triples['object'] or p not in self.kg.triples['object'][e]:
25 | return set()
26 |
27 | return set(self.kg.triples['object'][e][p])
28 |
29 | def find_set(self, e_set, p):
30 | result_set = set()
31 | for e in e_set:
32 | result_set.update(self.find(e, p))
33 |
34 | return result_set
35 |
36 | def find_reverse_set(self, e_set, p):
37 | result_set = set()
38 | for e in e_set:
39 | result_set.update(self.find_reverse(e, p))
40 |
41 | return result_set
42 |
43 | def filter_type(self, ent_set, typ):
44 | if type(ent_set) is not set or typ is None:
45 | return None
46 |
47 | result = set()
48 |
49 | for o in ent_set:
50 | if (o in self.kg.entity_type and typ in self.kg.entity_type[o]):
51 | result.add(o)
52 |
53 | return result
54 |
55 | def filter_multi_types(self, ent_set, t1, t2):
56 | typ_set = {t1, t2}
57 | if type(ent_set) is not set or type(typ_set) is not set:
58 | return None
59 |
60 | result = set()
61 |
62 | for o in ent_set:
63 | if (o in self.kg.entity_type and len(typ_set.intersection(set(self.kg.entity_type[o]))) > 0):
64 | result.add(o)
65 |
66 | return result
67 |
68 | def find_tuple_counts(self, r, t1, t2):
69 | if r is None or t1 is None or t2 is None:
70 | return None
71 |
72 | tuple_count = dict()
73 |
74 | for s in self.kg.triples['relation']['subject'][r]:
75 | if (s in self.kg.entity_type and t1 in self.kg.entity_type[s]):
76 | count = 0
77 | for o in self.kg.triples['relation']['subject'][r][s]:
78 | if (o in self.kg.entity_type and t2 in self.kg.entity_type[o]):
79 | count += 1
80 |
81 | tuple_count[s] = count
82 |
83 | return tuple_count
84 |
85 | def find_reverse_tuple_counts(self, r, t1, t2):
86 | if r is None or t1 is None or t2 is None:
87 | return None
88 |
89 | tuple_count = dict()
90 |
91 | for o in self.kg.triples['relation']['object'][r]:
92 | if (o in self.kg.entity_type and t1 in self.kg.entity_type[o]):
93 | count = 0
94 | for s in self.kg.triples['relation']['object'][r][o]:
95 | if (s in self.kg.entity_type and t2 in self.kg.entity_type[s]):
96 | count += 1
97 |
98 | tuple_count[o] = count
99 |
100 | return tuple_count
101 |
102 | def greater(self, type_dict, value):
103 | return set([k for k, v in type_dict.items() if v > value and v >= 0])
104 |
105 | def less(self, type_dict, value):
106 | return set([k for k, v in type_dict.items() if v < value and v >= 0])
107 |
108 | def equal(self, type_dict, value):
109 | return set([k for k, v in type_dict.items() if v == value and v >= 0])
110 |
111 | def approx(self, type_dict, value, interval=15):
112 | # ambiguous action
113 | # simply check for more than 0
114 | return set([k for k, v in type_dict.items() if v > 0])
115 |
116 | def atmost(self, type_dict, max_value):
117 | return set([k for k, v in type_dict.items() if v <= max_value and v >= 0])
118 |
119 | def atleast(self, type_dict, min_value):
120 | return set([k for k, v in type_dict.items() if v >= min_value and v >= 0])
121 |
122 | def argmin(self, type_dict, value=0):
123 | min_value = min(type_dict.values())
124 | return set([k for k, v in type_dict.items() if v == min_value])
125 |
126 | def argmax(self, type_dict, value=0):
127 | max_value = max(type_dict.values())
128 | return set([k for k, v in type_dict.items() if v == max_value])
129 |
130 | def is_in(self, ent, set_ent):
131 | return set(ent).issubset(set_ent)
132 |
133 | def count(self, in_set):
134 | return len(in_set)
135 |
136 | def union(self, *args):
137 | if all(isinstance(x, set) for x in args):
138 | return args[0].union(*args)
139 | else:
140 |             return {k: args[0].get(k, 0) + args[1].get(k, 0) for k in set(args[0]) | set(args[1])}  # tuple-count dicts: add counts key-wise (binary case only)
141 |
142 | def intersection(self, s1, s2):
143 | return s1.intersection(s2)
144 |
145 | def difference(self, s1, s2):
146 | return s1.difference(s2)
147 |
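148 | 
149 | if __name__ == '__main__':
150 |     # Minimal usage sketch with a toy KG (illustrative only; the real KG is
151 |     # loaded from the wikidata .json dumps, cf. lasagneSP/knowledge_graph/).
152 |     class ToyKG:
153 |         triples = {'subject': {'Q1': {'P1': ['Q2', 'Q3']}},
154 |                    'object': {'Q2': {'P1': ['Q1']}}}
155 |         entity_type = {'Q2': ['T1'], 'Q3': ['T2']}
156 | 
157 |     op = ActionOperator(ToyKG())
158 |     print(op.find('Q1', 'P1'))                        # {'Q2', 'Q3'}
159 |     print(op.filter_type(op.find('Q1', 'P1'), 'T1'))  # {'Q2'}
160 |     print(op.count(op.find_reverse('Q2', 'P1')))      # 1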
--------------------------------------------------------------------------------
/evaluation/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import re
4 | from pathlib import Path
5 |
6 | # set root path
7 | ROOT_PATH = Path(os.path.dirname(__file__))
8 |
9 | # model name
10 | MODEL_NAME = 'BertSP'
11 |
12 | # fields
13 | INPUT = 'input'
14 | STR_LOGICAL_FORM = 'str_logical_form'
15 | LOGICAL_FORM = 'logical_form'
16 | GRAPH = 'graph'
17 | NEL = 'nel'
18 | SEGMENT = 'segment'
19 | START = 'start'
20 | END = 'end'
21 | DYNBIN = 'dynbin'
22 | ENTITY_MAP = 'entity_map'
23 | GLOBAL_SYNTAX = 'global_syntax'
24 |
25 | # json annotations fields, used in precompute_local_subgraphs.py
26 | ES_LINKS = 'es_links'
27 | SPANS = 'tagged_words'
28 | ALLEN_SPANS = 'allen_tagged_words'
29 | ALLEN_TAGS = 'allennlp_tags'
30 | ALLEN_ES_LINKS = 'allen_es_links'
31 | STR_ES_LINKS = 'str_es_links'
32 | STR_SPANS = 'str_tagged_words'
33 |
34 | # helper tokens
35 | BOS_TOKEN = '[BOS]'
36 | BOS_TOKEN_BERT = '[unused6]'
37 | EOS_TOKEN = '[EOS]'
38 | EOS_TOKEN_BERT = '[unused7]'
39 | CTX_TOKEN = '[CTX]'
40 | CTX_TOKEN_BERT = '[unused2]'
41 | PAD_TOKEN = '[PAD]'
42 | UNK_TOKEN = '[UNK]'
43 | SEP_TOKEN = '[SEP]'
44 | CLS_TOKEN = '[CLS]'
45 | NA_TOKEN = 'NA'
46 | NA_TOKEN_BERT = '[unused3]'
47 | OBJ_TOKEN = '[OBJ]'
48 | SUBJ_TOKEN = '[SUBJ]'
49 | PRED_TOKEN = '[PRED]'
50 | TYPE_TOKEN = '[TYPE]'
51 | KGELEM_DELIMITER = ';'
52 | KGELEM_DELIMITER_BERT = '[unused4]'
53 | TMP_DELIMITER = '||||'
54 | RELARGS_DELIMITER = '[->]'
55 | RELARGS_DELIMITER_BERT = '[unused5]'
56 |
57 | # KB graph
58 | SUBJECT = 'subject'
59 | OBJECT = 'object'
60 | TYPE = 'type'
61 | TYPE_SUBJOBJ = 'typesubjobj'
62 |
63 | # ner tag
64 | B = 'B'
65 | I = 'I'
66 | O = 'O'
67 |
68 | # question types
69 | TOTAL = 'total'
70 | OVERALL = 'Overall'
71 | CLARIFICATION = 'Clarification'
72 | COMPARATIVE = 'Comparative Reasoning (All)'
73 | LOGICAL = 'Logical Reasoning (All)'
74 | QUANTITATIVE = 'Quantitative Reasoning (All)'
75 | SIMPLE_COREFERENCED = 'Simple Question (Coreferenced)'
76 | SIMPLE_DIRECT = 'Simple Question (Direct)'
77 | SIMPLE_ELLIPSIS = 'Simple Question (Ellipsis)'
78 | VERIFICATION = 'Verification (Boolean) (All)'
79 | QUANTITATIVE_COUNT = 'Quantitative Reasoning (Count) (All)'
80 | COMPARATIVE_COUNT = 'Comparative Reasoning (Count) (All)'
81 |
82 | # action related
83 | ENTITY = 'entity'
84 | RELATION = 'relation'
85 | TYPE = 'type'
86 | VALUE = 'value'
87 | ACTION = 'action'
88 |
89 | # other
90 | UTTERANCE = 'utterance'
91 | QUESTION_TYPE = 'question_type'
92 | DESCRIPTION = 'description'
93 | IS_CORRECT = 'is_correct'
94 | QUESTION = 'question'
95 | ANSWER = 'answer'
96 | ACTIONS = 'actions'
97 | GOLD_ACTIONS = 'sparql_delex'
98 | RESULTS = 'results'
99 | PREV_RESULTS = 'prev_results'
100 | CONTEXT_QUESTION = 'context_question'
101 | CONTEXT_ENTITIES = 'context_entities'
102 | BERT_BASE_UNCASED = 'bert-base-uncased'
103 | TURN_ID = 'turnID'
104 | USER = 'USER'
105 | SYSTEM = 'SYSTEM'
106 |
107 | # ENTITY and TYPE annotations options, defined in preprocess.py
108 | TGOLD = 'gold'
109 | TLINKED = 'linked'
110 | TNONE = 'none'
111 | NEGOLD = 'gold'
112 | NELGNEL = 'lgnel'
113 | NEALLENNEL = 'allennel'
114 | NESTRNEL = 'strnel'
115 |
116 | # max limits, truncations in inputs used in data_builder.py
117 | MAX_TYPE_RESTRICTIONS = 5
118 | MAX_LINKED_TYPES = 3 # graph from type linking, we know in average there are 2.3 gold types
119 | MAX_INPUTSEQ_LEN = 508
120 |
121 | # Eval script output json keys
122 | INSTANCES = 'instances'
123 | ACCURACY = 'accuracy'
124 | PRECISION = 'precision'
125 | RECALL = 'recall'
126 | F1SCORE = 'f1score'
127 | MACRO_F1SCORE = 'macro-f1score'
128 | EM = 'em'
129 | INME_CTX = 'Ctx=-1'
130 | LARGE_CTX = 'Ctx<-1'
131 | ELLIPSIS = 'ellipsis'
132 | MULTI_ENTITY = 'multi_entity'
133 |
134 | QTYPE_DICT = {
135 | 'Comparative Reasoning (All)': 0,
136 | 'Logical Reasoning (All)': 1,
137 | 'Quantitative Reasoning (All)': 2,
138 | 'Simple Question (Coreferenced)': 3,
139 | 'Simple Question (Direct)': 4,
140 | 'Simple Question (Ellipsis)': 5,
141 | 'Verification (Boolean) (All)': 6,
142 | 'Quantitative Reasoning (Count) (All)': 7,
143 | 'Comparative Reasoning (Count) (All)': 8,
144 | 'Clarification': 9
145 | }
146 |
147 | INV_QTYPE_DICT = {}
148 | for k, v in QTYPE_DICT.items():
149 | INV_QTYPE_DICT[v] = k
150 |
151 |
152 | def get_value(question):
153 | if 'min' in question.split():
154 | value = '0'
155 | elif 'max' in question.split():
156 | value = '0'
157 | elif 'exactly' in question.split():
158 | value = re.search(r'\d+', question.split('exactly')[1])
159 | if value:
160 | value = value.group()
161 | elif 'approximately' in question.split():
162 | value = re.search(r'\d+', question.split('approximately')[1])
163 | if value:
164 | value = value.group()
165 | elif 'around' in question.split():
166 | value = re.search(r'\d+', question.split('around')[1])
167 | if value:
168 | value = value.group()
169 | elif 'atmost' in question.split():
170 | value = re.search(r'\d+', question.split('atmost')[1])
171 | if value:
172 | value = value.group()
173 | elif 'atleast' in question.split():
174 | value = re.search(r'\d+', question.split('atleast')[1])
175 | if value:
176 | value = value.group()
177 | else:
178 | print(f'Could not extract value from question: {question}')
179 | value = '0'
180 |
181 | return value
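182 | 
183 | 
184 | # Illustrative examples, derived from the branches above:
185 | #   get_value('which rivers flow through exactly 3 countries ?')      -> '3'
186 | #   get_value('which rivers flow through atleast 2 countries ?')      -> '2'
187 | #   get_value('which river flows through min number of countries ?')  -> '0'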
--------------------------------------------------------------------------------
/evaluation/executor.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from actions import ActionOperator
3 | import requests
4 | import json
5 | import sys
6 | import random
7 |
8 | action_input_size_constrain = {'find': 2, 'find_reverse': 2, 'filter_type': 2,
9 | 'filter_multi_types': 3, 'find_tuple_counts': 3, 'find_reverse_tuple_counts': 3,
10 | 'greater': 2, 'less': 2, 'equal': 2, 'approx': 2, 'atmost': 2, 'atleast': 2, 'argmin': 1,
11 | 'argmax': 1, 'is_in': 2, 'count': 1, 'union': 10, 'intersection': 2, 'difference': 2}
12 |
13 | class ActionExecutor:
14 | def __init__(self, server_link, kg=None):
15 | self.server_link = server_link
16 | self.operator = kg
17 |
18 | def _parse_actions(self, actions):
19 | actions_to_execute = OrderedDict()
20 | for i, action in enumerate(actions):
21 | if i == 0:
22 | actions_to_execute[str(i)] = [action[1]]
23 | continue
24 |
25 | if action[0] == 'action':
26 | actions_to_execute[str(i)] = [action[1]]
27 | if actions[i-1][0] == 'action':
28 | # child of previous action
29 | actions_to_execute[str(i-1)].append(str(i))
30 | else:
31 | for k in reversed(list(actions_to_execute.keys())[:-1]):
32 | if (len(actions_to_execute[k])-1) < action_input_size_constrain[actions_to_execute[k][0]]:
33 | actions_to_execute[str(k)].append(str(i))
34 | break
35 | else:
36 | for j in reversed(list(actions_to_execute.keys())):
37 | j = int(j)
38 | if actions[j][0] == 'action' and len(actions_to_execute[str(j)]) < action_input_size_constrain[actions[j][1]] + 1:
39 | # child of previous action
40 | if action[1].isnumeric():
41 | actions_to_execute[str(j)].append(int(action[1]))
42 | else:
43 | actions_to_execute[str(j)].append(action[1])
44 | break
45 | elif actions[j][0] != 'action' and len(actions_to_execute[str(j)]) < action_input_size_constrain[actions[j][1]] + 1:
46 | actions_to_execute[str(j)].append(action[1])
47 |
48 | return actions_to_execute
49 |
50 | def _parse_actions_sparql(self, actions):
51 | all_actions = actions.split(' ')
52 | query = ""
53 | for idx, a in enumerate(all_actions):
54 | if idx > 1 and all_actions[idx - 1].lower() in ['wd:', 'wdt:']:
55 | query = query + a.upper()
56 | else:
57 | query = query + ' ' + a
58 | query = query.strip()
59 | return query
60 |
61 | def _execute_actions(self, actions_to_execute):
62 | # execute actions on kg
63 | partial_results = OrderedDict()
64 | for key, value in reversed(actions_to_execute.items()):
65 | if key == list(actions_to_execute.keys())[0] and value[0] == 'count':
66 | continue
67 | # create new values in case getting children results
68 | new_value = []
69 | for v in value:
70 | if isinstance(v, str) and v.isnumeric():
71 | new_value.append(partial_results[v])
72 | continue
73 | new_value.append(v)
74 |
75 | value = new_value.copy()
76 |
77 | # execute action
78 | action = value[0]
79 | if action == 'union' and len(value) >= 2:
80 | partial_results[key] = getattr(self.operator, action)(*value[1:])
81 | elif len(value) == 2:
82 | arg = value[1]
83 | partial_results[key] = getattr(self.operator, action)(arg)
84 | elif len(value) == 3:
85 | arg_1 = value[1]
86 | arg_2 = value[2]
87 | partial_results[key] = getattr(self.operator, action)(arg_1, arg_2)
88 | elif len(value) == 4:
89 | arg_1 = value[1]
90 | arg_2 = value[2]
91 | arg_3 = value[3]
92 | partial_results[key] = getattr(self.operator, action)(arg_1, arg_2, arg_3)
93 | else:
94 | raise NotImplementedError('Not implemented for more than 3 inputs!')
95 |
96 | return next(reversed(partial_results.values()))
97 |
98 | def _execute_actions_sparql(self, query):
99 | def run_q(query,link):
100 | acceptable_format = 'application/sparql-results+json'
101 | headers = {'Accept': acceptable_format}
102 |             response = requests.post(link, data={'query': query}, headers=headers)
103 | t = response.content
104 | j = json.loads(t)
105 | return j
106 |
107 | def get_results(results):
108 | if 'boolean' in results.keys():
109 | print(results['boolean'])
110 | return results['boolean'], 'boolean'
111 | else:
112 | print(results)
113 | varBindings = {}
114 | assert len(results['head']['vars']) == 1
115 | for var in results['head']['vars']:
116 | varBindings[var] = []
117 |                 for binding in results['results']['bindings']:
118 |                     print(binding)
119 |                     if var in binding.keys():
120 |                         print(var)
121 |                         varBindings[var].append(binding[var]['value'].split('/')[-1])
122 | assert len(varBindings.keys()) == 1
123 | for key in varBindings.keys():
124 | return varBindings[key], key
125 |
126 | #link= servers[self.server_link]
127 | j = run_q(query,self.server_link)
128 | results = get_results(j)
129 | return results
130 |
131 |
132 | def __call__(self, actions, prev_results, question_type, sparql=False):
133 | if sparql:
134 | sparql = self._parse_actions_sparql(actions)
135 | return self._execute_actions_sparql(sparql)
136 |
137 | if question_type in ['Logical Reasoning (All)', 'Quantitative Reasoning (All)', 'Comparative Reasoning (All)', 'Clarification', 'Quantitative Reasoning (Count) (All)', 'Comparative Reasoning (Count) (All)']:
138 | action_input_size_constrain['union'] = 2
139 | # parse actions
140 | actions_to_execute = self._parse_actions(actions)
141 | for key, value in actions_to_execute.items():
142 | if actions_to_execute[key][1] == 'prev_answer':
143 | actions_to_execute[key][1] = prev_results
144 | elif actions_to_execute[key][0] == 'is_in' and actions_to_execute[key][1].startswith('Q'):
145 | actions_to_execute[key][1] = [actions_to_execute[key][1]]
146 | # execute actions and return results
147 | return self._execute_actions(actions_to_execute)
148 |
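149 | 
150 | # Illustrative note: _parse_actions_sparql re-joins the tokenised Sparql and
151 | # upper-cases (and fuses) the id that follows each wd:/wdt: prefix token, e.g.
152 | #   'SELECT ?x WHERE { wd: q34770 wdt: p37 ?x . }'
153 | #   -> 'SELECT ?x WHERE { wd:Q34770 wdt:P37 ?x . }'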
--------------------------------------------------------------------------------
/evaluation/meters.py:
--------------------------------------------------------------------------------
1 | # meter class for storing results
2 | class AccuracyMeter(object):
3 | def __init__(self):
4 | self.reset()
5 |
6 | def reset(self):
7 | self.correct = 0
8 | self.wrong = 0
9 | self.accuracy = 0
10 | self.exact_match_acc = 0.0
11 | self.number_of_instance = 0
12 | self.correct_exact_match = 0.0
13 |
14 | def update(self, gold, result, gold_sparql, pred_sparql):
15 | self.number_of_instance += 1
16 | if gold_sparql is not None and pred_sparql is not None and gold_sparql.lower() == pred_sparql.lower():
17 | self.correct_exact_match += 1
18 | if gold == result:
19 | self.correct += 1
20 | else:
21 | self.wrong += 1
22 |
23 | self.accuracy = self.correct / (self.correct + self.wrong)
24 | self.exact_match_acc = self.correct_exact_match / self.number_of_instance
25 |
26 | class F1scoreMeter(object):
27 | def __init__(self):
28 | self.reset()
29 |
30 | def reset(self):
31 | self.tp = 0
32 | self.fp = 0
33 | self.fn = 0
34 | self.precision = 0
35 | self.recall = 0
36 | self.f1_score = 0
37 | self.exact_match_acc = 0
38 | self.correct_exact_match = 0.0
39 | self.number_of_instance = 0.0
40 | self.missmatch = 0.0
41 |         ## debug: running sums for per-instance (macro-averaged) scores
42 | self.acc_prec_macro = 0.0
43 | self.acc_rec_macro = 0.0
44 | self.acc_f1_macro = 0.0
45 |
46 | def update(self, gold, result, gold_sparql, pred_sparql):
47 | self.number_of_instance += 1
48 | if gold_sparql is not None and pred_sparql is not None and gold_sparql.lower() == pred_sparql.lower():
49 | self.correct_exact_match += 1
50 | if result != gold:
51 | self.missmatch += 1
52 | # debug
53 | print(gold_sparql)
54 | print('result', result)
55 | print('gold', gold)
56 | print('****** EM but <> results ******')
57 |
58 | self.tp += len(result.intersection(gold))
59 | self.fp += len(result.difference(gold))
60 | self.fn += len(gold.difference(result))
61 |         if self.tp > 0 or self.fp > 0:
62 |             self.precision = self.tp / (self.tp + self.fp)  # micro precision over all instances so far
63 |         if self.tp > 0 or self.fn > 0:
64 |             self.recall = self.tp / (self.tp + self.fn)  # micro recall over all instances so far
65 |         if self.precision > 0 or self.recall > 0:
66 |             self.f1_score = 2 * self.precision * self.recall / (self.precision + self.recall)  # micro F1
67 |
68 | prec, rec = 0,0
69 | if len(result) > 0:
70 | #print(f'Instance precision: {len(result.intersection(gold)) / len(result)}')
71 | prec = len(result.intersection(gold)) / len(result)
72 | self.acc_prec_macro += prec
73 | if len(gold) > 0:
74 | rec = len(result.intersection(gold)) / len(gold)
75 | self.acc_rec_macro += rec
76 | if prec > 0 or rec > 0:
77 | self.acc_f1_macro += 2 * prec * rec / (prec + rec)
78 |
79 | ###
80 | #if (len(result.intersection(gold))!= len(gold) or len(result.intersection(gold))!= len(result)) and \
81 | # (gold_sparql is not None and pred_sparql is not None and gold_sparql.lower() == pred_sparql.lower()):
82 | # print('gold', gold)
83 | # print('result', result)
84 | # print('prec/rec', self.precision, self.recall)
85 | # print('rec fla', self.tp / (self.tp + self.fn), self.tp, self.fn)
86 | # print('tp/inter', result.intersection(gold), len(result.intersection(gold)))
87 | # print('fn/diff', gold.difference(result), len(gold.difference(result)))
88 | # exit()
89 |
90 | self.exact_match_acc = self.correct_exact_match / self.number_of_instance
91 |
92 |
93 | ## Unused???
94 | class ExactMatchMeter(object):
95 | def __init__(self):
96 | self.total = 0
97 | self.correct = 0
98 |
99 | def update(self, gold, pred):
100 | pass
101 |
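102 | 
103 | # Usage sketch (illustrative): precision/recall/F1 are micro-averaged over all
104 | # updates so far.
105 | #   m = F1scoreMeter()
106 | #   m.update(gold={'Q1', 'Q2'}, result={'Q1'}, gold_sparql='s', pred_sparql='s')
107 | #   m.precision == 1.0; m.recall == 0.5; m.exact_match_acc == 1.0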
--------------------------------------------------------------------------------
/evaluation/summarise_results.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | from glob import glob
4 | import os
5 | import argparse
6 |
7 | # add arguments to parser
8 | parser = argparse.ArgumentParser(description='Summarise evaluation results')
9 | parser.add_argument('--file_path', default='', help='path that contains the .json files to summarise')
10 |
11 | args = parser.parse_args()
12 |
13 | # load constants module
14 | from constants import *
15 |
16 |
17 | files = glob(os.path.join(args.file_path, '*.json'))
18 |
19 | metric = {INME_CTX: {EM: 0, 'cnt': 0},
20 | LARGE_CTX: {EM: 0, 'cnt': 0},
21 | ELLIPSIS: {EM: 0, 'cnt': 0},
22 | MULTI_ENTITY: {EM: 0, 'cnt': 0}}
23 |
24 | for f in files:
25 | if 'eval_summary.json' in f:
26 | continue #skip the summary file in case it was already generated
27 | with open(f) as json_file:
28 | print(f)
29 | data = json.load(json_file)
30 | for key in data.keys():
31 | if MULTI_ENTITY == key and EM in data[key].keys():
32 | metric[MULTI_ENTITY][EM] += data[key][EM]
33 | metric[MULTI_ENTITY]['cnt'] += 1
34 | if ELLIPSIS == key and EM in data[key].keys():
35 | metric[ELLIPSIS][EM] += data[key][EM]
36 | metric[ELLIPSIS]['cnt'] += 1
37 | if INME_CTX == key and EM in data[key].keys():
38 | metric[INME_CTX][EM] += data[key][EM]
39 | metric[INME_CTX]['cnt'] += 1
40 | if LARGE_CTX == key and EM in data[key].keys():
41 | metric[LARGE_CTX][EM] += data[key][EM]
42 | metric[LARGE_CTX]['cnt'] += 1
43 |
44 | res = {}
45 | res[MULTI_ENTITY] = {}
46 | res[MULTI_ENTITY][EM] = metric[MULTI_ENTITY][EM] / metric[MULTI_ENTITY]['cnt'] if metric[MULTI_ENTITY]['cnt'] > 0 else 0
47 | res[ELLIPSIS] = {}
48 | res[ELLIPSIS][EM] = metric[ELLIPSIS][EM] / metric[ELLIPSIS]['cnt'] if metric[ELLIPSIS]['cnt'] > 0 else 0
49 | res[INME_CTX] = {}
50 | res[INME_CTX][EM] = metric[INME_CTX][EM] / metric[INME_CTX]['cnt'] if metric[INME_CTX]['cnt'] > 0 else 0
51 | res[LARGE_CTX] = {}
52 | res[LARGE_CTX][EM] = metric[LARGE_CTX][EM] / metric[LARGE_CTX]['cnt'] if metric[LARGE_CTX]['cnt'] > 0 else 0
53 |
54 | # write .json file with details about the results
55 | with open(os.path.join(args.file_path, 'eval_summary.json'), 'w') as fp:
56 | json.dump(res, fp)
57 |
58 |
--------------------------------------------------------------------------------
/lasagneSP/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Endri Kacupaj
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/lasagneSP/README.md:
--------------------------------------------------------------------------------
1 | This directory contains the lasagneSP model, adapted from https://github.com/endrikacupaj/LASAGNE.
2 |
3 |
4 | ## Inverted index on Wikidata entities
5 | To build an inverted index on wikidata entities we use [elastic](https://www.elastic.co/) search. See the script [csqa_elasticse.py](scripts/csqa_elasticse.py) for doing so.
6 |
7 | ## BERT embeddings
8 | Before training the framework, we need to create BERT embeddings for the knowledge graph (entity) types and relations. You can do that by running:
9 | ```
10 | python scripts/bert_embeddings.py
11 | ```
12 |
13 | ## Train lasagneSP
14 | ```
15 | python train.py --data_path /preprocessed_data
16 | ```
17 |
18 | ## Inference
19 | Inference is performed per question-type.
20 | ```
21 | python inference.py --question_type QTYPE --model_path experiments/snapshots/model_path.pth.tar --data_path /preprocessed_data
22 | ```
23 | Where QTYPE is in ("Clarification" "Comparative Reasoning (All)" "Comparative Reasoning (Count) (All)" "Logical Reasoning (All)" "Quantitative Reasoning (All)" "Quantitative Reasoning (Count) (All)" "Simple Question (Coreferenced)" "Simple Question (Direct)" "Simple Question (Ellipsis)" "Verification (Boolean) (All)")
24 |
25 | ## Evaluation
26 | To execute and evaluate the inferred files, run the following scripts in the evaluation folder.
27 | ```
28 | bash execute_all.sh
29 | python summarise_results.py --file_path out_dir
30 | ```
31 |
32 |
33 | ## License
34 | The repository is under [MIT License](LICENCE).
35 |
36 |
--------------------------------------------------------------------------------
/lasagneSP/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def str2bool(v):
5 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
6 | return True
7 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
8 | return False
9 | else:
10 | raise argparse.ArgumentTypeError('Boolean value expected.')
11 |
12 | def get_parser():
13 | parser = argparse.ArgumentParser(description='LASAGNE')
14 |
15 | # general
16 | parser.add_argument('--seed', default=1234, type=int)
17 | parser.add_argument('--no-cuda', action='store_true')
18 | parser.add_argument('--cuda_device', default=0, type=int)
19 |
20 | # data
21 | parser.add_argument('--data_path', default='/data/final/csqa')
22 |
23 | # experiments
24 | parser.add_argument('--snapshots', default='experiments/snapshots', type=str)
25 | parser.add_argument('--path_results', default='experiments/results', type=str)
26 | parser.add_argument('--path_error_analysis', default='experiments/error_analysis', type=str)
27 | parser.add_argument('--path_inference', default='experiments/inference', type=str)
28 |
29 | # task
30 | parser.add_argument('--task', default='multitask', choices=['multitask',
31 | 'logical_form',
32 | 'ner',
33 | 'coref',
34 | 'graph'], type=str)
35 |
36 | # model
37 | parser.add_argument('--emb_dim', default=300, type=int)
38 |     parser.add_argument('--dropout', default=0.1, type=float)  # dropout is a probability; type=int would truncate CLI values to 0
39 | parser.add_argument('--heads', default=6, type=int)
40 | parser.add_argument('--layers', default=2, type=int)
41 | parser.add_argument('--max_positions', default=1000, type=int)
42 | parser.add_argument('--pf_dim', default=300, type=int)
43 | parser.add_argument('--graph_heads', default=2, type=int)
44 | parser.add_argument('--bert_dim', default=3072, type=int)
45 |
46 | # training
47 | parser.add_argument('--lr', default=0.0001, type=float)
48 | parser.add_argument('--momentum', default=0.9, type=float)
49 | parser.add_argument('--warmup', default=4000, type=float)
50 | parser.add_argument('--factor', default=1, type=float)
51 | parser.add_argument('--weight_decay', default=0, type=float)
52 | parser.add_argument('--epochs', default=20, type=int)
53 | parser.add_argument('--start_epoch', default=0, type=int)
54 | parser.add_argument('--valfreq', default=1, type=int)
55 | parser.add_argument('--resume', default='', type=str)
56 | parser.add_argument('--clip', default=5, type=int)
57 | parser.add_argument('--batch_size', default=50, type=int)
58 |     parser.add_argument('--mapsplits', type=str2bool, nargs='?', const=True, default=False)
59 | parser.add_argument('--mapfile', default='')
60 |
61 | # test and inference
62 | parser.add_argument('--model_path', default='experiments/snapshots/', type=str)
63 | parser.add_argument('--inference_partition', default='test', choices=['val', 'test'], type=str)
64 | parser.add_argument('--question_type', default='Clarification',
65 | choices=['Clarification',
66 | 'Comparative Reasoning (All)',
67 | 'Logical Reasoning (All)',
68 | 'Quantitative Reasoning (All)',
69 | 'Simple Question (Coreferenced)',
70 | 'Simple Question (Direct)',
71 | 'Simple Question (Ellipsis)',
72 | 'Verification (Boolean) (All)',
73 | 'Quantitative Reasoning (Count) (All)',
74 | 'Comparative Reasoning (Count) (All)'], type=str)
75 |
76 | return parser
77 |
--------------------------------------------------------------------------------
/lasagneSP/execute_all.sh:
--------------------------------------------------------------------------------
1 | set -e
2 | QTYPE_ARRAY_WITHOUT_CONTEXT=("Clarification" "Comparative Reasoning (All)" "Comparative Reasoning (Count) (All)" "Logical Reasoning (All)" "Quantitative Reasoning (All)" "Quantitative Reasoning (Count) (All)" "Simple Question (Direct)" "Simple Question (Ellipsis)")
3 | QTYPE_ARRAY_WITH_CONTEXT=("Simple Question (Coreferenced)" "Verification (Boolean) (All)")
4 | base_p="prefix_path_test_"
5 | split=0
6 | out_dir="outdir"
7 | mkdir -p ${out_dir}
8 |
9 | for val in "${QTYPE_ARRAY_WITHOUT_CONTEXT[@]}"; do
10 | echo $val
11 |     python run_subtype_lf.py --file_path "${base_p}${val}.json" --question_type "${val}" --server_link "http://127.0.0.1:9999/blazegraph/namespace/wd/sparql" --out_eval_file split${split}_"${val}"_intermediate.json > ${out_dir}/split${split}_"${val}"_intermediate.out 2>&1
12 | done
13 |
14 | for val in "${QTYPE_ARRAY_WITH_CONTEXT[@]}"; do
15 | echo $val
16 |     python run_subtype_lf.py --file_path "${base_p}${val}.json" --question_type "${val}" --server_link "http://127.0.0.1:9999/blazegraph/namespace/wd/sparql" --context_dist_file CSQA_v9_skg.v6_compar_spqres9_subkg2_tyTop_nelctx_cleaned_context_distance_test.log --out_eval_file split${split}_"${val}"_intermediate.json > ${out_dir}/split${split}_"${val}"_intermediate.out 2>&1
17 | done
18 |
--------------------------------------------------------------------------------
/lasagneSP/graph.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | from torch_geometric.data import Data
6 |
7 | # import constants
8 | from myconstants import *
9 |
10 | class TypeRelationGraph:
11 | def __init__(self, vocab, type_path=f'{ROOT_PATH}'):
12 |         # NOTE: type_path used to be hard-coded to a user-specific checkout;
13 |         # we now honour the value passed in (defaults to ROOT_PATH).
14 | self.vocab = vocab
15 | self.existing_nodes = list(vocab.stoi.keys())
16 | self.type_triples = json.loads(open(f'{type_path}/knowledge_graph/wikidata_type_dict.json').read())
17 | self.bert_embeddings = json.loads(open(f'{type_path}/knowledge_graph/node_embeddings.json').read())
18 | self.nodes = torch.tensor([self.bert_embeddings[node] for node in self.existing_nodes], requires_grad=True)
19 | self.start = []
20 | self.end = []
21 |         self.existing_edges = set()  # set membership keeps _add_edge checks O(1)
22 |
23 | # create edges
24 | self._create_edges()
25 |
26 | # create PyG graph
27 | self.data = Data(x=self.nodes, edge_index=torch.LongTensor([self.start, self.end])).to(DEVICE)
28 |
29 | def _create_edges(self):
30 | # extract graph data from KG
31 | for head in self.type_triples:
32 | if head in self.vocab.stoi: # only types that are in vocab
33 | for relation in self.type_triples[head]:
34 | if relation in self.vocab.stoi: # only predicates that are in vocab
35 | self._add_edge(head, relation) # add head -> relation edge
36 | for tail in self.type_triples[head][relation]:
37 | if tail in self.vocab.stoi:
38 | self._add_edge(relation, tail) # add relation -> tail edge
39 |
40 | def _add_edge(self, start, end):
41 | if f'{start}->{end}' not in self.existing_edges:
42 | self.start.append(self.existing_nodes.index(start))
43 | self.end.append(self.existing_nodes.index(end))
44 |             self.existing_edges.add(f'{start}->{end}')
45 |
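46 | # Illustrative note: the resulting PyG Data object holds one BERT-embedded node
47 | # per vocabulary entry, with directed edges head->relation and relation->tail
48 | # for every type triple whose elements all appear in the vocabulary.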
--------------------------------------------------------------------------------
/lasagneSP/inference.py:
--------------------------------------------------------------------------------
1 | import random
2 | import logging
3 | import torch
4 | import numpy as np
5 | from model import LASAGNE
6 | from dataset import CSQADataset
7 | from utils import Predictor, Inference
8 |
9 | # import constants
10 | from myconstants import *
11 |
12 | # set logger
13 | print(args.question_type)
14 | logging.basicConfig(format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
15 | datefmt='%d/%m/%Y %I:%M:%S %p',
16 | level=logging.INFO,
17 | handlers=[
18 | logging.FileHandler(f'{args.path_results}/inference_{args.question_type}.log', 'w'),
19 | logging.StreamHandler()
20 | ])
21 | logger = logging.getLogger(__name__)
22 |
23 | # set a seed value
24 | random.seed(args.seed)
25 | np.random.seed(args.seed)
26 | if torch.cuda.is_available():
27 | torch.manual_seed(args.seed)
28 | torch.cuda.manual_seed(args.seed)
29 | torch.cuda.manual_seed_all(args.seed)
30 |
31 | def main():
32 | # load data
33 | dataset = CSQADataset()
34 | vocabs = dataset.get_vocabs()
35 | inference_data = dataset.get_inference_data(args.inference_partition)
36 |
37 | logger.info(f'Inference partition: {args.inference_partition}')
38 | logger.info(f'Inference question type: {args.question_type}')
39 | logger.info('Inference data prepared')
40 | logger.info(f"Num of inference data: {len(inference_data)}")
41 |
42 | # load model
43 | model = LASAGNE(vocabs).to(DEVICE)
44 |
45 | logger.info(f"=> loading checkpoint '{args.model_path}'")
46 | if DEVICE.type=='cpu':
47 | checkpoint = torch.load(f'{ROOT_PATH}/{args.model_path}', encoding='latin1', map_location='cpu')
48 | else:
49 | checkpoint = torch.load(f'{ROOT_PATH}/{args.model_path}', encoding='latin1')
50 | args.start_epoch = checkpoint['epoch']
51 | model.load_state_dict(checkpoint['state_dict'])
52 | logger.info(f"=> loaded checkpoint '{args.model_path}' (epoch {checkpoint['epoch']})")
53 |
54 | # construct actions
55 | predictor = Predictor(model, vocabs, DEVICE)
56 | Inference().construct_actions(inference_data, predictor, logger)
57 |
58 | if __name__ == '__main__':
59 | main()
60 |
--------------------------------------------------------------------------------
/lasagneSP/knowledge_graph/knowledge_graph.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ujson
3 | import time
4 | from pathlib import Path
5 | ROOT_PATH = Path(os.path.dirname(__file__))
6 |
7 | class KnowledgeGraph:
8 | def __init__(self, wikidata_path=f'{ROOT_PATH}'):
9 | tic = time.perf_counter()
10 |
11 | # id -> entity label
12 | self.id_entity = ujson.loads(open(f'{wikidata_path}/items_wikidata_n.json').read())
13 | print(f'Loaded id_entity {time.perf_counter()-tic:0.2f}s')
14 |
15 | # id -> relation label
16 | self.id_relation = ujson.loads(open(f'{wikidata_path}/filtered_property_wikidata4.json').read())
17 | print(f'Loaded id_relation {time.perf_counter()-tic:0.2f}s')
18 |
19 | # entity -> type
20 | self.entity_type = ujson.loads(open(f'{wikidata_path}/entity_type.json').read()) # dict[e] -> type
21 | print(f'Loaded entity_type {time.perf_counter()-tic:0.2f}s')
22 |
23 | # type -> relation -> type
24 | self.type_triples = ujson.loads(open(f'{wikidata_path}/wikidata_type_dict.json').read())
25 | print(f'Loaded type_triples {time.perf_counter()-tic:0.2f}s')
26 |
27 | # subject -> relation -> object
28 | self.subject_triples_1 = ujson.loads(open(f'{wikidata_path}/wikidata_short_1.json').read())
29 | self.subject_triples_2 = ujson.loads(open(f'{wikidata_path}/wikidata_short_2.json').read())
30 | self.subject_triples = {**self.subject_triples_1, **self.subject_triples_2}
31 | print(f'Loaded subject_triples {time.perf_counter()-tic:0.2f}s')
32 |
33 | # object -> relation -> subject
34 | self.object_triples = ujson.loads(open(f'{wikidata_path}/comp_wikidata_rev.json').read())
35 | print(f'Loaded object_triples {time.perf_counter()-tic:0.2f}s')
36 |
37 | # relation -> subject -> object | relation -> object -> subject
38 | self.relation_subject_object = ujson.loads(open(f'{wikidata_path}/relation_subject_object.json').read())
39 | self.relation_object_subject = ujson.loads(open(f'{wikidata_path}/relation_object_subject.json').read())
40 | print(f'Loaded relation_triples {time.perf_counter()-tic:0.2f}s')
41 |
42 | # labels
43 | self.labels = {
44 | 'entity': self.id_entity, # dict[e] -> label
45 | 'relation': self.id_relation # dict[r] -> label
46 | }
47 |
48 | # triples
49 | self.triples = {
50 | 'subject': self.subject_triples, # dict[s][r] -> [o1, o2, o3]
51 | 'object': self.object_triples, # dict[o][r] -> [s1, s2, s3]
52 | 'relation': {
53 | 'subject': self.relation_subject_object, # dict[r][s] -> [o1, o2, o3]
54 | 'object': self.relation_object_subject # dict[r][o] -> [s1, s2, s3]
55 | },
56 | 'type': self.type_triples # dict[t][r] -> [t1, t2, t3]
57 | }
58 |
59 |
--------------------------------------------------------------------------------
/lasagneSP/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from graph import TypeRelationGraph
7 | from torch_geometric.nn import GATConv
8 |
9 | # import constants
10 | from myconstants import *
11 |
12 | class LASAGNE(nn.Module):
13 | def __init__(self, vocabs):
14 | super(LASAGNE, self).__init__()
15 | self.vocabs = vocabs
16 | self.encoder = Encoder(vocabs[INPUT], DEVICE)
17 | self.decoder = Decoder(vocabs[LOGICAL_FORM], DEVICE)
18 | self.ner = NerNet(len(vocabs[NER]))
19 | self.coref = CorefNet(len(vocabs[COREF]))
20 | self.graph = TypeRelationGraph(vocabs[GRAPH]).data
21 | self.graph_net = GraphNet(len(vocabs[GRAPH]))
22 |
23 | def forward(self, src_tokens, trg_tokens):
24 | encoder_out = self.encoder(src_tokens)
25 | ner_out, ner_h = self.ner(encoder_out)
26 | coref_out = self.coref(torch.cat([encoder_out, ner_h], dim=-1))
27 | decoder_out, decoder_h = self.decoder(src_tokens, trg_tokens, encoder_out)
28 | encoder_ctx = encoder_out[:, -1:, :]
29 | graph_out = self.graph_net(encoder_ctx, decoder_h, self.graph)
30 |
31 | return {
32 | LOGICAL_FORM: decoder_out,
33 | NER: ner_out,
34 | COREF: coref_out,
35 | GRAPH: graph_out
36 | }
37 |
38 | def _predict_encoder(self, src_tensor):
39 | with torch.no_grad():
40 | encoder_out = self.encoder(src_tensor)
41 | ner_out, ner_h = self.ner(encoder_out)
42 | coref_out = self.coref(torch.cat([encoder_out, ner_h], dim=-1))
43 |
44 | return {
45 | ENCODER_OUT: encoder_out,
46 | NER: ner_out,
47 | COREF: coref_out
48 | }
49 |
50 | def _predict_decoder(self, src_tokens, trg_tokens, encoder_out):
51 | with torch.no_grad():
52 | decoder_out, decoder_h = self.decoder(src_tokens, trg_tokens, encoder_out)
53 | encoder_ctx = encoder_out[:, -1:, :]
54 | graph_out = self.graph_net(encoder_ctx, decoder_h, self.graph)
55 |
56 | return {
57 | DECODER_OUT: decoder_out,
58 | GRAPH: graph_out
59 | }
60 |
61 | class LstmFlatten(nn.Module):
62 | def forward(self, x):
63 | return x[0].squeeze(1)
64 |
65 | class Flatten(nn.Module):
66 | def forward(self, x):
67 | return x.contiguous().view(-1, x.shape[-1])
68 |
69 | class NerNet(nn.Module):
70 | def __init__(self, tags, dropout=args.dropout):
71 | super(NerNet, self).__init__()
72 | self.ner_lstm = nn.Sequential(
73 | nn.LSTM(input_size=args.emb_dim, hidden_size=args.emb_dim, batch_first=True),
74 | LstmFlatten(),
75 | nn.LeakyReLU()
76 | )
77 |
78 | self.ner_linear = nn.Sequential(
79 | Flatten(),
80 | nn.Dropout(dropout),
81 | nn.Linear(args.emb_dim, tags)
82 | )
83 |
84 | def forward(self, x):
85 | h = self.ner_lstm(x)
86 | return self.ner_linear(h), h
87 |
88 | class CorefNet(nn.Module):
89 | def __init__(self, tags, dropout=args.dropout):
90 | super(CorefNet, self).__init__()
91 | self.seq_net = nn.Sequential(
92 | nn.Linear(args.emb_dim*2, args.emb_dim),
93 | nn.LeakyReLU(),
94 | Flatten(),
95 | nn.Dropout(dropout),
96 | nn.Linear(args.emb_dim, tags)
97 | )
98 |
99 | def forward(self, x):
100 | return self.seq_net(x)
101 |
102 | class GraphNet(nn.Module):
103 | def __init__(self, num_nodes):
104 | super(GraphNet, self).__init__()
105 | self.gat = GATConv(args.bert_dim, args.emb_dim, heads=args.graph_heads, dropout=args.dropout)
106 | self.dropout = nn.Dropout(args.dropout)
107 | self.linear_out = nn.Linear((args.emb_dim*args.graph_heads)+args.emb_dim, args.emb_dim)
108 | self.score = nn.Linear(args.emb_dim, 1)
109 | self.context_net = nn.Sequential(
110 | nn.Linear(args.emb_dim*2, args.emb_dim),
111 | nn.LeakyReLU(),
112 | Flatten(),
113 | nn.Dropout(args.dropout),
114 | nn.Linear(args.emb_dim, num_nodes)
115 | )
116 |
117 |     def forward(self, encoder_ctx, decoder_h, graph):
118 |         g = self.gat(graph.x, graph.edge_index)  # (num_nodes, emb_dim*graph_heads)
119 |         g = self.dropout(g)
120 |         g = self.linear_out(torch.cat([encoder_ctx.repeat(1, graph.x.shape[0], 1), g.unsqueeze(0).repeat(encoder_ctx.shape[0], 1, 1)], dim=-1))  # fuse utterance context with every node
121 |         g = Flatten()(self.score(g).squeeze(-1).unsqueeze(1).repeat(1, decoder_h.shape[1], 1))  # one score per node, broadcast over decoder steps
122 |         x = self.context_net(torch.cat([encoder_ctx.expand(decoder_h.shape), decoder_h], dim=-1))  # node logits per decoder step
123 |         return x * g  # weight node logits by the graph-attention scores
124 |
125 | class Encoder(nn.Module):
126 | def __init__(self, vocabulary, device, embed_dim=args.emb_dim, layers=args.layers,
127 | heads=args.heads, pf_dim=args.pf_dim, dropout=args.dropout, max_positions=args.max_positions):
128 | super().__init__()
129 | input_dim = len(vocabulary)
130 | self.padding_idx = vocabulary.stoi[PAD_TOKEN]
131 | self.dropout = dropout
132 | self.device = device
133 |
134 | input_dim, embed_dim = vocabulary.vectors.size()
135 | self.scale = math.sqrt(embed_dim)
136 | self.embed_tokens = nn.Embedding(input_dim, embed_dim)
137 | self.embed_tokens.weight.data.copy_(vocabulary.vectors)
138 | self.embed_positions = PositionalEmbedding(embed_dim, dropout, max_positions)
139 |
140 | self.layers = nn.ModuleList([EncoderLayer(embed_dim, heads, pf_dim, dropout, device) for _ in range(layers)])
141 |
142 | def forward(self, src_tokens):
143 | src_mask = (src_tokens != self.padding_idx).unsqueeze(1).unsqueeze(2)
144 |
145 | x = self.embed_tokens(src_tokens) * self.scale
146 | x += self.embed_positions(src_tokens)
147 | x = F.dropout(x, p=self.dropout, training=self.training)
148 |
149 | for layer in self.layers:
150 | x = layer(x, src_mask)
151 |
152 | return x
153 |
154 | class EncoderLayer(nn.Module):
155 | def __init__(self, embed_dim, heads, pf_dim, dropout, device):
156 | super().__init__()
157 |
158 | self.layer_norm = nn.LayerNorm(embed_dim)
159 | self.self_attn = MultiHeadedAttention(embed_dim, heads, dropout, device)
160 | self.pos_ff = PositionwiseFeedforward(embed_dim, pf_dim, dropout)
161 | self.dropout = nn.Dropout(dropout)
162 |
163 | def forward(self, src_tokens, src_mask):
164 | x = self.layer_norm(src_tokens + self.dropout(self.self_attn(src_tokens, src_tokens, src_tokens, src_mask)))
165 | x = self.layer_norm(x + self.dropout(self.pos_ff(x)))
166 |
167 | return x
168 |
169 | class Decoder(nn.Module):
170 | def __init__(self, vocabulary, device, embed_dim=args.emb_dim, layers=args.layers,
171 | heads=args.heads, pf_dim=args.pf_dim, dropout=args.dropout, max_positions=args.max_positions):
172 | super().__init__()
173 |
174 | output_dim = len(vocabulary)
175 | self.pad_id = vocabulary.stoi[PAD_TOKEN]
176 | self.pf_dim = pf_dim
177 | self.dropout = dropout
178 | self.device = device
179 | self.max_positions = max_positions
180 |
181 | self.scale = math.sqrt(embed_dim)
182 | self.embed_tokens = nn.Embedding(output_dim, embed_dim)
183 | self.embed_positions = PositionalEmbedding(embed_dim, dropout, max_positions)
184 |
185 | self.layers = nn.ModuleList([DecoderLayer(embed_dim, heads, pf_dim, dropout, device) for _ in range(layers)])
186 |
187 | self.linear_out = nn.Linear(embed_dim, output_dim)
188 |
189 | def make_masks(self, src_tokens, trg_tokens):
190 | src_mask = (src_tokens != self.pad_id).unsqueeze(1).unsqueeze(2)
191 | trg_pad_mask = (trg_tokens != self.pad_id).unsqueeze(1).unsqueeze(3)
192 | trg_len = trg_tokens.shape[1]
193 | trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=self.device)).bool()
194 | trg_mask = trg_pad_mask & trg_sub_mask
195 | return src_mask, trg_mask
196 |
197 | def forward(self, src_tokens, trg_tokens, encoder_out):
198 | src_mask, trg_mask = self.make_masks(src_tokens, trg_tokens)
199 |
200 | x = self.embed_tokens(trg_tokens) * self.scale
201 | x += self.embed_positions(trg_tokens)
202 | h = F.dropout(x, p=self.dropout, training=self.training)
203 |
204 | for layer in self.layers:
205 | h = layer(h, encoder_out, trg_mask, src_mask)
206 |
207 | x = h.contiguous().view(-1, h.shape[-1])
208 | x = self.linear_out(x)
209 |
210 | return x, h
211 |
212 | class DecoderLayer(nn.Module):
213 | def __init__(self, embed_dim, heads, pf_dim, dropout, device):
214 | super().__init__()
215 | self.layer_norm = nn.LayerNorm(embed_dim)
216 | self.self_attn = MultiHeadedAttention(embed_dim, heads, dropout, device)
217 | self.src_attn = MultiHeadedAttention(embed_dim, heads, dropout, device)
218 | self.pos_ff = PositionwiseFeedforward(embed_dim, pf_dim, dropout)
219 | self.dropout = nn.Dropout(dropout)
220 |
221 | def forward(self, embed_trg, embed_src, trg_mask, src_mask):
222 | x = self.layer_norm(embed_trg + self.dropout(self.self_attn(embed_trg, embed_trg, embed_trg, trg_mask)))
223 | x = self.layer_norm(x + self.dropout(self.src_attn(x, embed_src, embed_src, src_mask)))
224 | x = self.layer_norm(x + self.dropout(self.pos_ff(x)))
225 |
226 | return x
227 |
228 | class MultiHeadedAttention(nn.Module):
229 | def __init__(self, embed_dim, heads, dropout, device):
230 | super().__init__()
231 | assert embed_dim % heads == 0
232 | self.attn_dim = embed_dim // heads
233 | self.heads = heads
234 | self.dropout = dropout
235 |
236 | self.linear_q = nn.Linear(embed_dim, embed_dim)
237 | self.linear_k = nn.Linear(embed_dim, embed_dim)
238 | self.linear_v = nn.Linear(embed_dim, embed_dim)
239 |
240 | self.scale = torch.sqrt(torch.FloatTensor([self.attn_dim])).to(device)
241 |
242 | self.linear_out = nn.Linear(embed_dim, embed_dim)
243 |
244 | def forward(self, query, key, value, mask=None):
245 | batch_size = query.shape[0]
246 |
247 | Q = self.linear_q(query)
248 | K = self.linear_k(key)
249 | V = self.linear_v(value)
250 |
251 | Q = Q.view(batch_size, -1, self.heads, self.attn_dim).permute(0, 2, 1, 3) # (batch, heads, sent_len, attn_dim)
252 | K = K.view(batch_size, -1, self.heads, self.attn_dim).permute(0, 2, 1, 3) # (batch, heads, sent_len, attn_dim)
253 | V = V.view(batch_size, -1, self.heads, self.attn_dim).permute(0, 2, 1, 3) # (batch, heads, sent_len, attn_dim)
254 |
255 | energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale # (batch, heads, sent_len, sent_len)
256 |
257 | if mask is not None:
258 | energy = energy.masked_fill(mask == 0, -1e10)
259 |
260 | attention = F.softmax(energy, dim=-1) # (batch, heads, sent_len, sent_len)
261 | attention = F.dropout(attention, p=self.dropout, training=self.training)
262 |
263 | x = torch.matmul(attention, V) # (batch, heads, sent_len, attn_dim)
264 | x = x.permute(0, 2, 1, 3).contiguous() # (batch, sent_len, heads, attn_dim)
265 | x = x.view(batch_size, -1, self.heads * (self.attn_dim)) # (batch, sent_len, embed_dim)
266 | x = self.linear_out(x)
267 |
268 | return x
269 |
270 | class PositionwiseFeedforward(nn.Module):
271 | def __init__(self, embed_dim, pf_dim, dropout):
272 | super().__init__()
273 | self.linear_1 = nn.Linear(embed_dim, pf_dim)
274 | self.linear_2 = nn.Linear(pf_dim, embed_dim)
275 | self.dropout = dropout
276 |
277 | def forward(self, x):
278 | x = torch.relu(self.linear_1(x))
279 | x = F.dropout(x, p=self.dropout, training=self.training)
280 |
281 | return self.linear_2(x)
282 |
283 | class PositionalEmbedding(nn.Module):
284 | def __init__(self, d_model, dropout, max_len=5000):
285 | super().__init__()
286 | pos_embed = torch.zeros(max_len, d_model)
287 | position = torch.arange(0., max_len).unsqueeze(1)
288 | div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
289 | pos_embed[:, 0::2] = torch.sin(position * div_term)
290 | pos_embed[:, 1::2] = torch.cos(position * div_term)
291 | pos_embed = pos_embed.unsqueeze(0)
292 | self.register_buffer('pos_embed', pos_embed)
293 |
294 | def forward(self, x):
295 | return Variable(self.pos_embed[:, :x.size(1)], requires_grad=False)
296 |
--------------------------------------------------------------------------------
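The decoder above builds its target-side mask by AND-ing a padding mask with a causal (lower-triangular) mask, and `MultiHeadedAttention` then drives masked positions to `-1e10` before the softmax, so they get effectively zero attention weight. A minimal sketch of that broadcast (not from the repo), with an illustrative `pad_id` of 0 and a toy batch:

```python
import torch

pad_id = 0
trg = torch.tensor([[5, 6, 7, 0]])  # one sequence; the last token is padding

# shapes follow Decoder.make_masks: (batch, 1, trg_len, 1) & (trg_len, trg_len)
trg_pad_mask = (trg != pad_id).unsqueeze(1).unsqueeze(3)
trg_sub_mask = torch.tril(torch.ones((4, 4))).bool()
trg_mask = trg_pad_mask & trg_sub_mask  # broadcasts to (batch, 1, trg_len, trg_len)

print(trg_mask[0, 0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [0, 0, 0, 0]])  # last row all zeros: the pad position attends to nothing
```

Position *i* may attend only to positions <= *i*, and any query position holding a pad token is masked out entirely.

--------------------------------------------------------------------------------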
/lasagneSP/myconstants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from pathlib import Path
4 | from args import get_parser
5 |
6 | # set root path
7 | ROOT_PATH = Path(os.path.dirname(__file__))
8 | #ROOT_PATH = Path("/home/s1959796/debug_lasange/LASAGNE-master/scripts")
9 |
10 | # read parser
11 | parser = get_parser()
12 | args = parser.parse_args()
13 |
14 | # model name
15 | MODEL_NAME = 'LASAGNE'
16 |
17 | # define device
18 | CUDA = 'cuda'
19 | CPU = 'cpu'
20 | DEVICE = torch.device(CUDA if torch.cuda.is_available() else CPU)
21 |
22 | # fields
23 | INPUT = 'input'
24 | LOGICAL_FORM = 'logical_form'
25 | NER = 'ner'
26 | COREF = 'coref'
27 | GRAPH = 'graph'
28 | MULTITASK = 'multitask'
29 |
30 | # helper tokens
31 | START_TOKEN = '[START]'
32 | END_TOKEN = '[END]'
33 | CTX_TOKEN = '[CTX]'
34 | PAD_TOKEN = '[PAD]'
35 | UNK_TOKEN = '[UNK]'
36 | SEP_TOKEN = '[SEP]'
37 | NA_TOKEN = 'NA'
38 |
39 | # ner tag
40 | B = 'B'
41 | I = 'I'
42 | O = 'O'
43 |
44 | # model
45 | ENCODER_OUT = 'encoder_out'
46 | DECODER_OUT = 'decoder_out'
47 |
48 | # training
49 | EPOCH = 'epoch'
50 | STATE_DICT = 'state_dict'
51 | BEST_VAL = 'best_val'
52 | OPTIMIZER = 'optimizer'
53 | CURR_VAL = 'curr_val'
54 |
55 | # question types
56 | TOTAL = 'total'
57 | OVERALL = 'Overall'
58 | CLARIFICATION = 'Clarification'
59 | COMPARATIVE = 'Comparative Reasoning (All)'
60 | LOGICAL = 'Logical Reasoning (All)'
61 | QUANTITATIVE = 'Quantitative Reasoning (All)'
62 | SIMPLE_COREFERENCED = 'Simple Question (Coreferenced)'
63 | SIMPLE_DIRECT = 'Simple Question (Direct)'
64 | SIMPLE_ELLIPSIS = 'Simple Question (Ellipsis)'
65 | VERIFICATION = 'Verification (Boolean) (All)'
66 | QUANTITATIVE_COUNT = 'Quantitative Reasoning (Count) (All)'
67 | COMPARATIVE_COUNT = 'Comparative Reasoning (Count) (All)'
68 |
69 | # action related
70 | ENTITY = 'entity'
71 | RELATION = 'relation'
72 | TYPE = 'type'
73 | VALUE = 'value'
74 | PREV_ANSWER = 'prev_answer'
75 | ACTION = 'action'
76 |
77 | # other
78 | DESCRIPTION = 'description'
79 | QUESTION_TYPE = 'question_type'
80 | IS_CORRECT = 'is_correct'
81 | QUESTION = 'question'
82 | ANSWER = 'answer'
83 | ACTIONS = 'actions'
84 | #GOLD_ACTIONS = 'gold_actions'
85 | GOLD_ACTIONS = 'sparql_delex'
86 | RESULTS = 'results'
87 | PREV_RESULTS = 'prev_results'
88 | CONTEXT_QUESTION = 'context_question'
89 | CONTEXT_ENTITIES = 'context_entities'
90 | BERT_BASE_UNCASED = 'bert-base-uncased'
91 | TURN_ID = 'turnID'
92 |
--------------------------------------------------------------------------------
/lasagneSP/prreprocess_command.sh:
--------------------------------------------------------------------------------
1 | python annotate_csqa/preprocess.py --partition $1 --annotation_task all --read_folder /knowledge_graph/CSQA_v9/ --write_folder /preprocessed_data
2 |
3 |
--------------------------------------------------------------------------------
/lasagneSP/requirements.txt:
--------------------------------------------------------------------------------
1 | dgl==0.4.3.post2
2 | transformers==2.8.0
3 | Unidecode==1.1.1
4 | torch_geometric==1.6.1
5 | torchtext==0.4.0
6 | networkx==2.2
7 | #flair==0.4.4
8 | flair
9 | ujson==2.0.3
10 | #torch_nightly==1.1.0.dev20190502
11 | elasticsearch==7.8.1
12 | numpy==1.17.4
13 | #torch==1.6.0
14 | torch
15 |
--------------------------------------------------------------------------------
/lasagneSP/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EdinburghNLP/SPICE/4afa4404b02f59d175976b7e02583fdf41c23c3a/lasagneSP/scripts/__init__.py
--------------------------------------------------------------------------------
/lasagneSP/scripts/bert_embeddings.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import torch
4 | import flair
5 | import sys
6 | sys.path.append("/home/s1959796/debug_lasange/lasagne_sparql/lasagne-baseline")
7 | from dataset import CSQADataset
8 | from flair.data import Sentence
9 | from flair.embeddings import FlairEmbeddings, BertEmbeddings, DocumentPoolEmbeddings
10 |
11 | # import constants
12 | from myconstants import *
13 |
14 | # set device
15 | torch.cuda.set_device(0)
16 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 | flair.device = DEVICE
18 |
19 | # load bert model
20 | bert = DocumentPoolEmbeddings([BertEmbeddings('bert-base-uncased')])
21 |
22 | # read nodes from dataset
23 | graph_nodes = list(CSQADataset().get_vocabs()[GRAPH].stoi.keys())
24 |
25 | print(len(graph_nodes))
26 |
27 | # read entity and relation labels
28 | #ROOT_PATH = '/home/s1959796/debug_lasange/LASAGNE-master'
29 | id_entity = json.loads(open(f'{ROOT_PATH}/knowledge_graph/items_wikidata_n.json').read())
30 | id_relation = json.loads(open(f'{ROOT_PATH}/knowledge_graph/filtered_property_wikidata4.json').read())
31 |
32 | # create embeddings
33 | na_node = Sentence(graph_nodes[0])
34 | pad_node = Sentence(graph_nodes[1])
35 | bert.embed(na_node)
36 | bert.embed(pad_node)
37 | node_embeddings = {
38 | graph_nodes[0]: na_node.embedding.detach().cpu().tolist(),
39 | graph_nodes[1]: pad_node.embedding.detach().cpu().tolist()
40 | }
41 | for node in graph_nodes[2:]:
42 | node_label = Sentence(id_entity[node] if node.startswith('Q') else id_relation[node])
43 | bert.embed(node_label)
44 | node_embeddings[node] = node_label.embedding.detach().cpu().tolist()
45 |
46 | with open(f'{ROOT_PATH}/knowledge_graph/node_embeddings.json', 'w') as outfile:
47 | json.dump(node_embeddings, outfile, indent=4)
48 |
--------------------------------------------------------------------------------
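The script above serializes one pooled BERT vector per graph node. A minimal sketch (not part of the repo) of reading that JSON back into a frozen `nn.Embedding`; it assumes the rows should follow the insertion order the script wrote (NA and PAD first, then the remaining nodes), and the path is the one used above:

```python
import json
import torch
import torch.nn as nn

with open('knowledge_graph/node_embeddings.json') as f:  # path assumed
    node_embeddings = json.load(f)

# json.load preserves key order, which the script above writes in
# graph-vocabulary order, so row i corresponds to node i
weights = torch.tensor(list(node_embeddings.values()))
node_embed = nn.Embedding.from_pretrained(weights, freeze=True)
print(node_embed.weight.shape)  # (num_nodes, bert_dim)
```

--------------------------------------------------------------------------------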
/lasagneSP/scripts/csqa_elasticse.py:
--------------------------------------------------------------------------------
1 | import time
2 | from unidecode import unidecode
3 | from elasticsearch import Elasticsearch
4 | import sys
5 | sys.path.append("/home/s1959796/debug_lasange/LASAGNE-master")
6 |
7 | from knowledge_graph.knowledge_graph import KnowledgeGraph
8 |
9 | kg = KnowledgeGraph()
10 | kg_entities = list(kg.id_entity.items())
11 | kg_types = kg.entity_type
12 | print(f'Num of wikidata entities: {len(kg_entities)}')
13 |
14 | #es = Elasticsearch([{'host': 'localhost', 'port': 9200}])
15 | es = Elasticsearch([{'host': 'localhost','port': 9200}], timeout=300, max_retries=100, retry_on_timeout=True)
16 | #es = Elasticsearch([{'host': '127.0.0.1','port': 9200}])
17 | # es.indices.delete(index='csqa_wikidata', ignore=[400, 404])
18 |
19 | tic = time.perf_counter()
20 | for i, (id, label) in enumerate(kg_entities):
21 | es.index(index='csqa_wikidata', doc_type='entities', id=i+1, body={'id': id, 'label': unidecode(label), 'type': kg_types[id] if id in kg_types else []})
22 | if (i+1) % 10000 == 0:
23 | print(f'==> Finished {((i+1)/len(kg_entities))*100:.4f}% -- {time.perf_counter() - tic:0.2f}s')
24 | print(i, flush=True)
25 | time.sleep(3)
26 |
27 | query = unidecode('Albania')
28 | res = es.search(index='csqa_wikidata', doc_type='entities', body={
29 | 'size': 50,
30 | 'query': {
31 | 'match': {
32 | 'label': {
33 | 'query': query,
34 | 'fuzziness': 'AUTO',
35 | }
36 | }
37 | }
38 | })
39 |
40 | for hit in res['hits']['hits']:
41 | print(f'{hit["_source"]["id"]} - {hit["_source"]["label"]} - {hit["_score"]}')
42 | print('**********************')
43 |
--------------------------------------------------------------------------------
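For reuse outside this one-off indexing script, the fuzzy lookup at the bottom can be wrapped in a small helper. A sketch under the same assumptions as above (local Elasticsearch 7.x, `csqa_wikidata` index already populated); `fuzzy_entity_search` is a hypothetical name, not a repo function:

```python
from unidecode import unidecode
from elasticsearch import Elasticsearch

es = Elasticsearch([{'host': 'localhost', 'port': 9200}])

def fuzzy_entity_search(label, size=10):
    """Return (wikidata_id, label, score) tuples for fuzzy label matches."""
    res = es.search(index='csqa_wikidata', doc_type='entities', body={
        'size': size,
        'query': {'match': {'label': {'query': unidecode(label), 'fuzziness': 'AUTO'}}}
    })
    return [(hit['_source']['id'], hit['_source']['label'], hit['_score'])
            for hit in res['hits']['hits']]

print(fuzzy_entity_search('Albania')[:3])
```

--------------------------------------------------------------------------------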
/lasagneSP/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import logging
6 | import torch
7 | import numpy as np
8 | import torch.optim
9 | import torch.nn as nn
10 | from pathlib import Path
11 | from args import get_parser
12 | from model import LASAGNE
13 | from dataset import CSQADataset
14 | from torchtext.data import BucketIterator
15 | #from torchtext.legacy.data import BucketIterator
16 | from utils import (NoamOpt, AverageMeter,
17 | SingleTaskLoss, MultiTaskLoss,
18 | save_checkpoint, init_weights)
19 |
20 | # import constants
21 | from myconstants import *
22 |
23 | # set logger
24 | logging.basicConfig(format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
25 | datefmt='%d/%m/%Y %I:%M:%S %p',
26 | level=logging.INFO,
27 | handlers=[
28 | logging.FileHandler(f'{args.path_results}/train_{args.task}.log', 'w'),
29 | logging.StreamHandler()
30 | ])
31 | logger = logging.getLogger(__name__)
32 |
33 | # set a seed value
34 | random.seed(args.seed)
35 | np.random.seed(args.seed)
36 | if torch.cuda.is_available():
37 | torch.manual_seed(args.seed)
38 | torch.cuda.manual_seed(args.seed)
39 | torch.cuda.manual_seed_all(args.seed)
40 |
41 | def main():
42 | # load data
43 | dataset = CSQADataset()
44 | print('Getting vocab...')
45 | vocabs = dataset.get_vocabs()
46 | print('Getting data...')
47 | train_data, val_data, _ = dataset.get_data()
48 |
49 | # load model
50 | model = LASAGNE(vocabs).to(DEVICE)
51 |
52 | # initialize model weights
53 | init_weights(model)
54 |
55 | logger.info(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')
56 |
57 | # define loss function (criterion)
58 | criterion = {
59 | LOGICAL_FORM: SingleTaskLoss,
60 | NER: SingleTaskLoss,
61 | COREF: SingleTaskLoss,
62 | GRAPH: SingleTaskLoss,
63 | MULTITASK: MultiTaskLoss
64 | }[args.task](ignore_index=vocabs[LOGICAL_FORM].stoi[PAD_TOKEN])
65 |
66 | # define optimizer
67 | optimizer = NoamOpt(torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
68 |
69 | if args.resume:
70 | if os.path.isfile(args.resume):
71 | logger.info(f"=> loading checkpoint '{args.resume}''")
72 | checkpoint = torch.load(args.resume)
73 | args.start_epoch = checkpoint[EPOCH]
74 | best_val = checkpoint[BEST_VAL]
75 | model.load_state_dict(checkpoint[STATE_DICT])
76 | optimizer.optimizer.load_state_dict(checkpoint[OPTIMIZER])
77 | logger.info(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint[EPOCH]})")
78 | else:
79 | logger.info(f"=> no checkpoint found at '{args.resume}'")
80 | best_val = float('inf')
81 | else:
82 | best_val = float('inf')
83 |
84 | # prepare training and validation loader
85 | train_loader, val_loader = BucketIterator.splits((train_data, val_data),
86 | batch_size=args.batch_size,
87 | sort_within_batch=False,
88 | sort_key=lambda x: len(x.input),
89 | device=DEVICE)
90 |
91 | logger.info('Loaders prepared.')
92 | logger.info(f"Training data: {len(train_data.examples)}")
93 | logger.info(f"Validation data: {len(val_data.examples)}")
94 | logger.info(f'Question example: {train_data.examples[0].input}')
95 | logger.info(f'Logical form example: {train_data.examples[0].logical_form}')
96 | logger.info(f"Unique tokens in input vocabulary: {len(vocabs[INPUT])}")
97 | logger.info(f"Unique tokens in logical form vocabulary: {len(vocabs[LOGICAL_FORM])}")
98 | logger.info(f"Unique tokens in ner vocabulary: {len(vocabs[NER])}")
99 | logger.info(f"Unique tokens in coref vocabulary: {len(vocabs[COREF])}")
100 | logger.info(f"Number of nodes in the graph: {len(vocabs[GRAPH])}")
101 | logger.info(f'Batch: {args.batch_size}')
102 | logger.info(f'Epochs: {args.epochs}')
103 |
104 | # run epochs
105 | for epoch in range(args.start_epoch, args.epochs):
106 | # train for one epoch
107 | train(train_loader, model, vocabs, criterion, optimizer, epoch)
108 |
109 |         # evaluate on validation set
110 |         if (epoch+1) % args.valfreq == 0:
111 |             logger.info('Start validating')
112 | val_loss = validate(val_loader, model, vocabs, criterion)
113 | # if val_loss < best_val:
114 | best_val = min(val_loss, best_val) # log every validation step
115 | logger.info(f'Saving checkpoint ...')
116 | save_checkpoint({
117 | EPOCH: epoch + 1,
118 | STATE_DICT: model.state_dict(),
119 | BEST_VAL: best_val,
120 | OPTIMIZER: optimizer.optimizer.state_dict(),
121 | CURR_VAL: val_loss})
122 | logger.info(f'* Val loss: {val_loss:.4f}')
123 |
124 | def train(train_loader, model, vocabs, criterion, optimizer, epoch):
125 | batch_time = AverageMeter()
126 | losses = AverageMeter()
127 |
128 | # switch to train mode
129 | model.train()
130 |
131 | end = time.time()
132 | for i, batch in enumerate(train_loader):
133 | # get inputs
134 | input = batch.input
135 | logical_form = batch.logical_form
136 | ner = batch.ner
137 | coref = batch.coref
138 | graph = batch.graph
139 |
140 | # compute output
141 | output = model(input, logical_form[:, :-1])
142 |
143 | # prepare targets
144 | target = {
145 | LOGICAL_FORM: logical_form[:, 1:].contiguous().view(-1), # (batch_size * trg_len)
146 | NER: ner.contiguous().view(-1),
147 | COREF: coref.contiguous().view(-1),
148 | GRAPH: graph[:, 1:].contiguous().view(-1)
149 | }
150 |
151 | # compute loss
152 | loss = criterion(output, target) if args.task == MULTITASK else criterion(output[args.task], target[args.task])
153 |
154 | # record loss
155 | losses.update(loss.data, input.size(0))
156 |
157 | # compute gradient and do Adam step
158 | optimizer.zero_grad()
159 | loss.backward()
160 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
161 | optimizer.step()
162 |
163 | # measure elapsed time
164 | batch_time.update(time.time() - end)
165 | end = time.time()
166 |
167 | logger.info(f'Epoch: {epoch+1} - Train loss: {losses.val:.4f} ({losses.avg:.4f}) - Batch: {((i+1)/len(train_loader))*100:.2f}% - Time: {batch_time.sum:0.2f}s')
168 |
169 | def validate(val_loader, model, vocabs, criterion):
170 | losses = AverageMeter()
171 |
172 | # switch to evaluate mode
173 | model.eval()
174 |
175 | with torch.no_grad():
176 | for _, batch in enumerate(val_loader):
177 | # get inputs
178 | input = batch.input
179 | logical_form = batch.logical_form
180 | ner = batch.ner
181 | coref = batch.coref
182 | graph = batch.graph
183 |
184 | # compute output
185 | output = model(input, logical_form[:, :-1])
186 |
187 | # prepare targets
188 | target = {
189 | LOGICAL_FORM: logical_form[:, 1:].contiguous().view(-1), # (batch_size * trg_len)
190 | NER: ner.contiguous().view(-1),
191 | COREF: coref.contiguous().view(-1),
192 | GRAPH: graph[:, 1:].contiguous().view(-1)
193 | }
194 |
195 | # compute loss
196 | loss = criterion(output, target) if args.task == MULTITASK else criterion(output[args.task], target[args.task])
197 |
198 | # record loss
199 | losses.update(loss.data, input.size(0))
200 |
201 | return losses.avg
202 |
203 | if __name__ == '__main__':
204 | main()
205 |
--------------------------------------------------------------------------------
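`train.py` constructs Adam with `lr=0` and hands it to `NoamOpt`, so the effective learning rate comes entirely from the schedule. `utils.py` is not shown here, but `NoamOpt` conventionally implements the transformer warmup schedule; a minimal sketch of that rate function, with illustrative (not the repo's) `model_size`, `factor`, and `warmup` values:

```python
def noam_rate(step, model_size=300, factor=1.0, warmup=4000):
    """lr = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)"""
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# the rate grows linearly during warmup, then decays as 1/sqrt(step)
for step in (1000, 4000, 16000):
    print(step, f'{noam_rate(step):.6f}')
```

--------------------------------------------------------------------------------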
/sparql-server/README.md:
--------------------------------------------------------------------------------
1 | # Minimal setup
2 | You can access a minimal setup from the following [link](https://uoe-my.sharepoint.com/:f:/g/personal/s1959796_ed_ac_uk/ErNL2lgTuI1Lu5vw8b6tEzEBk2YJ3QalLnYalG89e4Ge0g?e=KwkYC7) and run the following command.
3 |
4 | ```
5 | bash start.sh
6 | ```
7 |
8 | Alternatively, you can re-create the files by following the instructions below.
9 |
10 |
11 | # Instructions to create the triples file and load it in Blazegraph format
12 |
13 | The server files, including the triples, can be created by following the instructions below. We use Blazegraph to host the server; however, any other triple store should work too.
14 |
15 | ## Create triples file
16 | Download `wikidata_proc_json_2` from the [CSQA Dataset](https://amritasaha1812.github.io/CSQA/), then run:
17 | ```
18 | python json_to_triples.py
19 | ```
20 |
21 | ## Load triples to blazegraph
22 | ```
23 | bash load_ttl.sh
24 | ```
25 | This will produce `wikidata.jnl`.
26 |
27 | ## Start server
28 | - Copy the properties file, `wikidata.jnl`, and `wd_prefix.ttl` into `server_files`
29 | - Get `blazegraph.jar` via [Blazegraph](https://blazegraph.com/)
30 | ```
31 | export PATH=$PATH:$JAVA_HOME
32 | cd path_to/server_files
33 | /usr/java/java-11.0.5/bin/java -server -Xmx150g -XX:+UseG1GC -Dcom.bigdata.rdf.sail.sparql.PrefixDeclProcessor.additionalDeclsFile=wd_prefix.ttl -jar blazegraph.jar
34 | ```
--------------------------------------------------------------------------------
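Once the server is up, it speaks the standard SPARQL HTTP protocol. A minimal sketch of querying it from Python with only the standard library; the port (9999) and namespace (`kb`) are Blazegraph defaults and may differ in your setup:

```python
import json
import urllib.parse
import urllib.request

endpoint = 'http://localhost:9999/blazegraph/namespace/kb/sparql'  # assumed defaults
query = 'SELECT * WHERE { ?s ?p ?o } LIMIT 5'

url = endpoint + '?' + urllib.parse.urlencode({'query': query})
req = urllib.request.Request(url, headers={'Accept': 'application/sparql-results+json'})
with urllib.request.urlopen(req) as resp:
    results = json.load(resp)

for b in results['results']['bindings']:
    print(b['s']['value'], b['p']['value'], b['o']['value'])
```

--------------------------------------------------------------------------------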
/sparql-server/RWStore.properties:
--------------------------------------------------------------------------------
1 | # Dump data in target.
2 | com.bigdata.journal.AbstractJournal.file=wikidata.jnl
3 | com.bigdata.journal.AbstractJournal.bufferMode=DiskRW
4 | com.bigdata.service.AbstractTransactionService.minReleaseAge=1
5 | # Disable raw records - see https://phabricator.wikimedia.org/T213375
6 | com.bigdata.rdf.store.AbstractTripleStore.enableRawRecordsSupport=false
7 |
8 | com.bigdata.rdf.store.AbstractTripleStore.quads=false
9 | com.bigdata.rdf.store.AbstractTripleStore.statementIdentifiers=false
10 |
11 | # Don't use truth maintenance right yet.
12 | com.bigdata.rdf.sail.truthMaintenance=false
13 | com.bigdata.rdf.store.AbstractTripleStore.textIndex=false
14 | com.bigdata.rdf.store.AbstractTripleStore.axiomsClass=com.bigdata.rdf.axioms.NoAxioms
15 |
16 | # Use our private vocabularies
17 | #com.bigdata.rdf.store.AbstractTripleStore.vocabularyClass=org.wikidata.query.rdf.blazegraph.WikibaseVocabulary$V005
18 | # Enable values inlining - see https://phabricator.wikimedia.org/T213375
19 | #com.bigdata.rdf.store.AbstractTripleStore.inlineURIFactory=org.wikidata.query.rdf.blazegraph.WikibaseInlineUriFactory$V002
20 | #com.bigdata.rdf.store.AbstractTripleStore.extensionFactoryClass=org.wikidata.query.rdf.blazegraph.WikibaseExtensionFactory
21 |
22 | # Suggested settings from https://phabricator.wikimedia.org/T92308
23 | com.bigdata.btree.writeRetentionQueue.capacity=4000
24 | com.bigdata.btree.BTree.branchingFactor=128
25 | # 200M initial extent.
26 | com.bigdata.journal.AbstractJournal.initialExtent=209715200
27 | com.bigdata.journal.AbstractJournal.maximumExtent=209715200
28 | # Bump up the branching factor for the lexicon indices on the default kb.
29 | com.bigdata.namespace.wdq.lex.com.bigdata.btree.BTree.branchingFactor=400
30 | com.bigdata.namespace.wdq.lex.ID2TERM.com.bigdata.btree.BTree.branchingFactor=600
31 | com.bigdata.namespace.wdq.lex.TERM2ID.com.bigdata.btree.BTree.branchingFactor=330
32 | # Bump up the branching factor for the statement indices on the default kb.
33 | com.bigdata.namespace.wdq.spo.com.bigdata.btree.BTree.branchingFactor=1024
34 | com.bigdata.namespace.wdq.spo.OSP.com.bigdata.btree.BTree.branchingFactor=900
35 | com.bigdata.namespace.wdq.spo.SPO.com.bigdata.btree.BTree.branchingFactor=900
36 | # larger statement buffer capacity for bulk loading.
37 | com.bigdata.rdf.sail.bufferCapacity=100000
38 | # Override the #of write cache buffers to improve bulk load performance. Requires enough native heap!
39 | com.bigdata.journal.AbstractJournal.writeCacheBufferCount=1000
40 | # Enable small slot optimization!
41 | com.bigdata.rwstore.RWStore.smallSlotType=1024
42 | # See https://jira.blazegraph.com/browse/BLZG-1385 - reduce LRU cache timeout
43 | com.bigdata.journal.AbstractJournal.historicalIndexCacheCapacity=20
44 | com.bigdata.journal.AbstractJournal.historicalIndexCacheTimeout=5
45 | # default prefix
46 | com.bigdata.rdf.sail.sparql.PrefixDeclProcessor.additionalDeclsFile=wd_prefix.ttl
47 |
48 | # Geospatial ON
49 | #com.bigdata.rdf.store.AbstractTripleStore.geoSpatial=true
50 | #com.bigdata.rdf.store.AbstractTripleStore.geoSpatialDefaultDatatype=http\://www.opengis.net/ont/geosparql#wktLiteral
51 | #com.bigdata.rdf.store.AbstractTripleStore.geoSpatialIncludeBuiltinDatatypes=false
52 | #com.bigdata.rdf.store.AbstractTripleStore.geoSpatialDatatypeConfig.0={"config": \
53 | #{"uri":"http://www.opengis.net/ont/geosparql#wktLiteral",\
54 | # "literalSerializer":"org.wikidata.query.rdf.blazegraph.inline.literal.WKTSerializer",\
55 | # "fields":[\
56 | # {"valueType":"DOUBLE","multiplier":"1000000000","serviceMapping":"LONGITUDE"},\
57 | # {"valueType":"DOUBLE","multiplier":"1000000000","serviceMapping":"LATITUDE"},\
58 | # {"valueType":"LONG","multiplier":"1","minValue":"0","serviceMapping":"COORD_SYSTEM"}\
59 | # ]}}
60 |
--------------------------------------------------------------------------------
/sparql-server/json_to_triples.py:
--------------------------------------------------------------------------------
1 | import os, json
2 | from tqdm import tqdm
3 | from rdflib import Literal
4 |
5 | directory = '../wikidata_proc_json_2'
6 | out_dir = 'ttl_files'
7 | #WD="
2 | PREFIX wds:
3 | PREFIX wdv:
4 | PREFIX wdt:
5 | PREFIX wikibase:
6 | PREFIX p:
7 | PREFIX ps:
8 | PREFIX pq:
9 | PREFIX rdfs:
10 | PREFIX bd:
--------------------------------------------------------------------------------