├── petci.pdf
├── models
│   ├── requirements.txt
│   ├── lstm
│   │   ├── auto_train.sh
│   │   ├── auto_test.sh
│   │   ├── lstm.py
│   │   ├── test.py
│   │   ├── dataloader.py
│   │   └── train.py
│   ├── tree_lstm
│   │   ├── auto_train.sh
│   │   ├── auto_test.sh
│   │   ├── tree_lstm.py
│   │   ├── test.py
│   │   ├── train.py
│   │   └── dataloader.py
│   ├── bert
│   │   ├── dataloader.py
│   │   ├── train.py
│   │   └── test.py
│   └── nts
│       ├── vocab.yaml
│       └── nts.yaml
├── figs
│   └── plot.py
├── README.md
├── LICENSE
└── data
    └── dataset.py
/petci.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kenantang/petci/HEAD/petci.pdf
--------------------------------------------------------------------------------
/models/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.6.0
2 | opennmt-py==2.2.0
3 | transformers==4.12.5
4 | dgl==0.6.1
5 | stanza==1.3.0
--------------------------------------------------------------------------------
/models/lstm/auto_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # use 5 seeds
4 | for SEED in 41 42 43 44 45
5 | do
6 |     # use different training sets
7 |     for HM in gh gm ghm
8 |     do
9 |         # use different training set sizes
10 |         for PART in 1 2 3 4 5
11 |         do
12 |             python -u train.py --seed $SEED --train-set train-$HM-$PART --dev-set dev-$HM > log/$SEED-$HM-$PART.txt
13 |         done
14 |     done
15 | done
16 | 
--------------------------------------------------------------------------------
/models/tree_lstm/auto_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # use 5 seeds
4 | for SEED in 41 42 43 44 45
5 | do
6 |     # use different training sets
7 |     for HM in gh gm ghm
8 |     do
9 |         # use different training set sizes
10 |         for PART in 1 2 3 4 5
11 |         do
12 |             python -u train.py --seed $SEED --train-set train-$HM-$PART --dev-set dev-$HM > log/$SEED-$HM-$PART.txt
13 |         done
14 |     done
15 | done
16 | 
--------------------------------------------------------------------------------
/models/lstm/auto_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | DIRECTORY=batch-25
4 | echo $DIRECTORY
5 | 
6 | # WARNING: remove the test results from previous runs
7 | rm test_summary.jsonl
8 | 
9 | # test all 75 models
10 | for SEED in 41 42 43 44 45
11 | do
12 |     for HM in gh gm ghm
13 |     do
14 |         for PART in 1 2 3 4 5
15 |         do
16 |             MODEL=best_$SEED\_train-$HM-$PART\_dev-$HM.pkl
17 | 
18 |             # test on all four test sets
19 |             for SET in g h m all
20 |             do
21 |                 python test.py --directory $DIRECTORY --model $MODEL --test-set test-$SET
22 |             done
23 |         done
24 |     done
25 | done
--------------------------------------------------------------------------------
/models/tree_lstm/auto_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | DIRECTORY=batch-25
4 | echo $DIRECTORY
5 | 
6 | # WARNING: remove the test results from previous runs
7 | rm test_summary.jsonl
8 | 
9 | # test all 75 models
10 | for SEED in 41 42 43 44 45
11 | do
12 |     for HM in gh gm ghm
13 |     do
14 |         for PART in 1 2 3 4 5
15 |         do
16 |             MODEL=best_$SEED\_train-$HM-$PART\_dev-$HM.pkl
17 | 
18 |             # test on all four test sets
19 |             for SET in g h m all
20 |             do
21 |                 python test.py --directory $DIRECTORY --model $MODEL --test-set test-$SET
22 |             done
23 |         done
24 |     done
25 | done
--------------------------------------------------------------------------------
/models/bert/dataloader.py:
-------------------------------------------------------------------------------- 1 | import torch as th 2 | from transformers import BertTokenizer 3 | 4 | class Dataset(th.utils.data.Dataset): 5 | def __init__(self, mode): 6 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 7 | 8 | sents = [] 9 | self.labels = [] 10 | with open("../../data/"+mode+".txt") as f: 11 | for line in f: 12 | sents.append(line[:-3]) 13 | self.labels.append(int(line[-2])) 14 | 15 | self.encodings = tokenizer(sents, padding=True, truncation=True, max_length=512) 16 | 17 | def __getitem__(self, idx): 18 | item = {key: th.tensor(val[idx]) for key, val in self.encodings.items()} 19 | if self.labels: 20 | item["labels"] = th.tensor(self.labels[idx]) 21 | return item 22 | 23 | def __len__(self): 24 | return len(self.encodings["input_ids"]) -------------------------------------------------------------------------------- /models/nts/vocab.yaml: -------------------------------------------------------------------------------- 1 | # vocab.yaml 2 | 3 | # Create a shared Vocabulary 4 | share_vocab: True 5 | save_data: ../../data/simplify/all.vocab 6 | 7 | # WARNING: overwrite existing files 8 | overwrite: True 9 | 10 | # Use pretrained embedding 11 | both_embeddings: ../../data/embedding/glove.840B.300d.txt 12 | embeddings_type: "GloVe" 13 | word_vec_size: 300 14 | 15 | # Corpus opts: 16 | data: 17 | train: 18 | path_src: ../../data/simplify/train-src.txt 19 | path_tgt: ../../data/simplify/train-tgt.txt 20 | valid: 21 | path_src: ../../data/simplify/dev-src.txt 22 | path_tgt: ../../data/simplify/dev-tgt.txt 23 | test: 24 | path_src: ../../data/simplify/test-src.txt 25 | path_tgt: ../../data/simplify/test-tgt.txt 26 | 27 | # Vocabulary files that were just created 28 | src_vocab: ../../data/simplify/all.vocab 29 | tgt_vocab: ../../data/simplify/all.vocab -------------------------------------------------------------------------------- /models/lstm/lstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils.rnn import pack_padded_sequence 3 | 4 | class LSTM(nn.Module): 5 | def __init__(self, 6 | num_vocabs, 7 | x_size, 8 | h_size, 9 | num_classes, 10 | dropout, 11 | pretrained_emb=None): 12 | super(LSTM, self).__init__() 13 | self.x_size = x_size 14 | self.embedding = nn.Embedding(num_vocabs, x_size) 15 | if pretrained_emb is not None: 16 | print('Using glove') 17 | self.embedding.weight.data.copy_(pretrained_emb) 18 | self.embedding.weight.requires_grad = True 19 | self.lstm = nn.LSTM(x_size, h_size, batch_first=True) 20 | self.dropout = nn.Dropout(dropout) 21 | self.linear = nn.Linear(h_size, num_classes) 22 | 23 | def forward(self, batch): 24 | # no initial h and c, set to zero by default 25 | embeds = self.embedding(batch.wordid) 26 | embeds = self.dropout(embeds) 27 | packed = pack_padded_sequence(embeds, batch.lengths, batch_first=True) 28 | _, (h, _) = self.lstm(packed) 29 | h = self.dropout(h) 30 | logits = self.linear(h[-1]) 31 | return logits -------------------------------------------------------------------------------- /models/nts/nts.yaml: -------------------------------------------------------------------------------- 1 | # nts.yaml 2 | 3 | # Create a shared Vocabulary 4 | share_vocab: True 5 | save_data: ../../data/simplify/all.vocab 6 | 7 | # WARNING: overwrite existing files 8 | overwrite: True 9 | 10 | # Use pretrained embedding 11 | both_embeddings: ../../data/embedding/glove.840B.300d.txt 12 | embeddings_type: 
"GloVe" 13 | word_vec_size: 300 14 | 15 | # Corpus opts: 16 | data: 17 | train: 18 | path_src: ../../data/simplify/train-src.txt 19 | path_tgt: ../../data/simplify/train-tgt.txt 20 | valid: 21 | path_src: ../../data/simplify/dev-src.txt 22 | path_tgt: ../../data/simplify/dev-tgt.txt 23 | 24 | # Vocabulary files that were just created 25 | src_vocab: ../../data/simplify/all.vocab 26 | tgt_vocab: ../../data/simplify/all.vocab 27 | 28 | # Train on a single GPU 29 | world_size: 1 30 | gpu_ranks: [0] 31 | 32 | # Set seed! 33 | seed: 41 34 | 35 | # Where to save the checkpoints 36 | save_model: checkpoints/checkpoint 37 | 38 | # Calculation based on training size 39 | # size = 19421 40 | # batch = 64 41 | # step per epoch = size / batch = 303 \approx 300 42 | # total epochs = 15 43 | # total steps = total epochs * step per epoch 44 | # start decay epoch = 8 45 | 46 | # If using the SGD optimizer: 47 | # start_decay_steps: 2400 48 | 49 | batch_size: 64 50 | save_checkpoint_steps: 300 51 | train_steps: 9000 52 | valid_steps: 300 53 | 54 | optim: adam 55 | learning_rate: 0.001 56 | 57 | early_stopping: 5 58 | early_stopping_criteria: ppl -------------------------------------------------------------------------------- /models/bert/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import BertForSequenceClassification, Trainer, TrainingArguments 3 | from transformers.trainer_utils import set_seed 4 | from dataloader import Dataset 5 | 6 | def main(args): 7 | set_seed(args.seed) 8 | 9 | model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=args.num_classes) 10 | 11 | train_dataset = Dataset(mode=args.train_set) 12 | dev_dataset = Dataset(mode=args.dev_set) 13 | 14 | training_args = TrainingArguments( 15 | output_dir='./checkpoints', 16 | num_train_epochs=args.epochs, 17 | per_device_train_batch_size=args.batch_size, 18 | per_device_eval_batch_size=64, 19 | warmup_steps=500, 20 | weight_decay=0.01, 21 | seed=args.seed, 22 | save_total_limit=5, 23 | save_steps=args.save_steps, 24 | evaluation_strategy="no", 25 | logging_dir='./logs', 26 | ) 27 | 28 | trainer = Trainer( 29 | model=model, 30 | args=training_args, 31 | train_dataset=train_dataset, 32 | eval_dataset=dev_dataset 33 | ) 34 | 35 | trainer.train() 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--seed', type=int, default=41) 40 | parser.add_argument('--batch-size', type=int, default=16) 41 | parser.add_argument('--num-classes', type=int, default=2) 42 | parser.add_argument('--epochs', type=int, default=3) 43 | parser.add_argument('--save-steps', type=int, default=500) 44 | parser.add_argument('--train-set', type=str, default="sst-train-tiny") 45 | parser.add_argument('--dev-set', type=str, default="sst-dev-tiny") 46 | args = parser.parse_args() 47 | print(args) 48 | main(args) -------------------------------------------------------------------------------- /models/bert/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import BertForSequenceClassification, Trainer 3 | from transformers.trainer_utils import set_seed 4 | import numpy as np 5 | from dataloader import Dataset 6 | import json 7 | 8 | def main(args): 9 | set_seed(args.seed) 10 | 11 | test_dataset = Dataset(args.test_set) 12 | 13 | model_path = "./checkpoints/" + args.model 14 | model = 
BertForSequenceClassification.from_pretrained(model_path, num_labels=args.num_classes) 15 | 16 | # Define test trainer 17 | test_trainer = Trainer(model) 18 | 19 | # Make prediction 20 | raw_pred, _, _ = test_trainer.predict(test_dataset) 21 | # Preprocess raw predictions 22 | y_pred = np.argmax(raw_pred, axis=1) 23 | y = test_dataset.labels 24 | correct = np.sum(np.equal(y_pred, y)) 25 | total = len(y) 26 | test_acc = 1.0*correct/total 27 | 28 | # print result as json 29 | result = {} 30 | 31 | result["model"] = args.model 32 | result["test-set"] = args.test_set 33 | result["correct"] = int(correct) 34 | result["total"] = int(total) 35 | result["accuracy"] = round(test_acc, 4) 36 | 37 | print(result) 38 | 39 | if args.save: 40 | result["model"] = "best_{}_train-{}-{}_dev-{}.pkl".format( 41 | str(args.seed), 42 | args.hm, 43 | str(args.part), 44 | args.hm 45 | ) 46 | with open(args.summary, "a") as f: 47 | f.write(json.dumps(result)+'\n') 48 | 49 | else: 50 | # print to a temporary file 51 | with open("check_points_acc.jsonl", "a") as f: 52 | f.write(json.dumps(result)+'\n') 53 | 54 | if __name__ == '__main__': 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--seed', type=int, default=41) 57 | parser.add_argument('--hm', type=str, default="gh") 58 | parser.add_argument('--part', type=int, default=1) 59 | parser.add_argument('--num-classes', type=int, default=2) 60 | parser.add_argument('--test-set', type=str, default="test-all") 61 | parser.add_argument('--model', type=str, default=None) 62 | parser.add_argument('--save', action="store_true") 63 | parser.add_argument('--summary', type=str, default='test_summary.jsonl') 64 | args = parser.parse_args() 65 | 66 | assert args.model is not None, "No model provided!" 67 | 68 | main(args) -------------------------------------------------------------------------------- /models/tree_lstm/tree_lstm.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import dgl 4 | 5 | class TreeLSTMCell(nn.Module): 6 | def __init__(self, x_size, h_size): 7 | super(TreeLSTMCell, self).__init__() 8 | self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False) 9 | self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False) 10 | self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size)) 11 | self.U_f = nn.Linear(2 * h_size, 2 * h_size) 12 | 13 | def message_func(self, edges): 14 | return {'h': edges.src['h'], 'c': edges.src['c']} 15 | 16 | def reduce_func(self, nodes): 17 | h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1) 18 | f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size()) 19 | c = th.sum(f * nodes.mailbox['c'], 1) 20 | return {'iou': self.U_iou(h_cat), 'c': c} 21 | 22 | def apply_node_func(self, nodes): 23 | iou = nodes.data['iou'] + self.b_iou 24 | i, o, u = th.chunk(iou, 3, 1) 25 | i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u) 26 | c = i * u + nodes.data['c'] 27 | h = o * th.tanh(c) 28 | return {'h' : h, 'c' : c} 29 | 30 | class TreeLSTM(nn.Module): 31 | def __init__(self, 32 | num_vocabs, 33 | x_size, 34 | h_size, 35 | num_classes, 36 | dropout, 37 | cell_type='nary', 38 | pretrained_emb=None): 39 | super(TreeLSTM, self).__init__() 40 | self.x_size = x_size 41 | self.embedding = nn.Embedding(num_vocabs, x_size) 42 | if pretrained_emb is not None: 43 | print('Using glove') 44 | self.embedding.weight.data.copy_(pretrained_emb) 45 | self.embedding.weight.requires_grad = True 46 | self.dropout = nn.Dropout(dropout) 47 | self.linear = 
nn.Linear(h_size, num_classes) 48 | cell = TreeLSTMCell 49 | self.cell = cell(x_size, h_size) 50 | 51 | def forward(self, batch, g, h, c): 52 | # feed embedding 53 | embeds = self.embedding(batch.wordid * batch.mask) 54 | g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1) 55 | g.ndata['h'] = h 56 | g.ndata['c'] = c 57 | # propagate 58 | dgl.prop_nodes_topo(g, self.cell.message_func, self.cell.reduce_func, apply_node_func=self.cell.apply_node_func) 59 | # compute logits 60 | h = self.dropout(g.ndata.pop('h')) 61 | h_root = h[batch.rootid][:] 62 | logits = self.linear(h_root) 63 | 64 | return logits -------------------------------------------------------------------------------- /models/lstm/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch as th 3 | from torch.utils.data import DataLoader 4 | import numpy as np 5 | 6 | from lstm import LSTM 7 | from dataloader import Dataset 8 | from dataloader import batcher 9 | 10 | import json 11 | from os.path import exists 12 | 13 | def main(args): 14 | device = th.device('cpu') 15 | 16 | testset = Dataset(mode=args.test_set) 17 | test_loader = DataLoader(dataset=testset, 18 | batch_size=100, 19 | collate_fn=batcher, 20 | shuffle=True, 21 | num_workers=0) 22 | 23 | model = LSTM(testset.vocab_size, 24 | args.x_size, 25 | args.h_size, 26 | args.num_classes, 27 | args.dropout, 28 | pretrained_emb = testset.pretrained_emb).to(device) 29 | print(model) 30 | 31 | # test 32 | model.load_state_dict(th.load(args.directory + '/' + args.model)) 33 | accs = [] 34 | model.eval() 35 | for batch in test_loader: 36 | with th.no_grad(): 37 | logits = model(batch) 38 | 39 | pred = th.argmax(logits, 1) 40 | acc = th.sum(th.eq(batch.label, pred)).item() 41 | accs.append([acc, len(batch.label)]) 42 | 43 | correct = np.sum([x[0] for x in accs]) 44 | total = np.sum([x[1] for x in accs]) 45 | test_acc = 1.0*correct/total 46 | 47 | # print result as json 48 | result = {} 49 | 50 | result["model"] = args.model 51 | result["test-set"] = args.test_set 52 | result["correct"] = int(correct) 53 | result["total"] = int(total) 54 | result["accuracy"] = round(test_acc, 4) 55 | 56 | print(result) 57 | 58 | # save to file 59 | with open(args.summary, "a") as f: 60 | f.write(json.dumps(result)+'\n') 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--x-size', type=int, default=300) 66 | parser.add_argument('--h-size', type=int, default=168) 67 | parser.add_argument('--num-classes', type=int, default=2) 68 | parser.add_argument('--dropout', type=float, default=0.5) 69 | parser.add_argument('--test-set', type=str, default="test-g") 70 | parser.add_argument('--model', type=str, default=None) 71 | parser.add_argument('--directory', type=str, default=None) 72 | parser.add_argument('--summary', type=str, default='test_summary.jsonl') 73 | args = parser.parse_args() 74 | print(args) 75 | assert args.model is not None, "No model provided!" 76 | assert args.directory is not None, "Directory not specified!" 77 | assert exists(args.directory + '/' + args.model), "Model does not exist!" 
78 | main(args) -------------------------------------------------------------------------------- /models/tree_lstm/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch as th 4 | from torch.utils.data import DataLoader 5 | 6 | from dataloader import Dataset 7 | 8 | from tree_lstm import TreeLSTM 9 | from train import batcher 10 | 11 | import json 12 | from os.path import exists 13 | 14 | def main(args): 15 | # only use cpu 16 | device = th.device('cpu') 17 | 18 | testset = Dataset(mode=args.test_set) 19 | test_loader = DataLoader(dataset=testset, 20 | batch_size=100, 21 | collate_fn=batcher(device), 22 | shuffle=True, 23 | num_workers=0) 24 | 25 | model = TreeLSTM(testset.vocab_size, 26 | args.x_size, 27 | args.h_size, 28 | testset.num_classes, 29 | args.dropout, 30 | cell_type='nary', 31 | pretrained_emb = testset.pretrained_emb).to(device) 32 | print(model) 33 | 34 | # test 35 | model.load_state_dict(th.load(args.directory + '/' + args.model)) 36 | accs = [] 37 | model.eval() 38 | for batch in test_loader: 39 | 40 | g = batch.graph.to(device) 41 | n = g.number_of_nodes() 42 | 43 | with th.no_grad(): 44 | h = th.zeros((n, args.h_size)).to(device) 45 | c = th.zeros((n, args.h_size)).to(device) 46 | logits = model(batch, g, h, c) 47 | 48 | pred = th.argmax(logits, 1) 49 | acc = th.sum(th.eq(batch.label[batch.rootid], pred)).item() 50 | accs.append([acc, len(batch.label[batch.rootid])]) 51 | 52 | correct = np.sum([x[0] for x in accs]) 53 | total = np.sum([x[1] for x in accs]) 54 | test_acc = 1.0*correct/total 55 | 56 | # print result as json 57 | result = {} 58 | 59 | result["model"] = args.model 60 | result["test-set"] = args.test_set 61 | result["correct"] = int(correct) 62 | result["total"] = int(total) 63 | result["accuracy"] = round(test_acc, 4) 64 | 65 | print(result) 66 | 67 | # save to file 68 | with open(args.summary, "a") as f: 69 | f.write(json.dumps(result)+'\n') 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--x-size', type=int, default=300) 74 | parser.add_argument('--h-size', type=int, default=150) 75 | parser.add_argument('--num-classes', type=int, default=2) 76 | parser.add_argument('--dropout', type=float, default=0.5) 77 | parser.add_argument('--test-set', type=str, default="test-g") 78 | parser.add_argument('--model', type=str, default=None) 79 | parser.add_argument('--directory', type=str, default=None) 80 | parser.add_argument('--summary', type=str, default='test_summary.jsonl') 81 | args = parser.parse_args() 82 | print(args) 83 | assert args.model is not None, "No model provided!" 84 | assert args.directory is not None, "Directory not specified!" 85 | assert exists(args.directory + '/' + args.model), "Model does not exist!" 
86 | main(args) -------------------------------------------------------------------------------- /figs/plot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os.path import exists 3 | import json 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | labels = ["gh", "gm", "ghm"] 8 | 9 | def main(args): 10 | with open(args.size_file, "r") as f: 11 | size = json.load(f) 12 | 13 | accs = {} 14 | tests = [] 15 | 16 | with open(args.summary, "r") as f: 17 | for line in f: 18 | result = json.loads(line) 19 | info = result['model'].split('_') 20 | train = info[2] 21 | test = result['test-set'] 22 | if test not in tests: 23 | tests.append(test) 24 | acc = result['accuracy'] 25 | 26 | if train not in accs: 27 | accs[train] = {} 28 | if test in accs[train]: 29 | accs[train][test].append(acc) 30 | else: 31 | accs[train][test] = [acc] 32 | 33 | 34 | # 4 plots 35 | # g acc, h acc, m acc, total acc 36 | 37 | figs = {} 38 | axs = {} 39 | for t in tests: 40 | figs[t], axs[t] = plt.subplots() 41 | axs[t].set_ylim(0, 1) 42 | 43 | accs_mean = {} 44 | accs_sd = {} 45 | accs_sizes = {} 46 | 47 | for l in labels: 48 | accs_mean[l] = [] 49 | accs_sd[l] = [] 50 | accs_sizes[l] = [] 51 | 52 | for k in size: 53 | for l in labels: 54 | # plot 3 lines, gh-, gm-, ghm- 55 | if l+'-' in k: 56 | accs_sizes[l].append(size[k]) 57 | accs_mean[l].append(np.mean(accs[k][t])) 58 | accs_sd[l].append(np.std(accs[k][t])) 59 | 60 | print(t) 61 | 62 | for l in labels: 63 | with open("accs/{}-{}-{}.dat".format(args.model, t, l), "w") as fd: 64 | for idx in range(len(accs_sizes[l])): 65 | fd.write("{}\t{:.2f}\t{:.2f}\n".format( 66 | accs_sizes[l][idx], 67 | accs_mean[l][idx]*100, 68 | accs_sd[l][idx]*100 69 | )) 70 | 71 | # print the result in latex table format 72 | print(" & {:.2f} ({:.2f})".format(accs_mean[l][idx]*100, accs_sd[l][idx]*100), end="") 73 | print() 74 | 75 | for l in labels: 76 | axs[t].errorbar(accs_sizes[l], accs_mean[l], yerr=accs_sd[l], label=l) 77 | 78 | 79 | for t in tests: 80 | axs[t].legend() 81 | figs[t].savefig('{}-{}.png'.format(args.model, t)) 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--size-file', type=str, default="../data/json/size.json") 87 | parser.add_argument('--model', type=str, default=None) 88 | parser.add_argument('--summary', type=str, default=None) 89 | args = parser.parse_args() 90 | 91 | assert args.model is not None, "Model not specified!" 92 | args.summary = "../models/"+args.model+"/test_summary.jsonl" 93 | assert exists(args.summary), "Summary file does not exist!" 94 | main(args) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PETCI: A Parallel English Translation Dataset of Chinese Idioms 2 | 3 | PETCI is a **P**arallel **E**nglish **T**ranslation dataset of **C**hinese **I**dioms, collected from an idiom dictionary and Google and DeepL translation. PETCI contains **4,310** Chinese idioms with **29,936** English translations. These translations capture diverse translation errors and paraphrase strategies. 4 | 5 | We provide several baseline models to facilitate future research on this dataset. 6 | 7 | ## Data 8 | 9 | The Chinese idioms and their translations are in the `./data/json/raw.json` file. 
Here is one example:
10 | 
11 | ```python
12 | {
13 |     "id": 0,
14 |     "chinese": "一波未平,一波又起",
15 |     "book": [
16 |         "suffer a string of reverses",
17 |         "hardly has one wave subsided when another rises",
18 |         "one trouble follows another"
19 |     ],
20 |     "google": [
21 |         "One wave is not flat, another wave is rising"
22 |     ],
23 |     "deepl": [
24 |         "before the first wave subsides, a new wave rises"
25 |     ]
26 | }
27 | ```
28 | 
29 | - `id` is the index of the idiom in the dictionary
30 | - `chinese` is the Chinese idiom
31 | - `book` is the translations from the dictionary
32 | - `google` is the translation from Google
33 | - `deepl` is the translation from DeepL
34 | 
35 | In `./data/json/filtered.json`, the `machine` translations that are the same as dictionary translations are removed, and the dictionary translations are split into `gold` and `human` translations.
36 | 
37 | ## Training and Testing
38 | 
39 | ### Prerequisites
40 | 
41 | Run `pip install -r ./models/requirements.txt` to install the required packages. Download `glove.840B.300d.txt` and put it in `./data/embedding`. Download [CoreNLP](https://stanfordnlp.github.io/CoreNLP/index.html).
42 | 
43 | ### Create Datasets
44 | Before training, run `java -Xmx4g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -parse.binaryTrees` to start the CoreNLP server, then run the following commands in the `./data` folder to create the necessary datasets.
45 | 
46 | ```shell
47 | mkdir label simplify tree
48 | python dataset.py
49 | ```
50 | 
51 | ### LSTM
52 | 
53 | In the `./models/lstm` folder, run
54 | ```shell
55 | ./auto_train.sh
56 | ./auto_test.sh
57 | ```
58 | 
59 | ### Tree-LSTM
60 | 
61 | In the `./models/tree_lstm` folder, run
62 | ```shell
63 | ./auto_train.sh
64 | ./auto_test.sh
65 | ```
66 | 
67 | ### BERT
68 | 
69 | In the `./models/bert` folder, run
70 | ```shell
71 | SEED=45
72 | HM=ghm
73 | PART=5
74 | python train.py --seed $SEED --train-set train-$HM-$PART --dev-set dev-$HM
75 | 
76 | MODEL=checkpoint-5000
77 | python test.py --model $MODEL --test-set dev-$HM --seed $SEED --hm $HM --part $PART
78 | ```
79 | 
80 | ### NTS
81 | 
82 | In the `./models/nts` folder, run
83 | ```shell
84 | onmt_build_vocab -config vocab.yaml -n_sample -1
85 | 
86 | onmt_train -config nts.yaml
87 | 
88 | BEST=checkpoints/checkpoint_step_300.pt
89 | SRC=../../data/simplify/test-src.txt
90 | OUTPUT=../test-output.txt
91 | onmt_translate -model $BEST -src $SRC -output $OUTPUT -verbose -beam_size 5
92 | ```
93 | 
94 | ### Figures
95 | In the `figs` folder, run `python plot.py --model lstm`, where the model name can be replaced by `tree_lstm` or `bert`.
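
The figures are drawn from the `test_summary.jsonl` file that the test scripts append to in each model folder, with one JSON object per evaluation containing the fields `model`, `test-set`, `correct`, `total`, and `accuracy`. Below is a minimal sketch, not part of the original pipeline, for averaging these accuracies directly; the path assumes the LSTM summaries and the repository root as the working directory, and can point to any other model folder.

```python
import json
from collections import defaultdict

# Group the reported accuracies by test set.
accs = defaultdict(list)
with open("./models/lstm/test_summary.jsonl") as f:
    for line in f:
        result = json.loads(line)
        accs[result["test-set"]].append(result["accuracy"])

# Print the mean accuracy over all seeds and training sets.
for test_set, values in sorted(accs.items()):
    print(test_set, round(sum(values) / len(values), 4))
```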
96 | -------------------------------------------------------------------------------- /models/lstm/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import stanza 5 | from collections import OrderedDict 6 | from tqdm import tqdm 7 | import pickle 8 | from os.path import exists 9 | import collections 10 | 11 | def tokenize(text, nlp): 12 | doc = nlp(text) 13 | result = [] 14 | for sentence in doc.sentences: 15 | result.append([token.text for token in sentence.tokens]) 16 | return result 17 | 18 | Batch = collections.namedtuple('Batch', ['wordid', 'lengths','label']) 19 | 20 | def batcher(data): 21 | sents, labels = zip(*data) 22 | max_len = max(map(len, sents)) 23 | seq_tensor = [] 24 | sl = [] 25 | for sent in sents: 26 | seq_tensor.append(th.cat([th.Tensor(sent), th.zeros(max_len-len(sent))])) 27 | sl.append(len(sent)) 28 | seq_tensor = th.Tensor(np.stack(seq_tensor, 0)) 29 | seq_lengths = th.Tensor(sl) 30 | seq_lengths, perm_idx = seq_lengths.sort(0, descending=True) 31 | seq_tensor = seq_tensor[perm_idx] 32 | labels = th.Tensor(labels)[perm_idx] 33 | 34 | return Batch(seq_tensor.long(), seq_lengths, labels.long()) 35 | 36 | class Dataset(th.utils.data.Dataset): 37 | 38 | def __init__(self, mode): 39 | self.PAD_WORD = -1 # special pad word id 40 | self.UNK_WORD = 0 # out-of-vocabulary word id 41 | # load vocab file 42 | self._vocab = OrderedDict() 43 | with open("../../data/label/vocab.txt", encoding='utf-8') as vf: 44 | for line in vf.readlines(): 45 | line = line.strip() 46 | self._vocab[line] = len(self._vocab) 47 | 48 | self.vocab_size = len(self._vocab) 49 | 50 | if exists('../../data/label/emb.pkl'): 51 | with open('../../data/label/emb.pkl', 'rb') as handle: 52 | self.pretrained_emb = pickle.load(handle) 53 | else: 54 | glove_emb = {} 55 | with open("../../data/embedding/glove.840B.300d.txt", 'r', encoding='utf-8') as pf: 56 | for line in tqdm(pf.readlines()): 57 | sp = line.split(' ') 58 | if sp[0].lower() in self._vocab: 59 | glove_emb[sp[0].lower()] = np.asarray([float(x) for x in sp[1:]]) 60 | 61 | # initialize with glove 62 | pretrained_emb = [] 63 | for line in self._vocab.keys(): 64 | pretrained_emb.append(glove_emb.get(line.lower(), np.random.uniform(-0.05, 0.05, 300))) 65 | 66 | self.pretrained_emb = F.Tensor(np.stack(pretrained_emb, 0)) 67 | 68 | with open('../../data/label/emb.pkl', 'wb') as handle: 69 | pickle.dump(self.pretrained_emb, handle, protocol=pickle.HIGHEST_PROTOCOL) 70 | 71 | stanza.download('en') 72 | nlp = stanza.Pipeline(lang='en', processors='tokenize', tokenize_no_ssplit=True) 73 | 74 | self.sents = [] 75 | self.labels = [] 76 | with open("../../data/label/"+mode+".txt") as f: 77 | 78 | # parse all sentences at once to speed up 79 | all_sentences = "" 80 | for line in f: 81 | all_sentences += line[:-3] + '\n\n' 82 | self.labels.append(int(line[-2])) 83 | 84 | for sent in tokenize(all_sentences, nlp): 85 | self.sents.append([self._vocab.get(word, self.UNK_WORD) for word in sent]) 86 | 87 | print("Finished tokenization for {}!".format(mode)) 88 | 89 | def __len__(self): 90 | return len(self.labels) 91 | 92 | def __getitem__(self, index): 93 | return self.sents[index], self.labels[index] -------------------------------------------------------------------------------- /models/lstm/train.py: -------------------------------------------------------------------------------- 1 | from ast import Pass 2 | import time 3 | import torch 
as th 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as INIT 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | import numpy as np 10 | import argparse 11 | 12 | from lstm import LSTM 13 | from dataloader import Dataset 14 | from dataloader import batcher 15 | 16 | def main(args): 17 | np.random.seed(args.seed) 18 | th.manual_seed(args.seed) 19 | th.cuda.manual_seed(args.seed) 20 | 21 | best_epoch = -1 22 | best_dev_acc = 0 23 | 24 | # always use cpu 25 | device = th.device('cpu') 26 | 27 | trainset = Dataset(mode=args.train_set) 28 | train_loader = DataLoader(dataset=trainset, 29 | batch_size=args.batch_size, 30 | collate_fn=batcher, 31 | shuffle=True, 32 | num_workers=0) 33 | 34 | devset = Dataset(mode=args.dev_set) 35 | dev_loader = DataLoader(dataset=devset, 36 | batch_size=100, 37 | collate_fn=batcher, 38 | shuffle=False, 39 | num_workers=0) 40 | 41 | model = LSTM(trainset.vocab_size, 42 | args.x_size, 43 | args.h_size, 44 | args.num_classes, 45 | args.dropout, 46 | pretrained_emb = trainset.pretrained_emb).to(device) 47 | print(model) 48 | 49 | # parameters that are not embedding 50 | params_ex_emb =[x for x in list(model.parameters()) if x.requires_grad and x.size(0)!=trainset.vocab_size] 51 | params_emb = list(model.embedding.parameters()) 52 | 53 | for p in params_ex_emb: 54 | if p.dim() > 1: 55 | INIT.xavier_uniform_(p) 56 | 57 | optimizer = optim.Adagrad([ 58 | {'params':params_ex_emb, 'lr':args.lr, 'weight_decay':args.weight_decay}, 59 | {'params':params_emb, 'lr':0.1*args.lr}]) 60 | 61 | dur = [] 62 | for epoch in range(args.epochs): 63 | t_epoch = time.time() 64 | model.train() 65 | for step, batch in enumerate(train_loader): 66 | if step >= 3: 67 | t0 = time.time() 68 | 69 | logits = model(batch) 70 | logp = F.log_softmax(logits, 1) 71 | loss = F.nll_loss(logp, batch.label, reduction='sum') 72 | 73 | optimizer.zero_grad() 74 | loss.backward() 75 | optimizer.step() 76 | 77 | if step >= 3: 78 | dur.append(time.time() - t0) 79 | 80 | if step > 0 and step % args.log_every == 0: 81 | pred = th.argmax(logits, 1) 82 | acc = th.sum(th.eq(batch.label, pred)) 83 | 84 | print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Time(s) {:.4f}".format( 85 | epoch, step, loss.item(), 1.0*acc/len(batch.label), np.mean(dur))) 86 | print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch)) 87 | 88 | # eval on dev set 89 | accs = [] 90 | model.eval() 91 | for step, batch in enumerate(dev_loader): 92 | with th.no_grad(): 93 | logits = model(batch) 94 | 95 | pred = th.argmax(logits, 1) 96 | acc = th.sum(th.eq(batch.label, pred)).item() 97 | accs.append([acc, len(batch.label)]) 98 | 99 | dev_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs]) 100 | print("Epoch {:05d} | Dev Acc {:.4f}".format(epoch, dev_acc)) 101 | 102 | if dev_acc > best_dev_acc: 103 | best_dev_acc = dev_acc 104 | best_epoch = epoch 105 | th.save(model.state_dict(), 106 | 'checkpoints/best_{}_{}_{}.pkl'.format(args.seed, args.train_set,args.dev_set)) 107 | else: 108 | # early stopping 109 | if best_epoch <= epoch - 10: 110 | break 111 | 112 | # lr decay 113 | for param_group in optimizer.param_groups: 114 | param_group['lr'] = max(1e-5, param_group['lr']*0.99) 115 | print(param_group['lr']) 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('--seed', type=int, default=41) 120 | parser.add_argument('--batch-size', type=int, default=20) 121 | 
parser.add_argument('--x-size', type=int, default=300) 122 | parser.add_argument('--h-size', type=int, default=168) 123 | parser.add_argument('--num-classes', type=int, default=2) 124 | parser.add_argument('--epochs', type=int, default=100) 125 | parser.add_argument('--log-every', type=int, default=10) 126 | parser.add_argument('--lr', type=float, default=0.05) 127 | parser.add_argument('--weight-decay', type=float, default=1e-4) 128 | parser.add_argument('--dropout', type=float, default=0.5) 129 | parser.add_argument('--train-set', type=str, default="train-gh-1") 130 | parser.add_argument('--dev-set', type=str, default="dev-gh") 131 | args = parser.parse_args() 132 | print(args) 133 | main(args) -------------------------------------------------------------------------------- /models/tree_lstm/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import time 4 | import numpy as np 5 | import torch as th 6 | import torch.nn.functional as F 7 | import torch.nn.init as INIT 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | 11 | import dgl 12 | from dataloader import Dataset 13 | 14 | from tree_lstm import TreeLSTM 15 | 16 | # skip warnings, may have side effects 17 | import os 18 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 19 | 20 | Batch = collections.namedtuple('Batch', ['graph', 'mask', 'wordid', 'label', 'rootid']) 21 | 22 | eps = 1e-30 23 | 24 | def batcher(device): 25 | def batcher_dev(batch): 26 | batch_trees = dgl.batch(batch) 27 | 28 | rootid = [0] 29 | for b in batch: 30 | rootid.append(len(b)+rootid[-1]) 31 | rootid.pop() 32 | 33 | return Batch(graph=batch_trees, 34 | mask=batch_trees.ndata['mask'].to(device), 35 | wordid=batch_trees.ndata['x'].to(device), 36 | label=batch_trees.ndata['y'].to(device), 37 | rootid=th.tensor(rootid).to(device)) 38 | return batcher_dev 39 | 40 | def main(args): 41 | np.random.seed(args.seed) 42 | th.manual_seed(args.seed) 43 | th.cuda.manual_seed(args.seed) 44 | 45 | best_epoch = -1 46 | best_dev_acc = 0 47 | 48 | # only use cpu 49 | device = th.device('cpu') 50 | 51 | trainset = Dataset(mode=args.train_set) 52 | train_loader = DataLoader(dataset=trainset, 53 | batch_size=args.batch_size, 54 | collate_fn=batcher(device), 55 | shuffle=True, 56 | num_workers=0) 57 | devset = Dataset(mode=args.dev_set) 58 | dev_loader = DataLoader(dataset=devset, 59 | batch_size=100, 60 | collate_fn=batcher(device), 61 | shuffle=False, 62 | num_workers=0) 63 | 64 | model = TreeLSTM(trainset.vocab_size, 65 | args.x_size, 66 | args.h_size, 67 | trainset.num_classes, 68 | args.dropout, 69 | cell_type='nary', 70 | pretrained_emb = trainset.pretrained_emb).to(device) 71 | print(model) 72 | 73 | params_ex_emb =[x for x in list(model.parameters()) if x.requires_grad and x.size(0)!=trainset.vocab_size] 74 | params_emb = list(model.embedding.parameters()) 75 | 76 | for p in params_ex_emb: 77 | if p.dim() > 1: 78 | INIT.xavier_uniform_(p) 79 | 80 | optimizer = optim.Adagrad([ 81 | {'params':params_ex_emb, 'lr':args.lr, 'weight_decay':args.weight_decay}, 82 | {'params':params_emb, 'lr':0.1*args.lr}]) 83 | 84 | dur = [] 85 | for epoch in range(args.epochs): 86 | training_loss = 0 87 | 88 | t_epoch = time.time() 89 | model.train() 90 | for step, batch in enumerate(train_loader): 91 | g = batch.graph.to(device) 92 | n = g.number_of_nodes() 93 | h = th.zeros((n, args.h_size)).to(device) 94 | c = th.zeros((n, args.h_size)).to(device) 95 | if step >= 3: 96 | t0 = time.time() 
97 | 98 | logits = model(batch, g, h, c) 99 | logp = F.log_softmax(logits, 1) 100 | loss = F.nll_loss(logp, batch.label[batch.rootid], reduction='sum') 101 | 102 | training_loss += loss.data.item() 103 | 104 | optimizer.zero_grad() 105 | loss.backward() 106 | optimizer.step() 107 | 108 | if step >= 3: 109 | dur.append(time.time() - t0) 110 | 111 | if step > 0 and step % args.log_every == 0: 112 | pred = th.argmax(logits, 1) 113 | 114 | acc = th.sum(th.eq(batch.label[batch.rootid], pred)) 115 | print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format( 116 | epoch, step, loss.item(), 1.0*acc/len(batch.rootid), np.mean(dur))) 117 | print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch)) 118 | print('Training loss:', training_loss) 119 | 120 | # eval on dev set 121 | accs = [] 122 | model.eval() 123 | for step, batch in enumerate(dev_loader): 124 | g = batch.graph.to(device) 125 | n = g.number_of_nodes() 126 | with th.no_grad(): 127 | h = th.zeros((n, args.h_size)).to(device) 128 | c = th.zeros((n, args.h_size)).to(device) 129 | logits = model(batch, g, h, c) 130 | 131 | pred = th.argmax(logits, 1) 132 | acc = th.sum(th.eq(batch.label[batch.rootid], pred)) 133 | accs.append([acc, len(batch.label[batch.rootid])]) 134 | 135 | dev_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs]) 136 | print("Epoch {:05d} | Dev Acc {:.4f}".format(epoch, dev_acc)) 137 | 138 | if dev_acc > best_dev_acc: 139 | best_dev_acc = dev_acc 140 | best_epoch = epoch 141 | th.save(model.state_dict(), 142 | 'checkpoints/best_{}_{}_{}.pkl'.format(args.seed, args.train_set,args.dev_set)) 143 | else: 144 | if best_epoch <= epoch - 10: 145 | break 146 | 147 | # lr decay 148 | for param_group in optimizer.param_groups: 149 | param_group['lr'] = max(1e-5, param_group['lr']*0.99) #10 150 | print(param_group['lr']) 151 | 152 | if __name__ == '__main__': 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument('--seed', type=int, default=41) 155 | parser.add_argument('--batch-size', type=int, default=20) 156 | parser.add_argument('--x-size', type=int, default=300) 157 | parser.add_argument('--h-size', type=int, default=150) 158 | parser.add_argument('--epochs', type=int, default=100) 159 | parser.add_argument('--log-every', type=int, default=5) 160 | parser.add_argument('--lr', type=float, default=0.05) 161 | parser.add_argument('--weight-decay', type=float, default=1e-4) 162 | parser.add_argument('--dropout', type=float, default=0.5) 163 | parser.add_argument('--train-set', type=str, default="train-gh-1") 164 | parser.add_argument('--dev-set', type=str, default="dev-gh") 165 | args = parser.parse_args() 166 | print(args) 167 | main(args) -------------------------------------------------------------------------------- /models/tree_lstm/dataloader.py: -------------------------------------------------------------------------------- 1 | from dgl.data import DGLDataset 2 | from collections import OrderedDict 3 | import networkx as nx 4 | import numpy as np 5 | import os 6 | from tqdm import tqdm 7 | 8 | import torch as F 9 | from dgl.data.utils import save_graphs, save_info, load_graphs, \ 10 | load_info, deprecate_property 11 | from dgl.convert import from_networkx 12 | 13 | class Dataset(DGLDataset): 14 | 15 | PAD_WORD = 0 # special pad word id 16 | UNK_WORD = 0 # out-of-vocabulary word id 17 | 18 | def __init__(self, 19 | mode='train', 20 | glove_embed_file="../../data/embedding/glove.840B.300d.txt", 21 | vocab_file=None, 22 | raw_dir="../../data/", 
23 | force_reload=False, 24 | verbose=False): 25 | 26 | name = "tree" 27 | self._glove_embed_file = glove_embed_file 28 | if not os.path.exists(raw_dir+name+"/emb.pkl"): 29 | self._glove_embed_file = glove_embed_file 30 | else: 31 | self._glove_embed_file = None 32 | self.mode = mode 33 | self._vocab_file = vocab_file 34 | super().__init__(name=name, 35 | url=None, 36 | raw_dir=raw_dir, 37 | force_reload=force_reload, 38 | verbose=verbose) 39 | 40 | def process(self): 41 | from nltk.corpus.reader import BracketParseCorpusReader 42 | # load vocab file 43 | self._vocab = OrderedDict() 44 | vocab_file = self._vocab_file if self._vocab_file is not None else os.path.join(self.raw_path, 'vocab.txt') 45 | with open(vocab_file, encoding='utf-8') as vf: 46 | for line in vf.readlines(): 47 | line = line.strip() 48 | self._vocab[line] = len(self._vocab) 49 | 50 | # filter glove 51 | if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file): 52 | glove_emb = {} 53 | with open(self._glove_embed_file, 'r', encoding='utf-8') as pf: 54 | for line in tqdm(pf.readlines()): 55 | sp = line.split(' ') 56 | if sp[0].lower() in self._vocab: 57 | glove_emb[sp[0].lower()] = np.asarray([float(x) for x in sp[1:]]) 58 | files = ['{}.txt'.format(self.mode)] 59 | corpus = BracketParseCorpusReader(self.raw_path, files) 60 | sents = corpus.parsed_sents(files[0]) 61 | 62 | # initialize with glove 63 | pretrained_emb = [] 64 | fail_cnt = 0 65 | for line in self._vocab.keys(): 66 | if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file): 67 | if not line.lower() in glove_emb: 68 | fail_cnt += 1 69 | pretrained_emb.append(glove_emb.get(line.lower(), np.random.uniform(-0.05, 0.05, 300))) 70 | 71 | self._pretrained_emb = None 72 | if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file): 73 | self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0)) 74 | print('Miss word in GloVe {0:.4f}'.format(1.0 * fail_cnt / len(self._pretrained_emb))) 75 | # build trees 76 | self._trees = [] 77 | for sent in sents: 78 | self._trees.append(self._build_tree(sent)) 79 | 80 | def _build_tree(self, root): 81 | g = nx.DiGraph() 82 | 83 | def _rec_build(nid, node): 84 | for child in node: 85 | cid = g.number_of_nodes() 86 | 87 | # account for trees with a single node 88 | if isinstance(child, str): 89 | return 90 | 91 | if isinstance(child[0], str) or isinstance(child[0], bytes): 92 | # leaf node 93 | word = self.vocab.get(child[0].lower(), self.UNK_WORD) 94 | g.add_node(cid, x=word, y=int(child.label()), mask=1) 95 | else: 96 | g.add_node(cid, x=Dataset.PAD_WORD, y=int(child.label()), mask=0) 97 | _rec_build(cid, child) 98 | g.add_edge(cid, nid) 99 | 100 | # add root 101 | g.add_node(0, x=Dataset.PAD_WORD, y=int(root.label()), mask=0) 102 | _rec_build(0, root) 103 | ret = from_networkx(g, node_attrs=['x', 'y', 'mask']) 104 | return ret 105 | 106 | def has_cache(self): 107 | graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') 108 | vocab_path = os.path.join(self.save_path, 'vocab.pkl') 109 | return os.path.exists(graph_path) and os.path.exists(vocab_path) 110 | 111 | def save(self): 112 | graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') 113 | save_graphs(graph_path, self._trees) 114 | vocab_path = os.path.join(self.save_path, 'vocab.pkl') 115 | save_info(vocab_path, {'vocab': self.vocab}) 116 | if self.pretrained_emb is not None: 117 | emb_path = os.path.join(self.save_path, 'emb.pkl') 118 | save_info(emb_path, {'embed': 
self.pretrained_emb}) 119 | 120 | def load(self): 121 | graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') 122 | vocab_path = os.path.join(self.save_path, 'vocab.pkl') 123 | emb_path = os.path.join(self.save_path, 'emb.pkl') 124 | 125 | self._trees = load_graphs(graph_path)[0] 126 | self._vocab = load_info(vocab_path)['vocab'] 127 | self._pretrained_emb = None 128 | if os.path.exists(emb_path): 129 | self._pretrained_emb = load_info(emb_path)['embed'] 130 | 131 | @property 132 | def trees(self): 133 | deprecate_property('dataset.trees', '[dataset[i] for i in len(dataset)]') 134 | return self._trees 135 | 136 | @property 137 | def vocab(self): 138 | return self._vocab 139 | 140 | @property 141 | def pretrained_emb(self): 142 | return self._pretrained_emb 143 | 144 | def __getitem__(self, idx): 145 | return self._trees[idx] 146 | 147 | 148 | def __len__(self): 149 | return len(self._trees) 150 | 151 | 152 | @property 153 | def num_vocabs(self): 154 | deprecate_property('dataset.num_vocabs', 'dataset.vocab_size') 155 | return self.vocab_size 156 | 157 | @property 158 | def vocab_size(self): 159 | return len(self._vocab) 160 | 161 | @property 162 | def num_classes(self): 163 | return 2 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import stanza 3 | import json 4 | import numpy as np 5 | from os.path import exists 6 | import requests 7 | import json 8 | import re 9 | from tqdm import tqdm 10 | 11 | # clean half nodes from a binary parse 12 | def clean_binary_parse(bp): 13 | bp = bp.replace('\n', '') 14 | bp = re.sub(' +', ' ', bp) 15 | bp = re.sub('\((.*?) 
', '(0 ', bp) 16 | 17 | # construct a tree from the string 18 | 19 | words = [] 20 | words_lc = [] 21 | words_rc = [] 22 | words_p = [] 23 | word = "" 24 | root_p = -1 25 | cur_p = -1 26 | for c in bp: 27 | if c != '(' and c != ')': 28 | word += c 29 | else: 30 | if not word.isspace() and word!='': 31 | words.append(word) 32 | words_lc.append(None) 33 | words_rc.append(None) 34 | words_p.append(cur_p) 35 | cur_c = len(words_p)-1 36 | if cur_p != root_p: 37 | if not words_lc[cur_p]: 38 | words_lc[cur_p] = cur_c 39 | else: 40 | words_rc[cur_p] = cur_c 41 | cur_p = cur_c 42 | if c == ')': 43 | cur_p = words_p[cur_p] 44 | word = "" 45 | 46 | # remove half-nodes 47 | removed = [False for _ in words] 48 | for i in range(len(words)): 49 | 50 | # by our construction, any half-node always has left child 51 | if words_lc[i] and not words_rc[i]: 52 | g = words_p[i] 53 | c = words_lc[i] 54 | words_p[c] = g 55 | 56 | # we do not want returns the last element for index==-1 57 | if g != root_p: 58 | if words_lc[g] == i: 59 | words_lc[g] = c 60 | else: 61 | words_rc[g] = c 62 | 63 | removed[i] = True 64 | if words_p[i] == root_p and not removed[i]: 65 | root = i 66 | 67 | def inorder(root): 68 | if not root: 69 | return '' 70 | s = '(' + words[root] 71 | s += inorder(words_lc[root]) 72 | if words_lc[root]: 73 | s += ' ' 74 | s += inorder(words_rc[root]) + ')' 75 | return s 76 | 77 | # in the end, there should be 2n-1 labels if there are n tokens 78 | return inorder(root) 79 | 80 | def main(args): 81 | np.random.seed(args.seed) 82 | 83 | # open the json 84 | with open("json/filtered.json", "r") as f: 85 | filtered = json.load(f) 86 | 87 | all_sentences = "" 88 | for i in filtered: 89 | all_sentences += i["gold"].lower() + '\n\n' 90 | for h in i["human"]: 91 | all_sentences += h.lower() + '\n\n' 92 | for m in i["machine"]: 93 | all_sentences += m.lower() + '\n\n' 94 | 95 | # create vocabulary for the labelled sentences 96 | # required by LSTM but not by BERT 97 | if exists("label/vocab.txt"): 98 | print("Vocabulary for labelled sentences exists!") 99 | else: 100 | stanza.download('en') 101 | 102 | vocab = set() 103 | nlp = stanza.Pipeline(lang='en', processors='tokenize', tokenize_no_ssplit=True) 104 | 105 | print("Creating vocabulary for labelled sentences...") 106 | 107 | doc = nlp(all_sentences) 108 | for sentence in doc.sentences: 109 | for token in sentence.tokens: 110 | vocab.add(token.text) 111 | 112 | with open("label/vocab.txt", "w") as fo: 113 | fo.write("\n") 114 | for v in list(vocab): 115 | fo.write(v+"\n") 116 | 117 | 118 | # get parse results of all sentences 119 | parsed_sentences = [] 120 | 121 | if exists("tree/parse.txt"): 122 | # need to manually check if this file is complete 123 | with open("tree/parse.txt", "r") as f: 124 | for line in f: 125 | parsed_sentences.append(json.loads(line)) 126 | else: 127 | # parse sentences from scratch, may take around 10 minutes! 
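        # Illustration (hypothetical input, not taken from the data): clean_binary_parse,
        # defined above, rewrites every constituent label to "0" and drops single-child
        # (half) nodes, so a CoreNLP binary parse such as
        #     '(ROOT (S (NP (PRP he)) (VP (VBZ runs))))'
        # becomes '(0 (0 he) (0 runs))', i.e. 2n-1 = 3 labels for n = 2 tokens.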
128 |         r = 'http://[::]:9000/?properties={"annotators":"tokenize,ssplit,pos,parse","outputFormat":"json"}'
129 |         separate_sentences = all_sentences.split('\n\n') # last one is empty
130 |
131 |         # file to save the parsed sentences
132 |         fp = open("tree/parse.txt", "w")
133 |
134 |         for s in tqdm(separate_sentences):
135 |             j = json.loads(requests.post(r, data = s).text)
136 |
137 |             # save only relevant fields
138 |             j_rel = {}
139 |
140 |             # clean the binary parse right after results are returned
141 |             j_rel['binaryParse'] = clean_binary_parse(j['sentences'][0]['binaryParse'])
142 |
143 |             # save tokens for building vocabulary
144 |             j_rel['tokens'] = j['sentences'][0]['tokens']
145 |             parsed_sentences.append(j_rel)
146 |             json.dump(parsed_sentences[-1], fp)
147 |             fp.write('\n')
148 |
149 |     # read the parsed sentences for each idiom
150 |     parsed_idiom = []
151 |     idx = 0
152 |     for i in filtered:
153 |         parsed_idiom.append({"gold": parsed_sentences[idx]})
154 |
155 |         # mark the gold sentence by changing its root-node label to 1
156 |         parsed_idiom[-1]["gold"]["binaryParse"] = '(1' + parsed_idiom[-1]["gold"]["binaryParse"][2:]
157 |         idx += 1
158 |         parsed_idiom[-1]["human"] = []
159 |         parsed_idiom[-1]["machine"] = []
160 |         for h in i["human"]:
161 |             parsed_idiom[-1]["human"].append(parsed_sentences[idx])
162 |             idx += 1
163 |         for m in i["machine"]:
164 |             parsed_idiom[-1]["machine"].append(parsed_sentences[idx])
165 |             idx += 1
166 |
167 |     # create vocabulary for the constituency parse trees
168 |     # note: the parsing step above requires a running CoreNLP server, started by:
169 |     # java -Xmx4g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -parse.binaryTrees
170 |     if exists("tree/vocab.txt"):
171 |         print("Vocabulary for parsed sentences exists!")
172 |     else:
173 |         vocab = set()
174 |
175 |         print("Creating vocabulary for parsed sentences...")
176 |
177 |         for sentence in parsed_sentences:
178 |             for token in sentence['tokens']:
179 |                 vocab.add(token['word'])
180 |
181 |         with open("tree/vocab.txt", "w") as fo:
182 |             fo.write("\n")
183 |             for v in list(vocab):
184 |                 fo.write(v+"\n")
185 |
186 |     assert args.points >= 2, "Too few data points!"
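    # Worked illustration of the split computed below (hypothetical count): with 1,000
    # idioms, int(1000 * 0.8) = 800 and int(1000 * 0.9) = 900, so the shuffled indices
    # give 800/100/100 idioms for train/dev/test. The training subsets built afterwards
    # grow with np.linspace from one translation per idiom (the *_head lists) to all
    # available translations (head plus shuffled tail), yielding args.points nested sizes.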
187 | 188 | idiom_total = len(filtered) 189 | split = {"train": [], "dev": [], "test": []} 190 | idx_train = int(idiom_total * 0.8) 191 | idx_dev = int(idiom_total * 0.9) 192 | all_idx =list(range(idiom_total)) 193 | 194 | # need to shuffle, because idioms of similar meaning are put together in the dictionary 195 | np.random.shuffle(all_idx) 196 | split['train'] = all_idx[0:idx_train] 197 | split['dev'] = all_idx[idx_train:idx_dev] 198 | split['test'] = all_idx[idx_dev:idiom_total] 199 | print("train/dev/test: {}, {}, {}".format(idx_train, idx_dev-idx_train, idiom_total-idx_dev)) 200 | 201 | with open("json/split.json", "w") as fo: 202 | json.dump(split, fo) 203 | 204 | # all lower case 205 | for w in filtered: 206 | w["gold"] = w["gold"].lower() 207 | w["human"] = [h.lower() for h in w["human"]] 208 | w["machine"] = [m.lower() for m in w["machine"]] 209 | 210 | # train set 211 | 212 | hum_head = [] 213 | mac_head = [] 214 | hum_tail = [] 215 | mac_tail = [] 216 | 217 | filtered_train = [filtered[i] for i in split["train"]] 218 | parsed_idiom_train = [parsed_idiom[i] for i in split["train"]] 219 | 220 | for idx, i in enumerate(filtered_train): 221 | lh = len(i["human"]) 222 | if lh > 0: 223 | hum_head.append((idx, 0)) 224 | if lh > 1: 225 | hum_tail.extend([(idx, j) for j in range(1, lh)]) 226 | mh = len(i["machine"]) 227 | if mh > 0: 228 | mac_head.append((idx, 0)) 229 | if mh > 1: 230 | mac_tail.extend([(idx, j) for j in range(1, mh)]) 231 | 232 | np.random.shuffle(hum_tail) 233 | np.random.shuffle(mac_tail) 234 | hum_all = hum_head + hum_tail 235 | mac_all = mac_head + mac_tail 236 | 237 | hum_indices = np.linspace(len(hum_head), len(hum_all), args.points, dtype=int) 238 | mac_indices = np.linspace(len(mac_head), len(mac_all), args.points, dtype=int) 239 | 240 | print("Number of human translations for training:", hum_indices) 241 | print("Number of machine translations for training:", mac_indices) 242 | 243 | sizes = {} 244 | for i in range(args.points): 245 | name = str(i+1) 246 | sizes["train-gh-"+name] = int(hum_indices[i]) 247 | sizes["train-gm-"+name] = int(mac_indices[i]) 248 | sizes["train-ghm-"+name] = int(hum_indices[i]+mac_indices[i]) 249 | 250 | # save training set size for plotting 251 | with open("json/size.json", "w") as f: 252 | json.dump(sizes, f, indent=4) 253 | 254 | for i, (hr, mr) in enumerate(zip(hum_indices, mac_indices)): 255 | hl = 0 256 | ml = 0 257 | 258 | file_idx = str(i+1) 259 | 260 | # save labelled sentences 261 | foh_train = open("label/train-gh-"+file_idx+".txt", "w") 262 | fom_train = open("label/train-gm-"+file_idx+".txt", "w") 263 | fohm_train = open("label/train-ghm-"+file_idx+".txt", "w") 264 | 265 | # save parsed sentences 266 | th_train = open("tree/train-gh-"+file_idx+".txt", "w") 267 | tm_train = open("tree/train-gm-"+file_idx+".txt", "w") 268 | thm_train = open("tree/train-ghm-"+file_idx+".txt", "w") 269 | 270 | # save simplification 271 | if i == args.points - 1: 272 | ss_train = open("simplify/train-src.txt", "w") 273 | st_train = open("simplify/train-tgt.txt", "w") 274 | 275 | # balance training data 276 | 277 | for j in range(hl, hr): 278 | i_idx, j_idx = hum_all[j] 279 | 280 | # labelled 281 | foh_train.write(filtered_train[i_idx]["gold"] + " 1\n") 282 | foh_train.write(filtered_train[i_idx]["human"][j_idx] + " 0\n") 283 | fohm_train.write(filtered_train[i_idx]["gold"] + " 1\n") 284 | fohm_train.write(filtered_train[i_idx]["human"][j_idx] + " 0\n") 285 | 286 | # tree 287 | 
th_train.write(parsed_idiom_train[i_idx]["gold"]['binaryParse']+"\n") 288 | th_train.write(parsed_idiom_train[i_idx]["human"][j_idx]['binaryParse']+"\n") 289 | thm_train.write(parsed_idiom_train[i_idx]["gold"]['binaryParse']+"\n") 290 | thm_train.write(parsed_idiom_train[i_idx]["human"][j_idx]['binaryParse']+"\n") 291 | 292 | # simplify, do not save parts, but all 293 | if i == args.points - 1: 294 | ss_train.write(filtered_train[i_idx]["human"][j_idx]+'\n') 295 | st_train.write(filtered_train[i_idx]["gold"]+'\n') 296 | 297 | for j in range(ml, mr): 298 | i_idx, j_idx = mac_all[j] 299 | 300 | # labelled 301 | fom_train.write(filtered_train[i_idx]["gold"] + " 1\n") 302 | fom_train.write(filtered_train[i_idx]["machine"][j_idx] + " 0\n") 303 | fohm_train.write(filtered_train[i_idx]["gold"] + " 1\n") 304 | fohm_train.write(filtered_train[i_idx]["machine"][j_idx] + " 0\n") 305 | 306 | # tree 307 | tm_train.write(parsed_idiom_train[i_idx]["gold"]['binaryParse']+"\n") 308 | tm_train.write(parsed_idiom_train[i_idx]["machine"][j_idx]['binaryParse']+"\n") 309 | thm_train.write(parsed_idiom_train[i_idx]["gold"]['binaryParse']+"\n") 310 | thm_train.write(parsed_idiom_train[i_idx]["machine"][j_idx]['binaryParse']+"\n") 311 | 312 | # simplify 313 | if i == args.points - 1: 314 | ss_train.write(filtered_train[i_idx]["machine"][j_idx]+'\n') 315 | st_train.write(filtered_train[i_idx]["gold"]+'\n') 316 | 317 | # dev set 318 | 319 | foh_dev = open("label/dev-gh.txt", "w") 320 | fom_dev = open("label/dev-gm.txt", "w") 321 | fohm_dev = open("label/dev-ghm.txt", "w") 322 | 323 | th_dev = open("tree/dev-gh.txt", "w") 324 | tm_dev = open("tree/dev-gm.txt", "w") 325 | thm_dev = open("tree/dev-ghm.txt", "w") 326 | 327 | ss_dev = open("simplify/dev-src.txt", "w") 328 | st_dev = open("simplify/dev-tgt.txt", "w") 329 | 330 | for i in split["dev"]: 331 | w = filtered[i] 332 | wp = parsed_idiom[i] 333 | 334 | foh_dev.write(w["gold"] + " 1\n") 335 | fom_dev.write(w["gold"] + " 1\n") 336 | 337 | th_dev.write(wp["gold"]['binaryParse']+'\n') 338 | tm_dev.write(wp["gold"]['binaryParse']+'\n') 339 | 340 | # labelled and simplify 341 | for h in w["human"]: 342 | foh_dev.write(h + " 0\n") 343 | fohm_dev.write(h + " 0\n") 344 | 345 | ss_dev.write(h+'\n') 346 | st_dev.write(w["gold"]+'\n') 347 | 348 | for m in w["machine"]: 349 | fom_dev.write(m + " 0\n") 350 | fohm_dev.write(m + " 0\n") 351 | 352 | ss_dev.write(m+'\n') 353 | st_dev.write(w["gold"]+'\n') 354 | 355 | # tree 356 | for h in wp["human"]: 357 | th_dev.write(h['binaryParse']+'\n') 358 | thm_dev.write(h['binaryParse']+'\n') 359 | for m in wp["machine"]: 360 | tm_dev.write(m['binaryParse']+'\n') 361 | thm_dev.write(m['binaryParse']+'\n') 362 | 363 | 364 | # test set 365 | 366 | fog_test = open("label/test-g.txt", "w") 367 | foh_test = open("label/test-h.txt", "w") 368 | fom_test = open("label/test-m.txt", "w") 369 | foa_test = open("label/test-all.txt", "w") 370 | 371 | tg_test = open("tree/test-g.txt", "w") 372 | th_test = open("tree/test-h.txt", "w") 373 | tm_test = open("tree/test-m.txt", "w") 374 | ta_test = open("tree/test-all.txt", "w") 375 | 376 | ss_test = open("simplify/test-src.txt", "w") 377 | st_test = open("simplify/test-tgt.txt", "w") 378 | 379 | 380 | for i in split["test"]: 381 | # label 382 | w = filtered[i] 383 | fog_test.write(w["gold"] + " 1\n") 384 | foa_test.write(w["gold"] + " 1\n") 385 | for h in w["human"]: 386 | foh_test.write(h + " 0\n") 387 | foa_test.write(h + " 0\n") 388 | 389 | ss_test.write(h+'\n') 390 | st_test.write(w["gold"]+'\n') 391 | 
392 | for m in w["machine"]: 393 | fom_test.write(m + " 0\n") 394 | foa_test.write(m + " 0\n") 395 | 396 | ss_test.write(m+'\n') 397 | st_test.write(w["gold"]+'\n') 398 | 399 | # tree 400 | wp = parsed_idiom[i] 401 | tg_test.write(wp["gold"]["binaryParse"]+'\n') 402 | ta_test.write(wp["gold"]["binaryParse"]+'\n') 403 | for h in wp["human"]: 404 | th_test.write(h["binaryParse"]+'\n') 405 | ta_test.write(h["binaryParse"]+'\n') 406 | for m in wp["machine"]: 407 | tm_test.write(m["binaryParse"]+'\n') 408 | ta_test.write(m["binaryParse"]+'\n') 409 | 410 | 411 | 412 | if __name__ == '__main__': 413 | parser = argparse.ArgumentParser() 414 | parser.add_argument('--seed', type=int, default=41) 415 | parser.add_argument('-p', '--points', type=int, default=5) 416 | args = parser.parse_args() 417 | main(args) --------------------------------------------------------------------------------
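Note on the files written by data/dataset.py: each labelled sentence is saved as the sentence text, a space, and a 0/1 label (1 for the gold translation, 0 for a human or machine translation). A minimal reader sketch under that assumption; the helper name is illustrative and not part of the repository:

def read_label_file(path):
    # each line looks like "<sentence> <0 or 1>"; split on the last space
    sentences, labels = [], []
    with open(path) as f:
        for line in f:
            text, label = line.rstrip("\n").rsplit(" ", 1)
            sentences.append(text)
            labels.append(int(label))
    return sentences, labels

# e.g. sents, labels = read_label_file("label/train-gh-1.txt")

Running the script (e.g. python dataset.py --seed 41 --points 5, matching its argparse defaults) assumes that json/filtered.json exists, that the label/, tree/, simplify/, and json/ directories are present, and that either a complete tree/parse.txt is available or a Stanford CoreNLP server with binary trees enabled is running.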