├── src
│   ├── .gitkeep
│   ├── __init__.py
│   ├── eval
│   │   ├── .gitkeep
│   │   ├── __init__.py
│   │   ├── EEtesting.py
│   │   ├── Groundingtesting.py
│   │   ├── SRtesting.py
│   │   └── EEvisualizing.py
│   ├── util
│   │   ├── .gitkeep
│   │   ├── __init__.py
│   │   ├── consts.py
│   │   ├── helper.py
│   │   ├── vocab.py
│   │   ├── util_model.py
│   │   ├── constant.py
│   │   └── util_img.py
│   ├── dataflow
│   │   ├── .gitkeep
│   │   ├── __init__.py
│   │   ├── numpy
│   │   │   ├── __init__.py
│   │   │   ├── .DS_Store
│   │   │   ├── dataset_image_download.py
│   │   │   ├── anno_mapping.py
│   │   │   └── prepare_vocab.py
│   │   └── torch
│   │       ├── __init__.py
│   │       ├── Corpus.py
│   │       ├── Data.py
│   │       └── Sentence.py
│   ├── engine
│   │   ├── .gitkeep
│   │   ├── __init__.py
│   │   ├── TestRunnerEE.py
│   │   └── Groundingrunner.py
│   └── models
│       ├── .gitkeep
│       ├── __init__.py
│       ├── modules
│       │   ├── .gitkeep
│       │   ├── __init__.py
│       │   ├── HighWay.py
│       │   ├── FMapLayerImage.py
│       │   ├── EmbeddingLayerImage.py
│       │   ├── model.py
│       │   ├── SelfAttention.py
│       │   ├── DynamicLSTM.py
│       │   ├── GCN.py
│       │   └── EmbeddingLayer.py
│       └── ace_classifier.py
├── scripts
│   ├── .gitkeep
│   ├── eval
│   │   ├── .gitkeep
│   │   ├── test_ee_m2e2.sh
│   │   ├── test_sr_m2e2_obj.sh
│   │   ├── test_sr_m2e2.sh
│   │   ├── test_joint_object.sh
│   │   └── test_joint.sh
│   └── train
│       ├── .gitkeep
│       ├── train_ee.sh
│       ├── train_grounding.sh
│       ├── train_sr_object.sh
│       ├── train_sr.sh
│       ├── train_joint_obj.sh
│       └── train_joint_att.sh
├── training-testing.jpg
├── data
│   ├── m2e2_annotations.zip
│   ├── ace
│   │   └── ace_sr_mapping.txt
│   └── object
│       └── class-descriptions-boxable.csv
├── .idea
│   ├── dictionaries
│   │   └── lynn.xml
│   ├── vcs.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── m2e2-multimedia-event-extraction.iml
├── requirements.txt
└── README.md
/src/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/eval/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/util/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/eval/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/train/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/dataflow/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/engine/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/engine/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/eval/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/dataflow/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/modules/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/dataflow/numpy/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/dataflow/torch/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/training-testing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limanling/m2e2/HEAD/training-testing.jpg
--------------------------------------------------------------------------------
/data/m2e2_annotations.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limanling/m2e2/HEAD/data/m2e2_annotations.zip
--------------------------------------------------------------------------------
/src/dataflow/numpy/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limanling/m2e2/HEAD/src/dataflow/numpy/.DS_Store
--------------------------------------------------------------------------------
/.idea/dictionaries/lynn.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/src/util/consts.py:
--------------------------------------------------------------------------------
1 | UNK_LABEL = "<unk>"
2 | UNK_IDX = 0
3 | PADDING_LABEL = "<pad>"
4 | PADDING_IDX = 1
5 |
6 | VOCAB_START_IDX = 2
7 |
8 | CUTOFF = 50
9 | TRIGGER_GOLDEN_COMPENSATION = 12
10 | ARGUMENT_GOLDEN_COMPENSATION = 27
11 |
12 | O_LABEL = None
13 | ROLE_O_LABEL = None
14 |
15 | O_LABEL_NAME = "O"
16 | ROLE_O_LABEL_NAME = "OTHER"
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchtext==0.4.0
2 | ujson==1.35
3 | seqeval==0.0.12
4 | six==1.12.0
5 | jieba==0.39
6 | nltk==3.4.5
7 | scipy==1.3.0
8 | torchvision==0.3.0
9 | Pillow==8.1.1
10 | common==0.1.2
11 | constants==0.6.0
12 | graphstate==1.0.6
13 | ipdb==0.12.2
14 | preprocessing==0.1.13
15 | stanfordcorenlp==3.9.1.1
16 | tensorboardX==1.8
17 | utils==0.9.0
18 |
--------------------------------------------------------------------------------
/.idea/m2e2-multimedia-event-extraction.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/src/dataflow/torch/Corpus.py:
--------------------------------------------------------------------------------
1 | from torchtext.data import Dataset
2 |
3 |
4 | class Corpus(Dataset):
5 | def __init__(self, path, fields, amr, **kwargs):
6 | '''
7 | Create a corpus given a path, a field mapping, and an AMR flag.
8 | :param path: str, path to the data file
9 | :param fields: dict[str: tuple(str, Field)],
10 | the keys should be a subset of the JSON keys or CSV/TSV columns,
11 | and the values should be tuples of (name, field).
12 | Keys not present in the input dictionary are ignored. This allows
13 | renaming columns from their JSON/CSV/TSV key names and selecting
14 | a subset of columns to load.
15 | :param amr: bool, whether the examples carry AMR parses (forwarded to parse_example)
16 | '''
17 | self.path = path
18 | self._size = None
19 |
20 | examples = self.parse_example(path, fields, amr, **kwargs)
21 | fields = list(fields.values())
22 | super(Corpus, self).__init__(examples, fields, **kwargs)
23 |
24 | def parse_example(self, path, fields, amr, **kwargs):
25 | raise NotImplementedError
--------------------------------------------------------------------------------
/src/models/ace_classifier.py:
--------------------------------------------------------------------------------
1 | from src.models.modules.model import Model
2 | from src.util.util_model import BottledXavierLinear
3 | from torch import nn
4 |
5 | class ACEClassifier(Model):
6 | def __init__(self, common_dim, type_num, role_num, device):
7 | # self.common_dim, hyps["oc"], hyps["ae_oc"]
8 | super(ACEClassifier, self).__init__()
9 |
10 | self.device = device
11 |
12 | # Output Linear
13 | self.ol = BottledXavierLinear(in_features=common_dim, out_features=type_num).to(device=device)
14 |
15 | # AE Output Linear
16 | self.ae_ol = BottledXavierLinear(in_features=2 * common_dim, out_features=role_num).to(device=device)
17 | # self.ae_l1 = nn.Linear(in_features=2 * common_dim, out_features=common_dim)
18 | # self.ae_bn1 = nn.BatchNorm1d(num_features=common_dim)
19 | # self.ae_l2 = nn.Linear(in_features=common_dim, out_features=role_num)
20 |
21 | # Move to right device
22 | self.to(self.device)
23 |
24 | def forward_type(self, feature_in):
25 | ed_logits = self.ol(feature_in)
26 | return ed_logits
27 |
28 | def forward_role(self, entity_feature_in):
29 | ae_logits = self.ae_ol(entity_feature_in)
30 | return ae_logits
31 |
32 |
33 |
--------------------------------------------------------------------------------
/scripts/train/train_ee.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/ee/'$1
6 | mkdir -p $logdir
7 |
8 | datadir="/scratch/manling2/data/mm-event-graph"
9 | glovedir="/scratch/manling2/data/glove"
10 |
11 | python ../../src/engine/EErunner.py \
12 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \
13 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \
14 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \
15 | --earlystop 10 --optimizer "adam" --lr 1e-4 \
16 | --webd $glovedir"/glove.840B.300d.txt" \
17 | --batch 32 --epochs 100 --device "cuda" --out $logdir \
18 | --shuffle \
19 | --hps "{
20 | 'wemb_dim': 300,
21 | 'wemb_ft': True,
22 | 'wemb_dp': 0.5,
23 | 'pemb_dim': 50,
24 | 'pemb_dp': 0.5,
25 | 'eemb_dim': 50,
26 | 'eemb_dp': 0.5,
27 | 'psemb_dim': 50,
28 | 'psemb_dp': 0.5,
29 | 'lstm_dim': 220,
30 | 'lstm_layers': 1,
31 | 'lstm_dp': 0,
32 | 'gcn_et': 3,
33 | 'gcn_use_bn': True,
34 | 'gcn_layers': 2,
35 | 'gcn_dp': 0.5,
36 | 'sa_dim': 300,
37 | 'use_highway': True,
38 | 'loss_alpha': 5
39 | }" \
40 | >& $logdir/stdout.log &
41 |
--------------------------------------------------------------------------------
/src/dataflow/numpy/dataset_image_download.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import urllib.request
4 | import json
5 |
6 | def download_image(url_image, path_save):
7 | urllib.request.urlretrieve(url_image, path_save)
8 |
9 | def download_image_list(meta_json, dir_save):
10 | if not os.path.exists(meta_json):
11 | raise FileNotFoundError('input_metadata_json does not exist: %s' % meta_json)
12 | metadata = json.load(open(meta_json))
13 | for doc_id in metadata:
14 | for img_id in metadata[doc_id]:
15 | url_image = metadata[doc_id][img_id]['url']
16 | suffix_image = url_image.split('.')[-1]
17 | image_path_save = os.path.join(dir_save, '%s_%s.%s' % (doc_id, img_id, suffix_image))
18 | # if not url_image.endswith('.jpg'):
19 | # print(image_path_save, url_image)
20 | download_image(url_image, image_path_save)
21 |
22 |
23 | if __name__ == '__main__':
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('input_metadata_json', type=str,
26 | help='input_metadata_json')
27 | parser.add_argument('output_image_dir', type=str,
28 | help='output_image_dir')
29 | args = parser.parse_args()
30 |
31 | input_metadata_json = args.input_metadata_json
32 | output_image_dir = args.output_image_dir
33 |
34 | download_image_list(input_metadata_json, output_image_dir)
--------------------------------------------------------------------------------
/src/models/modules/HighWay.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn import functional as F
3 |
4 | import sys
5 | #sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2')
6 | from src.util.util_model import BottledXavierLinear
7 |
8 |
9 | class HighWay(nn.Module):
10 | def __init__(self, size, num_layers=1, dropout_ratio=0.5):
11 | super(HighWay, self).__init__()
12 | self.size = size
13 | self.num_layers = num_layers
14 | self.trans = nn.ModuleList()
15 | self.gate = nn.ModuleList()
16 | self.dropout = dropout_ratio
17 |
18 | for i in range(num_layers):
19 | tmptrans = BottledXavierLinear(size, size)
20 | tmpgate = BottledXavierLinear(size, size)
21 | self.trans.append(tmptrans)
22 | self.gate.append(tmpgate)
23 |
24 | def forward(self, x):
25 | '''
26 | forward this module
27 | :param x: torch.FloatTensor, (N, D) or (N1, N2, D)
28 | :return: torch.FloatTensor, (N, D) or (N1, N2, D)
29 | '''
30 |
31 | g = F.sigmoid(self.gate[0](x))
32 | h = F.relu(self.trans[0](x))
33 | x = g * h + (1 - g) * x
34 |
35 | for i in range(1, self.num_layers):
36 | x = F.dropout(x, p=self.dropout, training=self.training)
37 | g = F.sigmoid(self.gate[i](x))
38 | h = F.relu(self.trans[i](x))
39 | x = g * h + (1 - g) * x
40 |
41 | return x
42 |
--------------------------------------------------------------------------------
/src/models/modules/FMapLayerImage.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | from torch.nn import functional as F
5 |
6 | class Flatten(torch.nn.Module):
7 | def forward(self, x):
8 | batch_size = x.shape[0]
9 | return x.view(batch_size, -1)
10 |
11 | class FMapLayerImage(nn.Module):
12 | def __init__(self, fine_tune=False, dropout=0.5,
13 | device=torch.device("cpu"), backbone='resnet152'):
14 |
15 | super(FMapLayerImage, self).__init__()
16 |
17 | net = getattr(models, backbone)(pretrained=True)
18 | #print(net)
19 | if backbone=='vgg16':
20 | b1, pool, b2 = list(net.children())
21 | modules_1 = list(b1.children())
22 | modules_2 = [pool, Flatten()] + list(b2.children())[:-1]
23 |
24 | else:
25 | b1 = list(net.children())
26 | modules_1 = b1[:-2]
27 | modules_2 = [b1[-2]]
28 |
29 |
30 | self.backbone = nn.Sequential(*modules_1)
31 | self.pooler = nn.Sequential(*modules_2)
32 | #print(self.backbone, self.pooler)
33 |
34 |
35 | for p in self.backbone.parameters():
36 | p.requires_grad = fine_tune
37 | for p in self.pooler.parameters():
38 | p.requires_grad = fine_tune
39 |
40 | self.device = device
41 | self.to(device)
42 |
43 | def forward(self, images):
44 | fmap = self.backbone(images)
45 | emb = self.pooler(fmap)
46 | return fmap, emb
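47 |
48 | # Minimal usage sketch (assumption: pretrained weights are downloaded on
49 | # first use): with the default resnet152 backbone, a batch of 224x224
50 | # images yields a 7x7 feature map plus a pooled embedding.
51 | if __name__ == "__main__":
52 |     layer = FMapLayerImage(fine_tune=False)
53 |     fmap, emb = layer(torch.randn(2, 3, 224, 224))
54 |     print(fmap.size(), emb.size())  # (2, 2048, 7, 7) and (2, 2048, 1, 1)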
--------------------------------------------------------------------------------
/scripts/eval/test_ee_m2e2.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/ee_test/'$1
6 | mkdir -p $logdir
7 |
8 | datadir="/scratch/manling2/data/mm-event-graph"
9 | glovedir="/scratch/manling2/data/glove"
10 | checkpoint="/scratch/manling2/mm-event-graph/log/joint/obj_gcn2/model/model_ee_17.pt"
11 | checkpoint_params="/scratch/manling2/mm-event-graph/log/joint/obj_gcn2/ee_hyps.json"
12 |
13 | python ../../src/engine/TestRunnerEE_m2e2.py \
14 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \
15 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \
16 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \
17 | --img_dir_grounding $datadir"/voa/rawdata/VOA_image_en/" \
18 | --object_detection_pkl_file_g $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \
19 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \
20 | --object_detection_threshold 0.2 \
21 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \
22 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \
23 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \
24 | --webd $glovedir"/glove.840B.300d.txt" \
25 | --batch 32 --device "cuda" --out $logdir \
26 | --finetune ${checkpoint} \
27 | --hps_path ${checkpoint_params} \
28 | --gt_voa_text $datadir"/voa_anno_m2e2/article_event.json" \
29 | --keep_events 1 \
30 | --load_grounding \
31 | >& $logdir/testout.log &
32 |
33 |
--------------------------------------------------------------------------------
/src/models/modules/EmbeddingLayerImage.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | from torch.nn import functional as F
5 |
6 | class Flatten(torch.nn.Module):
7 | def forward(self, x):
8 | batch_size = x.shape[0]
9 | return x.view(batch_size, -1)
10 |
11 | class EmbeddingLayerImage(nn.Module):
12 | def __init__(self, fine_tune=False, dropout=0.5,
13 | device=torch.device("cpu"), backbone='resnet152'):
14 |
15 | super(EmbeddingLayerImage, self).__init__()
16 |
17 | resnet = getattr(models, backbone)(pretrained=True)
18 | #print(resnet)
19 | if backbone=='vgg16':
20 | b1, pool, b2 = list(resnet.children())
21 | modules = list(b1.children()) + [pool, Flatten()] + list(b2.children())[:-1]
22 | self.dropout = None
23 | else:
24 | modules = list(resnet.children())[:-1]
25 | self.dropout = dropout if type(dropout) == float and -1e-7 < dropout < 1 + 1e-7 else None
26 |
27 | self.resnet = nn.Sequential(*modules)
28 | #print(self.resnet)
29 |
30 |
31 | for p in self.resnet.parameters():
32 | p.requires_grad = fine_tune
33 |
34 | self.device = device
35 | self.to(device)
36 |
37 | def forward(self, images):
38 | if self.dropout is not None:
39 | return F.dropout(self.resnet(images), p=self.dropout, training=self.training)
40 | else:
41 | return self.resnet(images) # batchsize * 2048 * imageLength * imageLength (imageLength=1)
--------------------------------------------------------------------------------
/src/util/helper.py:
--------------------------------------------------------------------------------
1 | """
2 | Helper functions.
3 | """
4 |
5 | import os
6 | import subprocess
7 | import json
8 | import argparse
9 |
10 | ### IO
11 | def check_dir(d):
12 | if not os.path.exists(d):
13 | print("Directory {} does not exist. Exit.".format(d))
14 | exit(1)
15 |
16 | def check_files(files):
17 | for f in files:
18 | if f is not None and not os.path.exists(f):
19 | print("File {} does not exist. Exit.".format(f))
20 | exit(1)
21 |
22 | def ensure_dir(d, verbose=True):
23 | if not os.path.exists(d):
24 | if verbose:
25 | print("Directory {} do not exist; creating...".format(d))
26 | os.makedirs(d)
27 |
28 | def save_config(config, path, verbose=True):
29 | with open(path, 'w') as outfile:
30 | json.dump(config, outfile, indent=2)
31 | if verbose:
32 | print("Config saved to file {}".format(path))
33 | return config
34 |
35 | def load_config(path, verbose=True):
36 | with open(path) as f:
37 | config = json.load(f)
38 | if verbose:
39 | print("Config loaded from file {}".format(path))
40 | return config
41 |
42 | def print_config(config):
43 | info = "Running with the following configs:\n"
44 | for k,v in config.items():
45 | info += "\t{} : {}\n".format(k, str(v))
46 | print("\n" + info + "\n")
47 | return
48 |
49 | class FileLogger(object):
50 | """
51 | A file logger that opens the file periodically and write to it.
52 | """
53 | def __init__(self, filename, header=None):
54 | self.filename = filename
55 | if os.path.exists(filename):
56 | # remove the old file
57 | os.remove(filename)
58 | if header is not None:
59 | with open(filename, 'w') as out:
60 | print(header, file=out)
61 |
62 | def log(self, message):
63 | with open(self.filename, 'a') as out:
64 | print(message, file=out)
65 |
66 |
--------------------------------------------------------------------------------
/scripts/train/train_grounding.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/grounding/'$1
6 | mkdir -p $logdir
7 |
8 | datadir="/scratch/manling2/data/mm-event-graph"
9 | glovedir="/scratch/manling2/data/glove"
10 |
11 | python ../../src/engine/Groundingrunner.py \
12 | --train $datadir"/grounding/grounding_train.json" \
13 | --test $datadir"/grounding/grounding_test.json" \
14 | --dev $datadir"/grounding/grounding_valid.json" \
15 | --earlystop 10 --restart 999999 --optimizer "adam" --lr 0.0001 \
16 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \
17 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \
18 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \
19 | --vocab $datadir"/vocab/" \
20 | --webd $glovedir"/glove.840B.300d.txt" \
21 | --img_dir $datadir"/voa/rawdata/VOA_image_en/" \
22 | --shuffle \
23 | --batch 32 --epochs 100 --device "cuda" --out $logdir \
24 | --ee_hps "{
25 | 'wemb_dim': 300,
26 | 'wemb_ft': True,
27 | 'wemb_dp': 0.5,
28 | 'pemb_dim': 50,
29 | 'pemb_dp': 0.5,
30 | 'eemb_dim': 50,
31 | 'eemb_dp': 0.5,
32 | 'psemb_dim': 50,
33 | 'psemb_dp': 0.5,
34 | 'lstm_dim': 150,
35 | 'lstm_layers': 1,
36 | 'lstm_dp': 0,
37 | 'gcn_et': 3,
38 | 'gcn_use_bn': True,
39 | 'gcn_layers': 3,
40 | 'gcn_dp': 0.5,
41 | 'sa_dim': 300,
42 | 'use_highway': True,
43 | 'loss_alpha': 5
44 | }" \
45 | --sr_hps "{
46 | 'wemb_dim': 300,
47 | 'wemb_ft': True,
48 | 'wemb_dp': 0.0,
49 | 'iemb_backbone': 'vgg16',
50 | 'iemb_dim':4096,
51 | 'iemb_ft': False,
52 | 'iemb_dp': 0.0,
53 | 'posemb_dim': 512,
54 | 'fmap_dim': 512,
55 | 'fmap_size': 7,
56 | 'att_dim': 1024,
57 | 'loss_weight_verb': 1.0,
58 | 'loss_weight_noun': 0.1,
59 | 'loss_weight_role': 0.0
60 | }" \
61 | >& $logdir/stdout.log &
62 |
--------------------------------------------------------------------------------
/scripts/train/train_sr_object.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/sr/'$1
6 | mkdir -p $logdir
7 |
8 | datadir="/scratch/manling2/data/mm-event-graph"
9 | glovedir="/scratch/manling2/data/glove"
10 |
11 | python ../../src/engine/SRrunner.py \
12 | --train_sr $datadir"/imSitu/train.json" \
13 | --test_sr $datadir"/imSitu/test.json" \
14 | --dev_sr $datadir"/imSitu/dev.json" \
15 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \
16 | --webd $glovedir"/glove.840B.300d.txt" \
17 | --earlystop 10 --restart 999999 --optimizer "adam" --lr 0.0001 \
18 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \
19 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \
20 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \
21 | --vocab $datadir"/vocab/" \
22 | --image_dir $datadir"/imSitu/of500_images_resized" \
23 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \
24 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \
25 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \
26 | --object_detection_pkl_file $datadir"/imSitu/object_detection/det_results_imsitu_oi_1.pkl" \
27 | --object_detection_threshold 0.2 \
28 | --shuffle \
29 | --add_object \
30 | --filter_place \
31 | --batch 12 --epochs 100 --device "cuda" --out $logdir \
32 | --hps "{
33 | 'wemb_dim': 300,
34 | 'wemb_ft': False,
35 | 'wemb_dp': 0.0,
36 | 'iemb_backbone': 'vgg16',
37 | 'iemb_dim':4096,
38 | 'iemb_ft': False,
39 | 'iemb_dp': 0.0,
40 | 'posemb_dim': 512,
41 | 'fmap_dim': 512,
42 | 'fmap_size': 7,
43 | 'att_dim': 1024,
44 | 'loss_weight_verb': 1.0,
45 | 'loss_weight_noun': 0.1,
46 | 'loss_weight_role': 0.0,
47 | 'gcn_layers': 1,
48 | 'gcn_dp': False,
49 | 'gcn_use_bn': False,
50 | 'use_highway': False,
51 | }" \
52 | >& $logdir/stdout.log &
53 |
54 | #--filter_irrelevant_verbs
55 | # --train_ace \
--------------------------------------------------------------------------------
/scripts/train/train_sr.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/sr/'$1
6 | mkdir -p $logdir
7 |
8 | datadir="/scratch/manling2/data/mm-event-graph"
9 | glovedir="/scratch/manling2/data/glove"
10 |
11 | python ../../src/engine/SRrunner.py \
12 | --train_sr $datadir"/imSitu/train.json" \
13 | --test_sr $datadir"/imSitu/test.json" \
14 | --dev_sr $datadir"/imSitu/dev.json" \
15 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \
16 | --webd $glovedir"/glove.840B.300d.txt" \
17 | --earlystop 10 --restart 999999 \
18 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \
19 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \
20 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \
21 | --vocab $datadir"/vocab/" \
22 | --image_dir $datadir"/imSitu/of500_images_resized" \
23 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \
24 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \
25 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \
26 | --object_detection_pkl_file $datadir"/imSitu/object_detection/det_results_imsitu_oi_1.pkl" \
27 | --object_detection_threshold 0.2 \
28 | --shuffle \
29 | --filter_place \
30 | --batch 8 --epochs 40 --device "cuda" --out $logdir \
31 | --optimizer "adam" --lr 0.0001 \
32 | --hps "{
33 | 'wemb_dim': 300,
34 | 'wemb_ft': False,
35 | 'wemb_dp': 0.0,
36 | 'iemb_backbone': 'vgg16',
37 | 'iemb_dim':4096,
38 | 'iemb_ft': False,
39 | 'iemb_dp': 0.0,
40 | 'posemb_dim': 512,
41 | 'fmap_dim': 512,
42 | 'fmap_size': 7,
43 | 'att_dim': 1024,
44 | 'loss_weight_verb': 1.0,
45 | 'loss_weight_noun': 0.1,
46 | 'loss_weight_role': 0.0,
47 | 'gcn_layers': 1,
48 | 'gcn_dp': False,
49 | 'gcn_use_bn': False,
50 | 'use_highway': False,
51 | }" \
52 | >& $logdir/stdout.log &
53 |
54 | # --train_ace
55 | # --filter_irrelevant_verbs
56 | # --optimizer "adam" --lr 0.0001 \
--------------------------------------------------------------------------------
/scripts/eval/test_sr_m2e2_obj.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/sr_test/'$1
6 | mkdir -p $logdir
7 | mkdir -p $logdir/'image_result'
8 |
9 | datadir="/scratch/manling2/data/mm-event-graph"
10 | glovedir="/scratch/manling2/data/glove"
11 | checkpoint="/scratch/manling2/mm-event-graph/log/joint/obj_gcn2/model/model_sr_37.pt"
12 | checkpoint_params="/scratch/manling2/mm-event-graph/log/joint/obj_gcn2/sr_hyps.json"
13 |
14 | python ../../src/engine/TestRunnerSR_m2e2.py \
15 | --train_sr $datadir"/imSitu/train.json" \
16 | --test_sr $datadir"/imSitu/test.json" \
17 | --dev_sr $datadir"/imSitu/dev.json" \
18 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \
19 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \
20 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \
21 | --vocab $datadir"/vocab/" \
22 | --image_dir $datadir"/imSitu/of500_images_resized" \
23 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \
24 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \
25 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \
26 | --filter_irrelevant_verbs --filter_place \
27 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \
28 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \
29 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \
30 | --webd $glovedir"/glove.840B.300d.txt" \
31 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \
32 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \
33 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \
34 | --batch 1 --device "cuda" --out $logdir \
35 | --test_voa_image $datadir"/voa/rawdata/VOA_image_en/" \
36 | --gt_voa_image $datadir"/voa_anno_m2e2/image_event.json" \
37 | --gt_voa_text $datadir"/voa_anno_m2e2/article_event.json" \
38 | --finetune_sr ${checkpoint} \
39 | --sr_hps_path ${checkpoint_params} \
40 | --object_detection_pkl_file $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \
41 | --ignore_place_sr_test \
42 | --ignore_time_test \
43 | --object_detection_threshold 0.1 \
44 | --keep_events 1 \
45 | --keep_events_sr 0 \
46 | --add_object \
47 | >& $logdir/testout.log &
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/scripts/eval/test_sr_m2e2.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/sr_test/'$1
6 | mkdir -p $logdir
7 | mkdir -p $logdir/'image_result'
8 |
9 | datadir="/scratch/manling2/data/mm-event-graph"
10 | glovedir="/scratch/manling2/data/glove"
11 | checkpoint="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/model/model_sr_4.pt"
12 | checkpoint_params="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/sr_hyps.json"
13 |
14 | python ../../src/engine/TestRunnerSR_m2e2.py \
15 | --train_sr $datadir"/imSitu/train.json" \
16 | --test_sr $datadir"/imSitu/test.json" \
17 | --dev_sr $datadir"/imSitu/dev.json" \
18 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \
19 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \
20 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \
21 | --vocab $datadir"/vocab/" \
22 | --image_dir $datadir"/imSitu/of500_images_resized" \
23 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \
24 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \
25 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \
26 | --filter_irrelevant_verbs --filter_place \
27 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \
28 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \
29 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \
30 | --webd $glovedir"/glove.840B.300d.txt" \
31 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \
32 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \
33 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \
34 | --batch 1 --device "cuda" --out $logdir \
35 | --test_voa_image $datadir"/voa/rawdata/VOA_image_en/" \
36 | --gt_voa_image $datadir"/voa_anno_m2e2/image_event.json" \
37 | --gt_voa_text $datadir"/voa_anno_m2e2/article_event.json" \
38 | --finetune_sr ${checkpoint} \
39 | --sr_hps_path ${checkpoint_params} \
40 | --object_detection_pkl_file $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \
41 | --ignore_place_sr_test \
42 | --ignore_time_test \
43 | --object_detection_threshold 0.1 \
44 | --keep_events 1 \
45 | --keep_events_sr 0 \
46 | --visual_voa_sr_path $logdir/'image_result' \
47 | >& $logdir/testout.log &
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/src/models/modules/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import sys
5 | #sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2')
6 | from src.util.util_model import log
7 |
8 |
9 | class Model(nn.Module):
10 | def __init__(self, *args, **kwargs):
11 | super(Model, self).__init__()
12 | self.hyperparams = None
13 | self.device = torch.device("cpu")
14 |
15 | def __getnewargs__(self):
16 | # for pickle
17 | return self.hyperparams
18 |
19 | def __new__(cls, *args, **kwargs):
20 | log('created %s with params %s' % (str(cls), str(args)))
21 |
22 | instance = super(Model, cls).__new__(cls)
23 | # __init__ is invoked automatically once __new__ returns the instance, so no explicit call (avoids double initialization)
24 | return instance
25 |
26 | def test_mode_on(self):
27 | self.test_mode = True
28 | self.eval()
29 |
30 | def test_mode_off(self):
31 | self.test_mode = False
32 | self.train()
33 |
34 | def parameters_requires_grads(self):
35 | return list(filter(lambda p: p.requires_grad, self.parameters()))
36 |
37 | def parameters_requires_grad_clipping(self):
38 | return self.parameters_requires_grads()
39 |
40 | def save_model(self, path):
41 | state_dict = self.state_dict()
42 | for key, value in state_dict.items():
43 | # print('save key', key)
44 | state_dict[key] = value.cpu()
45 | torch.save(state_dict, path)
46 |
47 | def load_model(self, path, load_partial=False):
48 | pretrained_dict = torch.load(path)
49 | try:
50 | self.load_state_dict(pretrained_dict)
51 | except Exception as e:
52 | if load_partial:
53 | # load matched part
54 | model_dict = self.state_dict()
55 | ignore_keys = set(['ace_classifier.ol.linear.weight',
56 | 'ace_classifier.ol.linear.bias',
57 | 'ace_classifier.ae_ol.linear.weight',
58 | 'ace_classifier.ae_ol.linear.bias'])
59 | # 1. filter out unnecessary keys
60 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k not in ignore_keys}
61 | # 2. overwrite entries in the existing state dict
62 | model_dict.update(pretrained_dict)
63 | # 3. load the new state dict
64 | self.load_state_dict(model_dict)
65 | else:
66 | print(e)
67 | exit(-1)
68 |
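69 | # Checkpointing sketch (hypothetical subclass and path): save_model() stores
70 | # a CPU state dict; load_model(..., load_partial=True) skips the ACE
71 | # classifier output heads so a checkpoint can be reused across label sets.
72 | #
73 | #     model.save_model("model_ee_1.pt")
74 | #     model.load_model("model_ee_1.pt", load_partial=True)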
--------------------------------------------------------------------------------
/scripts/train/train_joint_obj.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/joint/'$1
6 | mkdir -p $logdir
7 |
8 | datadir="/scratch/manling2/data/mm-event-graph"
9 | glovedir="/scratch/manling2/data/glove"
10 |
11 | python ../../src/engine/JOINTRunner.py \
12 | --train_sr $datadir"/imSitu/train.json" \
13 | --test_sr $datadir"/imSitu/test.json" \
14 | --dev_sr $datadir"/imSitu/dev.json" \
15 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \
16 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \
17 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \
18 | --vocab $datadir"/vocab/" \
19 | --image_dir $datadir"/imSitu/of500_images_resized" \
20 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \
21 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \
22 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \
23 | --object_detection_pkl_file $datadir"/imSitu/object_detection/det_results_imsitu_oi_1.pkl" \
24 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \
25 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \
26 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \
27 | --webd $glovedir"/glove.840B.300d.txt" \
28 | --ee_hps "{
29 | 'wemb_dim': 300,
30 | 'wemb_ft': True,
31 | 'wemb_dp': 0.5,
32 | 'pemb_dim': 50,
33 | 'pemb_dp': 0.5,
34 | 'eemb_dim': 50,
35 | 'eemb_dp': 0.5,
36 | 'psemb_dim': 50,
37 | 'psemb_dp': 0.5,
38 | 'lstm_dim': 150,
39 | 'lstm_layers': 1,
40 | 'lstm_dp': 0,
41 | 'gcn_et': 3,
42 | 'gcn_use_bn': True,
43 | 'gcn_layers': 3,
44 | 'gcn_dp': 0.5,
45 | 'sa_dim': 300,
46 | 'use_highway': True,
47 | 'loss_alpha': 5
48 | }" \
49 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \
50 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \
51 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \
52 | --img_dir_grounding $datadir"/voa/rawdata/VOA_image_en/" \
53 | --object_detection_pkl_file_g $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \
54 | --earlystop 10 --restart 999999 --optimizer "adadelta" --lr 1 \
55 | --finetune_sr '/scratch/manling2/mm-event-graph/log/sr/retest/model.pt' \
56 | --sr_hps_path '/scratch/manling2/mm-event-graph/log/sr/retest/sr_hyps.json' \
57 | --object_detection_threshold 0.0 \
58 | --shuffle \
59 | --filter_place \
60 | --add_object \
61 | --batch 4 --epochs 100 --device "cuda" --out $logdir \
62 | >& $logdir/stdout.log &
63 |
--------------------------------------------------------------------------------
/src/models/modules/SelfAttention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class AttentionLayer(nn.Module):
7 | def __init__(self, D, H=128, return_sequences=False):
8 | '''
9 | An additive self-attention layer with masked softmax
10 | :param D: int, input feature dim
11 | :param H: int, hidden feature dim
12 | :param return_sequences: boolean, whether return sequence
13 | '''
14 | super(AttentionLayer, self).__init__()
15 |
16 | # Config copying
17 | self.H = H
18 | self.return_sequences = return_sequences
19 | self.D = D
20 | self.linear1 = nn.Linear(D, H)
21 | self.linear2 = nn.Linear(H, 1)
22 |
23 | def softmax_mask(self, x, mask):
24 | '''
25 | Softmax with mask
26 | :param x: torch.FloatTensor, logits, [batch_size, seq_len]
27 | :param mask: torch.ByteTensor, masks for sentences, [batch_size, seq_len]
28 | :return: torch.FloatTensor, probabilities, [batch_size, seq_len]
29 | '''
30 | x_exp = torch.exp(x)
31 | if mask is not None:
32 | x_exp = x_exp * mask.float()
33 | x_sum = torch.sum(x_exp, dim=-1, keepdim=True) + 1e-6
34 | x_exp /= x_sum
35 | return x_exp
36 |
37 | def forward(self, x_text, mask, x_attention=None):
38 | '''
39 | Forward this module
40 | :param x_text: torch.FloatTensor, input features, [batch_size, seq_len, D]
41 | :param mask: torch.ByteTensor, masks for features, [batch_size, seq_len]
42 | :param x_attention: torch.FloatTensor, optional second set of features from which the attention weights over x_text are computed, [batch_size, seq_len, D]
43 | :return: torch.FloatTensor, output features, if return sequences, output shape is [batch, SEQ_LEN, D];
44 | otherwise output shape is [batch, D]
45 | '''
46 | if x_attention is None:
47 | x_attention = x_text
48 | SEQ_LEN = x_text.size()[-2]
49 | x_attention = x_attention.contiguous().view(-1, self.D) # [batch_size * seq_len, D]
50 | attention = F.tanh(self.linear1(x_attention)) # [batch_size * seq_len, H]
51 | attention = self.linear2(attention) # [batch_size * seq_len, 1]
52 | attention = attention.view(-1, SEQ_LEN) # [batch_size, seq_len]
53 | attention = self.softmax_mask(attention, mask) # [batch_size, seq_len]
54 | output = x_text * attention.unsqueeze(-1).expand_as(x_text) # [batch_size, seq_len, D]
55 |
56 | if not self.return_sequences:
57 | output = torch.sum(output, -2)
58 | output = output.squeeze(1)
59 | return output
60 |
61 |
62 | if __name__ == "__main__":
63 | al = AttentionLayer(2, 3, return_sequences=True)
64 | x = torch.randn(5, 3, 2)
65 | print(x.size())
66 | mask = torch.ByteTensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]])
67 | print(mask.size())
68 | y = al(x, mask)
69 | print(y.size())
70 |
--------------------------------------------------------------------------------
/scripts/eval/test_joint_object.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/joint_test/'$1
6 | mkdir -p $logdir
7 | mkdir -p $logdir/grounding_result
8 | mkdir -p $logdir/image_result
9 | mkdir -p $logdir/text_result
10 | rm -rf $logdir/image_result/*
11 | rm -rf $logdir/grounding_result/*
12 | rm -rf $logdir/text_result/*
13 | ln -s "/scratch/manling2/data/mm-event-graph/voa/rawdata/VOA_image_en" $logdir/grounding_result/VOA_image_en
14 |
15 | datadir="/scratch/manling2/data/mm-event-graph"
16 | glovedir="/scratch/manling2/data/glove"
17 | checkpoint_sr="/scratch/manling2/mm-event-graph/log/joint/obj_gcn3/model/model_sr_52.pt"
18 | checkpoint_sr_params="/scratch/manling2/mm-event-graph/log/joint/obj_gcn3/sr_hyps.json"
19 | checkpoint_ee="/scratch/manling2/mm-event-graph/log/joint/obj_gcn3/model/model_ee_52.pt"
20 | checkpoint_ee_params="/scratch/manling2/mm-event-graph/log/joint/obj_gcn3/ee_hyps.json"
21 |
22 | python ../../src/engine/TestRunnerJOINT.py \
23 | --train_sr $datadir"/imSitu/train.json" \
24 | --test_sr $datadir"/imSitu/test.json" \
25 | --dev_sr $datadir"/imSitu/dev.json" \
26 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \
27 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \
28 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \
29 | --vocab $datadir"/vocab/" \
30 | --image_dir $datadir"/imSitu/of500_images_resized" \
31 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \
32 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \
33 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \
34 | --filter_irrelevant_verbs --filter_place \
35 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \
36 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \
37 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \
38 | --webd $glovedir"/glove.840B.300d.txt" \
39 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \
40 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \
41 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \
42 | --batch 1 --device "cuda" --out $logdir \
43 | --test_voa_image $datadir"/voa/rawdata/VOA_image_en/" \
44 | --gt_voa_image $datadir"/voa_anno_m2e2/image_event.json" \
45 | --gt_voa_text $datadir"/voa_anno_m2e2/article_event.json" \
46 | --gt_voa_align $datadir"/voa_anno_m2e2/article_event.json" \
47 | --object_detection_pkl_file $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \
48 | --finetune_sr ${checkpoint_sr} \
49 | --sr_hps_path ${checkpoint_sr_params} \
50 | --finetune_ee ${checkpoint_ee} \
51 | --ee_hps_path ${checkpoint_ee_params} \
52 | --ignore_place_sr_test \
53 | --ignore_time_test \
54 | --object_detection_threshold 0.1 \
55 | --add_object \
56 | --keep_events 1 \
57 | --keep_events_sr 0 \
58 | --with_sentid \
59 | --apply_ee_role_mask \
60 | >& $logdir/stdout.log &
61 |
--------------------------------------------------------------------------------
/scripts/eval/test_joint.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/joint_test/'$1
6 | mkdir -p $logdir
7 | mkdir -p $logdir/grounding_result
8 | mkdir -p $logdir/image_result
9 | mkdir -p $logdir/text_result
10 | rm -rf $logdir/image_result/*
11 | rm -rf $logdir/grounding_result/*
12 | rm -rf $logdir/text_result/*
13 | ln -s "/scratch/manling2/data/mm-event-graph/voa/rawdata/VOA_image_en" $logdir/grounding_result/VOA_image_en
14 |
15 | datadir="/scratch/manling2/data/mm-event-graph"
16 | glovedir="/scratch/manling2/data/glove"
17 | checkpoint_sr="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/model/model_sr_4.pt"
18 | checkpoint_sr_params="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/sr_hyps.json"
19 | checkpoint_ee="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/model/model_ee_4.pt"
20 | checkpoint_ee_params="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/ee_hyps.json"
21 |
22 | python ../../src/engine/TestRunnerJOINT.py \
23 | --train_sr $datadir"/imSitu/train.json" \
24 | --test_sr $datadir"/imSitu/test.json" \
25 | --dev_sr $datadir"/imSitu/dev.json" \
26 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \
27 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \
28 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \
29 | --vocab $datadir"/vocab/" \
30 | --image_dir $datadir"/imSitu/of500_images_resized" \
31 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \
32 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \
33 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \
34 | --filter_irrelevant_verbs --filter_place \
35 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \
36 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \
37 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \
38 | --webd $glovedir"/glove.840B.300d.txt" \
39 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \
40 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \
41 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \
42 | --object_detection_pkl_file_g $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \
43 | --batch 1 --device "cuda" --out $logdir \
44 | --test_voa_image $datadir"/voa/rawdata/VOA_image_en/" \
45 | --gt_voa_image $datadir"/voa_anno_m2e2/image_event.json" \
46 | --gt_voa_text $datadir"/voa_anno_m2e2/article_event.json" \
47 | --gt_voa_align $datadir"/voa_anno_m2e2/article_event.json" \
48 | --object_detection_pkl_file $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \
49 | --finetune_sr ${checkpoint_sr} \
50 | --sr_hps_path ${checkpoint_sr_params} \
51 | --finetune_ee ${checkpoint_ee} \
52 | --ee_hps_path ${checkpoint_ee_params} \
53 | --ignore_place_sr_test \
54 | --ignore_time_test \
55 | --object_detection_threshold 0.1 \
56 | --keep_events 1 \
57 | --keep_events_sr 0 \
58 | --with_sentid \
59 | --apply_ee_role_mask \
60 | >& $logdir/stdout.log &
61 |
--------------------------------------------------------------------------------
/scripts/train/train_joint_att.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=$2
4 |
5 | logdir='../../log/joint/'$1
6 | mkdir -p $logdir
7 |
8 | datadir="/scratch/manling2/data/mm-event-graph"
9 | glovedir="/scratch/manling2/data/glove"
10 |
11 | python ../../src/engine/JOINTRunner.py \
12 | --train_sr $datadir"/imSitu/train.json" \
13 | --test_sr $datadir"/imSitu/test.json" \
14 | --dev_sr $datadir"/imSitu/dev.json" \
15 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \
16 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \
17 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \
18 | --vocab $datadir"/vocab/" \
19 | --image_dir $datadir"/imSitu/of500_images_resized" \
20 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \
21 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \
22 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \
23 | --object_detection_pkl_file $datadir"/imSitu/object_detection/det_results_imsitu_oi_1.pkl" \
24 | --object_detection_threshold 0.2 \
25 | --sr_hps "{
26 | 'wemb_dim': 300,
27 | 'wemb_ft': False,
28 | 'wemb_dp': 0.0,
29 | 'iemb_backbone': 'vgg16',
30 | 'iemb_dim':4096,
31 | 'iemb_ft': False,
32 | 'iemb_dp': 0.0,
33 | 'posemb_dim': 512,
34 | 'fmap_dim': 512,
35 | 'fmap_size': 7,
36 | 'att_dim': 1024,
37 | 'loss_weight_verb': 1.0,
38 | 'loss_weight_noun': 0.1,
39 | 'loss_weight_role': 0.0,
40 | 'gcn_layers': 3,
41 | 'gcn_dp': 0.5,
42 | 'gcn_use_bn': True,
43 | 'use_highway': False,
44 | }" \
45 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \
46 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \
47 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \
48 | --webd $glovedir"/glove.840B.300d.txt" \
49 | --ee_hps "{
50 | 'wemb_dim': 300,
51 | 'wemb_ft': True,
52 | 'wemb_dp': 0.5,
53 | 'pemb_dim': 50,
54 | 'pemb_dp': 0.5,
55 | 'eemb_dim': 50,
56 | 'eemb_dp': 0.5,
57 | 'psemb_dim': 50,
58 | 'psemb_dp': 0.5,
59 | 'lstm_dim': 150,
60 | 'lstm_layers': 1,
61 | 'lstm_dp': 0,
62 | 'gcn_et': 3,
63 | 'gcn_use_bn': True,
64 | 'gcn_layers': 3,
65 | 'gcn_dp': 0.5,
66 | 'sa_dim': 300,
67 | 'use_highway': True,
68 | 'loss_alpha': 5
69 | }" \
70 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \
71 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \
72 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \
73 | --img_dir_grounding $datadir"/voa/rawdata/VOA_image_en/" \
74 | --object_detection_pkl_file_g $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \
75 | --earlystop 10 --restart 999999 --optimizer "adadelta" --lr 1 \
76 | --shuffle \
77 | --filter_place \
78 | --finetune_sr /scratch/manling2/mm-event-graph/log/sr/retest/model.pt \
79 | --batch 8 --epochs 60 --device "cuda" --out $logdir \
80 | >& $logdir/stdout.log &
81 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-media Structured Common Space for Multimedia Event Extraction
2 |
3 | Table of Contents
4 | =================
5 | * [Overview](#overview)
6 | * [Requirements](#requirements)
7 | * [Data](#data)
8 | * [Quickstart](#quickstart)
9 | * [Citation](#citation)
10 |
11 | ## Overview
12 | This repository contains the code for the paper [Cross-media Structured Common Space for Multimedia Event Extraction](http://blender.cs.illinois.edu/software/m2e2/) (ACL 2020).
13 |
14 | The overall training and testing pipeline:
15 |
16 | ![Training and testing overview](training-testing.jpg)
17 |
18 | ## Requirements
19 |
20 | You can install the required packages using `requirements.txt`:
21 |
22 | ```bash
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ## Data
27 |
28 | ### Situation Recognition (Visual Event Extraction Data)
29 | We use the situation recognition data from [imSitu](http://imsitu.org/). The preprocessed data is available at [PreprocessedSR](https://drive.google.com/drive/folders/1h0qwYWeGEoCx8m-zwH-XcoPSyffmrC-c?usp=sharing).
30 |
31 | ### ACE (Text Event Extraction Data)
32 | We preprocessed ACE following [JMEE](https://github.com/lx865712528/EMNLP2018-JMEE/tree/master); the data format is illustrated in [sample.json](https://github.com/lx865712528/EMNLP2018-JMEE/blob/master/ace-05-splits/sample.json). For license reasons, the ACE 2005 dataset is only accessible to holders of an LDC2006T06 license; please email `manling2@illinois.edu` with proof of the license to obtain the processed data.
33 |
34 | ### Voice of America Image-Caption Pairs
35 | We crawled VOA image-caption pairs to train the common space. The [image-caption pairs](https://uofi.box.com/s/xtn9p6m8z5qtjbbi5tqrl45tn6apew4x) list image URLs (we share URLs rather than the images themselves for license reasons); the images can be fetched with [dataset_image_download.py](https://github.com/limanling/m2e2/blob/master/src/dataflow/numpy/dataset_image_download.py), as shown below. Preprocessing includes object detection and sentence parsing; the preprocessed data is in [PreprocessedVOA](https://drive.google.com/drive/folders/1I9vMGIhWZpKqxQYip91eLoDRnrkqRxnt?usp=sharing).
36 |
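37 | For example, to fetch the images (a minimal sketch; `voa_img_dataset.json` is a hypothetical name for the metadata JSON downloaded above):
38 |
39 | ```bash
40 | mkdir -p data/voa/rawdata/VOA_image_en
41 | python src/dataflow/numpy/dataset_image_download.py \
42 |     voa_img_dataset.json data/voa/rawdata/VOA_image_en/
43 | ```
44 |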
37 | ### M2E2 (Multimedia Event Extraction Benchmark)
38 |
39 | The images and text articles are in [m2e2_rawdata](https://drive.google.com/file/d/1xtFMjt_eYgeBts5rBomOWbPo7wV_mnhy/view?usp=sharing), and the annotations are in [m2e2_annotation](http://blender.cs.illinois.edu/software/m2e2/m2e2_v0.1/m2e2_annotations.zip); a copy also ships under this repo's `data` directory.
40 |
41 | ### Vocabulary
42 | The preprocessed vocabulary is in [PreprocessedVocab](https://drive.google.com/drive/folders/1MWklNkpedJp-P80wpKF-WOQAvYaYKzP2?usp=sharing). It can be read with the `Vocab` class in `src/util/vocab.py`, as sketched below.
43 |
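44 | For example, a saved vocabulary pickle can be loaded with the repo's own `Vocab` class (a sketch; `word_vocab.pkl` is a hypothetical file name):
45 |
46 | ```python
47 | from src.util.vocab import Vocab
48 |
49 | # load=True reads the pickled id2word list written by Vocab.save()
50 | vocab = Vocab("data/vocab/word_vocab.pkl", load=True)
51 | print(vocab.map(["the", "president"]))  # unseen tokens map to the UNK id
52 | ```
53 |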
44 |
45 | ## Quickstart
46 |
47 | ### Training
48 |
49 | We provide two variants for parsing images into situation graphs: one parses images into role-driven attention graphs, and the other parses images into object graphs.
50 |
51 | (1) attention-graph based version
52 | ```bash
53 | sh scripts/train/train_joint_att.sh <run_name> <gpu_id>
54 | ```
55 | (2) object-graph based version:
56 | ```bash
57 | sh scripts/train/train_joint_obj.sh <run_name> <gpu_id>
58 | ```
59 | Please specify the data paths `datadir` and `glovedir` in the scripts. Each script takes a run name (used to name the log directory) as its first argument and a GPU id (exported as `CUDA_VISIBLE_DEVICES`) as its second; the `../../` paths inside the scripts assume they are launched from their own directory, as in the example below.
60 |
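61 | For example (a sketch; `att_run1` is an arbitrary run name):
62 |
63 | ```bash
64 | cd scripts/train
65 | sh train_joint_att.sh att_run1 0   # GPU 0; logs and checkpoints go to log/joint/att_run1/
66 | ```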
61 |
62 | ### Testing
63 |
64 | (1) attention-graph based version
65 | ```bash
66 | sh scripts/eval/test_joint.sh <run_name> <gpu_id>
67 | ```
68 | (2) object-graph based version:
69 | ```bash
70 | sh scripts/eval/test_joint_object.sh <run_name> <gpu_id>
71 | ```
72 |
73 | Please specify the data paths `datadir`, `glovedir`, and the model paths `checkpoint_sr`, `checkpoint_sr_params`, `checkpoint_ee`, `checkpoint_ee_params` in the scripts.
74 |
75 |
76 | ## Citation
77 |
78 | Manling Li, Alireza Zareian, Qi Zeng, Spencer Whitehead, Di Lu, Heng Ji, Shih-Fu Chang. 2020. Cross-media Structured Common Space for Multimedia Event Extraction. Proceedings of The 58th Annual Meeting of the Association for Computational Linguistics.
79 | ```
80 | @inproceedings{li2020multimediaevent,
81 | title={Cross-media Structured Common Space for Multimedia Event Extraction},
82 | author={Manling Li and Alireza Zareian and Qi Zeng and Spencer Whitehead and Di Lu and Heng Ji and Shih-Fu Chang},
83 | booktitle={Proceedings of The 58th Annual Meeting of the Association for Computational Linguistics},
84 | year={2020}
85 | }
86 | ```
--------------------------------------------------------------------------------
/src/util/vocab.py:
--------------------------------------------------------------------------------
1 | """
2 | A class for basic vocab operations.
3 | """
4 |
5 | from __future__ import print_function
6 | import os
7 | import random
8 | import numpy as np
9 | import pickle
10 |
11 | from . import constant
12 |
13 | random.seed(1234)
14 | np.random.seed(1234)
15 |
16 | def build_embedding(wv_file, vocab, wv_dim):
17 | vocab_size = len(vocab)
18 | emb = np.random.uniform(-1, 1, (vocab_size, wv_dim))
19 | emb[constant.PAD_ID] = 0 # should be all 0
20 |
21 | w2id = {w: i for i, w in enumerate(vocab)}
22 | with open(wv_file, encoding="utf8") as f:
23 | for line in f:
24 | elems = line.split()
25 | token = ''.join(elems[0:-wv_dim])
26 | if token in w2id:
27 | emb[w2id[token]] = [float(v) for v in elems[-wv_dim:]]
28 | return emb
29 |
30 | def load_glove_vocab(file, wv_dim):
31 | """
32 | Load all words from glove.
33 | """
34 | vocab = set()
35 | with open(file, encoding='utf8') as f:
36 | for line in f:
37 | elems = line.split()
38 | token = ''.join(elems[0:-wv_dim])
39 | vocab.add(token)
40 | return vocab
41 |
42 | class Vocab(object):
43 | def __init__(self, filename, load=False, word_counter=None, threshold=0):
44 | if load:
45 | assert os.path.exists(filename), "Vocab file does not exist at " + filename
46 | # load from file and ignore all other params
47 | self.id2word, self.word2id = self.load(filename)
48 | self.size = len(self.id2word)
49 | print("Vocab size {} loaded from file".format(self.size))
50 | else:
51 | print("Creating vocab from scratch...")
52 | assert word_counter is not None, "word_counter is not provided for vocab creation."
53 | self.word_counter = word_counter
54 | if threshold > 1:
55 | # remove words that occur less than thres
56 | self.word_counter = dict([(k,v) for k,v in self.word_counter.items() if v >= threshold])
57 | self.id2word = sorted(self.word_counter, key=lambda k:self.word_counter[k], reverse=True)
58 | # add special tokens to the beginning
59 | self.id2word = [constant.PAD_TOKEN, constant.UNK_TOKEN] + self.id2word
60 | self.word2id = dict([(self.id2word[idx],idx) for idx in range(len(self.id2word))])
61 | self.size = len(self.id2word)
62 | self.save(filename)
63 | print("Vocab size {} saved to file {}".format(self.size, filename))
64 |
65 | def load(self, filename):
66 | with open(filename, 'rb') as infile:
67 | id2word = pickle.load(infile)
68 | word2id = dict([(id2word[idx], idx) for idx in range(len(id2word))])
69 | return id2word, word2id
70 |
71 | def save(self, filename):
72 | if os.path.exists(filename):
73 | print("Overwriting old vocab file at " + filename)
74 | os.remove(filename)
75 | with open(filename, 'wb') as outfile:
76 | pickle.dump(self.id2word, outfile)
77 | return
78 |
79 | def map(self, token_list):
80 | """
81 | Map a list of tokens to their ids.
82 | """
83 | return [self.word2id[w] if w in self.word2id else constant.VOCAB_UNK_ID for w in token_list]
84 |
85 | def unmap(self, idx_list):
86 | """
87 | Unmap ids back to tokens.
88 | """
89 | return [self.id2word[idx] for idx in idx_list]
90 |
91 | def get_embeddings(self, word_vectors=None, dim=100):
92 | self.embeddings = 2 * constant.EMB_INIT_RANGE * np.random.rand(self.size, dim) - constant.EMB_INIT_RANGE
93 | if word_vectors is not None:
94 | assert len(list(word_vectors.values())[0]) == dim, \
95 | "Word vectors does not have required dimension {}.".format(dim)
96 | for w, idx in self.word2id.items():
97 | if w in word_vectors:
98 | self.embeddings[idx] = np.asarray(word_vectors[w])
99 | return self.embeddings
100 |
101 |
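
A minimal usage sketch of `Vocab` and `build_embedding` (the vocab and GloVe file names below are placeholders; the GloVe path assumes vectors are already downloaded):

```python
from collections import Counter
from src.util.vocab import Vocab, build_embedding

counter = Counter("the cat sat on the mat".split())
v = Vocab("vocab.pkl", load=False, word_counter=counter)  # builds and saves vocab.pkl
print(v.map(["the", "cat", "never-seen"]))                # unseen tokens fall back to UNK_ID

emb = build_embedding("glove.840B.300d.txt", v.id2word, wv_dim=300)  # placeholder path
print(emb.shape)                                          # (v.size, 300)
```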
--------------------------------------------------------------------------------
/src/models/modules/DynamicLSTM.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class DynamicLSTM(nn.Module):
8 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0,
9 | bidirectional=False, device=torch.device("cpu")):
10 | """
11 | A dynamic LSTM that handles variable-length sequences, used like TensorFlow's dynamic RNN(input, length, ...).
12 |
13 | :param input_size: The number of expected features in the input x
14 | :param hidden_size: The number of features in the hidden state h
15 | :param num_layers: Number of recurrent layers.
16 | :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
17 | :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
18 | :param dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
19 | :param bidirectional: If True, becomes a bidirectional RNN. Default: False
20 | """
21 | super(DynamicLSTM, self).__init__()
22 | self.input_size = input_size
23 | self.hidden_size = hidden_size
24 | self.num_layers = num_layers
25 | self.bias = bias
26 | self.batch_first = batch_first
27 | self.dropout = dropout
28 | self.bidirectional = bidirectional
29 | self.LSTM = nn.LSTM(
30 | input_size=input_size,
31 | hidden_size=hidden_size,
32 | num_layers=num_layers,
33 | bias=bias,
34 | batch_first=batch_first,
35 | dropout=dropout,
36 | bidirectional=bidirectional
37 | )
38 |
39 | self.device = device
40 | self.to(device)
41 |
42 | def forward(self, x, x_len, only_use_last_hidden_state=False):
43 | """
44 | sequence -> sort -> pad and pack -> process using RNN -> unpack -> unsort
45 |
46 | :param x: FloatTensor, pre-padded input sequence (batch_size, seq_len, feature_dim)
47 | :param x_len: numpy array, the actual (unpadded) length of each sequence
48 | :return: output, (h_n, c_n)
49 | - **output**: FloatTensor, padded output sequence (batch_size, seq_len, hidden_size * num_directions)
50 | containing the output features `(h_t)` from the last layer of the LSTM, for each t.
51 | - **h_n**: FloatTensor, (num_layers * num_directions, batch, hidden_size)
52 | containing the hidden state for `t = seq_len`
53 | - **c_n**: FloatTensor, (num_layers * num_directions, batch, hidden_size)
54 | containing the cell state for `t = seq_len`
55 | """
56 | # 1. sort
57 | x_sort_idx = np.argsort(-x_len)
58 | x_unsort_idx = torch.LongTensor(np.argsort(x_sort_idx)).to(self.device)
59 | x_len = x_len[x_sort_idx]
60 | x = x[torch.LongTensor(x_sort_idx).to(self.device)]
61 | # 2. pack
62 | x_p = torch.nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=self.batch_first)
63 | # 3. process using RNN
64 | out_pack, (ht, ct) = self.LSTM(x_p, None)
65 | # 4. unsort h
66 | ht = torch.transpose(ht, 0, 1)[x_unsort_idx]
67 | ht = torch.transpose(ht, 0, 1)
68 |
69 | if only_use_last_hidden_state:
70 | return ht
71 | else:
72 | # 5. unpack output
73 | out = torch.nn.utils.rnn.pad_packed_sequence(out_pack, batch_first=self.batch_first) # (sequence, lengths)
74 | out = out[0]  # keep the padded sequences, drop the lengths
75 | # 6. unsort out and c
76 | out = out[x_unsort_idx]
77 | ct = torch.transpose(ct, 0, 1)[x_unsort_idx]
78 | ct = torch.transpose(ct, 0, 1)
79 | return out, (ht, ct)
80 |
81 |
82 | if __name__ == "__main__":
83 | BATCH_SIZE = 5
84 | SEQ_LEN = 3
85 | D = 2
86 | aa = DynamicLSTM(input_size=2, hidden_size=2, batch_first=True)
87 | binp = torch.rand(BATCH_SIZE, SEQ_LEN, D)
88 | binlen = np.array([3, 3, 3, 2, 2], dtype=np.int8)
89 | boup, _ = aa(binp, binlen)
90 | print(boup)
91 |
92 | for i in range(BATCH_SIZE):
93 | sinp = binp[i].unsqueeze(0)
94 | sinlen = np.array([binlen[i]], dtype=np.int8)
95 | soup, _ = aa(sinp, sinlen)
96 | print(soup)
97 |
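
On PyTorch 1.1+, the manual sort/pack/unsort dance above can be delegated to the library: `pack_padded_sequence(..., enforce_sorted=False)` sorts the batch and restores the original order internally. A minimal sketch of the equivalent call:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=2, hidden_size=2, batch_first=True)
x = torch.rand(5, 3, 2)                  # (batch, seq_len, feature)
lengths = torch.tensor([3, 3, 3, 2, 2])  # true lengths, in any order
packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
out_packed, (h_n, c_n) = lstm(packed)
out, _ = nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True)
print(out.shape)                         # torch.Size([5, 3, 2])
```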
--------------------------------------------------------------------------------
/src/eval/EEtesting.py:
--------------------------------------------------------------------------------
1 | from seqeval.metrics import f1_score, precision_score, recall_score
2 |
3 |
4 | class EDTester():
5 | def __init__(self, type_i2s, role_i2s, ignore_time):
6 | self.voc_i2s = type_i2s
7 | self.role_i2s = role_i2s
8 | self.ignore_time = ignore_time
9 |
10 | def calculate_report(self, y, y_, transform=True):
11 | '''
12 | calculating F1, P, R
13 |
14 | :param y: golden label, list
15 | :param y_: model output, list
16 | :return:
17 | '''
18 | if transform:
19 | for i in range(len(y)):
20 | for j in range(len(y[i])):
21 | y[i][j] = self.voc_i2s[y[i][j]]
22 | for i in range(len(y_)):
23 | for j in range(len(y_[i])):
24 | y_[i][j] = self.voc_i2s[y_[i][j]]
25 | return precision_score(y, y_), recall_score(y, y_), f1_score(y, y_)
26 |
27 | @staticmethod
28 | def merge_segments(y, send_id=None):
29 | segs = {}
30 | tt = ""
31 | st, ed = -1, -1
32 | for i, x in enumerate(y):
33 | if x.startswith("B-"):
34 | if tt == "":
35 | tt = x[2:]
36 | if send_id is None:
37 | st = i
38 | else:
39 | st = '%s__%d' % (send_id, i)
40 | else:
41 | ed = i
42 | segs[st] = (ed, tt)
43 | tt = x[2:]
44 | if send_id is None:
45 | st = i
46 | else:
47 | st = '%s__%d' % (send_id, i)
48 | elif x.startswith("I-"):
49 | if tt == "":
50 | y[i] = "B" + y[i][1:]
51 | tt = x[2:]
52 | if send_id is None:
53 | st = i
54 | else:
55 | st = '%s__%d' % (send_id, i)
56 | else:
57 | if tt != x[2:]:
58 | ed = i
59 | segs[st] = (ed, tt)
60 | y[i] = "B" + y[i][1:]
61 | tt = x[2:]
62 | if send_id is None:
63 | st = i
64 | else:
65 | st = '%s__%d' % (send_id, i)
66 | else:
67 | ed = i
68 | if tt != "":
69 | segs[st] = (ed, tt)
70 | tt = ""
71 |
72 | if tt != "":
73 | segs[st] = (len(y), tt)
74 | return segs
75 |
76 | def calculate_sets(self, y, y_):
77 | ct, p1, p2 = 0, 0, 0
78 | for sent, sent_ in zip(y, y_):
79 | # trigger start
80 | for key, value in sent.items():
81 | # key = trigger end, event type
82 | # value = args
83 | p1 += len(value)
84 | if key not in sent_:
85 | continue
86 | # matched sentences
87 | arguments = value
88 | arguments_ = sent_[key]
89 | for item, item_ in zip(arguments, arguments_):
90 | # print('item', self.role_i2s[item[2]], self.role_i2s[item_[2]])
91 | if self.ignore_time and self.role_i2s[item[2]].upper().startswith('TIME'):
92 | continue
93 | if item[2] == item_[2]:
94 | ct += 1
95 |
96 | for key, value in sent_.items():
97 | # p2_key += 1
98 | # p2 += len(value)
99 | # print('key', key)
100 | for item in sent_[key]:
101 | if self.ignore_time and self.role_i2s[item[2]].upper().startswith('TIME'):
102 | continue
103 | p2 += 1
104 |
105 |
106 | if ct == 0 or p1 == 0 or p2 == 0:
107 | return 0.0, 0.0, 0.0
108 | else:
109 | p = 1.0 * ct / p2
110 | r = 1.0 * ct / p1
111 | f1 = 2.0 * p * r / (p + r)
112 | print('ct', ct)
113 | print('p1', p1)
114 | print('p2', p2)
115 | return p, r, f1
116 |
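
A toy sketch of `EDTester.merge_segments`, which merges BIO tags into `{start: (end, type)}` spans with exclusive ends (the labels are illustrative):

```python
y = ["O", "B-Attack", "I-Attack", "O", "B-Die"]
print(EDTester.merge_segments(y))
# {1: (3, 'Attack'), 4: (5, 'Die')}
```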
--------------------------------------------------------------------------------
/src/eval/Groundingtesting.py:
--------------------------------------------------------------------------------
1 |
2 | class GroundingTester():
3 | def __init__(self):
4 | pass
5 |
6 | def calculate_lists(self, y, y_):
7 | '''
8 | for a sequence, whether the prediction is correct
9 | note that len(y) == len(y_)
10 | :param y:
11 | :param y_:
12 | :return:
13 | '''
14 | ct = 0
15 | p2 = len(y_)
16 | p1 = len(y)
17 | for i in range(p2):
18 | if y[i] == y_[i]:
19 | ct = ct + 1
20 | if ct == 0 or p1 == 0 or p2 == 0:
21 | return 0.0, 0.0, 0.0
22 | else:
23 | p = 1.0 * ct / p2
24 | r = 1.0 * ct / p1
25 | f1 = 2.0 * p * r / (p + r)
26 | return p, r, f1
27 |
28 | def calculate_sets_no_order(self, y, y_):
29 | '''
30 | for each predicted item, whether it is in the gt
31 | :param y: [batch, items]
32 | :param y_: [batch, items]
33 | :return:
34 | '''
35 | ct, p1, p2 = 0, 0, 0
36 | for batch, batch_ in zip(y, y_):
37 | value_set = set(batch)
38 | value_set_ = set(batch_)
39 | p1 += len(value_set)
40 | p2 += len(value_set_)
41 |
42 | for value_ in value_set_:
43 | # if value_ == '(0,0,0)':
44 | if value_ in value_set:
45 | ct += 1
46 |
47 | if ct == 0 or p1 == 0 or p2 == 0:
48 | return 0.0, 0.0, 0.0
49 | else:
50 | p = 1.0 * ct / p2
51 | r = 1.0 * ct / p1
52 | f1 = 2.0 * p * r / (p + r)
53 | return p, r, f1
54 |
55 | def calculate_sets_noun(self, y, y_):
56 | '''
57 | for each ground truth entity, whether it is in the predicted entities
58 | :param y: [batch, role_num, multiple_args]
59 | :param y_: [batch, role_num, multiple_entities]
60 | :return:
61 | '''
62 | # print('y', y)
63 | # print('y_', y_)
64 | ct, p1, p2 = 0, 0, 0
65 | # for batch_idx, batch_idx_ in zip(y, y_):
66 | # batch = y[batch_idx]
67 | # batch_ = y_[batch_idx_]
68 | for batch, batch_ in zip(y, y_):
69 | # print('batch', batch)
70 | # print('batch_', batch_)
71 | p1 += len(batch)
72 | p2 += len(batch_)
73 | for role in batch:
74 | found = False
75 | entities = batch[role]
76 | for entity in entities:
77 | for role_ in batch_:
78 | entities_ = batch_[role_]
79 | if entity in entities_:
80 | ct += 1
81 | found = True
82 | break
83 | if found:
84 | break
85 |
86 | if ct == 0 or p1 == 0 or p2 == 0:
87 | return 0.0, 0.0, 0.0
88 | else:
89 | p = 1.0 * ct / p2
90 | r = 1.0 * ct / p1
91 | f1 = 2.0 * p * r / (p + r)
92 | return p, r, f1
93 |
94 | def calculate_sets_triple(self, y, y_):
95 | '''
96 | for each role, whether the predicted entities have overlap with the gt entities
97 | :param y: dict, role -> entities
98 | :param y_: dict, role -> entities
99 | :return:
100 | '''
101 | ct, p1, p2 = 0, 0, 0
102 | # for batch_idx, batch_idx_ in zip(y, y_):
103 | # batch = y[batch_idx]
104 | # batch_ = y_[batch_idx_]
105 | for batch, batch_ in zip(y, y_):
106 | p1 += len(batch)
107 | p2 += len(batch_)
108 | for role in batch:
109 | entities = batch[role]
110 | if role in batch_:
111 | entities_ = batch_[role]
112 | for entity_ in entities_:
113 | if entity_ in entities:
114 | ct += 1
115 | break
116 |
117 | if ct == 0 or p1 == 0 or p2 == 0:
118 | return 0.0, 0.0, 0.0
119 | else:
120 | p = 1.0 * ct / p2
121 | r = 1.0 * ct / p1
122 | f1 = 2.0 * p * r / (p + r)
123 | return p, r, f1
124 |
125 |
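
A toy sketch of `calculate_sets_no_order` on placeholder items:

```python
tester = GroundingTester()
y  = [["a", "b"], ["c"]]   # gold items per example
y_ = [["a"], ["c", "d"]]   # predicted items per example
print(tester.calculate_sets_no_order(y, y_))
# ≈ (0.667, 0.667, 0.667): 2 matches over 3 gold and 3 predicted items
```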
--------------------------------------------------------------------------------
/src/models/modules/GCN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from torch.optim import Adadelta
5 |
6 | import sys
7 | #sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2')
8 | from src.util.util_model import BottledOrthogonalLinear, log
9 |
10 |
11 | class GraphConvolution(nn.Module):
12 | def __init__(self, in_features, out_features, edge_types, dropout=0.5, bias=True, use_bn=False,
13 | device=torch.device("cpu")):
14 | """
15 | Single Layer GraphConvolution
16 |
17 | :param in_features: The number of incoming features
18 | :param out_features: The number of output features
19 | :param edge_types: The number of edge types in the whole graph
20 | :param dropout: float, dropout rate; only applied when given a float in [0, 1]. Default: 0.5
21 | :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
22 | """
23 | super(GraphConvolution, self).__init__()
24 | self.in_features = in_features
25 | self.out_features = out_features
26 | self.edge_types = edge_types
27 | self.dropout = dropout if type(dropout) == float and -1e-7 < dropout < 1 + 1e-7 else None
28 | # parameters for gates
29 | self.Gates = nn.ModuleList()
30 | # parameters for graph convolutions
31 | self.GraphConv = nn.ModuleList()
32 | # batch norm
33 | self.use_bn = use_bn
34 | if self.use_bn:
35 | self.bn = nn.BatchNorm1d(self.out_features)
36 |
37 | for _ in range(edge_types):
38 | self.Gates.append(BottledOrthogonalLinear(in_features=in_features,
39 | out_features=1,
40 | bias=bias))
41 | self.GraphConv.append(BottledOrthogonalLinear(in_features=in_features,
42 | out_features=out_features,
43 | bias=bias))
44 | self.device = device
45 | self.to(device)
46 |
47 | def forward(self, input, adj):
48 | """
49 |
50 | :param input: FloatTensor, input feature tensor, (batch_size, seq_len, hidden_size)
51 | :param adj: FloatTensor (sparse.FloatTensor.to_dense()), adjacency matrix for the provided graph of padded sequences, (batch_size, edge_types, seq_len, seq_len)
52 | :return: output
53 | - **output**: FloatTensor, output feature tensor with the same size of input, (batch_size, seq_len, hidden_size)
54 | """
55 |
56 | adj_ = adj.transpose(0, 1) # (edge_types, batch_size, seq_len, seq_len)
57 | ts = []
58 | for i in range(self.edge_types):
59 | gate_status = torch.sigmoid(self.Gates[i](input)) # (batch_size, seq_len, 1)
60 | adj_hat_i = adj_[i] * gate_status # (batch_size, seq_len, seq_len)
61 | ts.append(torch.bmm(adj_hat_i, self.GraphConv[i](input)))
62 | ts = torch.stack(ts).sum(dim=0, keepdim=False).to(self.device)
63 | if self.use_bn:
64 | ts = ts.transpose(1, 2).contiguous()
65 | ts = self.bn(ts)
66 | ts = ts.transpose(1, 2).contiguous()
67 | ts = F.relu(ts)
68 | if self.dropout is not None:
69 | ts = F.dropout(ts, p=self.dropout, training=self.training)
70 | return ts
71 |
72 | def __repr__(self):
73 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
74 |
75 |
76 | if __name__ == "__main__":
77 | device = torch.device("cuda")
78 |
79 | BATCH_SIZE = 1
80 | SEQ_LEN = 8
81 | D = 6
82 | ET = 1
83 | CLASSN = 2
84 | adj = torch.sparse.FloatTensor(
85 | torch.LongTensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
86 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
87 | [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 0, 1, 2, 3, 4, 5, 6, 7],
88 | [1, 3, 0, 2, 1, 5, 0, 4, 3, 7, 2, 6, 5, 7, 4, 6, 0, 1, 2, 3, 4, 5, 6, 7]]),
89 | torch.FloatTensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
90 | torch.Size([BATCH_SIZE, ET, SEQ_LEN, SEQ_LEN])).to_dense().to(device)
91 | input = torch.randn(BATCH_SIZE, SEQ_LEN, D).to(device)
92 | label = torch.LongTensor([0, 1, 0, 1, 0, 1, 0, 1]).to(device)
93 |
94 | cc = GraphConvolution(in_features=D, out_features=D, edge_types=ET, device=device, use_bn=True)
95 | oo = BottledOrthogonalLinear(in_features=D, out_features=CLASSN).to(device)
96 |
97 | optimizer = Adadelta(list(cc.parameters()) + list(oo.parameters()))
98 |
99 | aloss = 1e9
100 | df = 1e9
101 | while df > 1e-7:
102 | output = oo(cc(input, adj)).view(BATCH_SIZE * SEQ_LEN, CLASSN)
103 | loss = F.cross_entropy(output, label)
104 | df = abs(aloss - loss.item())
105 | aloss = loss.item()
106 | optimizer.zero_grad()
107 | loss.backward()
108 | optimizer.step()
109 | log(aloss)
110 |
111 | log(F.softmax(output, dim=-1))
111 |
--------------------------------------------------------------------------------
/src/dataflow/numpy/anno_mapping.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | def event_type_norm(type_str):
4 | return type_str.replace('.', '||').replace(':', '||').replace('-', '|').upper()
5 |
6 |
7 | def role_name_norm(type_str):
8 | return type_str.upper()
9 |
10 | entity_type_mapping_brat = {
11 | 'PER': 'PER',
12 | 'ORG': 'ORG',
13 | 'GPE': 'GPE',
14 | 'LOC': 'LOC',
15 | 'FAC': 'FAC',
16 | 'VEH': 'VEH',
17 | 'WEA': 'WEA',
18 | # 'TIM': 'TIME',
19 | # 'NUM': 'VALUE',
20 | # 'TIT',
21 | 'MON': 'VALUE',
22 | # 'URL',
23 | # 'RES',
24 | # 'BALLOT'
25 | }
26 |
27 | event_type_mapping_brat2ace = {
28 | 'Die': 'Life:Die',
29 | # 'Injure': 'Life:Injure',
30 | 'TransferMoney': 'Transaction:Transfer-Money',
31 | 'Attack': 'Conflict:Attack',
32 | 'Demonstrate': 'Conflict:Demonstrate',
33 | 'Correspondence': 'Contact:Phone-Write',
34 | 'Meet': 'Contact:Meet',
35 | 'ArrestJail': 'Justice:Arrest-Jail',
36 | # 'ReleaseParole': 'Justice:Release-Parole',
37 | 'TransportPerson': 'Movement:Transport',
38 | 'TransportArtifact': 'Movement:Transport',
39 | }
40 |
41 | event_type_mapping_ace2brat = {event_type_norm(v): k for k, v in event_type_mapping_brat2ace.items()}
42 | # print(event_type_mapping_ace2brat)
43 |
44 | event_type_mapping_aida2brat = {
45 | 'Life.Die': 'Die',
46 | 'Life.Injure': 'Injure',
47 | 'Transaction.TransferMoney': 'TransferMoney',
48 | 'Conflict.Attack': 'Attack',
49 | 'Conflict.Demonstrate': 'Demonstrate',
50 | 'Contact.Correspondence': 'Correspondence',
51 | 'Contact.Meet': 'Meet',
52 | 'Justice.ArrestJail': 'ArrestJail',
53 | 'Justice.ReleaseParole': 'ReleaseParole',
54 | 'Movement.TransportPerson': 'TransportPerson',
55 | 'Movement.TransportArtifact': 'TransportArtifact',
56 | }
57 |
58 | event_role_mapping_brat2ace = { # delete time-related ones when testing
59 | 'Die': {'Victim': 'Victim', 'Agent': 'Agent', 'Instrument': 'Instrument', 'Place': 'Place'}, #, 'Time': 'Time'},
60 | # 'Injure': {'Victim': 'Victim', 'Agent': 'Agent', 'Instrument': 'Instrument', 'Place': 'Place'}, #, 'Time': 'Time'},
61 | 'TransferMoney': {'Giver': 'Giver', 'Recipient': 'Recipient', 'Beneficiary': 'Beneficiary', 'Money': 'Money', 'Place': 'Place'}, #, 'Time': 'Time'},
62 | 'Attack': {'Attacker': 'Attacker', 'Instrument': 'Instrument', 'Place': 'Place', 'Target': 'Target'}, #, 'Time': 'Time'},
63 | 'Demonstrate': {'Demonstrator': 'Entity', 'Place': 'Place'}, #, 'Time': 'Time'},
64 | 'Correspondence': {'Participant': 'Entity', 'Place': 'Place'}, #, 'Time': 'Time'},
65 | 'Meet': {'Participant': 'Entity', 'Place': 'Place'}, #, 'Time': 'Time'},
66 | 'ArrestJail': {'Agent': 'Agent', 'Person': 'Person', 'Place': 'Place'}, #, 'Time': 'Time'},
67 | # 'ReleaseParole': {'Agent': 'Entity', 'Person': 'Person', 'Place': 'Place'}, #, 'Time': 'Time'},
68 | 'TransportPerson': {'Agent': 'Agent', 'Person': 'Artifact', 'Instrument': 'Vehicle', 'Destination': 'Destination', 'Origin': 'Origin'}, #, 'Time': 'Time'},
69 | 'TransportArtifact': {'Agent': 'Agent', 'Artifact': 'Artifact', 'Instrument': 'Vehicle', 'Destination': 'Destination', 'Origin': 'Origin'}, #, 'Time': 'Time'},
70 | }
71 |
72 | # event_role_mapping_ace2brat = {role_name_norm(v): k for t in event_role_mapping_brat2ace.items() for k, v in event_role_mapping_brat2ace[t]}
73 | event_role_mapping_ace2brat = defaultdict(dict)
74 | for t in event_role_mapping_brat2ace:
75 | for k, v in event_role_mapping_brat2ace[t].items():
76 | event_type_ace = event_type_norm(event_type_mapping_brat2ace[t])
77 | event_role_mapping_ace2brat[event_type_ace][role_name_norm(v)] = k
78 | # print(event_role_mapping_ace2brat)
79 |
80 | event_type_mapping_image2ace = {
81 | 'Movement.TransportPerson': 'Movement:Transport',
82 | 'Movement.TransportArtifact': 'Movement:Transport',
83 | 'Life.Die': 'Life:Die',
84 | 'Conflict.Attack': 'Conflict:Attack',
85 | 'Conflict.Demonstrate': 'Conflict:Demonstrate',
86 | 'Contact.Phone-Write': 'Contact:Phone-Write',
87 | 'Contact.Meet': 'Contact:Meet',
88 | 'Transaction.TransferMoney': 'Transaction:Transfer-Money',
89 | 'Justice.ArrestJail': 'Justice:Arrest-Jail',
90 | }
91 |
92 | event_role_mapping_image2ace = {
93 | 'Life.Die': {'victim': 'Victim', 'agent': 'Agent', 'instrument': 'Instrument', 'place': 'Place'}, #, 'Time': 'Time'},
94 | 'Transaction.TransferMoney': {'instrument': 'Instrument', 'giver': 'Giver', 'recipient': 'Recipient', 'beneficiary': 'Beneficiary', 'money': 'Money', 'place': 'Place'}, #, 'Time': 'Time'},
95 | 'Conflict.Attack': {'attacker': 'Attacker', 'instrument': 'Instrument', 'place': 'Place', 'target': 'Target', 'victim': 'Target'}, #, 'Time': 'Time'},
96 | 'Conflict.Demonstrate': {'police':'Police', 'instrument': 'Instrument', 'demonstrator': 'Entity', 'place': 'Place', 'participant': 'Entity'}, #, 'Time': 'Time'},
97 | 'Contact.Phone-Write': {'instrument':'Instrument', 'participant': 'Entity', 'place': 'Place'}, #, 'Time': 'Time'},
98 | 'Contact.Meet': {'participant': 'Entity', 'place': 'Place'}, #, 'Time': 'Time'},
99 | 'Justice.ArrestJail': {'instrument': 'Instrument', 'agent': 'Agent', 'person': 'Person', 'place': 'Place'}, #, 'Time': 'Time'},
100 | 'Movement.TransportPerson': {'agent': 'Agent', 'person': 'Artifact', 'instrument': 'Instrument', 'destination': 'Destination', 'origin': 'Origin'}, #, 'Time': 'Time'},
101 | 'Movement.TransportArtifact': {'person': 'Artifact', 'agent': 'Agent', 'artifact': 'Artifact', 'instrument': 'Instrument', 'destination': 'Destination', 'origin': 'Origin'}, #, 'Time': 'Time'},
102 | }
103 | # 'TransportArtifact' has 'person' in image annotation
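
A quick sketch of how these tables compose: `event_type_norm` rewrites ACE-style type names into the `'||'`/`'|'` scheme used by `data/ace/ace_sr_mapping.txt`, and the inverted mapping is keyed by the normalized form:

```python
print(event_type_norm('Justice:Arrest-Jail'))    # JUSTICE||ARREST|JAIL
print(event_type_mapping_ace2brat['LIFE||DIE'])  # Die
```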
--------------------------------------------------------------------------------
/src/dataflow/numpy/prepare_vocab.py:
--------------------------------------------------------------------------------
1 | """
2 | Prepare vocabulary and initial word vectors.
3 | """
4 | import json
5 | import pickle
6 | import argparse
7 | import numpy as np
8 | from collections import Counter
9 |
10 | import sys
11 | #sys.path.append("../")
12 | from src.util import vocab, constant, helper
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser(description='Prepare vocab for event extraction.')
16 | parser.add_argument('data_dir', help='Data directory containing JMEE_{train,dev,test}.json.')
17 | # /data/m1/lim22/multimedia-common-space/Multimedia-Common-Space/ace/JMEE_data/head
18 | parser.add_argument('vocab_dir', help='Output vocab directory.')
19 | # /data/m1/lim22/multimedia-common-space/Multimedia-Common-Space/ace/vocab
20 | parser.add_argument('--glove_dir', default='/data/m1/lim22/env/glove', help='GloVe directory.')
21 | parser.add_argument('--wv_file', default='glove.840B.300d.txt', help='GloVe vector file.')
22 | parser.add_argument('--wv_dim', type=int, default=300, help='GloVe vector dimension.')
23 | parser.add_argument('--min_freq', type=int, default=0, help='If > 0, use min_freq as the cutoff.')
24 | parser.add_argument('--lower', action='store_true', help='If specified, lowercase all words.')
25 |
26 | args = parser.parse_args()
27 | return args
28 |
29 | def main():
30 | args = parse_args()
31 |
32 | # input files
33 | train_file = args.data_dir + '/JMEE_train.json'
34 | dev_file = args.data_dir + '/JMEE_dev.json'
35 | test_file = args.data_dir + '/JMEE_test.json'
36 | wv_file = args.glove_dir + '/' + args.wv_file
37 | wv_dim = args.wv_dim
38 |
39 | # output files
40 | helper.ensure_dir(args.vocab_dir)
41 | vocab_file = args.vocab_dir + '/vocab.pkl'
42 | emb_file = args.vocab_dir + '/embedding.npy'
43 |
44 | # load files
45 | print("loading files...")
46 | train_tokens = load_tokens(train_file)
47 | dev_tokens = load_tokens(dev_file)
48 | test_tokens = load_tokens(test_file)
49 | if args.lower:
50 | train_tokens, dev_tokens, test_tokens = [[t.lower() for t in tokens] for tokens in\
51 | (train_tokens, dev_tokens, test_tokens)]
52 |
53 | # load glove
54 | print("loading glove...")
55 | glove_vocab = vocab.load_glove_vocab(wv_file, wv_dim)
56 | print("{} words loaded from glove.".format(len(glove_vocab)))
57 |
58 | print("building vocab...")
59 | v = build_vocab(train_tokens, glove_vocab, args.min_freq)
60 |
61 | print("calculating oov...")
62 | datasets = {'train': train_tokens, 'dev': dev_tokens, 'test': test_tokens}
63 | for dname, d in datasets.items():
64 | total, oov = count_oov(d, v)
65 | print("{} oov: {}/{} ({:.2f}%)".format(dname, oov, total, oov*100.0/total))
66 |
67 | print("building embeddings...")
68 | embedding = vocab.build_embedding(wv_file, v, wv_dim)
69 | print("embedding size: {} x {}".format(*embedding.shape))
70 |
71 | print("dumping to files...")
72 | with open(vocab_file, 'wb') as outfile:
73 | pickle.dump(v, outfile)
74 | np.save(emb_file, embedding)
75 | print("all done.")
76 |
77 | # def load_tokens(filename):
78 | # with open(filename) as infile:
79 | # data = json.load(infile)
80 | # tokens = []
81 | # for d in data:
82 | # ts = d['token']
83 | # ss, se, os, oe = d['subj_start'], d['subj_end'], d['obj_start'], d['obj_end']
84 | # # do not create vocab for entity words
85 | # ts[ss:se+1] = ['']*(se-ss+1)
86 | # ts[os:oe+1] = ['']*(oe-os+1)
87 | # tokens += list(filter(lambda t: t!='', ts))
88 | # print("{} tokens from {} examples loaded from {}.".format(len(tokens), len(data), filename))
89 | # return tokens
90 | def load_tokens(filename):
91 | parsed_data = json.load(open(filename))
92 |
93 | tokens = []
94 | for sent in parsed_data:
95 | ts = sent['words']
96 | # # do not create vocab for entity words
97 | # events = sent['golden-event-mentions']
98 | # for event in events:
99 | # trigger = event['trigger']
100 | # args = event['arguments']
101 | # for arg in args:
102 | # arg_start = arg['start']
103 | # arg_end = arg['end']
104 | tokens += list(filter(lambda t: t != '', ts))
105 | print("{} tokens from {} examples loaded from {}.".format(len(tokens), len(parsed_data), filename))
106 | return tokens
107 |
108 | def build_vocab(tokens, glove_vocab, min_freq):
109 | """ build vocab from tokens and glove words. """
110 | counter = Counter(t for t in tokens)
111 | # if min_freq > 0, use min_freq, otherwise keep all glove words
112 | if min_freq > 0:
113 | v = sorted([t for t in counter if counter.get(t) >= min_freq], key=counter.get, reverse=True)
114 | else:
115 | v = sorted([t for t in counter if t in glove_vocab], key=counter.get, reverse=True)
116 | # add special tokens and entity mask tokens
117 | v = constant.VOCAB_PREFIX + entity_masks() + v
118 | print("vocab built with {}/{} words.".format(len(v), len(counter)))
119 | return v
120 |
121 | def count_oov(tokens, vocab):
122 | c = Counter(t for t in tokens)
123 | total = sum(c.values())
124 | matched = sum(c[t] for t in vocab)
125 | return total, total - matched
126 |
127 | def entity_masks():
128 | """ Get all entity mask tokens as a list. Accoeding to constant.py"""
129 | masks = []
130 | # subj_entities = list(constant.SUBJ_NER_TO_ID.keys())[2:]
131 | obj_entities = list(constant.OBJ_NER_TO_ID.keys())[2:]
132 | # masks += ["SUBJ-" + e for e in subj_entities]
133 | # masks += ["OBJ-" + e for e in obj_entities]
134 | masks += [e for e in obj_entities]
135 | print(masks) #['SUBJ-ORGANIZATION', 'SUBJ-PERSON', 'OBJ-PERSON', 'OBJ-ORGANIZATION', 'OBJ-DATE', 'OBJ-NUMBER', 'OBJ-TITLE', 'OBJ-COUNTRY', 'OBJ-LOCATION', 'OBJ-CITY', 'OBJ-MISC', 'OBJ-STATE_OR_PROVINCE', 'OBJ-DURATION', 'OBJ-NATIONALITY', 'OBJ-CAUSE_OF_DEATH', 'OBJ-CRIMINAL_CHARGE', 'OBJ-RELIGION', 'OBJ-URL', 'OBJ-IDEOLOGY']
136 | return masks
137 |
138 | if __name__ == '__main__':
139 | main()
140 |
141 |
142 |
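
A toy sketch of the vocab-building core (assumes the `src.util` imports above resolve, e.g. when run from the repository root; no GloVe file is needed for this part):

```python
toy_tokens = "the cat sat on the mat".split()
toy_glove = {"the", "cat", "sat", "on", "mat"}
v = build_vocab(toy_tokens, toy_glove, min_freq=0)  # special tokens + entity masks + corpus words
total, oov = count_oov(toy_tokens, v)
print(total, oov)  # 6 0
```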
--------------------------------------------------------------------------------
/src/util/util_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | import ujson as json
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import init
8 | import numpy as np
9 |
10 |
11 |
12 | #sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2')
13 |
14 |
15 |
16 | class SparseMM(torch.autograd.Function):
17 | """
18 | Sparse x dense matrix multiplication with autograd support.
19 |
20 | Implementation by Soumith Chintala:
21 | https://discuss.pytorch.org/t/
22 | does-pytorch-support-autograd-on-sparse-matrix/6156/7
23 | """
24 |
25 | def __init__(self, sparse):
26 | super(SparseMM, self).__init__()
27 | self.sparse = sparse
28 |
29 | def forward(self, dense):
30 | return torch.mm(self.sparse, dense)
31 |
32 | def backward(self, grad_output):
33 | grad_input = None
34 | if self.needs_input_grad[0]:
35 | grad_input = torch.mm(self.sparse.t(), grad_output)
36 | return grad_input
37 |
38 |
39 | class Bottle(nn.Module):
40 | ''' Perform the reshape routine before and after an operation '''
41 |
42 | def forward(self, input):
43 | size = input.size()
44 | out = super(Bottle, self).forward(input.contiguous().view(np.prod(size[:-1]), size[-1]))
45 | return out.view(*(size[:-1] + (-1,)))
46 |
47 |
48 |
49 | class XavierLinear(nn.Module):
50 | '''
51 | Simple Linear layer with Xavier init
52 |
53 | Paper by Xavier Glorot and Yoshua Bengio (2010):
54 | Understanding the difficulty of training deep feedforward neural networks
55 | http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
56 | '''
57 |
58 | def __init__(self, in_features, out_features, bias=True):
59 | super(XavierLinear, self).__init__()
60 | self.linear = nn.Linear(in_features, out_features, bias=bias)
61 | init.xavier_normal_(self.linear.weight)
62 |
63 | def forward(self, x):
64 | return self.linear(x)
65 |
66 | class OrthogonalLinear(nn.Module):
67 | def __init__(self, in_features, out_features, bias=True):
68 | super(OrthogonalLinear, self).__init__()
69 | self.linear = nn.Linear(in_features, out_features, bias=bias)
70 | init.orthogonal_(self.linear.weight)
71 |
72 | def forward(self, x):
73 | return self.linear(x)
74 |
75 |
76 | class BottledLinear(Bottle, nn.Linear):
77 | pass
78 |
79 |
80 | class BottledXavierLinear(Bottle, XavierLinear):
81 | pass
82 |
83 |
84 | class BottledOrthogonalLinear(Bottle, OrthogonalLinear):
85 | pass
86 |
87 |
88 | class MLP(nn.Module):
89 | def __init__(self, dim_in_hid_out, act_fn='ReLU', last_act=False):
90 | super(MLP, self).__init__()
91 | layers = []
92 | for i in range(len(dim_in_hid_out) - 1):
93 | layers.append(XavierLinear(dim_in_hid_out[i], dim_in_hid_out[i + 1]))
94 | if i < len(dim_in_hid_out) - 2 or last_act:
95 | layers.append(getattr(torch.nn, act_fn)())
96 | self.model = torch.nn.Sequential(*layers)
97 |
98 | def forward(self, x):
99 | return self.model(x)
100 |
101 | class BottledMLP(Bottle, MLP):
102 | pass
103 |
104 |
105 |
106 | def log(*args, **kwargs):
107 | print(file=sys.stdout, flush=True, *args, **kwargs)
108 |
109 |
110 | def logerr(*args, **kwargs):
111 | print(file=sys.stderr, flush=True, *args, **kwargs)
112 |
113 |
114 | def logonfile(fp, *args, **kwargs):
115 | fp.write(*args, **kwargs)
116 |
117 |
118 | def progressbar(cur, total, other_information):
119 | percent = '{:.2%}'.format(cur / total)
120 | if type(other_information) is str:
121 | log("\r[%-50s] %s %s" % ('=' * int(math.floor(cur * 50 / total)), percent, other_information))
122 | else:
123 | log("\r[%-50s] %s" % ('=' * int(math.floor(cur * 50 / total)), percent))
124 |
125 |
126 | def save_hyps(hyps, fp):
127 | json.dump(hyps, fp)
128 |
129 |
130 | def load_hyps(fp):
131 | hyps = json.load(fp)
132 | return hyps
133 |
134 |
135 | def masked_log_softmax(vector, mask, dim=-1):
136 | """
137 | mask: [1,1,1,0,0]: the padded one is 0
138 |
139 | ``torch.nn.functional.log_softmax(vector)`` does not work if some elements of ``vector`` should be
140 | masked. This performs a log_softmax on just the non-masked portions of ``vector``. Passing
141 | ``None`` in for the mask is also acceptable; you'll just get a regular log_softmax.
142 | ``vector`` can have an arbitrary number of dimensions; the only requirement is that ``mask`` is
143 | broadcastable to ``vector's`` shape. If ``mask`` has fewer dimensions than ``vector``, we will
144 | unsqueeze on dimension 1 until they match. If you need a different unsqueezing of your mask,
145 | do it yourself before passing the mask into this function.
146 | In the case that the input vector is completely masked, the return value of this function is
147 | arbitrary, but not ``nan``. You should be masking the result of whatever computation comes out
148 | of this in that case, anyway, so the specific values returned shouldn't matter. Also, the way
149 | that we deal with this case relies on having single-precision floats; mixing half-precision
150 | floats with fully-masked vectors will likely give you ``nans``.
151 | If your logits are all extremely negative (i.e., the max value in your logit vector is -50 or
152 | lower), the way we handle masking here could mess you up. But if you've got logit values that
153 | extreme, you've got bigger problems than this.
154 | """
155 | if mask is not None:
156 | mask = mask.float()
157 | while mask.dim() < vector.dim():
158 | mask = mask.unsqueeze(1)
159 | # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
160 | # results in nans when the whole vector is masked. We need a very small value instead of a
161 | # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely
162 | # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it
163 | # becomes 0 - this is just the smallest value we can actually use.
164 | vector = vector + (mask + 1e-45).log() # float('-inf')
165 | return torch.nn.functional.log_softmax(vector, dim=dim)
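
A minimal sketch of `masked_log_softmax` on a padded score vector:

```python
import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 0.0, 0.0]])  # last two positions are padding
mask = torch.tensor([[1, 1, 1, 0, 0]])
log_probs = masked_log_softmax(scores, mask)
print(log_probs.exp())  # padded positions get (numerically) zero probability
```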
--------------------------------------------------------------------------------
/src/models/modules/EmbeddingLayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class EmbeddingLayer(nn.Module):
7 | def __init__(self, embedding_size=None, embedding_matrix=None,
8 | fine_tune=True, dropout=0.5,
9 | padding_idx=None,
10 | max_norm=None, norm_type=2, scale_grad_by_freq=False,
11 | sparse=False,
12 | device=torch.device("cpu")):
13 | '''
14 | The embedding layer needs at least one of `embedding_size` and `embedding_matrix`
15 | :param embedding_size: tuple, contains 2 integers indicating the shape of embedding matrix, eg: (20000, 300)
16 | :param embedding_matrix: torch.Tensor, the pre-trained value of embedding matrix
17 | :param fine_tune: boolean, whether to fine-tune the embedding matrix
18 | :param dropout: float, dropout rate
19 | :param padding_idx: int, if given, pads the output with zeros whenever it encounters the index
20 | :param max_norm: float, if given, will renormalize the embeddings to always have a norm less than this
21 | :param norm_type: float, the p of the p-norm to compute for the max_norm option
22 | :param scale_grad_by_freq: boolean, if given, this will scale gradients by the frequency of the words in the mini-batch
23 | :param sparse: boolean, *unclear option copied from original module*
24 | '''
25 | super(EmbeddingLayer, self).__init__()
26 |
27 | assert embedding_size is not None or embedding_matrix is not None
28 | if embedding_matrix is not None:
29 | embedding_size = embedding_matrix.size()
30 | else:
31 | embedding_matrix = torch.nn.init.uniform_(torch.FloatTensor(embedding_size[0], embedding_size[1]),
32 | a=-0.15,
33 | b=0.15)
34 |
35 | # Config copying
36 | self.matrix = nn.Embedding(num_embeddings=embedding_size[0],
37 | embedding_dim=embedding_size[1],
38 | padding_idx=padding_idx,
39 | max_norm=max_norm,
40 | norm_type=norm_type,
41 | scale_grad_by_freq=scale_grad_by_freq,
42 | sparse=sparse)
43 | self.matrix.weight.data.copy_(embedding_matrix)
44 | self.matrix.weight.requires_grad = fine_tune
45 | self.dropout = dropout if type(dropout) == float and -1e-7 < dropout < 1 + 1e-7 else None
46 |
47 | self.device = device
48 | self.to(device)
49 |
50 | # def init_embeddings(self):
51 | # if self.emb_matrix is None:
52 | # self.emb.weight.data[1:,:].uniform_(-1.0, 1.0)
53 | # else:
54 | # self.emb_matrix = torch.from_numpy(self.emb_matrix)
55 | # self.emb.weight.data.copy_(self.emb_matrix)
56 | # # decide finetuning
57 | # if self.opt['topn'] <= 0:
58 | # print("Do not finetune word embedding layer.")
59 | # self.emb.weight.requires_grad = False
60 | # elif self.opt['topn'] < self.opt['vocab_size']:
61 | # print("Finetune top {} word embeddings.".format(self.opt['topn']))
62 | # self.emb.weight.register_hook(lambda x: \
63 | # torch_utils.keep_partial_grad(x, self.opt['topn']))
64 | # else:
65 | # print("Finetune all embeddings.")
66 |
67 | def forward(self, x):
68 | '''
69 | Forward this module
70 | :param x: torch.LongTensor, token sequence or sentence, shape is [batch, sentence_len]
71 | :return: torch.FloatTensor, output data, shape is [batch, sentence_len, embedding_size]
72 | '''
73 | if self.dropout is not None:
74 | return F.dropout(self.matrix(x), p=self.dropout, training=self.training)
75 | else:
76 | return self.matrix(x)
77 |
78 |
79 | class MultiLabelEmbeddingLayer(nn.Module):
80 | def __init__(self, embedding_size=None, embedding_matrix=None,
81 | fine_tune=True, dropout=0.5,
82 | padding_idx=None,
83 | max_norm=None, norm_type=2, scale_grad_by_freq=False,
84 | sparse=False,
85 | device=torch.device("cpu")):
86 | '''
87 | The MultiLabelEmbeddingLayer needs at least one of `embedding_size` and `embedding_matrix`
88 | :param embedding_size: tuple, contains 2 integers indicating the shape of embedding matrix, eg: (20000, 300)
89 | :param embedding_matrix: torch.Tensor, the pre-trained value of embedding matrix
90 | :param fine_tune: boolean, whether to fine-tune the embedding matrix
91 | :param dropout: float, dropout rate
92 | :param padding_idx: int, if given, pads the output with zeros whenever it encounters the index
93 | :param max_norm: float, if given, will renormalize the embeddings to always have a norm less than this
94 | :param norm_type: float, the p of the p-norm to compute for the max_norm option
95 | :param scale_grad_by_freq: boolean, if given, this will scale gradients by the frequency of the words in the mini-batch
96 | :param sparse: boolean, *unclear option copied from original module*
97 | '''
98 | super(MultiLabelEmbeddingLayer, self).__init__()
99 |
100 | assert embedding_size is not None or embedding_matrix is not None
101 | if embedding_matrix is not None:
102 | embedding_size = embedding_matrix.size()
103 | else:
104 | embedding_matrix = torch.randn(embedding_size[0], embedding_size[1])
105 |
106 | # Config copying
107 | self.matrix = nn.Embedding(num_embeddings=embedding_size[0],
108 | embedding_dim=embedding_size[1],
109 | padding_idx=padding_idx,
110 | max_norm=max_norm,
111 | norm_type=norm_type,
112 | scale_grad_by_freq=scale_grad_by_freq,
113 | sparse=sparse)
114 | self.matrix.weight.data.copy_(embedding_matrix)
115 | self.matrix.weight.requires_grad = fine_tune
116 | self.dropout = dropout if type(dropout) == float and -1e-7 < dropout < 1 + 1e-7 else None
117 |
118 | self.device = device
119 | self.to(device)
120 |
121 | def forward(self, x):
122 | '''
123 | Forward this module
124 | :param x: list, token sequence or sentence, shape is [batch, sentence_len, variable_size(>=1)]
125 | :return: torch.FloatTensor, output data, shape is [batch, sentence_len, embedding_size]
126 | '''
127 | BATCH = len(x)
128 | SEQ_LEN = len(x[0])
129 | x = [self.matrix(torch.LongTensor(x[i][j]).to(self.device)).sum(0)
130 | for i in range(BATCH)
131 | for j in range(SEQ_LEN)]
132 | x = torch.stack(x).view(BATCH, SEQ_LEN, -1)
133 | if self.dropout is not None:
134 | return F.dropout(x, p=self.dropout, training=self.training)
135 | else:
136 | return x
137 |
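
A minimal sketch of `MultiLabelEmbeddingLayer`, which sums one embedding per label for every token (the sizes and label ids are arbitrary):

```python
layer = MultiLabelEmbeddingLayer(embedding_size=(10, 4), dropout=None)
x = [[[1], [2, 3]],   # batch of 2 sentences, 2 tokens each;
     [[4], [0]]]      # each token carries a variable number of label ids
print(layer(x).shape) # torch.Size([2, 2, 4])
```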
--------------------------------------------------------------------------------
/src/eval/SRtesting.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | class SRTester():
4 | def __init__(self):
5 | pass
6 |
7 | def calculate_lists(self, y, y_):
8 | '''
9 | for a sequence, whether the prediction is correct
10 | note that len(y) == len(y_)
11 | :param y:
12 | :param y_:
13 | :return:
14 | '''
15 | ct = 0
16 | p2 = len(y_)
17 | p1 = len(y)
18 | for i in range(p2):
19 | if y[i] == y_[i]:
20 | ct = ct + 1
21 | if ct == 0 or p1 == 0 or p2 == 0:
22 | return 0.0, 0.0, 0.0
23 | else:
24 | p = 1.0 * ct / p2
25 | r = 1.0 * ct / p1
26 | f1 = 2.0 * p * r / (p + r)
27 | return p, r, f1
28 |
29 | def calculate_sets_no_order(self, y, y_):
30 | '''
31 | for each predicted item, whether it is in the gt
32 | :param y: [batch, items]
33 | :param y_: [batch, items]
34 | :return:
35 | '''
36 | ct, p1, p2 = 0, 0, 0
37 | for batch, batch_ in zip(y, y_):
38 | value_set = set(batch)
39 | value_set_ = set(batch_)
40 | p1 += len(value_set)
41 | p2 += len(value_set_)
42 |
43 | for value_ in value_set_:
44 | # if value_ == '(0,0,0)':
45 | if value_ in value_set:
46 | ct += 1
47 |
48 | if ct == 0 or p1 == 0 or p2 == 0:
49 | return 0.0, 0.0, 0.0
50 | else:
51 | p = 1.0 * ct / p2
52 | r = 1.0 * ct / p1
53 | f1 = 2.0 * p * r / (p + r)
54 | return p, r, f1
55 |
56 | def calculate_sets_noun(self, y, y_):
57 | '''
58 | for each ground truth entity, whether it is in the predicted entities
59 | :param y: [batch, role_num, multiple_args]
60 | :param y_: [batch, role_num, multiple_entities]
61 | :return:
62 | '''
63 | # print('y', y)
64 | # print('y_', y_)
65 | ct, p1, p2 = 0, 0, 0
66 | # for batch_idx, batch_idx_ in zip(y, y_):
67 | # batch = y[batch_idx]
68 | # batch_ = y_[batch_idx_]
69 | for batch, batch_ in zip(y, y_):
70 | # print('batch', batch)
71 | # print('batch_', batch_)
72 | p1 += len(batch)
73 | p2 += len(batch_)
74 | for role in batch:
75 | found = False
76 | entities = batch[role]
77 | for entity in entities:
78 | for role_ in batch_:
79 | entities_ = batch_[role_]
80 | if entity in entities_:
81 | ct += 1
82 | found = True
83 | break
84 | if found:
85 | break
86 |
87 | if ct == 0 or p1 == 0 or p2 == 0:
88 | return 0.0, 0.0, 0.0
89 | else:
90 | p = 1.0 * ct / p2
91 | r = 1.0 * ct / p1
92 | f1 = 2.0 * p * r / (p + r)
93 | return p, r, f1
94 |
95 | def calculate_sets_triple(self, y, y_):
96 | '''
97 | for each role, whether the predicted entities have overlap with the gt entities
98 | :param y: dict, role -> entities
99 | :param y_: dict, role -> entities
100 | :return:
101 | '''
102 | ct, p1, p2 = 0, 0, 0
103 | # for batch_idx, batch_idx_ in zip(y, y_):
104 | # batch = y[batch_idx]
105 | # batch_ = y_[batch_idx_]
106 | for batch, batch_ in zip(y, y_):
107 | p1 += len(batch)
108 | p2 += len(batch_)
109 | for role in batch:
110 | # is_correct = False
111 | entities = batch[role]
112 | if role in batch_:
113 | entities_ = batch_[role]
114 | for entity_ in entities_:
115 | if entity_ in entities:
116 | ct += 1
117 | # is_correct = True
118 | break
119 | # if not is_correct:
120 | # print('Wrong one:', role, batch[role])
121 |
122 | if ct == 0 or p1 == 0 or p2 == 0:
123 | return 0.0, 0.0, 0.0
124 | else:
125 | p = 1.0 * ct / p2
126 | r = 1.0 * ct / p1
127 | f1 = 2.0 * p * r / (p + r)
128 | return p, r, f1
129 |
130 |
131 | def visualize_sets_triple(self, image_id_batch, y_verb, y_verb_, y, y_, verb_id2s, role_id2s, noun_id2s, sr_visualpath, image_path=None):
132 | '''
133 | for each role, whether the predicted entities have overlap with the gt entities
134 | :param y: dict, role -> entities
135 | :param y_: dict, role -> entities
136 | :return:
137 | '''
138 |
139 | # sr_visualpath = '/scratch/manling2/html/m2e2/sr_errors'
140 | image_path = '../of500_images_resized/'
141 |
142 | ct, p1, p2 = 0, 0, 0
143 | # for batch_idx, batch_idx_ in zip(y, y_):
144 | # batch = y[batch_idx]
145 | # batch_ = y_[batch_idx_]
146 | batch_idx = 0
147 | print('visualize image:', image_id_batch)
148 |
149 | if y_verb_[0] == y_verb[0]:
150 | f_html = open(os.path.join(sr_visualpath, 'verb_correct', '%s.html' % (image_id_batch[0])), 'w')
151 | else:
152 | f_html = open(os.path.join(sr_visualpath, 'verb_wrong', '%s.html' % (image_id_batch[0])), 'w')
153 | f_html.write("
\n
")
154 |
155 | f_html.write('[verb_prediction] %s \n<br>' % verb_id2s[y_verb_[0]])
156 | f_html.write('[verb_ground_truth] %s \n<br>\n<br>' % verb_id2s[y_verb[0]])
157 |
158 | for batch, batch_ in zip(y, y_):
159 | # print('image_id', image_id_batch[batch_idx])
160 | p1 += len(batch)
161 | p2 += len(batch_)
162 | for role in batch:
163 | is_correct = False
164 | entities = batch[role]
165 | if role in batch_:
166 | entities_ = batch_[role]
167 | for entity_ in entities_:
168 | if entity_ in entities:
169 | ct += 1
170 | is_correct = True
171 | break
172 | if not is_correct:
173 | f_html.write('[prediction]\n<br>')
174 | f_html.write('%s.%s = [' % (verb_id2s[y_verb_[batch_idx]], role_id2s[role]))
175 | for entity_ in entities_:
176 | f_html.write('%s, ' % noun_id2s[entity_])
177 | f_html.write(']\n<br>')
178 | f_html.write('[ground truth]\n<br>')
179 | f_html.write('%s.%s = [' % (verb_id2s[y_verb[batch_idx]], role_id2s[role]))
180 | for entity in entities:
181 | f_html.write('%s, ' % noun_id2s[entity])
182 | f_html.write(']\n\n<br>')
183 | batch_idx = batch_idx + 1
184 | f_html.flush()
185 | f_html.close()
186 |
187 | if ct == 0 or p1 == 0 or p2 == 0:
188 | return 0.0, 0.0, 0.0
189 | else:
190 | p = 1.0 * ct / p2
191 | r = 1.0 * ct / p1
192 | f1 = 2.0 * p * r / (p + r)
193 | return p, r, f1
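
A toy sketch of `calculate_sets_triple` with placeholder roles and entities:

```python
tester = SRTester()
y  = [{"agent": ["man"], "place": ["street"]}]  # gold: role -> entities
y_ = [{"agent": ["man"], "place": ["park"]}]    # predicted: role -> entities
print(tester.calculate_sets_triple(y, y_))
# (0.5, 0.5, 0.5): only the 'agent' role overlaps
```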
--------------------------------------------------------------------------------
/src/util/constant.py:
--------------------------------------------------------------------------------
1 | """
2 | Define constants.
3 | """
4 | EMB_INIT_RANGE = 1.0
5 |
6 | # vocab
7 | PAD_TOKEN = '<PAD>'
8 | PAD_ID = 0
9 | UNK_TOKEN = '<UNK>'
10 | UNK_ID = 1
11 |
12 | VOCAB_PREFIX = [PAD_TOKEN, UNK_TOKEN]
13 |
14 | # hard-coded mappings from fields to ids
15 | # SUBJ is trigger, no type
16 | # SUBJ_NER_TO_ID = {PAD_TOKEN: 0, UNK_TOKEN: 1, 'ORGANIZATION': 2, 'PERSON': 3}
17 |
18 | OBJ_NER_TO_ID = {PAD_TOKEN: 0, UNK_TOKEN: 1, 'Weapon': 2, 'Organization': 3, 'Geopolitical_Entity': 4, 'Location': 5, 'Time': 6, 'Vehicle': 7, 'Value': 8, 'Facility': 9, 'Person': 10}
19 |
20 | NER_NORM = {'WEA': 'Weapon', 'ORG': 'Organization', 'GPE': 'Geopolitical_Entity', 'LOC': 'Location', 'TIME': 'Time', 'VEH': 'Vehicle', 'VALUE': 'Value', 'FAC': 'Facility', 'PER': 'Person'}
21 |
22 | NER_NIL_LABEL = 'O'
23 | NUM_OTHER_NER = 3
24 | NER_TO_ID = {PAD_TOKEN: 0, UNK_TOKEN: 1, NER_NIL_LABEL: 2, 'Weapon': 3, 'Organization': 4, 'Geopolitical_Entity': 5, 'Location': 6, 'Time': 7, 'Vehicle': 8, 'Value': 9, 'Facility': 10, 'Person': 11}
25 | # {'ORG:Sports', 'VALUE:other', 'GPE:GPE-Cluster', 'VEH:Subarea-Vehicle', 'WEA:Nuclear', 'ORG:Religious', 'ORG:Government', 'FAC:Building-Grounds', 'LOC:Address', 'PER:Indeterminate', 'LOC:Boundary', 'ORG:Entertainment', 'WEA:Exploding', 'PER:Group', 'WEA:Blunt', 'VEH:Underspecified', 'WEA:Projectile', 'GPE:Nation', 'VEH:Water', 'GPE:Special', 'LOC:Region-General', 'LOC:Celestial', 'GPE:County-or-District', 'VEH:Air', 'WEA:Sharp', 'ORG:Educational', 'PER:Individual', 'FAC:Airport', 'WEA:Chemical', 'GPE:Population-Center', 'LOC:Region-International', 'TIME:other', 'ORG:Media', 'LOC:Water-Body', 'WEA:Biological', 'FAC:Subarea-Facility', 'WEA:Shooting', 'ORG:Commercial', 'ORG:Non-Governmental', 'GPE:Continent', 'ORG:Medical-Science', 'WEA:Underspecified', 'FAC:Path', 'GPE:State-or-Province', 'VEH:Land', 'LOC:Land-Region-Natural', 'FAC:Plant'
26 |
27 | POS_TO_ID = {PAD_TOKEN: 0, UNK_TOKEN: 1, 'NNP': 2, 'NN': 3, 'IN': 4, 'DT': 5, ',': 6, 'JJ': 7, 'NNS': 8, 'VBD': 9, 'CD': 10, 'CC': 11, '.': 12, 'RB': 13, 'VBN': 14, 'PRP': 15, 'TO': 16, 'VB': 17, 'VBG': 18, 'VBZ': 19, 'PRP$': 20, ':': 21, 'POS': 22, '\'\'': 23, '``': 24, '-RRB-': 25, '-LRB-': 26, 'VBP': 27, 'MD': 28, 'NNPS': 29, 'WP': 30, 'WDT': 31, 'WRB': 32, 'RP': 33, 'JJR': 34, 'JJS': 35, '$': 36, 'FW': 37, 'RBR': 38, 'SYM': 39, 'EX': 40, 'RBS': 41, 'WP$': 42, 'PDT': 43, 'LS': 44, 'UH': 45, '#': 46}
28 |
29 | DEPREL_TO_ID = {PAD_TOKEN: 0, UNK_TOKEN: 1, 'punct': 2, 'compound': 3, 'case': 4, 'nmod': 5, 'det': 6, 'nsubj': 7, 'amod': 8, 'conj': 9, 'dobj': 10, 'ROOT': 11, 'cc': 12, 'nmod:poss': 13, 'mark': 14, 'advmod': 15, 'appos': 16, 'nummod': 17, 'dep': 18, 'ccomp': 19, 'aux': 20, 'advcl': 21, 'acl:relcl': 22, 'xcomp': 23, 'cop': 24, 'acl': 25, 'auxpass': 26, 'nsubjpass': 27, 'nmod:tmod': 28, 'neg': 29, 'compound:prt': 30, 'mwe': 31, 'parataxis': 32, 'root': 33, 'nmod:npmod': 34, 'expl': 35, 'csubj': 36, 'cc:preconj': 37, 'iobj': 38, 'det:predet': 39, 'discourse': 40, 'csubjpass': 41}
30 |
31 | NEGATIVE_LABEL_TRIGGER = 'other_event'
32 | NUM_OTHER_TRIGGER = 2
33 | LABEL_TO_ID_TRIGGER = {PAD_TOKEN: 0, NEGATIVE_LABEL_TRIGGER: 1, 'Conflict:Attack': 2, 'Conflict:Demonstrate': 3,
34 | 'Contact:Meet': 4, 'Contact:Phone-Write': 5, 'Life:Die': 6, 'Movement:Transport': 7,
35 | 'Justice:Arrest-Jail': 8, 'Transaction:Transfer-Money': 9} # 'Life:Injure', 'Justice:Release-Parole'
36 | LABEL_TO_ID_TRIGGER_UNSEEN = {'Personnel:Elect': 2,
37 | 'Life:Marry': 3, 'Life:Injure': 4, 'Justice:Execute': 5, 'Justice:Trial-Hearing': 6,
38 | 'Life:Be-Born': 7, 'Justice:Convict': 8, 'Justice:Release-Parole': 9, 'Justice:Fine': 10}
39 | LABEL_TO_ID_TRIGGER_UNVISUAL = {'Personnel:End-Position': 3, 'Personnel:Start-Position': 4, 'Personnel:Nominate': 5,
40 | 'Justice:Sue': 7, 'Business:End-Org': 9, 'Business:Start-Org': 11,
41 | 'Transaction:Transfer-Ownership': 10, 'Justice:Sentence': 13,
42 | 'Justice:Charge-Indict': 16, 'Business:Declare-Bankruptcy': 18,
43 | 'Justice:Pardon': 21, 'Justice:Appeal': 22, 'Justice:Extradite': 23,
44 | 'Life:Divorce': 24, 'Business:Merge-Org': 25, 'Justice:Acquit': 26}
45 | LABEL_TO_ID_TRIGGER_ALL = {PAD_TOKEN: 0, UNK_TOKEN: 1, NEGATIVE_LABEL_TRIGGER: 2, 'Personnel:Elect': 3,
46 | 'Personnel:End-Position': 4, 'Personnel:Start-Position': 5,
47 | 'Movement:Transport': 6, 'Conflict:Attack': 7, 'Personnel:Nominate': 8, 'Contact:Meet': 9,
48 | 'Life:Marry': 10,
49 | 'Justice:Sue': 11, 'Contact:Phone-Write': 12, 'Transaction:Transfer-Money': 13,
50 | 'Conflict:Demonstrate': 14, 'Life:Injure': 15, 'Business:End-Org': 16, 'Life:Die': 17,
51 | 'Justice:Arrest-Jail': 18, 'Transaction:Transfer-Ownership': 19, 'Business:Start-Org': 20,
52 | 'Justice:Execute': 21, 'Justice:Sentence': 22, 'Justice:Trial-Hearing': 23, 'Life:Be-Born': 24,
53 | 'Justice:Charge-Indict': 25, 'Justice:Convict': 26, 'Business:Declare-Bankruptcy': 27,
54 | 'Justice:Release-Parole': 28, 'Justice:Fine': 29, 'Justice:Pardon': 30, 'Justice:Appeal': 31,
55 | 'Justice:Extradite': 32, 'Life:Divorce': 33, 'Business:Merge-Org': 34, 'Justice:Acquit': 35}
56 |
57 | NEGATIVE_LABEL_ROLE = 'other_role'
58 | NUM_OTHER_ROLE = 2
59 | LABEL_TO_ID_ROLE = {PAD_TOKEN: 0, NEGATIVE_LABEL_ROLE: 1, 'Buyer': 2, 'Target': 3, 'Agent': 4, 'Vehicle': 5,
60 | 'Instrument': 6, 'Person': 7, 'Victim': 8, 'Attacker': 9, 'Artifact': 10, 'Seller': 11,
61 | 'Recipient': 12, 'Money': 13, 'Giver': 14, 'Entity': 15, 'Place': 16, 'Defendant': 17,
62 | 'Destination': 18, 'Origin': 19}
63 | LABEL_TO_ID_ROLE_UNSEEN = {'Beneficiary': 2, 'Prosecutor': 3,
64 | 'Plaintiff': 4, 'Adjudicator': 5, 'Agent': 6, 'Vehicle': 7}
65 | LABEL_TO_ID_ROLE_UNVISUAL = {'Time-Holds': 8, 'Time-Starting': 10, 'Price': 11, 'Time-Before': 12,
66 | 'Time-After': 14, 'Time-Ending': 19, 'Time-At-End': 22,
67 | 'Time-At-Beginning': 28, 'Time-Within': 32, 'Org': 18,
68 | 'Sentence': 15, 'Crime': 16, 'Position': 24 } #
69 | LABEL_TO_ID_ROLE_ALL = {PAD_TOKEN: 0, UNK_TOKEN: 1, NEGATIVE_LABEL_ROLE: 2, 'Buyer': 3, 'Target': 4,
70 | 'Instrument': 5, 'Person': 6, 'Victim': 7,
71 | 'Time-Holds': 8, 'Attacker': 9, 'Time-Starting': 10, 'Price': 11, 'Time-Before': 12,
72 | 'Artifact': 13, 'Time-After': 14, 'Sentence': 15, 'Crime': 16, 'Destination': 17,
73 | 'Org': 18, 'Time-Ending': 19, 'Beneficiary': 20, 'Seller': 21, 'Time-At-End': 22,
74 | 'Recipient': 23, 'Position': 24, 'Money': 25, 'Giver': 26, 'Prosecutor': 27,
75 | 'Time-At-Beginning': 28, 'Entity': 29, 'Place': 30, 'Defendant': 31,
76 | 'Time-Within': 32, 'Plaintiff': 33, 'Adjudicator': 34, 'Agent': 35,
77 | 'Origin': 36, 'Vehicle': 37}
78 |
79 | # NEGATIVE_LABEL = 'no_relation'
80 | # LABEL_TO_ID = {'no_relation': 0, 'per:title': 1, 'org:top_members/employees': 2, 'per:employee_of': 3, 'org:alternate_names': 4, 'org:country_of_headquarters': 5, 'per:countries_of_residence': 6, 'org:city_of_headquarters': 7, 'per:cities_of_residence': 8, 'per:age': 9, 'per:stateorprovinces_of_residence': 10, 'per:origin': 11, 'org:subsidiaries': 12, 'org:parents': 13, 'per:spouse': 14, 'org:stateorprovince_of_headquarters': 15, 'per:children': 16, 'per:other_family': 17, 'per:alternate_names': 18, 'org:members': 19, 'per:siblings': 20, 'per:schools_attended': 21, 'per:parents': 22, 'per:date_of_death': 23, 'org:member_of': 24, 'org:founded_by': 25, 'org:website': 26, 'per:cause_of_death': 27, 'org:political/religious_affiliation': 28, 'org:founded': 29, 'per:city_of_death': 30, 'org:shareholders': 31, 'org:number_of_employees/members': 32, 'per:date_of_birth': 33, 'per:city_of_birth': 34, 'per:charges': 35, 'per:stateorprovince_of_death': 36, 'per:religion': 37, 'per:stateorprovince_of_birth': 38, 'per:country_of_birth': 39, 'org:dissolved': 40, 'per:country_of_death': 41}
81 |
82 | INFINITY_NUMBER = 1e12
83 |
84 | # TYPE_ROLE_MAP : in encoder_ont
85 | # no masking (zero-shot), so only list the ones that belong to the ontology
86 |
--------------------------------------------------------------------------------
/data/ace/ace_sr_mapping.txt:
--------------------------------------------------------------------------------
1 | apprehending agent Justice||Arrest|Jail Agent
2 | arresting agent Justice||Arrest|Jail Agent
3 | catching agent Justice||Arrest|Jail Agent
4 | chasing agent Justice||Arrest|Jail Agent
5 | detaining agent Justice||Arrest|Jail Agent
6 | dragging agent Justice||Arrest|Jail Agent
7 | frisking agent Justice||Arrest|Jail Agent
8 | handcuffing agent Justice||Arrest|Jail Agent
9 | burying agent Life||Die Agent
10 | colliding agent Life||Die Agent
11 | crashing agent Life||Die Agent
12 | carrying agent Movement||Transport Agent
13 | carrying agentpart Movement||Transport Agent
14 | carrying item Movement||Transport Agent
15 | fetching agent Movement||Transport Agent
16 | hauling carrier Movement||Transport Agent
17 | lifting agent Movement||Transport Agent
18 | loading agent Movement||Transport Agent
19 | towing agent Movement||Transport Agent
20 | unloading agent Movement||Transport Agent
21 | unpacking agent Movement||Transport Agent
22 | disembarking agent Movement||Transport Agent
23 | landing agent Movement||Transport Agent
24 | marching agent Movement||Transport Agent
25 | piloting agent Movement||Transport Agent
26 | rafting agent Movement||Transport Agent
27 | rowing agent Movement||Transport Agent
28 | skidding agent Movement||Transport Agent
29 | taxiing agent Movement||Transport Agent
30 | wheeling agent Movement||Transport Agent
31 | boarding agent Movement||Transport Artifact
32 | boating boaters Movement||Transport Artifact
33 | fetching item Movement||Transport Artifact
34 | hauling item Movement||Transport Artifact
35 | lifting item Movement||Transport Artifact
36 | loading item Movement||Transport Artifact
37 | towing item Movement||Transport Artifact
38 | unloading item Movement||Transport Artifact
39 | unpacking item Movement||Transport Artifact
40 | wheeling item Movement||Transport Artifact
41 | aiming agent Conflict||Attack Attacker
42 | attacking agent Conflict||Attack Attacker
43 | burning agent Conflict||Attack Attacker
44 | butting agent Conflict||Attack Attacker
45 | deflecting agent Conflict||Attack Attacker
46 | destroying agent Conflict||Attack Attacker
47 | ejecting agent Conflict||Attack Attacker
48 | erupting agent Conflict||Attack Attacker
49 | flaming agent Conflict||Attack Attacker
50 | hitting agent Conflict||Attack Attacker
51 | launching agent Conflict||Attack Attacker
52 | punching agent Conflict||Attack Attacker
53 | ramming agent Conflict||Attack Attacker
54 | shooting agent Conflict||Attack Attacker
55 | slapping agent Conflict||Attack Attacker
56 | striking agent Conflict||Attack Attacker
57 | striking coagent Conflict||Attack Attacker
58 | subduing agent Conflict||Attack Attacker
59 | tackling agent Conflict||Attack Attacker
60 | launching source Conflict||Attack Attacker
61 | fetching destination Movement||Transport Destination
62 | lifting end Movement||Transport Destination
63 | loading destination Movement||Transport Destination
64 | landing destination Movement||Transport Destination
65 | piloting end Movement||Transport Destination
66 | confronting agent Conflict||Demonstrate Entity
67 | confronting confronted Conflict||Demonstrate Entity
68 | congregating individuals Conflict||Demonstrate Entity
69 | gathering gatherers Conflict||Demonstrate Entity
70 | parading agent Conflict||Demonstrate Entity
71 | protesting agent Conflict||Demonstrate Entity
72 | communicating adressee Contact||Meet Entity
73 | communicating agent Contact||Meet Entity
74 | discussing agents Contact||Meet Entity
75 | saluting agent Contact||Meet Entity
76 | saluting target Contact||Meet Entity
77 | scolding agent Contact||Meet Entity
78 | scolding victim Contact||Meet Entity
79 | shaking agent Contact||Meet Entity
80 | socializing agent Contact||Meet Entity
81 | socializing coagent Contact||Meet Entity
82 | talking agent Contact||Meet Entity
83 | talking listener Contact||Meet Entity
84 | calling agent Contact||Phone|Write Entity
85 | phoning agent Contact||Phone|Write Entity
86 | telephoning agent Contact||Phone|Write Entity
87 | writing agent Contact||Phone|Write Entity
88 | writing target Contact||Phone|Write Entity
89 | buying agent Transaction||Transfer|Money Giver
90 | paying agent Transaction||Transfer|Money Giver
91 | aiming item Conflict||Attack Instrument
92 | attacking weapon Conflict||Attack Instrument
93 | deflecting deflecteditem Conflict||Attack Instrument
94 | destroying tool Conflict||Attack Instrument
95 | ejecting item Conflict||Attack Instrument
96 | hitting tool Conflict||Attack Instrument
97 | launching item Conflict||Attack Instrument
98 | ramming rammingitem Conflict||Attack Instrument
99 | shooting firearm Conflict||Attack Instrument
100 | shooting projectile Conflict||Attack Instrument
101 | slapping tool Conflict||Attack Instrument
102 | striking tool Conflict||Attack Instrument
103 | calling tool Contact||Phone|Write Instrument
104 | phoning tool Contact||Phone|Write Instrument
105 | writing tool Contact||Phone|Write Instrument
106 | catching tool Justice||Arrest|Jail Instrument
107 | dragging tool Justice||Arrest|Jail Instrument
108 | burying tool Life||Die Instrument
109 | crashing item Life||Die Instrument
110 | buying payment Transaction||Transfer|Money Instrument
111 | fetching source Movement||Transport Origin
112 | lifting start Movement||Transport Origin
113 | unloading source Movement||Transport Origin
114 | piloting start Movement||Transport Origin
115 | dragging contact Justice||Arrest|Jail Person
116 | apprehending victim Justice||Arrest|Jail Person
117 | arresting suspect Justice||Arrest|Jail Person
118 | catching caughtitem Justice||Arrest|Jail Person
119 | chasing chasee Justice||Arrest|Jail Person
120 | detaining victim Justice||Arrest|Jail Person
121 | dragging item Justice||Arrest|Jail Person
122 | frisking victim Justice||Arrest|Jail Person
123 | handcuffing victim Justice||Arrest|Jail Person
124 | aiming place Conflict||Attack Place
125 | attacking place Conflict||Attack Place
126 | burning place Conflict||Attack Place
127 | butting place Conflict||Attack Place
128 | deflecting place Conflict||Attack Place
129 | destroying place Conflict||Attack Place
130 | ejecting place Conflict||Attack Place
131 | ejecting source Conflict||Attack Place
132 | erupting place Conflict||Attack Place
133 | flaming place Conflict||Attack Place
134 | hitting place Conflict||Attack Place
135 | launching place Conflict||Attack Place
136 | punching place Conflict||Attack Place
137 | ramming place Conflict||Attack Place
138 | slapping place Conflict||Attack Place
139 | striking place Conflict||Attack Place
140 | subduing place Conflict||Attack Place
141 | tackling place Conflict||Attack Place
142 | confronting place Conflict||Demonstrate Place
143 | congregating place Conflict||Demonstrate Place
144 | gathering place Conflict||Demonstrate Place
145 | parading place Conflict||Demonstrate Place
146 | protesting place Conflict||Demonstrate Place
147 | communicating place Contact||Meet Place
148 | discussing place Contact||Meet Place
149 | saluting place Contact||Meet Place
150 | scolding place Contact||Meet Place
151 | shaking place Contact||Meet Place
152 | socializing place Contact||Meet Place
153 | talking place Contact||Meet Place
154 | calling place Contact||Phone|Write Place
155 | phoning place Contact||Phone|Write Place
156 | telephoning place Contact||Phone|Write Place
157 | writing place Contact||Phone|Write Place
158 | apprehending place Justice||Arrest|Jail Place
159 | arresting place Justice||Arrest|Jail Place
160 | catching place Justice||Arrest|Jail Place
161 | chasing place Justice||Arrest|Jail Place
162 | detaining place Justice||Arrest|Jail Place
163 | dragging place Justice||Arrest|Jail Place
164 | frisking place Justice||Arrest|Jail Place
165 | handcuffing place Justice||Arrest|Jail Place
166 | burying destination Life||Die Place
167 | burying place Life||Die Place
168 | colliding place Life||Die Place
169 | crashing place Life||Die Place
170 | steering place Movement||Transport Place
171 | carrying place Movement||Transport Place
172 | fetching place Movement||Transport Place
173 | hauling place Movement||Transport Place
174 | lifting place Movement||Transport Place
175 | loading place Movement||Transport Place
176 | towing place Movement||Transport Place
177 | unloading place Movement||Transport Place
178 | unpacking place Movement||Transport Place
179 | boarding place Movement||Transport Place
180 | boating place Movement||Transport Place
181 | disembarking place Movement||Transport Place
182 | landing place Movement||Transport Place
183 | marching place Movement||Transport Place
184 | piloting place Movement||Transport Place
185 | rafting place Movement||Transport Place
186 | rowing place Movement||Transport Place
187 | skidding place Movement||Transport Place
188 | taxiing place Movement||Transport Place
189 | wheeling place Movement||Transport Place
190 | buying place Transaction||Transfer|Money Place
191 | paying place Transaction||Transfer|Money Place
192 | buying seller Transaction||Transfer|Money Recipient
193 | paying seller Transaction||Transfer|Money Recipient
194 | aiming target Conflict||Attack Target
195 | attacking victim Conflict||Attack Target
196 | burning target Conflict||Attack Target
197 | butting target Conflict||Attack Target
198 | destroying destroyeditem Conflict||Attack Target
199 | hitting victim Conflict||Attack Target
200 | hitting victimpart Conflict||Attack Target
201 | punching bodypart Conflict||Attack Target
202 | punching victim Conflict||Attack Target
203 | ramming victim Conflict||Attack Target
204 | shooting target Conflict||Attack Target
205 | slapping victim Conflict||Attack Target
206 | slapping victimpart Conflict||Attack Target
207 | striking agentpart Conflict||Attack Target
208 | subduing target Conflict||Attack Target
209 | tackling victim Conflict||Attack Target
210 | ejecting destination Conflict||Attack Target
211 | launching destination Conflict||Attack Target
212 | boarding vehicle Movement||Transport Vehicle
213 | boating vehicle Movement||Transport Vehicle
214 | steering agent Movement||Transport Vehicle
215 | steering tool Movement||Transport Vehicle
216 | steering vehicle Movement||Transport Vehicle
217 | hauling tool Movement||Transport Vehicle
218 | loading tool Movement||Transport Vehicle
219 | unloading tool Movement||Transport Vehicle
220 | unpacking container Movement||Transport Vehicle
221 | disembarking vehicle Movement||Transport Vehicle
222 | piloting vehicle Movement||Transport Vehicle
223 | rowing vehicle Movement||Transport Vehicle
224 | wheeling carrier Movement||Transport Vehicle
225 | burying item Life||Die Victim
226 | crashing against Life||Die Victim
227 | colliding item Life||Die Victim
--------------------------------------------------------------------------------
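The mapping above pairs situation-recognition verb/role labels with ACE event types and argument roles. A minimal loader sketch, assuming the columns are tab-separated; the helper name load_sr_mapping is illustrative and not part of the repo:

def load_sr_mapping(path):
    # Builds {(verb, role): (ace_event_type, ace_arg_role)}, e.g.
    # ('shooting', 'firearm') -> ('Conflict||Attack', 'Instrument').
    mapping = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            parts = line.rstrip('\n').split('\t')
            if len(parts) != 4:
                continue  # skip malformed rows
            verb, role, event_type, arg_role = parts
            mapping[(verb, role)] = (event_type, arg_role)
    return mapping
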
/src/dataflow/torch/Data.py:
--------------------------------------------------------------------------------
1 | import json
2 | # import ujson as json
3 | from collections import Counter, OrderedDict
4 |
5 | import codecs
6 | import six
7 | import torch
8 | from torchtext.data import Field, Example, Pipeline, Dataset
9 |
10 | # import sys
11 | # sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2')
12 | from src.dataflow.torch.Corpus import Corpus
13 | from src.dataflow.torch.Sentence import Sentence_ace
14 |
15 |
16 | class SparseField(Field):
17 | def process(self, batch, device=None, train=True):
18 | return batch
19 |
20 |
21 | class EntityField(Field):
22 | '''
23 | Processes the per-sentence entity annotations, each a list like
24 |
25 | [(2, 3, "entity_type")]
26 | '''
27 |
28 | def preprocess(self, x):
29 | return x
30 |
31 | def pad(self, minibatch):
32 | return minibatch
33 |
34 | def numericalize(self, arr, device=None, train=True):
35 | return arr
36 |
37 |
38 | class EventField(Field):
39 | '''
40 | Processes the per-sentence event annotations, each a dict like
41 |
42 | {
43 | (2, 3, "event_type_str") --> [(1, 2, "role_type_str"), ...]
44 | ...
45 | }
46 | '''
47 |
48 | def preprocess(self, x):
49 | return x
50 |
51 | def build_vocab(self, *args, **kwargs):
52 | counter = Counter()
53 | sources = []
54 | for arg in args:
55 | if isinstance(arg, Dataset):
56 | sources += [getattr(arg, name) for name, field in
57 | arg.fields.items() if field is self]
58 | else:
59 | sources.append(arg)
60 | for data in sources:
61 | for x in data:
62 | for key, value in x.items():
63 | for v in value:
64 | counter.update([v[2]])
65 | self.vocab = self.vocab_cls(counter, specials=["OTHER"], **kwargs)
66 |
67 | def pad(self, minibatch):
68 | return minibatch
69 |
70 | def numericalize(self, arr, device=None, train=True):
71 | if self.use_vocab:
72 | # arr = [{key: [(v[0], v[1], self.vocab.stoi[v[2]]) for v in value] for key, value in dd.items()} for dd in
73 | # arr]
74 | arr_num = list()
75 | for dd in arr:
76 | dd_dict = dict()
77 | for key, value in dd.items():
78 | value_list = []
79 | for v in value:
80 | if v[2] in self.vocab.stoi:
81 | value_list.append( (v[0], v[1], self.vocab.stoi[v[2]]) )
82 | else:
83 | value_list.append( (v[0], v[1], 0) )
84 | dd_dict[key] = value_list
85 | arr_num.append(dd_dict)
86 | return arr_num
87 |
88 |
89 | class MultiTokenField(Field):
90 | '''
91 | Processes nested token-label lists like [ ["A", "A", "A"], ["A", "A"], ["A", "A"], ["A"] ]
92 | '''
93 |
94 | def preprocess(self, x):
95 | """Load a single example using this field, tokenizing if necessary.
96 |
97 | If the input is a Python 2 `str`, it will be converted to Unicode
98 | first. If `sequential=True`, it will be tokenized. Then the input
99 | will be optionally lowercased and passed to the user-provided
100 | `preprocessing` Pipeline."""
101 | if (six.PY2 and isinstance(x, six.string_types) and
102 | not isinstance(x, six.text_type)): # never
103 | x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x)
104 | if self.sequential and isinstance(x, six.text_type): # never
105 | x = self.tokenize(x.rstrip('\n'))
106 | if self.lower:
107 | x = [Pipeline(six.text_type.lower)(xx) for xx in x]
108 | if self.preprocessing is not None:
109 | return self.preprocessing(x)
110 | else:
111 | return x
112 |
113 | def build_vocab(self, *args, **kwargs):
114 | counter = Counter()
115 | sources = []
116 | for arg in args:
117 | if isinstance(arg, Dataset):
118 | sources += [getattr(arg, name) for name, field in
119 | arg.fields.items() if field is self]
120 | else:
121 | sources.append(arg)
122 | for data in sources:
123 | for x in data:
124 | if not self.sequential:
125 | x = [x]
126 | for xx in x:
127 | counter.update(xx)
128 | specials = list(OrderedDict.fromkeys(
129 | tok for tok in [self.unk_token, self.pad_token, self.init_token,
130 | self.eos_token]
131 | if tok is not None))
132 | self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
133 |
134 | def pad(self, minibatch):
135 | minibatch = list(minibatch)
136 | if not self.sequential:
137 | return minibatch
138 | if self.fix_length is None:
139 | max_len = max(len(x) for x in minibatch)
140 | else:
141 | max_len = self.fix_length + (
142 | self.init_token, self.eos_token).count(None) - 2
143 | padded, lengths = [], []
144 | for x in minibatch:
145 | if self.pad_first:
146 | padded.append(
147 | [[self.pad_token]] * max(0, max_len - len(x)) +
148 | ([] if self.init_token is None else [[self.init_token]]) +
149 | list(x[-max_len:] if self.truncate_first else x[:max_len]) +
150 | ([] if self.eos_token is None else [[self.eos_token]]))
151 | else:
152 | padded.append(
153 | ([] if self.init_token is None else [[self.init_token]]) +
154 | list(x[-max_len:] if self.truncate_first else x[:max_len]) +
155 | ([] if self.eos_token is None else [[self.eos_token]]) +
156 | [[self.pad_token]] * max(0, max_len - len(x)))
157 | lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
158 |
159 | if self.include_lengths:
160 | return (padded, lengths)
161 | return padded
162 |
163 | def numericalize(self, arr, device=None, train=True):
164 | if self.include_lengths and not isinstance(arr, tuple):
165 | raise ValueError("Field has include_lengths set to True, but "
166 | "input data is not a tuple of "
167 | "(data batch, batch lengths).")
168 | if isinstance(arr, tuple):
169 | arr, lengths = arr
170 | lengths = torch.LongTensor(lengths)
171 |
172 | if self.use_vocab:
173 | if self.sequential:
174 | arr = [[[self.vocab.stoi[xx] for xx in x] for x in ex] for ex in arr]
175 |
176 | if self.postprocessing is not None:
177 | arr = self.postprocessing(arr, self.vocab, train)
178 |
179 | if self.include_lengths:
180 | return arr, lengths
181 | return arr
182 |
183 |
184 | class ACE2005Dataset(Corpus):
185 | """
186 | Defines a dataset composed of Examples along with its Fields.
187 | """
188 |
189 | sort_key = None
190 |
191 | def __init__(self, path, fields, amr=False, keep_events=None, only_keep=False, **kwargs):
192 | '''
193 | Create a corpus given a path, field list, and a filter function.
194 |
195 | :param path: str, Path to the data file
196 | :param fields: dict[str: tuple(str, Field)],
197 | If using a dict, the keys should be a subset of the JSON keys or CSV/TSV
198 | columns, and the values should be tuples of (name, field).
199 | Keys not present in the input dictionary are ignored.
200 | This allows the user to rename columns from their JSON/CSV/TSV key names
201 | and also enables selecting a subset of columns to load.
202 | :param keep_events: int, minimum number of events a sentence must contain to be kept (treated as an exact count when only_keep is set). Default: keep all sentences.
203 | '''
204 | self.keep_events = keep_events
205 | self.only_keep = only_keep
206 | super(ACE2005Dataset, self).__init__(path, fields, amr, **kwargs)
207 |
208 | def parse_example(self, path, fields, amr, **kwargs):
209 | examples = []
210 |
211 | _file = codecs.open(path, 'r', 'utf-8')
212 | jl = json.load(_file)
213 | print(path, len(jl))
214 | for js in jl:
215 | ex = self.parse_sentence(js, fields, amr)
216 | if ex is not None:
217 | examples.append(ex)
218 | # for line in f:
219 | # line = line.strip()
220 | # if len(line) == 0:
221 | # continue
222 | # print(line)
223 | # jl = json.loads(line, encoding="utf-8")
224 | # for js in jl:
225 | # ex = self.parse_sentence(js, fields)
226 | # if ex is not None:
227 | # examples.append(ex)
228 |
229 | return examples
230 |
231 | def parse_sentence(self, js, fields, amr):
232 | SENTID = fields["sentence_id"]
233 | WORDS = fields["words"]
234 | POSTAGS = fields["pos-tags"]
235 | # LEMMAS = fields["lemma"]
236 | ENTITYLABELS = fields["golden-entity-mentions"]
237 | if amr:
238 | colcc = "simple-parsing"
239 | else:
240 | colcc = "combined-parsing"
241 | # print(colcc)
242 | ADJMATRIX = fields[colcc]
243 | LABELS = fields["golden-event-mentions"]
244 | EVENTS = fields["all-events"]
245 | ENTITIES = fields["all-entities"]
246 |
247 | sentence = Sentence_ace(json_content=js, graph_field_name=colcc)
248 | ex = Example()
249 | # print('sentence.wordList', WORDS[1].preprocess(sentence.wordList))
250 | setattr(ex, SENTID[0], SENTID[1].preprocess(sentence.sentence_id))
251 | setattr(ex, WORDS[0], WORDS[1].preprocess(sentence.wordList))
252 | setattr(ex, POSTAGS[0], POSTAGS[1].preprocess(sentence.posLabelList))
253 | # setattr(ex, LEMMAS[0], LEMMAS[1].preprocess(sentence.lemmaList))
254 | setattr(ex, ENTITYLABELS[0], ENTITYLABELS[1].preprocess(sentence.entityLabelList))
255 | setattr(ex, ADJMATRIX[0], (sentence.adjpos, sentence.adjv))
256 | setattr(ex, LABELS[0], LABELS[1].preprocess(sentence.triggerLabelList))
257 | setattr(ex, EVENTS[0], EVENTS[1].preprocess(sentence.events))
258 | setattr(ex, ENTITIES[0], ENTITIES[1].preprocess(sentence.entities))
259 |
260 | if self.keep_events is not None:
261 | if self.only_keep and sentence.containsEvents != self.keep_events:
262 | return None
263 | elif not self.only_keep and sentence.containsEvents < self.keep_events:
264 | return None
265 | else:
266 | return ex
267 | else:
268 | return ex
269 |
270 | def longest(self):
271 | return max([len(x.POSTAGS) for x in self.examples])
272 |
--------------------------------------------------------------------------------
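To make the custom field contracts above concrete, a small illustrative sketch (values invented) of the per-sentence payloads that EntityField and EventField pass through preprocess() and pad() unchanged:

# EntityField: a list of entity spans, kept as-is end to end.
entities = [(2, 3, "PER"), (7, 9, "GPE")]

# EventField: a dict from trigger span to argument spans.
events = {
    (4, 5, "CONFLICT||ATTACK"): [(2, 3, "ATTACKER"), (7, 9, "PLACE")],
}

# EventField.build_vocab() counts only the role strings (v[2]) and reserves the
# special "OTHER" label at index 0; numericalize() keeps the span indices and
# replaces each role string with its vocab id, falling back to 0 ("OTHER") for
# roles unseen at vocabulary-building time.
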
/src/dataflow/torch/Sentence.py:
--------------------------------------------------------------------------------
1 | # import sys
2 | # sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2')
3 |
4 | from src.util.consts import CUTOFF
5 | # from PIL import Image
6 | # import os
7 |
8 |
9 | def pretty_str(a):
10 | a = a.upper()
11 | if a == 'O':
12 | return a
13 | elif a[1] == '-':
14 | return a[:2] + "|".join(a[2:].split("-")).replace(":", "||")
15 | else:
16 | return "|".join(a.split("-")).replace(":", "||")
17 |
18 |
19 | class Sentence:
20 | def __init__(self, json_content, with_sentid=False):
21 | # self.wordList = json_content["words"][:CUTOFF]
22 | # self.posLabelList = json_content["pos-tags"][:CUTOFF]
23 | # # self.lemmaList = json_content["lemma"][:CUTOFF]
24 | # self.length = len(self.wordList)
25 | #
26 | # self.entityLabelList = self.generateEntityLabelList(json_content["golden-entity-mentions"])
27 | # # self.triggerLabelList = self.generateTriggerLabelList(json_content["golden-event-mentions"])
28 | # self.adjpos, self.adjv = self.generateAdjMatrix(json_content[graph_field_name])
29 | #
30 | # self.entities = self.generateGoldenEntities(json_content["golden-entity-mentions"])
31 | # # self.events = self.generateGoldenEvents(json_content["golden-event-mentions"])
32 | #
33 | # # self.containsEvents = len(json_content["golden-event-mentions"])
34 | # self.tokenList = self.makeTokenList()
35 | self.json_content = json_content
36 | self.with_sentid = with_sentid
37 |
38 | def generateEntityLabelList(self, entitiesJsonList):
39 | '''
40 | Keep the overlapping entity labels
41 | :param entitiesJsonList:
42 | :return:
43 | '''
44 |
45 | entityLabel = [["O"] for _ in range(self.length)]
46 |
47 | def assignEntityLabel(index, label):
48 | if index >= CUTOFF:
49 | return
50 | if len(entityLabel[index]) == 1 and entityLabel[index][0] == "O":
51 | entityLabel[index][0] = pretty_str(label)
52 | else:
53 | entityLabel[index].append(pretty_str(label))
54 |
55 | for entityJson in entitiesJsonList:
56 | start = entityJson["start"]
57 | end = entityJson["end"]
58 | etype = entityJson["entity-type"].split(":")[0]
59 | assignEntityLabel(start, "B-" + etype)
60 | for i in range(start + 1, end):
61 | assignEntityLabel(i, "I-" + etype)
62 |
63 | return entityLabel
64 |
65 | def generateGoldenEntities(self, entitiesJson):
66 | '''
67 | [(2, 3, "entity_type")]
68 | '''
69 | golden_list = []
70 | for entityJson in entitiesJson:
71 | start = entityJson["start"]
72 | if start >= CUTOFF:
73 | continue
74 | end = min(entityJson["end"], CUTOFF)
75 | etype = entityJson["entity-type"].split(":")[0]
76 | golden_list.append((start, end, etype))
77 | return golden_list
78 |
79 | def generateGoldenEvents(self, eventsJson, with_sentid=False, sent_id=None):
80 | '''
81 |
82 | {
83 | (2, 3, "event_type_str") --> [(1, 2, "role_type_str"), ...]
84 | ...
85 | }
86 |
87 | '''
88 | golden_dict = {}
89 | for eventJson in eventsJson:
90 | triggerJson = eventJson["trigger"]
91 | if triggerJson["start"] >= CUTOFF:
92 | continue
93 | if with_sentid:
94 | key = ('%s__%d' % (sent_id, triggerJson["start"]),
95 | min(triggerJson["end"], CUTOFF), pretty_str(eventJson["event_type"]))
96 | else:
97 | key = (triggerJson["start"],
98 | min(triggerJson["end"], CUTOFF), pretty_str(eventJson["event_type"]))
99 | values = []
100 | for argumentJson in eventJson["arguments"]:
101 | if argumentJson["start"] >= CUTOFF:
102 | continue
103 | value = (argumentJson["start"], min(argumentJson["end"], CUTOFF), pretty_str(argumentJson["role"]))
104 | values.append(value)
105 | golden_dict[key] = list(sorted(values))
106 | return golden_dict
107 |
108 | def generateTriggerLabelList(self, triggerJsonList):
109 | triggerLabel = ["O" for _ in range(self.length)]
110 |
111 | def assignTriggerLabel(index, label):
112 | if index >= CUTOFF:
113 | return
114 | triggerLabel[index] = pretty_str(label)
115 |
116 | for eventJson in triggerJsonList:
117 | triggerJson = eventJson["trigger"]
118 | start = triggerJson["start"]
119 | end = triggerJson["end"]
120 | etype = eventJson["event_type"]
121 | assignTriggerLabel(start, "B-" + etype)
122 | for i in range(start + 1, end):
123 | assignTriggerLabel(i, "I-" + etype)
124 | return triggerLabel
125 |
126 | def generateAdjMatrix(self, edgeJsonList):
127 | sparseAdjMatrixPos = [[], [], []]
128 | sparseAdjMatrixValues = []
129 |
130 | def addedge(type_, from_, to_, value_):
131 | sparseAdjMatrixPos[0].append(type_)
132 | sparseAdjMatrixPos[1].append(from_)
133 | sparseAdjMatrixPos[2].append(to_)
134 | sparseAdjMatrixValues.append(value_)
135 |
136 | for edgeJson in edgeJsonList:
137 | ss = edgeJson.split("/")
138 | fromIndex = int(ss[-1].split("=")[-1])
139 | toIndex = int(ss[-2].split("=")[-1])
140 | etype = ss[0]
141 | if etype.lower() == "root" or fromIndex == -1 or toIndex == -1 or fromIndex >= CUTOFF or toIndex >= CUTOFF:
142 | continue
143 | addedge(0, fromIndex, toIndex, 1.0)
144 | addedge(1, toIndex, fromIndex, 1.0)
145 |
146 | for i in range(self.length):
147 | addedge(2, i, i, 1.0)
148 |
149 | return sparseAdjMatrixPos, sparseAdjMatrixValues
150 |
151 | # def makeTokenList(self):
152 | # # return [Token(self.wordList[i], self.posLabelList[i], self.lemmaList[i], self.entityLabelList[i],
153 | # # self.triggerLabelList[i])
154 | # # for i in range(self.length)]
155 | # return [Token(self.wordList[i], self.posLabelList[i], self.entityLabelList[i],
156 | # self.triggerLabelList[i])
157 | # for i in range(self.length)]
158 |
159 | def __len__(self):
160 | return self.length
161 |
162 | def __iter__(self):
163 | for x in self.tokenList:
164 | yield x
165 |
166 | def __getitem__(self, index):
167 | return self.tokenList[index]
168 |
169 |
170 | class Sentence_ace(Sentence):
171 | def __init__(self, json_content, graph_field_name):
172 | Sentence.__init__(self, json_content)
173 | if "sentence_id" in json_content:
174 | self.sentence_id = json_content["sentence_id"]
175 | else:
176 | self.sentence_id = "none"
177 | self.wordList = json_content["words"][:CUTOFF]
178 | self.posLabelList = json_content["pos-tags"][:CUTOFF]
179 | # self.lemmaList = json_content["lemma"][:CUTOFF]
180 | self.length = len(self.wordList)
181 |
182 | if "golden-entity-mentions" in json_content:
183 | self.entityLabelList = self.generateEntityLabelList(json_content["golden-entity-mentions"])
184 | self.entities = self.generateGoldenEntities(json_content["golden-entity-mentions"])
185 | else:
186 | self.entityLabelList = self.generateEntityLabelList(list())
187 | self.entities = self.generateGoldenEntities(list())
188 | if "golden-event-mentions" in json_content:
189 | self.triggerLabelList = self.generateTriggerLabelList(json_content["golden-event-mentions"])
190 | self.events = self.generateGoldenEvents(json_content["golden-event-mentions"])
191 | self.containsEvents = len(json_content["golden-event-mentions"])
192 | else:
193 | self.triggerLabelList = self.generateTriggerLabelList(list())
194 | self.events = self.generateGoldenEvents(list())
195 | self.containsEvents = 0
196 |
197 | self.adjpos, self.adjv = self.generateAdjMatrix(json_content[graph_field_name])
198 | self.tokenList = self.makeTokenList()
199 |
200 | def makeTokenList(self):
201 | # return [Token(self.wordList[i], self.posLabelList[i], self.lemmaList[i], self.entityLabelList[i],
202 | # self.triggerLabelList[i])
203 | # for i in range(self.length)]
204 | return [Token(self.wordList[i], self.posLabelList[i], self.entityLabelList[i],
205 | self.triggerLabelList[i])
206 | for i in range(self.length)]
207 |
208 |
209 | class Sentence_grounding(Sentence):
210 | def __init__(self, json_content, graph_field_name, img_dir, transform=None):
211 | Sentence.__init__(self, json_content)
212 | self.image_id = json_content["image"]
213 | self.sentence_id = json_content["sentence_id"]
214 | self.wordList = json_content["words"][:CUTOFF]
215 | self.posLabelList = json_content["pos-tags"][:CUTOFF]
216 | # self.lemmaList = json_content["lemma"][:CUTOFF]
217 | self.length = len(self.wordList)
218 |
219 | self.entityLabelList = self.generateEntityLabelList(json_content["golden-entity-mentions"])
220 | self.adjpos, self.adjv = self.generateAdjMatrix(json_content[graph_field_name])
221 |
222 | self.entities = self.generateGoldenEntities(json_content["golden-entity-mentions"])
223 |
224 | self.tokenList = self.makeTokenList()
225 |
226 | # get the image vectors
227 | # img_path = os.path.join(img_dir, self.image_id)
228 | # try:
229 | # self.image_vec = Image.open(img_path).convert('RGB')
230 | # if transform is not None:
231 | # self.image_vec = transform(self.image_vec)
232 | # except:
233 | # self.image_vec = None
234 |
235 | def makeTokenList(self):
236 | return [Token(self.wordList[i], self.posLabelList[i], self.entityLabelList[i])
237 | for i in range(self.length)]
238 |
239 | class Sentence_m2e2(Sentence):
240 | def __init__(self, json_content, graph_field_name, with_sentid=False):
241 | Sentence.__init__(self, json_content, with_sentid=with_sentid)
242 | self.image_id = json_content["image"]
243 | self.sentence_id = json_content["sentence_id"]
244 | self.wordList = json_content["words"][:CUTOFF]
245 | self.posLabelList = json_content["pos-tags"][:CUTOFF]
246 | self.length = len(self.wordList)
247 |
248 | self.entityLabelList = self.generateEntityLabelList(json_content["golden-entity-mentions"])
249 | self.triggerLabelList = self.generateTriggerLabelList(json_content["golden-event-mentions"])
250 | self.adjpos, self.adjv = self.generateAdjMatrix(json_content[graph_field_name])
251 |
252 | self.entities = self.generateGoldenEntities(json_content["golden-entity-mentions"])
253 | self.events = self.generateGoldenEvents(json_content["golden-event-mentions"], self.with_sentid, self.sentence_id)
254 |
255 | self.containsEvents = len(json_content["golden-event-mentions"])
256 | self.tokenList = self.makeTokenList()
257 |
258 | def makeTokenList(self):
259 | return [Token(self.wordList[i], self.posLabelList[i], self.entityLabelList[i],
260 | self.triggerLabelList[i])
261 | for i in range(self.length)]
262 |
263 |
264 | class Token:
265 | # def __init__(self, word, posLabel, lemmaLabel, entityLabel, triggerLabel):
266 | def __init__(self, word, posLabel, entityLabel, triggerLabel=None):
267 | self.word = word
268 | self.posLabel = posLabel
269 | # self.lemmaLabel = lemmaLabel
270 | self.entityLabel = entityLabel
271 | self.triggerLabel = triggerLabel
272 | self.predictedLabel = None
273 |
274 | def addPredictedLabel(self, label):
275 | self.predictedLabel = label
--------------------------------------------------------------------------------
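For reference, pretty_str() above normalizes ACE-style labels into the uppercased, '||'/'|'-delimited form used in the mapping file. A few examples that follow directly from the function as written (inputs illustrative):

from src.dataflow.torch.Sentence import pretty_str

assert pretty_str('O') == 'O'
assert pretty_str('Movement:Transport') == 'MOVEMENT||TRANSPORT'
assert pretty_str('B-Conflict:Attack') == 'B-CONFLICT||ATTACK'
assert pretty_str('Transaction:Transfer-Money') == 'TRANSACTION||TRANSFER|MONEY'

generateAdjMatrix() returns the dependency graph as sparse triples over three channels: channel 0 holds the parsed edges, channel 1 their reverses, and channel 2 self-loops, all with weight 1.0.
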
/src/engine/TestRunnerEE.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import json
5 | import sys
6 | from functools import partial
7 |
8 | import numpy as np
9 | import torch
10 | from tensorboardX import SummaryWriter
11 | from torchtext.data import Field
12 | from torchtext.vocab import Vectors
13 | from torchtext.data import BucketIterator
14 | from math import ceil
15 |
16 | # (sys is already imported above)
17 | #export PATH=/dvmm-filer2/users/manling/mm-event-graph2:$PATH
18 | sys.path.append('../..')
19 |
20 | from src.util import consts
21 | from src.dataflow.torch.Data import ACE2005Dataset, MultiTokenField, SparseField, EventField, EntityField
22 | from src.models.ee import EDModel
23 | from src.eval.EEtesting import EDTester
24 | from src.eval.EEvisualizing import EDVisualizer
25 | from src.engine.EEtraining import ee_train
26 | from src.util.util_model import log
27 | from src.engine.EEtraining import run_over_data
28 | from src.engine.EErunner import load_ee_model
29 | from src.util.util_model import progressbar
30 | from src.dataflow.torch.Sentence import Token
31 | from src.engine.EErunner import event_role_mask
32 |
33 |
34 | class EERunnerTest(object):
35 | def __init__(self):
36 | parser = argparse.ArgumentParser(description="neural networks trainer")
37 | parser.add_argument("--test_ee", help="event extraction validation set")
38 | parser.add_argument("--train_ee", help="event extraction training set", required=False)
39 | parser.add_argument("--dev_ee", help="event extraction development set", required=False)
40 | parser.add_argument("--webd", help="word embedding", required=False)
41 | parser.add_argument("--ignore_time_test", help="testing ignore place in sr model", action='store_true')
42 |
43 | parser.add_argument("--batch", help="batch size", default=128, type=int)
44 | # parser.add_argument("--epochs", help="n of epochs", default=sys.maxsize, type=int)
45 |
46 | parser.add_argument("--seed", help="RNG seed", default=1111, type=int)
47 | # parser.add_argument("--optimizer", default="adam")
48 | # parser.add_argument("--lr", default=1, type=float)
49 | # parser.add_argument("--l2decay", default=0, type=float)
50 | parser.add_argument("--maxnorm", default=3, type=float)
51 |
52 | parser.add_argument("--out", help="output model path", default="out")
53 | parser.add_argument("--finetune", help="pretrained model path")
54 | # parser.add_argument("--earlystop", default=999999, type=int)
55 | # parser.add_argument("--restart", default=999999, type=int)
56 | # parser.add_argument("--shuffle", help="shuffle", action='store_true')
57 | parser.add_argument("--amr", help="use amr", action='store_true')
58 |
59 | parser.add_argument("--device", default="cpu")
60 | parser.add_argument("--hps_path", help="model hyperparams", required=False)
61 | parser.add_argument("--hps", help="model hyperparams", required=False)
62 |
63 | self.a = parser.parse_args()
64 | if self.a.hps_path:
65 | self.a.hps = json.load(open(self.a.hps_path))
66 | print(self.a.hps)
67 |
68 | def set_device(self, device="cpu"):
69 | # note: the requested device string is ignored; CUDA is used whenever it is available
70 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71 |
72 | def get_device(self):
73 | return self.device
74 |
75 | def get_tester(self, voc_i2s, voc_role_i2s):
76 | return EDTester(voc_i2s, voc_role_i2s, self.a.ignore_time_test)
77 |
78 | def run(self):
79 | print("Running on", self.a.device)
80 | self.set_device(self.a.device)
81 |
82 | np.random.seed(self.a.seed)
83 | torch.manual_seed(self.a.seed)
84 | torch.backends.cudnn.benchmark = True
85 |
86 | # create training set
87 | if self.a.test_ee:
88 | log('loading event extraction corpus from %s' % self.a.test_ee)
89 |
90 | WordsField = Field(lower=True, include_lengths=True, batch_first=True)
91 | PosTagsField = Field(lower=True, batch_first=True)
92 | EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
93 | AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True)
94 | LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None)
95 | EventsField = EventField(lower=False, batch_first=True)
96 | EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False)
97 | SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)
98 | if self.a.amr:
99 | colcc = 'simple-parsing'
100 | else:
101 | colcc = 'combined-parsing'
102 | print(colcc)
103 |
104 | train_ee_set = ACE2005Dataset(path=self.a.train_ee,
105 | fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
106 | "pos-tags": ("POSTAGS", PosTagsField),
107 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
108 | colcc: ("ADJM", AdjMatrixField),
109 | "golden-event-mentions": ("LABEL", LabelField),
110 | "all-events": ("EVENT", EventsField),
111 | "all-entities": ("ENTITIES", EntitiesField)},
112 | amr=self.a.amr, keep_events=1)
113 |
114 | dev_ee_set = ACE2005Dataset(path=self.a.dev_ee,
115 | fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
116 | "pos-tags": ("POSTAGS", PosTagsField),
117 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
118 | colcc: ("ADJM", AdjMatrixField),
119 | "golden-event-mentions": ("LABEL", LabelField),
120 | "all-events": ("EVENT", EventsField),
121 | "all-entities": ("ENTITIES", EntitiesField)},
122 | amr=self.a.amr, keep_events=0)
123 |
124 | test_ee_set = ACE2005Dataset(path=self.a.test_ee,
125 | fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
126 | "pos-tags": ("POSTAGS", PosTagsField),
127 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
128 | colcc: ("ADJM", AdjMatrixField),
129 | "golden-event-mentions": ("LABEL", LabelField),
130 | "all-events": ("EVENT", EventsField),
131 | "all-entities": ("ENTITIES", EntitiesField)},
132 | amr=self.a.amr, keep_events=0)
133 |
134 | if self.a.webd:
135 | pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15))
136 | WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, vectors=pretrained_embedding)
137 | else:
138 | WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS)
139 | PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS)
140 | EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS)
141 | LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL)
142 | EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT)
143 | consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME]
144 | # print("O label is", consts.O_LABEL)
145 | consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME]
146 | # print("O label for AE is", consts.ROLE_O_LABEL)
147 |
148 | self.a.label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5
149 | self.a.label_weight[consts.O_LABEL] = 1.0
150 | self.a.arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5
151 | # add role mask
152 | self.a.role_mask = event_role_mask(self.a.test_ee, self.a.train_ee, self.a.dev_ee, LabelField.vocab.stoi,
153 | EventsField.vocab.stoi, self.device)
154 | # print('self.a.hps', self.a.hps)
155 | if not self.a.hps_path:
156 | self.a.hps = eval(self.a.hps)  # parse the dict literal passed via --hps
157 | if "wemb_size" not in self.a.hps:
158 | self.a.hps["wemb_size"] = len(WordsField.vocab.itos)
159 | if "pemb_size" not in self.a.hps:
160 | self.a.hps["pemb_size"] = len(PosTagsField.vocab.itos)
161 | if "psemb_size" not in self.a.hps:
162 | self.a.hps["psemb_size"] = max([train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest()]) + 2
163 | if "eemb_size" not in self.a.hps:
164 | self.a.hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
165 | if "oc" not in self.a.hps:
166 | self.a.hps["oc"] = len(LabelField.vocab.itos)
167 | if "ae_oc" not in self.a.hps:
168 | self.a.hps["ae_oc"] = len(EventsField.vocab.itos)
169 |
170 | tester = self.get_tester(LabelField.vocab.itos, EventsField.vocab.itos)
171 | visualizer = EDVisualizer(self.a.test_ee)
172 |
173 | if self.a.finetune:
174 | log('init model from ' + self.a.finetune)
175 | model = load_ee_model(self.a.hps, self.a.finetune, WordsField.vocab.vectors, self.device)
176 | log('model loaded, there are %i sets of params' % len(model.parameters_requires_grads()))
177 | else:
178 | model = load_ee_model(self.a.hps, None, WordsField.vocab.vectors, self.device)
179 | log('model created from scratch, there are %i sets of params' % len(model.parameters_requires_grads()))
180 |
181 | self.a.word_i2s = WordsField.vocab.itos
182 | self.a.label_i2s = LabelField.vocab.itos
183 | self.a.role_i2s = EventsField.vocab.itos
184 | writer = SummaryWriter(os.path.join(self.a.out, "exp"))
185 | self.a.writer = writer
186 |
187 | # train_iter = BucketIterator(train_ee_set, batch_size=self.a.batch,
188 | # train=True, shuffle=False, device=-1,
189 | # sort_key=lambda x: len(x.POSTAGS))
190 | # dev_iter = BucketIterator(dev_ee_set, batch_size=self.a.batch, train=False,
191 | # shuffle=False, device=-1,
192 | # sort_key=lambda x: len(x.POSTAGS))
193 | test_iter = BucketIterator(test_ee_set, batch_size=self.a.batch, train=False,
194 | shuffle=False, device=-1,
195 | sort_key=lambda x: len(x.POSTAGS))
196 |
197 | print("\nStarting testing ...\n")
198 |
199 | # Testing phase
200 | test_loss, test_ed_p, test_ed_r, test_ed_f1, \
201 | test_ae_p, test_ae_r, test_ae_f1 = run_over_data(data_iter=test_iter,
202 | optimizer=None,
203 | model=model,
204 | need_backward=False,
205 | MAX_STEP=len(test_iter),
206 | tester=tester,
207 | visualizer=visualizer,
208 | hyps=model.hyperparams,
209 | device=model.device,
210 | maxnorm=self.a.maxnorm,
211 | word_i2s=self.a.word_i2s,
212 | label_i2s=self.a.label_i2s,
213 | role_i2s=self.a.role_i2s,
214 | weight=self.a.label_weight,
215 | arg_weight=self.a.arg_weight,
216 | save_output=os.path.join(
217 | self.a.out,
218 | "test_final.txt"),
219 | role_mask=self.a.role_mask)
220 |
221 | print("\nFinally test loss: ", test_loss,
222 | "\ntest ed p: ", test_ed_p,
223 | " test ed r: ", test_ed_r,
224 | " test ed f1: ", test_ed_f1,
225 | "\ntest ae p: ", test_ae_p,
226 | " test ae r: ", test_ae_r,
227 | " test ae f1: ", test_ae_f1)
228 |
229 | if __name__ == "__main__":
230 | EERunnerTest().run()
--------------------------------------------------------------------------------
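A hedged invocation sketch for the runner above (all paths are placeholders; the repo's own commands live under scripts/eval/, e.g. test_ee_m2e2.sh). Note that --train_ee and --dev_ee are needed even at test time, because the word, label, and role vocabularies and the event-role mask are rebuilt from them:

# python src/engine/TestRunnerEE.py \
#     --train_ee data/ace/train.json --dev_ee data/ace/dev.json \
#     --test_ee data/ace/test.json --webd path/to/word_vectors.txt \
#     --finetune out/ee_model.pt --hps_path hps.json \
#     --batch 128 --device cuda --out out
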
/src/eval/EEvisualizing.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | import os
4 | import sys
5 | sys.path.append('../../')
6 | from src.dataflow.numpy.anno_mapping import entity_type_mapping_brat, event_type_mapping_ace2brat, event_role_mapping_ace2brat, event_role_mapping_brat2ace
7 | import shutil
8 |
9 | class EDVisualizer():
10 | def __init__(self, JMEE_data_json):
11 | self.sent_info = defaultdict(lambda : defaultdict())
12 | self._load_sent_info(JMEE_data_json)
13 |
14 | def _load_sent_info(self, JMEE_data_json):
15 | raw_data = json.load(open(JMEE_data_json))
16 | for data_instance in raw_data:
17 | sent_id = data_instance['sentence_id']
18 | words = data_instance['words']
19 | # index_all = data_instance['index_all']
20 | # index = data_instance['index']
21 | # entity = data_instance['golden-entity-mentions']
22 | sentence = data_instance['sentence']
23 |
24 | self.sent_info[sent_id]['words'] = words
25 | # self.sent_info[sent_id]['index_all'] = index_all
26 | # self.sent_info[sent_id]['index'] = index
27 | # self.sent_info[sent_id]['entity'] = entity
28 | self.sent_info[sent_id]['sentence'] = sentence
29 |
30 |
31 | def save_json(self, predicted_event_triggers, predicted_events, sent_id, ee_role_i2s, text_result_json):
32 | sent_id = sent_id[0]
33 | for result_dict in predicted_events:
34 | for key in result_dict:
35 | # print('self.sent_info[sent_id]', self.sent_info)
36 | # print('sent_id', sent_id)
37 | eventstart_sentid, event_end, event_type_ace = key
38 | trigger_word = ' '.join(self.sent_info[sent_id]['words'][int(eventstart_sentid):int(event_end)])
39 | text_result_json[eventstart_sentid]['sentence_id'] = sent_id
40 | text_result_json[eventstart_sentid]['sentence'] = self.sent_info[sent_id]['sentence']
41 | text_result_json[eventstart_sentid]['sentence_tokens'] = self.sent_info[sent_id]['words']
42 | text_result_json[eventstart_sentid]['pred_event_type'] = event_type_ace
43 | text_result_json[eventstart_sentid]['pred_trigger'] = {'index_start': eventstart_sentid, 'index_end':event_end, 'text':trigger_word}
44 | text_result_json[eventstart_sentid]['pred_roles'] = defaultdict(list)
45 |
46 | args = result_dict[key]
47 | for arg_start, arg_end, arg_type in args:
48 | arg_word = ' '.join(self.sent_info[sent_id]['words'][int(arg_start):int(arg_end)])
49 | # arg_start_offset, arg_end_offset = self.get_offset_by_idx(sent_id, arg_start, arg_end)
50 | role_name_ace = ee_role_i2s[arg_type]
51 | text_result_json[eventstart_sentid]['pred_roles'][role_name_ace].append( {'index_start':arg_start, 'index_end':arg_end, 'text':arg_word} )
52 |
53 |
54 |
55 | def visualize_html(self, predicted_event_triggers, predicted_events, sent_id, ee_role_i2s, visual_writer):
56 | for eventstart_sentid in predicted_event_triggers:
57 | # print('eventstart_sentid', eventstart_sentid)
58 | event_end, event_type_ace = predicted_event_triggers[eventstart_sentid]
59 | # if '__' in eventstart_sentid:
60 | event_start = int(eventstart_sentid[eventstart_sentid.find('__')+2:])
61 | # print('predicted_event_triggers', predicted_event_triggers)
62 | # event_start_offset, event_end_offset = self.get_offset_by_idx(sent_id, event_start, event_end)
63 | visual_writer.write("sentence: %s
\n" % self.sent_info[sent_id]['sentence'])
64 | trigger_word = ' '.join(self.sent_info[sent_id]['words'][event_start:event_end]) # multiple words]
65 | visual_writer.write("event type: %s
\n" % event_type_ace)
66 | visual_writer.write("trigger word: %s
\n" % trigger_word)
67 | if (eventstart_sentid, event_end, event_type_ace) in predicted_events:
68 | args = predicted_events[(eventstart_sentid, event_end, event_type_ace)]
69 | for arg_start, arg_end, arg_type in args:
70 | arg_word = ' '.join(self.sent_info[sent_id]['words'][arg_start:arg_end])
71 | # arg_start_offset, arg_end_offset = self.get_offset_by_idx(sent_id, arg_start, arg_end)
72 | role_name_ace = ee_role_i2s[arg_type]
73 | visual_writer.write(" [%s]: %s
\n" % (role_name_ace, arg_word))
74 | visual_writer.write('<br>\n<br>\n')
75 |
76 | visual_writer.flush()
77 | visual_writer.close()
78 |
79 |
80 | def visualize_brat(self, predicted_event_triggers, predicted_events, sent_id, ee_role_i2s, visual_writer, save_entity=False):
81 | print('predicted_event_triggers', predicted_event_triggers)
82 | print('predicted_events', predicted_events)
83 | for eventstart_sentid in predicted_event_triggers:
84 | # print('eventstart_sentid', eventstart_sentid)
85 | event_end, event_type_ace = predicted_event_triggers[eventstart_sentid]
86 | # if '__' in eventstart_sentid:
87 | event_start = int(eventstart_sentid[eventstart_sentid.find('__')+2:])
88 | # print('predicted_event_triggers', predicted_event_triggers)
89 | event_start_offset, event_end_offset = self.get_offset_by_idx(sent_id, event_start, event_end)
90 | trigger_word = ' '.join(self.sent_info[sent_id]['words'][event_start:event_end]) # multiple words
91 | # T44 Attack 6595 6604 onslaught
92 | # event_id = 'T%s_%d_%d' % (sent_id[sent_id.find('_'):], event_start_offset, event_end_offset)
93 | token_id = 'T%d_%d' % (event_start_offset, event_end_offset)
94 | if event_type_ace not in event_type_mapping_ace2brat:
95 | print('ignored type', event_type_ace)
96 | continue
97 | event_type_brat = event_type_mapping_ace2brat[event_type_ace]
98 | visual_writer.write(token_id)
99 | visual_writer.write('\t%s' % event_type_brat)
100 | visual_writer.write(' %d %d' % (event_start_offset, event_end_offset))
101 | visual_writer.write('\t%s\n' % trigger_word)
102 |
103 | # E11 Attack:T44 Attacker:T26493 Target:T45
104 | # save triggers
105 | event_id = token_id.replace('T', 'E')
106 | visual_writer.write('%s' % event_id)
107 | visual_writer.write('\t%s:%s' % (event_type_brat, token_id))
108 | # save args
109 | if save_entity:
110 | entity_lines = list()
111 | # try:
112 | if (eventstart_sentid, event_end, event_type_ace) in predicted_events:
113 | args = predicted_events[(eventstart_sentid, event_end, event_type_ace)]
114 | for arg_start, arg_end, arg_type in args:
115 | # index -> offset index
116 | arg_word = ' '.join(self.sent_info[sent_id]['words'][arg_start:arg_end])
117 | arg_start_offset, arg_end_offset = self.get_offset_by_idx(sent_id, arg_start, arg_end)
118 | role_name_ace = ee_role_i2s[arg_type]
119 | # print(event_type_ace, role_name_ace)
120 | role_name_brat = event_role_mapping_ace2brat[event_type_ace][role_name_ace]
121 | visual_writer.write(' %s:T%d_%d' % (role_name_brat, arg_start_offset, arg_end_offset))
122 | if save_entity:
123 | # T44 Attack 6595 6604 onslaught
124 | arg_type_brat = 'VAL' # do not save the entity types?
125 | entity_lines.append('T%d_%d\t%s %d %d\t%s\n' % (
126 | arg_start_offset, arg_end_offset, arg_type_brat,
127 | arg_start_offset, arg_end_offset, arg_word
128 | ))
129 | # except:
130 | # print('no arguments ', (eventstart_sentid, event_end, event_type_ace))
131 | visual_writer.write('\n')
132 | if save_entity:
133 | visual_writer.write(''.join(entity_lines))
134 |
135 | visual_writer.flush()
136 | visual_writer.close()
137 |
138 | def get_offset_by_idx(self, sent_id, idx_start, idx_end):
139 | # print('idx_start', idx_start, 'idx_end', idx_end)
140 | # print(self.sent_info[sent_id]['index_all'])
141 | start_offset, _ = self.sent_info[sent_id]['index_all'][idx_start]
142 | _, end_offset = self.sent_info[sent_id]['index_all'][idx_end - 1]
143 | end_offset = end_offset + 1 # brat is [start, end] not [start, end)
144 | return start_offset, end_offset
145 |
146 | def rewrite_brat(self, event_tmp_dir, ann_dir, save_entity=False):
147 | '''
148 | The previous format is not brat format, need postprocessing
149 | :return:
150 | '''
151 | # if os.path.isdir(event_tmp_dir):
152 | for event_tmp_file in os.listdir(event_tmp_dir):
153 | if event_tmp_file.endswith('.ann_tmp'):
154 | event_tmp_path = os.path.join(event_tmp_dir, event_tmp_file)
155 | ann_path = os.path.join(ann_dir, event_tmp_file.replace('.ann_tmp', '.ann'))
156 | if not save_entity:
157 | self._rewrite_brat(event_tmp_path, ann_path)
158 | else:
159 | # remove repeated lines?
160 | # simpler version: copy as final ann file, ignore the repeated lines
161 | shutil.copyfile(event_tmp_path, event_tmp_path.replace('_tmp', ''))
162 |
163 | # copy rsd txt
164 | rsd_path = os.path.join(ann_dir, event_tmp_file.replace('.ann_tmp', '.txt'))
165 | rsd_path_new = os.path.join(event_tmp_dir, event_tmp_file.replace('.ann_tmp', '.txt'))
166 | shutil.copyfile(rsd_path, rsd_path_new)
167 |
168 | def _rewrite_brat(self, event_tmp_file, anno_file):
169 | # get all entity offset and entity id mapping
170 | token_offsetid2realid = dict()
171 |
172 | entity_lines = list()
173 | if os.path.exists(anno_file):
174 | lines = open(anno_file).readlines()
175 | for line in lines:
176 | if line.startswith('T'):
177 | tabs = line.split('\t')
178 | id = tabs[0]
179 | subs = tabs[1].split(' ')
180 | type = subs[0]
181 | if type in entity_type_mapping_brat:
182 | start = int(subs[1])
183 | end = int(subs[2])
184 | # mention = tabs[2]
185 | offset_id = 'T%d_%d' % (start, end)
186 | token_offsetid2realid[offset_id] = id
187 | entity_lines.append(line)
188 | else:
189 | print('NoANN', anno_file)
190 |
191 | # write entities
192 | writer = open(event_tmp_file.replace('_tmp', ''), 'w')
193 | print(event_tmp_file.replace('_tmp', ''))
194 | writer.write(''.join(entity_lines))
195 |
196 | # write events:
197 | for line in open(event_tmp_file):
198 | line = line.rstrip('\n')
199 | if line.startswith('T'):
200 | writer.write('%s\n' % line)
201 | else:
202 | # update the arg token_id
203 | # E11 Attack:T44 Attacker:T26493 Target:T45
204 | role_values = line.split('\t')[-1].split(' ')
205 | # event_id = role_values[0].split(':')[1]
206 | # event_type = role_values[0].split(':')[0]
207 | writer.write(line.split('\t')[0])
208 | writer.write('\t')
209 | writer.write(role_values[0])
210 | entity_added_lines=list()
211 | for role_value in role_values[1:]:
212 | # role_name_raw = role_value.split(':')[0].replace('2', '').replace('3', '').replace('4', '').replace(
213 | # '5', '')
214 | # if len(role_name_raw) == 0:
215 | # continue
216 | # if role_name_raw not in event_role_mapping_brat2ace[event_type]:
217 | # print('ignored role: ', event_id, role_value)
218 | # continue
219 | # role_name = event_role_mapping_brat2ace[event_type][role_name_raw]
220 | entity = role_value.split(':')[1]
221 | if entity not in token_offsetid2realid:
222 | print('[ERROR] entity id cannot be found in *.ann', entity, anno_file)
223 | entity_type = 'VAL'
224 | entity_str = 'VAL'
225 | entity_added_lines.append('%s\t%s %s %s\t%s\n' % (entity, entity_type,
226 | entity[1:].split('_')[0],
227 | entity[1:].split('_')[1],
228 | entity_str) )
229 | entity_realid = entity
230 | else:
231 | entity_realid = token_offsetid2realid[entity]
232 | # print('entity', entity, 'entity_realid', entity_realid)
233 | writer.write(' ')
234 | writer.write(role_value.replace(entity, entity_realid))
235 | writer.write('\n')
236 | writer.write(''.join(entity_added_lines))
237 |
238 | writer.flush()
239 | writer.close()
240 |
241 |
--------------------------------------------------------------------------------
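To summarize the brat convention used above (ids and offsets below are illustrative): visualize_brat() first emits offset-keyed ids into *.ann_tmp files, and rewrite_brat() then swaps in the real entity ids from the gold .ann files:

# Intermediate *.ann_tmp lines (illustrative):
#   T6595_6604    Attack 6595 6604    onslaught            <- trigger span
#   E6595_6604    Attack:T6595_6604 Attacker:T100_110      <- event + arguments
#
# _rewrite_brat() copies the gold entity T-lines, then rewrites each offset-based
# argument id (e.g. T100_110) to the matching gold entity id, yielding standard
# brat standoff format.
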
/src/util/util_img.py:
--------------------------------------------------------------------------------
1 | '''
2 | Revised based on https://github.com/hassanhub/MultiGrounding/blob/master/code/utils.py
3 | '''
4 |
5 | from skimage.feature import peak_local_max
6 | import cv2
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | from scipy import ndimage as ndi
10 |
11 |
12 | # bbox generation config
13 | rel_peak_thr = .3
14 | rel_rel_thr = .3
15 | ioa_thr = .6
16 | topk_boxes = 3
17 |
18 |
19 | def heat2bbox(heat_map, original_image_shape):
20 | h, w = heat_map.shape
21 |
22 | bounding_boxes = []
23 |
24 | heat_map = heat_map - np.min(heat_map)
25 | heat_map = heat_map / np.max(heat_map)
26 |
27 | bboxes = []
28 | box_scores = []
29 |
30 | peak_coords = peak_local_max(heat_map, exclude_border=False,
31 | threshold_rel=rel_peak_thr) # find local peaks of heat map
32 |
33 | heat_resized = cv2.resize(heat_map, (
34 | original_image_shape[1], original_image_shape[0])) ## resize heat map to original image shape
35 | peak_coords_resized = ((peak_coords + 0.5) *
36 | np.asarray([original_image_shape]) /
37 | np.asarray([[h, w]])
38 | ).astype('int32')
39 |
40 | for pk_coord in peak_coords_resized:
41 | pk_value = heat_resized[tuple(pk_coord)]
42 | mask = heat_resized > pk_value * rel_rel_thr
43 | labeled, n = ndi.label(mask)
44 | l = labeled[tuple(pk_coord)]
45 | yy, xx = np.where(labeled == l)
46 | min_x = np.min(xx)
47 | min_y = np.min(yy)
48 | max_x = np.max(xx)
49 | max_y = np.max(yy)
50 | bboxes.append((min_x, min_y, max_x, max_y))
51 | box_scores.append(pk_value) # you can change to pk_value * probability of sentence matching image or etc.
52 |
53 | ## Merging boxes that overlap too much
54 | box_idx = np.argsort(-np.asarray(box_scores))
55 | box_idx = box_idx[:min(topk_boxes, len(box_scores))]
56 | bboxes = [bboxes[i] for i in box_idx]
57 | box_scores = [box_scores[i] for i in box_idx]
58 |
59 | to_remove = []
60 | for iii in range(len(bboxes)):
61 | for iiii in range(iii):
62 | if iiii in to_remove:
63 | continue
64 | b1 = bboxes[iii]
65 | b2 = bboxes[iiii]
66 | isec = max(min(b1[2], b2[2]) - max(b1[0], b2[0]), 0) * max(min(b1[3], b2[3]) - max(b1[1], b2[1]), 0)
67 | ioa1 = isec / ((b1[2] - b1[0]) * (b1[3] - b1[1]))
68 | ioa2 = isec / ((b2[2] - b2[0]) * (b2[3] - b2[1]))
69 | if ioa1 > ioa_thr and ioa1 == ioa2:
70 | to_remove.append(iii)
71 | elif ioa1 > ioa_thr and ioa1 >= ioa2:
72 | to_remove.append(iii)
73 | elif ioa2 > ioa_thr and ioa2 >= ioa1:
74 | to_remove.append(iiii)
75 |
76 | for i in range(len(bboxes)):
77 | if i not in to_remove:
78 | bounding_boxes.append({
79 | 'score': box_scores[i],
80 | 'bbox': bboxes[i],
81 | 'bbox_normalized': np.asarray([
82 | bboxes[i][0] / heat_resized.shape[1],
83 | bboxes[i][1] / heat_resized.shape[0],
84 | bboxes[i][2] / heat_resized.shape[1],
85 | bboxes[i][3] / heat_resized.shape[0],
86 | ]),
87 | })
88 |
89 | return bounding_boxes
90 |
91 |
92 | def img_heat_bbox_disp(image, heat_map, title='', en_name='', alpha=0.6, cmap='viridis', cbar='False', dot_max=False,
93 | bboxes=[], order=None, show=True):
94 | thr_hit = 1  # a bbox is acceptable if the hit point lies inside it (unused in this function)
95 | thr_fit = .60  # the biggest acceptable bbox should not exceed 60% of the image (unused in this function)
96 | H, W = image.shape[0:2]
97 | # resize heat map
98 | heat_map_resized = cv2.resize(heat_map, (W, H))  # cv2.resize expects (width, height)
99 |
100 | # display
101 | fig = plt.figure(figsize=(15, 5))
102 | fig.suptitle(title, size=15)
103 | ax = plt.subplot(1, 3, 1)
104 | plt.imshow(image)
105 | if dot_max:
106 | max_loc = np.unravel_index(np.argmax(heat_map_resized, axis=None), heat_map_resized.shape)
107 | plt.scatter(x=max_loc[1], y=max_loc[0], edgecolor='w', linewidth=3)
108 |
109 | if len(bboxes) > 0: # it gets normalized bbox
110 | if order is None:
111 | order = 'xxyy'
112 |
113 | for i in range(len(bboxes)):
114 | bbox_norm = bboxes[i]
115 | if order == 'xxyy':
116 | x_min, x_max, y_min, y_max = int(bbox_norm[0] * W), int(bbox_norm[1] * W), int(bbox_norm[2] * H), int(
117 | bbox_norm[3] * H)
118 | elif order == 'xyxy':
119 | x_min, x_max, y_min, y_max = int(bbox_norm[0] * W), int(bbox_norm[2] * W), int(bbox_norm[1] * H), int(
120 | bbox_norm[3] * H)
121 | x_length, y_length = x_max - x_min, y_max - y_min
122 | box = plt.Rectangle((x_min, y_min), x_length, y_length, edgecolor='w', linewidth=3, fill=False)
123 | plt.gca().add_patch(box)
124 | if en_name != '':
125 | ax.text(x_min + .5 * x_length, y_min + 10, en_name,
126 | verticalalignment='center', horizontalalignment='center',
127 | # transform=ax.transAxes,
128 | color='white', fontsize=15)
129 | # an = ax.annotate(en_name, xy=(x_min,y_min), xycoords="data", va="center", ha="center", bbox=dict(boxstyle="round", fc="w"))
130 | # plt.gca().add_patch(an)
131 |
132 | plt.imshow(heat_map_resized, alpha=alpha, cmap=cmap)
133 |
134 | # plt.figure(2, figsize=(6, 6))
135 | plt.subplot(1, 3, 2)
136 | plt.imshow(image)
137 | # plt.figure(3, figsize=(6, 6))
138 | plt.subplot(1, 3, 3)
139 | plt.imshow(heat_map_resized)
140 | fig.tight_layout()
141 | fig.subplots_adjust(top=.85)
142 |
143 | if show:
144 | plt.show()
145 | else:
146 | plt.close()
147 |
148 | return fig
149 |
150 |
151 | def filter_bbox(bbox_dict, order=None):
152 | thr_fit = .99 # the biggest acceptable bbox should not exceed 99% of the image
153 | if order is None:
154 | order = 'xxyy'
155 |
156 | filtered_bbox = []
157 | filtered_bbox_norm = []
158 | filtered_score = []
159 | if len(bbox_dict) > 0: # it gets normalized bbox
160 | for i in range(len(bbox_dict)):
161 | bbox = bbox_dict[i]['bbox']
162 | bbox_norm = bbox_dict[i]['bbox_normalized']
163 | bbox_score = bbox_dict[i]['score']
164 | if order == 'xxyy':
165 | x_min, x_max, y_min, y_max = bbox_norm[0], bbox_norm[1], bbox_norm[2], bbox_norm[3]
166 | elif order == 'xyxy':
167 | x_min, x_max, y_min, y_max = bbox_norm[0], bbox_norm[2], bbox_norm[1], bbox_norm[3]
168 | if bbox_score > 0:
169 | x_length, y_length = x_max - x_min, y_max - y_min
170 | if x_length * y_length < thr_fit:
171 | filtered_score.append(bbox_score)
172 | filtered_bbox.append(bbox)
173 | filtered_bbox_norm.append(bbox_norm)
174 | return filtered_bbox, filtered_bbox_norm, filtered_score
175 |
176 |
177 | def crop_resize_im(image, bbox, size, order='xxyy'):
178 | H, W, _ = image.shape
179 | if order == 'xxyy':
180 | roi = image[int(bbox[2] * H):int(bbox[3] * H), int(bbox[0] * W):int(bbox[1] * W), :]
181 | elif order == 'xyxy':
182 | roi = image[int(bbox[1] * H):int(bbox[3] * H), int(bbox[0] * W):int(bbox[2] * W), :]
183 | roi = cv2.resize(roi, size)
184 | return roi
185 |
186 |
187 | def im2double(im):
188 | return cv2.normalize(im.astype('float'), None, 0.0, 1.0, cv2.NORM_MINMAX)
189 |
190 |
191 | def IoU(boxA, boxB):
192 | # order = xyxy
193 | xA = max(boxA[0], boxB[0])
194 | yA = max(boxA[1], boxB[1])
195 | xB = min(boxA[2], boxB[2])
196 | yB = min(boxA[3], boxB[3])
197 |
198 | # compute the area of intersection rectangle
199 | interArea = max(0, xB - xA) * max(0, yB - yA)
200 |
201 | # compute the area of both the prediction and ground-truth
202 | # rectangles
203 | boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
204 | boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
205 |
206 | # compute the intersection over union by taking the intersection
207 | # area and dividing it by the sum of prediction + ground-truth
208 | # areas minus the intersection area
209 | iou = interArea / float(boxAArea + boxBArea - interArea)
210 |
211 | # return the intersection over union value
212 | return iou
213 |
214 |
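# A worked IoU example (illustrative numbers): boxA = (0, 0, 2, 2), boxB = (1, 1, 3, 3).
# The intersection is the 1x1 square from (1, 1) to (2, 2), so interArea = 1; both box
# areas are 4, giving IoU = 1 / (4 + 4 - 1) = 1/7, roughly 0.143.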
215 | def isCorrect(bbox_annot, bbox_pred, iou_thr=.4):
216 | iou_value_max = 0.0
217 | for bbox_p in bbox_pred:
218 | for bbox_a in bbox_annot:
219 | iou_value = IoU(bbox_p, bbox_a)
220 | iou_value_max = max(iou_value, iou_value_max)
221 | if iou_value >= iou_thr:
222 | return 1, iou_value
223 | return 0, iou_value_max
224 |
225 |
226 | def isCorrectHit(bbox_annot, heatmap, orig_img_shape):
227 | H, W = orig_img_shape
228 | heatmap_resized = cv2.resize(heatmap, (W, H))
229 | max_loc = np.unravel_index(np.argmax(heatmap_resized, axis=None), heatmap_resized.shape)
230 | print('max_loc', max_loc)
231 | for bbox in bbox_annot:
232 | if bbox[0] <= max_loc[1] <= bbox[2] and bbox[1] <= max_loc[0] <= bbox[3]:
233 | return 1
234 | return 0
235 |
236 |
237 | def check_percent(bboxes):
238 | for bbox in bboxes:
239 | x_length = bbox[2] - bbox[0]
240 | y_length = bbox[3] - bbox[1]
241 | if x_length * y_length < .05:
242 | return False
243 | return True
244 |
245 |
246 | def union(bbox):
247 |     if len(bbox) == 0:
248 |         return []
249 |     if isinstance(bbox[0], (int, float)):  # a single flat box was passed
250 |         bbox = [bbox]
251 |     maxes = np.max(bbox, axis=0)
252 |     mins = np.min(bbox, axis=0)
253 |     return [[mins[0], mins[1], maxes[2], maxes[3]]]
254 |
255 |
256 | def attCorrectness(bbox_annot, heatmap, orig_img_shape):
257 |     H, W = orig_img_shape
258 |     heatmap_resized = cv2.resize(heatmap, (W, H))
259 |     h_s = np.sum(heatmap_resized)
260 |     if h_s == 0:
261 |         return 0
262 |     else:
263 |         heatmap_resized /= h_s
264 |     att_correctness = 0
265 |     for bbox in bbox_annot:
266 |         x0, y0, x1, y1 = bbox
267 |         att_correctness += np.sum(heatmap_resized[y0:y1, x0:x1])
268 |     return att_correctness
269 |
270 |
271 | def calc_correctness(annot, heatmap, orig_img_shape, iou_thr=.5):
272 |     bbox_dict = heat2bbox(heatmap, orig_img_shape)
273 |     bbox, bbox_norm, bbox_score = filter_bbox(bbox_dict=bbox_dict, order='xyxy')
274 |     bbox_norm_annot = union(annot['bbox_norm'])
275 |     bbox_annot = annot['bbox']
276 |     bbox_norm_pred = union(bbox_norm)
277 |     # print('bbox_norm_annot', bbox_norm_annot)
278 |     # print('bbox_norm_pred', bbox_norm_pred)
279 |     # print('bbox_annot', bbox_annot)
280 |     # print('bbox_norm', bbox_norm)
281 |     bbox_correctness, bbox_iou = isCorrect(bbox_norm_annot, bbox_norm_pred, iou_thr=iou_thr)
282 |     hit_correctness = isCorrectHit(bbox_annot, heatmap, orig_img_shape)
283 |     att_correctness = attCorrectness(bbox_annot, heatmap, orig_img_shape)
284 |     return bbox_correctness, hit_correctness, att_correctness, bbox_iou
285 |
286 | def precision_bbox(boxA, boxB):
287 |     '''
288 |     Precision of a predicted box w.r.t. a ground-truth box.
289 |     :param boxA: predicted bbox (xyxy)
290 |     :param boxB: ground-truth bbox (xyxy)
291 |     :return: intersection area / predicted-box area
292 |     '''
293 |     # order = xyxy
294 |     xA = max(boxA[0], boxB[0])
295 |     yA = max(boxA[1], boxB[1])
296 |     xB = min(boxA[2], boxB[2])
297 |     yB = min(boxA[3], boxB[3])
298 | 
299 |     # compute the area of the intersection rectangle
300 |     interArea = max(0, xB - xA) * max(0, yB - yA)
301 | 
302 |     # compute the area of the predicted box
303 |     # (the ground-truth area is not needed for precision)
304 |     boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
305 |     # boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
306 | 
307 |     # precision: how much of the predicted box falls inside
308 |     # the ground truth, i.e. the intersection area divided
309 |     # by the predicted-box area
310 |     precision = interArea / float(boxAArea)
311 |     # recall = interArea / float(boxBArea)
312 | 
313 |     # return the precision value
314 |     return precision
315 |
316 | def calc_correctness_box(annot_all_norm, bbox_one_norm, iou_thr=.5, count_inside=False):
317 |     '''
318 |     Match one predicted box against all annotated boxes.
319 | 
320 |     :param annot_all_norm: list of annotated boxes (normalized xyxy)
321 |     :param bbox_one_norm: one predicted bbox (normalized xyxy)
322 |     :param iou_thr: IoU threshold for a match
323 |     :param count_inside: if True, a prediction lying (almost) entirely
324 |         inside an annotated box also counts as correct
325 |     :return: (1, score) on a match, else (0, best IoU seen)
326 |     '''
327 | 
328 |     # scan annotations, tracking the best IoU seen so far
329 |     iou_value_max = 0.0
330 |     for bbox_a in annot_all_norm:
331 |         iou_value = IoU(bbox_one_norm, bbox_a)
332 |         iou_value_max = max(iou_value, iou_value_max)
333 |         if iou_value >= iou_thr:
334 |             return 1, iou_value
335 |         if count_inside:
336 |             # fully-contained predictions are treated as hits
337 |             precision = precision_bbox(bbox_one_norm, bbox_a)
338 |             if precision >= 0.9:
339 |                 return 1, precision
340 |     return 0, iou_value_max
341 |
342 | def calc_correctness_box_uion(annot_all_norm, bbox_all_norm, iou_thr=.5, count_inside=False):
343 |     '''
344 |     Match the union of all predicted boxes against the union of all
345 |     annotated boxes (each union collapses to a single enclosing box).
346 | 
347 |     :param annot_all_norm: list of annotated boxes (normalized xyxy)
348 |     :param bbox_all_norm: list of predicted boxes (normalized xyxy)
349 |     :param iou_thr: IoU threshold for a match
350 |     :param count_inside: if True, a predicted union lying (almost)
351 |         entirely inside the annotated union also counts as correct
352 |     :return: (1, score) on a match, else (0, last IoU)
353 |     '''
354 | 
355 |     iou_value = 0.0
356 |     # collapse each box list into one enclosing box
357 |     annot_all_norm_union = union(annot_all_norm)
358 |     bbox_all_norm_union = union(bbox_all_norm)
359 | 
360 |     for bbox_a in annot_all_norm_union:
361 |         for bbox_b in bbox_all_norm_union:
362 |             iou_value = IoU(bbox_a, bbox_b)
363 |             if iou_value >= iou_thr:
364 |                 return 1, iou_value
365 |             if count_inside:
366 |                 # fully-contained predictions are treated as hits
367 |                 precision = precision_bbox(bbox_b, bbox_a)
368 |                 if precision >= 0.9:
369 |                     return 1, precision
370 |     return 0, iou_value
371 |
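372 | 
373 | # --- Minimal usage sketch (illustration only; the inputs below are made up,
374 | # --- everything else is defined in this file) ---
375 | #
376 | #   heatmap = np.random.rand(7, 7)                       # toy attention map
377 | #   annot = {'bbox': [[10, 20, 60, 90]],                 # pixel xyxy
378 | #            'bbox_norm': [[0.1, 0.2, 0.6, 0.9]]}        # normalized xyxy
379 | #   box_ok, hit_ok, att, iou = calc_correctness(annot, heatmap, (100, 100))
380 | #
381 | # box_ok and hit_ok are 0/1 flags, att is the fraction of attention mass that
382 | # falls inside the annotated boxes, and iou is the best box-level IoU found.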
--------------------------------------------------------------------------------
/data/object/class-descriptions-boxable.csv:
--------------------------------------------------------------------------------
1 | /m/011k07,Tortoise,0
2 | /m/011q46kg,Container,1
3 | /m/012074,Magpie,0
4 | /m/0120dh,Sea turtle,0
5 | /m/01226z,Football,0
6 | /m/012n7d,Ambulance,1
7 | /m/012w5l,Ladder,0
8 | /m/012xff,Toothbrush,0
9 | /m/012ysf,Syringe,0
10 | /m/0130jx,Sink,0
11 | /m/0138tl,Toy,0
12 | /m/013y1f,Organ,0
13 | /m/01432t,Cassette deck,0
14 | /m/014j1m,Apple,0
15 | /m/014sv8,Human eye,1
16 | /m/014trl,Cosmetics,0
17 | /m/014y4n,Paddle,0
18 | /m/0152hh,Snowman,0
19 | /m/01599,Beer,0
20 | /m/01_5g,Chopsticks,0
21 | /m/015h_t,Human beard,1
22 | /m/015p6,Bird,0
23 | /m/015qbp,Parking meter,0
24 | /m/015qff,Traffic light,0
25 | /m/015wgc,Croissant,0
26 | /m/015x4r,Cucumber,0
27 | /m/015x5n,Radish,0
28 | /m/0162_1,Towel,0
29 | /m/0167gd,Doll,0
30 | /m/016m2d,Skull,0
31 | /m/0174k2,Washing machine,0
32 | /m/0174n1,Glove,0
33 | /m/0175cv,Tick,0
34 | /m/0176mf,Belt,0
35 | /m/017ftj,Sunglasses,0
36 | /m/018j2,Banjo,0
37 | /m/018p4k,Cart,0
38 | /m/018xm,Ball,0
39 | /m/01940j,Backpack,0
40 | /m/0199g,Bicycle,1
41 | /m/019dx1,Home appliance,0
42 | /m/019h78,Centipede,0
43 | /m/019jd,Boat,1
44 | /m/019w40,Surfboard,0
45 | /m/01b638,Boot,0
46 | /m/01b7fy,Headphones,1
47 | /m/01b9xk,Hot dog,0
48 | /m/01bfm9,Shorts,0
49 | /m/01_bhs,Fast food,0
50 | /m/01bjv,Bus,1
51 | /m/01bl7v,Boy,1
52 | /m/01bms0,Screwdriver,0
53 | /m/01bqk0,Bicycle wheel,0
54 | /m/01btn,Barge,0
55 | /m/01c648,Laptop,0
56 | /m/01cmb2,Miniskirt,0
57 | /m/01d380,Drill,0
58 | /m/01d40f,Dress,0
59 | /m/01dws,Bear,0
60 | /m/01dwsz,Waffle,0
61 | /m/01dwwc,Pancake,0
62 | /m/01dxs,Brown bear,0
63 | /m/01dy8n,Woodpecker,0
64 | /m/01f8m5,Blue jay,0
65 | /m/01f91_,Pretzel,0
66 | /m/01fb_0,Bagel,0
67 | /m/01fdzj,Tower,0
68 | /m/01fh4r,Teapot,0
69 | /m/01g317,Person,1
70 | /m/01g3x7,Bow and arrow,0
71 | /m/01gkx_,Swimwear,0
72 | /m/01gllr,Beehive,0
73 | /m/01gmv2,Brassiere,0
74 | /m/01h3n,Bee,0
75 | /m/01h44,Bat,0
76 | /m/01h8tj,Starfish,0
77 | /m/01hrv5,Popcorn,0
78 | /m/01j3zr,Burrito,0
79 | /m/01j4z9,Chainsaw,0
80 | /m/01j51,Balloon,0
81 | /m/01j5ks,Wrench,0
82 | /m/01j61q,Tent,0
83 | /m/01jfm_,Vehicle registration plate,0
84 | /m/01jfsr,Lantern,0
85 | /m/01k6s3,Toaster,0
86 | /m/01kb5b,Flashlight,0
87 | /m/01knjb,Billboard,1
88 | /m/01krhy,Tiara,0
89 | /m/01lcw4,Limousine,0
90 | /m/01llwg,Necklace,0
91 | /m/01lrl,Carnivore,0
92 | /m/01lsmm,Scissors,0
93 | /m/01lynh,Stairs,0
94 | /m/01m2v,Computer keyboard,0
95 | /m/01m4t,Printer,0
96 | /m/01mqdt,Traffic sign,0
97 | /m/01mzpv,Chair,0
98 | /m/01n4qj,Shirt,0
99 | /m/01n5jq,Poster,1
100 | /m/01nkt,Cheese,0
101 | /m/01nq26,Sock,0
102 | /m/01pns0,Fire hydrant,0
103 | /m/01prls,Land vehicle,0
104 | /m/01r546,Earrings,0
105 | /m/01rkbr,Tie,0
106 | /m/01rzcn,Watercraft,1
107 | /m/01s105,Cabinetry,0
108 | /m/01s55n,Suitcase,0
109 | /m/01tcjp,Muffin,0
110 | /m/01vbnl,Bidet,0
111 | /m/01ww8y,Snack,0
112 | /m/01x3jk,Snowmobile,0
113 | /m/01x3z,Clock,0
114 | /m/01xgg_,Medical equipment,1
115 | /m/01xq0k1,Cattle,0
116 | /m/01xqw,Cello,1
117 | /m/01xs3r,Jet ski,1
118 | /m/01x_v,Camel,0
119 | /m/01xygc,Coat,0
120 | /m/01xyhv,Suit,0
121 | /m/01y9k5,Desk,0
122 | /m/01yrx,Cat,0
123 | /m/01yx86,Bronze sculpture,0
124 | /m/01z1kdw,Juice,0
125 | /m/02068x,Gondola,1
126 | /m/020jm,Beetle,0
127 | /m/020kz,Cannon,0
128 | /m/020lf,Computer mouse,0
129 | /m/021mn,Cookie,0
130 | /m/021sj1,Office building,1
131 | /m/0220r2,Fountain,0
132 | /m/0242l,Coin,0
133 | /m/024d2,Calculator,0
134 | /m/024g6,Cocktail,0
135 | /m/02522,Computer monitor,0
136 | /m/025dyy,Box,0
137 | /m/025fsf,Stapler,0
138 | /m/025nd,Christmas tree,0
139 | /m/025rp__,Cowboy hat,0
140 | /m/0268lbt,Hiking equipment,0
141 | /m/026qbn5,Studio couch,0
142 | /m/026t6,Drum,0
143 | /m/0270h,Dessert,0
144 | /m/0271qf7,Wine rack,0
145 | /m/0271t,Drink,0
146 | /m/027pcv,Zucchini,0
147 | /m/027rl48,Ladle,0
148 | /m/0283dt1,Human mouth,0
149 | /m/0284d,Dairy,0
150 | /m/029b3,Dice,0
151 | /m/029bxz,Oven,0
152 | /m/029tx,Dinosaur,0
153 | /m/02bm9n,Ratchet,0
154 | /m/02crq1,Couch,0
155 | /m/02ctlc,Cricket ball,0
156 | /m/02cvgx,Winter melon,0
157 | /m/02d1br,Spatula,0
158 | /m/02d9qx,Whiteboard,0
159 | /m/02ddwp,Pencil sharpener,0
160 | /m/02dgv,Door,0
161 | /m/02dl1y,Hat,0
162 | /m/02f9f_,Shower,0
163 | /m/02fh7f,Eraser,0
164 | /m/02fq_6,Fedora,0
165 | /m/02g30s,Guacamole,0
166 | /m/02gzp,Dagger,0
167 | /m/02h19r,Scarf,0
168 | /m/02hj4,Dolphin,0
169 | /m/02jfl0,Sombrero,0
170 | /m/02jnhm,Tin can,0
171 | /m/02jvh9,Mug,0
172 | /m/02jz0l,Tap,0
173 | /m/02l8p9,Harbor seal,0
174 | /m/02lbcq,Stretcher,1
175 | /m/02mqfb,Can opener,0
176 | /m/02_n6y,Goggles,0
177 | /m/02p0tk3,Human body,1
178 | /m/02p3w7d,Roller skates,0
179 | /m/02p5f1q,Coffee cup,0
180 | /m/02pdsw,Cutting board,0
181 | /m/02pjr4,Blender,0
182 | /m/02pkr5,Plumbing fixture,0
183 | /m/02pv19,Stop sign,0
184 | /m/02rdsp,Office supplies,0
185 | /m/02rgn06,Volleyball,0
186 | /m/02s195,Vase,0
187 | /m/02tsc9,Slow cooker,0
188 | /m/02vkqh8,Wardrobe,0
189 | /m/02vqfm,Coffee,0
190 | /m/02vwcm,Whisk,0
191 | /m/02w3r3,Paper towel,0
192 | /m/02w3_ws,Personal care,0
193 | /m/02wbm,Food,0
194 | /m/02wbtzl,Sun hat,0
195 | /m/02wg_p,Tree house,0
196 | /m/02wmf,Flying disc,0
197 | /m/02wv6h6,Skirt,0
198 | /m/02wv84t,Gas stove,0
199 | /m/02x8cch,Salt and pepper shakers,0
200 | /m/02x984l,Mechanical fan,0
201 | /m/02xb7qb,Face powder,0
202 | /m/02xqq,Fax,0
203 | /m/02xwb,Fruit,0
204 | /m/02y6n,French fries,0
205 | /m/02z51p,Nightstand,0
206 | /m/02zn6n,Barrel,0
207 | /m/02zt3,Kite,0
208 | /m/02zvsm,Tart,0
209 | /m/030610,Treadmill,0
210 | /m/0306r,Fox,0
211 | /m/03120,Flag,1
212 | /m/0319l,Horn,0
213 | /m/031b6r,Window blind,0
214 | /m/031n1,Human foot,0
215 | /m/0323sq,Golf cart,0
216 | /m/032b3c,Jacket,0
217 | /m/033cnk,Egg,0
218 | /m/033rq4,Street light,0
219 | /m/0342h,Guitar,0
220 | /m/034c16,Pillow,0
221 | /m/035r7c,Human leg,1
222 | /m/035vxb,Isopod,0
223 | /m/0388q,Grape,0
224 | /m/039xj_,Human ear,0
225 | /m/03bbps,Power plugs and sockets,0
226 | /m/03bj1,Panda,0
227 | /m/03bk1,Giraffe,0
228 | /m/03bt1vf,Woman,1
229 | /m/03c7gz,Door handle,0
230 | /m/03d443,Rhinoceros,0
231 | /m/03dnzn,Bathtub,0
232 | /m/03fj2,Goldfish,0
233 | /m/03fp41,Houseplant,0
234 | /m/03fwl,Goat,0
235 | /m/03g8mr,Baseball bat,0
236 | /m/03grzl,Baseball glove,0
237 | /m/03hj559,Mixing bowl,0
238 | /m/03hl4l9,Marine invertebrates,0
239 | /m/03hlz0c,Kitchen utensil,0
240 | /m/03jbxj,Light switch,0
241 | /m/03jm5,House,1
242 | /m/03k3r,Horse,0
243 | /m/03kt2w,Stationary bicycle,0
244 | /m/03l9g,Hammer,0
245 | /m/03ldnb,Ceiling fan,0
246 | /m/03m3pdh,Sofa bed,0
247 | /m/03m3vtv,Adhesive tape,0
248 | /m/03m5k,Harp,0
249 | /m/03nfch,Sandal,0
250 | /m/03p3bw,Bicycle helmet,0
251 | /m/03q5c7,Saucer,0
252 | /m/03q5t,Harpsichord,0
253 | /m/03q69,Human hair,0
254 | /m/03qhv5,Heater,0
255 | /m/03qjg,Harmonica,0
256 | /m/03qrc,Hamster,0
257 | /m/03rszm,Curtain,0
258 | /m/03ssj5,Bed,0
259 | /m/03s_tn,Kettle,0
260 | /m/03tw93,Fireplace,0
261 | /m/03txqz,Scale,0
262 | /m/03v5tg,Drinking straw,0
263 | /m/03vt0,Insect,0
264 | /m/03wvsk,Hair dryer,0
265 | /m/03_wxk,Kitchenware,0
266 | /m/03wym,Indoor rower,0
267 | /m/03xxp,Invertebrate,0
268 | /m/03y6mg,Food processor,0
269 | /m/03__z0,Bookcase,0
270 | /m/040b_t,Refrigerator,0
271 | /m/04169hn,Wood-burning stove,0
272 | /m/0420v5,Punching bag,0
273 | /m/043nyj,Common fig,0
274 | /m/0440zs,Cocktail shaker,0
275 | /m/0449p,Jaguar,0
276 | /m/044r5d,Golf ball,0
277 | /m/0463sg,Fashion accessory,0
278 | /m/046dlr,Alarm clock,0
279 | /m/047j0r,Filing cabinet,0
280 | /m/047v4b,Artichoke,0
281 | /m/04bcr3,Table,0
282 | /m/04brg2,Tableware,0
283 | /m/04c0y,Kangaroo,0
284 | /m/04cp_,Koala,0
285 | /m/04ctx,Knife,0
286 | /m/04dr76w,Bottle,0
287 | /m/04f5ws,Bottle opener,0
288 | /m/04g2r,Lynx,0
289 | /m/04gth,Lavender,0
290 | /m/04h7h,Lighthouse,0
291 | /m/04h8sr,Dumbbell,0
292 | /m/04hgtk,Human head,1
293 | /m/04kkgm,Bowl,0
294 | /m/04lvq_,Humidifier,0
295 | /m/04m6gz,Porch,0
296 | /m/04m9y,Lizard,0
297 | /m/04p0qw,Billiard table,0
298 | /m/04rky,Mammal,0
299 | /m/04rmv,Mouse,0
300 | /m/04_sv,Motorcycle,1
301 | /m/04szw,Musical instrument,0
302 | /m/04tn4x,Swim cap,0
303 | /m/04v6l4,Frying pan,0
304 | /m/04vv5k,Snowplow,0
305 | /m/04y4h8h,Bathroom cabinet,0
306 | /m/04ylt,Missile,1
307 | /m/04yqq2,Bust,0
308 | /m/04yx4,Man,1
309 | /m/04z4wx,Waffle iron,0
310 | /m/04zpv,Milk,0
311 | /m/04zwwv,Ring binder,0
312 | /m/050gv4,Plate,0
313 | /m/050k8,Mobile phone,1
314 | /m/052lwg6,Baked goods,0
315 | /m/052sf,Mushroom,0
316 | /m/05441v,Crutch,0
317 | /m/054fyh,Pitcher,0
318 | /m/054_l,Mirror,0
319 | /m/054xkw,Lifejacket,0
320 | /m/05_5p_0,Table tennis racket,0
321 | /m/05676x,Pencil case,0
322 | /m/057cc,Musical keyboard,0
323 | /m/057p5t,Scoreboard,0
324 | /m/0584n8,Briefcase,0
325 | /m/058qzx,Kitchen knife,0
326 | /m/05bm6,Nail,0
327 | /m/05ctyq,Tennis ball,0
328 | /m/05gqfk,Plastic bag,0
329 | /m/05kms,Oboe,0
330 | /m/05kyg_,Chest of drawers,0
331 | /m/05n4y,Ostrich,0
332 | /m/05r5c,Piano,0
333 | /m/05r655,Girl,1
334 | /m/05s2s,Plant,0
335 | /m/05vtc,Potato,0
336 | /m/05w9t9,Hair spray,0
337 | /m/05y5lj,Sports equipment,0
338 | /m/05z55,Pasta,0
339 | /m/05z6w,Penguin,0
340 | /m/05zsy,Pumpkin,0
341 | /m/061_f,Pear,0
342 | /m/061hd_,Infant bed,0
343 | /m/0633h,Polar bear,0
344 | /m/063rgb,Mixer,0
345 | /m/0642b4,Cupboard,0
346 | /m/065h6l,Jacuzzi,0
347 | /m/0663v,Pizza,0
348 | /m/06_72j,Digital clock,0
349 | /m/068zj,Pig,0
350 | /m/06bt6,Reptile,0
351 | /m/06c54,Rifle,1
352 | /m/06c7f7,Lipstick,0
353 | /m/06_fw,Skateboard,0
354 | /m/06j2d,Raven,0
355 | /m/06k2mb,High heels,0
356 | /m/06l9r,Red panda,0
357 | /m/06m11,Rose,0
358 | /m/06mf6,Rabbit,0
359 | /m/06msq,Sculpture,0
360 | /m/06ncr,Saxophone,0
361 | /m/06nrc,Shotgun,1
362 | /m/06nwz,Seafood,0
363 | /m/06pcq,Submarine sandwich,0
364 | /m/06__v,Snowboard,0
365 | /m/06y5r,Sword,0
366 | /m/06z37_,Picture frame,0
367 | /m/07030,Sushi,0
368 | /m/0703r8,Loveseat,0
369 | /m/071p9,Ski,0
370 | /m/071qp,Squirrel,0
371 | /m/073bxn,Tripod,0
372 | /m/073g6,Stethoscope,0
373 | /m/074d1,Submarine,0
374 | /m/0755b,Scorpion,0
375 | /m/076bq,Segway,0
376 | /m/076lb9,Training bench,0
377 | /m/078jl,Snake,0
378 | /m/078n6m,Coffee table,0
379 | /m/079cl,Skyscraper,1
380 | /m/07bgp,Sheep,0
381 | /m/07c52,Television,0
382 | /m/07c6l,Trombone,0
383 | /m/07clx,Tea,0
384 | /m/07cmd,Tank,1
385 | /m/07crc,Taco,0
386 | /m/07cx4,Telephone,0
387 | /m/07dd4,Torch,0
388 | /m/07dm6,Tiger,0
389 | /m/07fbm7,Strawberry,0
390 | /m/07gql,Trumpet,0
391 | /m/07j7r,Tree,0
392 | /m/07j87,Tomato,0
393 | /m/07jdr,Train,1
394 | /m/07k1x,Tool,1
395 | /m/07kng9,Picnic basket,0
396 | /m/07mcwg,Cooking spray,0
397 | /m/07mhn,Trousers,0
398 | /m/07pj7bq,Bowling equipment,0
399 | /m/07qxg_,Football helmet,0
400 | /m/07r04,Truck,1
401 | /m/07v9_z,Measuring cup,0
402 | /m/07xyvk,Coffeemaker,0
403 | /m/07y_7,Violin,0
404 | /m/07yv9,Vehicle,1
405 | /m/080hkjn,Handbag,0
406 | /m/080n7g,Paper cutter,0
407 | /m/081qc,Wine,0
408 | /m/083kb,Weapon,0
409 | /m/083wq,Wheel,0
410 | /m/084hf,Worm,0
411 | /m/084rd,Wok,0
412 | /m/084zz,Whale,0
413 | /m/0898b,Zebra,0
414 | /m/08dz3q,Auto part,1
415 | /m/08hvt4,Jug,0
416 | /m/08ks85,Pizza cutter,0
417 | /m/08p92x,Cream,0
418 | /m/08pbxl,Monkey,0
419 | /m/096mb,Lion,0
420 | /m/09728,Bread,0
421 | /m/099ssp,Platter,0
422 | /m/09b5t,Chicken,0
423 | /m/09csl,Eagle,0
424 | /m/09ct_,Helicopter,1
425 | /m/09d5_,Owl,0
426 | /m/09ddx,Duck,0
427 | /m/09dzg,Turtle,0
428 | /m/09f20,Hippopotamus,0
429 | /m/09f_2,Crocodile,0
430 | /m/09g1w,Toilet,0
431 | /m/09gtd,Toilet paper,0
432 | /m/09gys,Squid,0
433 | /m/09j2d,Clothing,0
434 | /m/09j5n,Footwear,0
435 | /m/09k_b,Lemon,0
436 | /m/09kmb,Spider,0
437 | /m/09kx5,Deer,0
438 | /m/09ld4,Frog,0
439 | /m/09qck,Banana,0
440 | /m/09rvcxw,Rocket,1
441 | /m/09tvcd,Wine glass,0
442 | /m/0b3fp9,Countertop,0
443 | /m/0bh9flk,Tablet computer,0
444 | /m/0bjyj5,Waste container,0
445 | /m/0b_rs,Swimming pool,0
446 | /m/0bt9lr,Dog,0
447 | /m/0bt_c3,Book,0
448 | /m/0bwd_0j,Elephant,0
449 | /m/0by6g,Shark,0
450 | /m/0c06p,Candle,0
451 | /m/0c29q,Leopard,0
452 | /m/0c2jj,Axe,0
453 | /m/0c3m8g,Hand dryer,0
454 | /m/0c3mkw,Soap dispenser,0
455 | /m/0c568,Porcupine,0
456 | /m/0c9ph5,Flower,0
457 | /m/0ccs93,Canary,0
458 | /m/0cd4d,Cheetah,0
459 | /m/0cdl1,Palm tree,0
460 | /m/0cdn1,Hamburger,0
461 | /m/0cffdh,Maple,0
462 | /m/0cgh4,Building,1
463 | /m/0ch_cf,Fish,0
464 | /m/0cjq5,Lobster,0
465 | /m/0cjs7,Asparagus,0
466 | /m/0c_jw,Furniture,0
467 | /m/0cl4p,Hedgehog,0
468 | /m/0cmf2,Airplane,1
469 | /m/0cmx8,Spoon,0
470 | /m/0cn6p,Otter,0
471 | /m/0cnyhnx,Bull,0
472 | /m/0_cp5,Oyster,0
473 | /m/0cqn2,Horizontal bar,0
474 | /m/0crjs,Convenience store,0
475 | /m/0ct4f,Bomb,1
476 | /m/0cvnqh,Bench,0
477 | /m/0cxn2,Ice cream,0
478 | /m/0cydv,Caterpillar,0
479 | /m/0cyf8,Butterfly,0
480 | /m/0cyfs,Parachute,0
481 | /m/0cyhj_,Orange,0
482 | /m/0czz2,Antelope,0
483 | /m/0d20w4,Beaker,0
484 | /m/0d_2m,Moths and butterflies,0
485 | /m/0d4v4,Window,0
486 | /m/0d4w1,Closet,0
487 | /m/0d5gx,Castle,0
488 | /m/0d8zb,Jellyfish,0
489 | /m/0dbvp,Goose,0
490 | /m/0dbzx,Mule,0
491 | /m/0dftk,Swan,0
492 | /m/0dj6p,Peach,0
493 | /m/0djtd,Coconut,0
494 | /m/0dkzw,Seat belt,0
495 | /m/0dq75,Raccoon,0
496 | /m/0_dqb,Chisel,0
497 | /m/0dt3t,Fork,0
498 | /m/0dtln,Lamp,0
499 | /m/0dv5r,Camera,0
500 | /m/0dv77,Squash,0
501 | /m/0dv9c,Racket,0
502 | /m/0dzct,Human face,1
503 | /m/0dzf4,Human arm,1
504 | /m/0f4s2w,Vegetable,0
505 | /m/0f571,Diaper,0
506 | /m/0f6nr,Unicycle,0
507 | /m/0f6wt,Falcon,0
508 | /m/0f8s22,Chime,0
509 | /m/0f9_l,Snail,0
510 | /m/0fbdv,Shellfish,0
511 | /m/0fbw6,Cabbage,0
512 | /m/0fj52s,Carrot,0
513 | /m/0fldg,Mango,0
514 | /m/0fly7,Jeans,0
515 | /m/0fm3zh,Flowerpot,0
516 | /m/0fp6w,Pineapple,0
517 | /m/0fqfqc,Drawer,0
518 | /m/0fqt361,Stool,0
519 | /m/0frqm,Envelope,0
520 | /m/0fszt,Cake,0
521 | /m/0ft9s,Dragonfly,0
522 | /m/0ftb8,Sunflower,0
523 | /m/0fx9l,Microwave oven,0
524 | /m/0fz0h,Honeycomb,0
525 | /m/0gd2v,Marine mammal,0
526 | /m/0gd36,Sea lion,0
527 | /m/0gj37,Ladybug,0
528 | /m/0gjbg72,Shelf,0
529 | /m/0gjkl,Watch,0
530 | /m/0gm28,Candy,0
531 | /m/0grw1,Salad,0
532 | /m/0gv1x,Parrot,0
533 | /m/0gxl3,Handgun,1
534 | /m/0h23m,Sparrow,0
535 | /m/0h2r6,Van,1
536 | /m/0h8jyh6,Grinder,0
537 | /m/0h8kx63,Spice rack,0
538 | /m/0h8l4fh,Light bulb,0
539 | /m/0h8lkj8,Corded phone,1
540 | /m/0h8mhzd,Sports uniform,0
541 | /m/0h8my_4,Tennis racket,0
542 | /m/0h8mzrc,Wall clock,0
543 | /m/0h8n27j,Serving tray,0
544 | /m/0h8n5zk,Kitchen & dining room table,0
545 | /m/0h8n6f9,Dog bed,0
546 | /m/0h8n6ft,Cake stand,0
547 | /m/0h8nm9j,Cat furniture,0
548 | /m/0h8nr_l,Bathroom accessory,0
549 | /m/0h8nsvg,Facial tissue holder,0
550 | /m/0h8ntjv,Pressure cooker,0
551 | /m/0h99cwc,Kitchen appliance,0
552 | /m/0h9mv,Tire,0
553 | /m/0hdln,Ruler,0
554 | /m/0hf58v5,Luggage and bags,0
555 | /m/0hg7b,Microphone,0
556 | /m/0hkxq,Broccoli,0
557 | /m/0hnnb,Umbrella,0
558 | /m/0hnyx,Pastry,0
559 | /m/0hqkz,Grapefruit,0
560 | /m/0j496,Band-aid,0
561 | /m/0jbk,Animal,0
562 | /m/0jg57,Bell pepper,0
563 | /m/0jly1,Turkey,0
564 | /m/0jqgx,Lily,0
565 | /m/0jwn_,Pomegranate,0
566 | /m/0jy4k,Doughnut,0
567 | /m/0jyfg,Glasses,0
568 | /m/0k0pj,Human nose,0
569 | /m/0k1tl,Pen,0
570 | /m/0_k2,Ant,0
571 | /m/0k4j,Car,1
572 | /m/0k5j,Aircraft,1
573 | /m/0k65p,Human hand,1
574 | /m/0km7z,Skunk,0
575 | /m/0kmg4,Teddy bear,0
576 | /m/0kpqd,Watermelon,0
577 | /m/0kpt_,Cantaloupe,0
578 | /m/0ky7b,Dishwasher,0
579 | /m/0l14j_,Flute,0
580 | /m/0l3ms,Balance beam,0
581 | /m/0l515,Sandwich,0
582 | /m/0ll1f78,Shrimp,0
583 | /m/0llzx,Sewing machine,0
584 | /m/0lt4_,Binoculars,0
585 | /m/0m53l,Rays and skates,0
586 | /m/0mcx2,Ipod,0
587 | /m/0mkg,Accordion,0
588 | /m/0mw_6,Willow,0
589 | /m/0n28_,Crab,0
590 | /m/0nl46,Crown,0
591 | /m/0nybt,Seahorse,0
592 | /m/0p833,Perfume,0
593 | /m/0pcr,Alpaca,0
594 | /m/0pg52,Taxi,0
595 | /m/0ph39,Canoe,0
596 | /m/0qjjc,Remote control,0
597 | /m/0qmmr,Wheelchair,0
598 | /m/0wdt60w,Rugby ball,0
599 | /m/0xfy,Armadillo,0
600 | /m/0xzly,Maracas,0
601 | /m/0zvk5,Helmet,0
--------------------------------------------------------------------------------
/src/engine/Groundingrunner.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import json
5 | import sys
6 | from functools import partial
7 |
8 | import numpy as np
9 | import torch
10 | from tensorboardX import SummaryWriter
11 | from torchvision import transforms
12 |
13 | # make the repository root importable when running from src/engine
14 | sys.path.append('../..')
15 | from src.util import consts
16 | from src.dataflow.torch.Data import MultiTokenField, SparseField, EntityField
17 | from torchtext.data import Field
18 | from src.util.vocab import Vocab
19 | from torchtext.vocab import Vectors
20 | from src.dataflow.numpy.data_loader_grounding import GroundingDataset
21 | from src.models.grounding import GroundingModel
22 | from src.eval.Groundingtesting import GroundingTester
23 | from src.engine.Groundingtraining import grounding_train
24 | from src.engine.SRrunner import load_sr_model
25 | from src.engine.EErunner import load_ee_model
26 | from src.util.util_model import log
27 |
28 |
29 | class GroundingRunner(object):
30 | def __init__(self):
31 | parser = argparse.ArgumentParser(description="neural networks trainer")
32 |         parser.add_argument("--test", help="test set")
33 | parser.add_argument("--train", help="training set", required=False)
34 | parser.add_argument("--dev", help="development set", required=False)
35 | parser.add_argument("--webd", help="word embedding", required=False)
36 | parser.add_argument("--img_dir", help="Grounding images directory", required=False)
37 | parser.add_argument("--amr", help="use amr", action='store_true')
38 |
39 | # sr model parameter
40 | parser.add_argument("--wnebd", help="noun word embedding", required=False)
41 | parser.add_argument("--wvebd", help="verb word embedding", required=False)
42 | parser.add_argument("--wrebd", help="role word embedding", required=False)
43 | parser.add_argument("--add_object", help="add_object", action='store_true')
44 | parser.add_argument("--object_class_map_file", help="object_class_map_file", required=False)
45 | parser.add_argument("--object_detection_pkl_file", help="object_detection_pkl_file", required=False)
46 | parser.add_argument("--object_detection_threshold", default=0.2, type=float, help="object_detection_threshold",
47 | required=False)
48 |
49 | parser.add_argument("--vocab", help="vocab_dir", required=False)
50 | parser.add_argument("--sr_hps", help="sr model hyperparams", required=False)
51 |
52 | # ee model parameter
53 | parser.add_argument("--ee_hps", help="ee model hyperparams", required=False)
54 |
55 | parser.add_argument("--batch", help="batch size", default=128, type=int)
56 | parser.add_argument("--epochs", help="n of epochs", default=sys.maxsize, type=int)
57 |
58 | parser.add_argument("--seed", help="RNG seed", default=42, type=int)
59 | parser.add_argument("--optimizer", default="adam")
60 | parser.add_argument("--lr", default=1e-3, type=float)
61 | parser.add_argument("--l2decay", default=0, type=float)
62 | parser.add_argument("--maxnorm", default=3, type=float)
63 |
64 | parser.add_argument("--out", help="output model path", default="out")
65 | parser.add_argument("--finetune_sr", help="pretrained sr model path")
66 | parser.add_argument("--finetune_ee", help="pretrained ee model path")
67 | parser.add_argument("--earlystop", default=999999, type=int)
68 | parser.add_argument("--restart", default=999999, type=int)
69 | parser.add_argument("--shuffle", help="shuffle", action='store_true')
70 |
71 | parser.add_argument("--device", default="cpu")
72 |
73 | self.a = parser.parse_args()
74 |
75 |     def set_device(self, device="cpu"):
76 |         # honor the requested device; fall back to CPU when CUDA is unavailable
77 |         self.device = torch.device(device if device == "cpu" or torch.cuda.is_available() else "cpu")
78 |
79 | def get_device(self):
80 | return self.device
81 |
82 | # def load_model(self, ed_model, sr_model, fine_tune):
83 | # if fine_tune is None:
84 | # train_model = GroundingModel(ed_model, sr_model, self.get_device())
85 | # return train_model
86 | # else:
87 | # mymodel = GroundingModel(ed_model, sr_model, self.get_device())
88 | # mymodel.load_model(fine_tune)
89 | # mymodel.to(self.get_device())
90 | # return mymodel
91 |
92 | def get_tester(self):
93 | return GroundingTester()
94 |
95 | def run(self):
96 | print("Running on", self.a.device)
97 | self.set_device(self.a.device)
98 |
99 | np.random.seed(self.a.seed)
100 | torch.manual_seed(self.a.seed)
101 | torch.backends.cudnn.benchmark = True
102 |
103 | # create training set
104 | if self.a.train:
105 | log('loading corpus from %s' % self.a.train)
106 |
107 | transform = transforms.Compose([
108 | transforms.Resize(256),
109 | transforms.RandomHorizontalFlip(),
110 | transforms.RandomCrop(224),
111 | transforms.ToTensor(),
112 | transforms.Normalize((0.485, 0.456, 0.406),
113 | (0.229, 0.224, 0.225))])
114 |
115 | IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)
116 | SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)
117 | # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True)
118 | WordsField = Field(lower=True, include_lengths=True, batch_first=True)
119 | PosTagsField = Field(lower=True, batch_first=True)
120 | EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
121 | AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True)
122 | EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False)
123 |
124 | if self.a.amr:
125 | colcc = 'simple-parsing'
126 | else:
127 | colcc = 'combined-parsing'
128 | print(colcc)
129 |
130 | train_set = GroundingDataset(path=self.a.train,
131 | img_dir=self.a.img_dir,
132 | fields={"id": ("IMAGEID", IMAGEIDField),
133 | "sentence_id": ("SENTID", SENTIDField),
134 | "words": ("WORDS", WordsField),
135 | "pos-tags": ("POSTAGS", PosTagsField),
136 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
137 | colcc: ("ADJM", AdjMatrixField),
138 | "all-entities": ("ENTITIES", EntitiesField),
139 | # "image": ("IMAGE", IMAGEField),
140 | },
141 | transform=transform,
142 | amr=self.a.amr,
143 | load_object=self.a.add_object,
144 | object_ontology_file=self.a.object_class_map_file,
145 | object_detection_pkl_file=self.a.object_detection_pkl_file,
146 | object_detection_threshold=self.a.object_detection_threshold,
147 | )
148 |
149 | dev_set = GroundingDataset(path=self.a.dev,
150 | img_dir=self.a.img_dir,
151 | fields={"id": ("IMAGEID", IMAGEIDField),
152 | "sentence_id": ("SENTID", SENTIDField),
153 | "words": ("WORDS", WordsField),
154 | "pos-tags": ("POSTAGS", PosTagsField),
155 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
156 | colcc: ("ADJM", AdjMatrixField),
157 | "all-entities": ("ENTITIES", EntitiesField),
158 | # "image": ("IMAGE", IMAGEField),
159 | },
160 | transform=transform,
161 | amr=self.a.amr,
162 | load_object=self.a.add_object,
163 | object_ontology_file=self.a.object_class_map_file,
164 | object_detection_pkl_file=self.a.object_detection_pkl_file,
165 | object_detection_threshold=self.a.object_detection_threshold,
166 | )
167 |
168 | test_set = GroundingDataset(path=self.a.test,
169 | img_dir=self.a.img_dir,
170 | fields={"id": ("IMAGEID", IMAGEIDField),
171 | "sentence_id": ("SENTID", SENTIDField),
172 | "words": ("WORDS", WordsField),
173 | "pos-tags": ("POSTAGS", PosTagsField),
174 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
175 | colcc: ("ADJM", AdjMatrixField),
176 | "all-entities": ("ENTITIES", EntitiesField),
177 | # "image": ("IMAGE", IMAGEField),
178 | },
179 | transform=transform,
180 | amr=self.a.amr,
181 | load_object=self.a.add_object,
182 | object_ontology_file=self.a.object_class_map_file,
183 | object_detection_pkl_file=self.a.object_detection_pkl_file,
184 | object_detection_threshold=self.a.object_detection_threshold,
185 | )
186 |
187 | if self.a.webd:
188 | pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15))
189 | WordsField.build_vocab(train_set.WORDS, dev_set.WORDS, vectors=pretrained_embedding)
190 | else:
191 | WordsField.build_vocab(train_set.WORDS, dev_set.WORDS)
192 | 
193 | PosTagsField.build_vocab(train_set.POSTAGS, dev_set.POSTAGS)
194 | EntityLabelsField.build_vocab(train_set.ENTITYLABELS, dev_set.ENTITYLABELS)
195 |
196 | # sr model initialization
197 |         self.a.sr_hps = eval(self.a.sr_hps)  # SR hyperparams arrive as a Python dict-literal string
198 | vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True)
199 | vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True)
200 | vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True)
201 | embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to(self.device)
202 | embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to(self.device)
203 | embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to(self.device)
204 | if "wvemb_size" not in self.a.sr_hps:
205 | self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word)
206 | if "wremb_size" not in self.a.sr_hps:
207 | self.a.sr_hps["wremb_size"] = len(vocab_role.id2word)
208 | if "wnemb_size" not in self.a.sr_hps:
209 | self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word)
210 | if "ae_oc" not in self.a.sr_hps:
211 | self.a.sr_hps["ae_oc"] = len(vocab_role.id2word)
212 |
213 |         self.a.ee_hps = eval(self.a.ee_hps)  # EE hyperparams, same dict-literal format
214 | if "wemb_size" not in self.a.ee_hps:
215 | self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos)
216 | if "pemb_size" not in self.a.ee_hps:
217 | self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos)
218 | if "psemb_size" not in self.a.ee_hps:
219 | self.a.ee_hps["psemb_size"] = max([train_set.longest(), dev_set.longest(), test_set.longest()]) + 2
220 | if "eemb_size" not in self.a.ee_hps:
221 | self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
222 |         if "oc" not in self.a.ee_hps:
223 |             self.a.ee_hps["oc"] = 36  # hard-coded fallback: number of event-type classes
224 |         if "ae_oc" not in self.a.ee_hps:
225 |             self.a.ee_hps["ae_oc"] = 20  # hard-coded fallback: number of argument-role classes
226 |
227 | tester = self.get_tester()
228 |
229 | if self.a.finetune_sr:
230 | log('init sr model from ' + self.a.finetune_sr)
231 | sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune_sr, self.device)
232 | log('sr model loaded, there are %i sets of params' % len(sr_model.parameters_requires_grads()))
233 | else:
234 | sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, None, self.device)
235 | log('sr model created from scratch, there are %i sets of params' % len(sr_model.parameters_requires_grads()))
236 |
237 | if self.a.finetune_ee:
238 | log('init model from ' + self.a.finetune_ee)
239 | ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee, WordsField.vocab.vectors, self.device)
240 | log('model loaded, there are %i sets of params' % len(ee_model.parameters_requires_grads()))
241 | else:
242 | ee_model = load_ee_model(self.a.ee_hps, None, WordsField.vocab.vectors, self.device)
243 | log('model created from scratch, there are %i sets of params' % len(ee_model.parameters_requires_grads()))
244 |
245 | model = GroundingModel(ee_model, sr_model, self.get_device())
246 |
247 | if self.a.optimizer == "adadelta":
248 | optimizer_constructor = partial(torch.optim.Adadelta, params=model.parameters_requires_grads(),
249 | weight_decay=self.a.l2decay)
250 | elif self.a.optimizer == "adam":
251 | optimizer_constructor = partial(torch.optim.Adam, params=model.parameters_requires_grads(),
252 | weight_decay=self.a.l2decay)
253 | else:
254 | optimizer_constructor = partial(torch.optim.SGD, params=model.parameters_requires_grads(),
255 | weight_decay=self.a.l2decay,
256 | momentum=0.9)
257 |
258 | log('optimizer in use: %s' % str(self.a.optimizer))
259 |
260 | if not os.path.exists(self.a.out):
261 | os.mkdir(self.a.out)
262 | with open(os.path.join(self.a.out, "word.vec"), "wb") as f:
263 | pickle.dump(WordsField.vocab, f)
264 | with open(os.path.join(self.a.out, "pos.vec"), "wb") as f:
265 | pickle.dump(PosTagsField.vocab.stoi, f)
266 | with open(os.path.join(self.a.out, "entity.vec"), "wb") as f:
267 | pickle.dump(EntityLabelsField.vocab.stoi, f)
268 | with open(os.path.join(self.a.out, "ee_hyps.json"), "w") as f:
269 | json.dump(self.a.ee_hps, f)
270 | with open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f:
271 | json.dump(self.a.sr_hps, f)
272 |
273 | log('init complete\n')
274 |
275 |         # word_i2s is taken from the text vocabulary below, not from vocab_noun
276 |         self.a.label_i2s = vocab_verb.id2word  # LabelField.vocab.itos
277 |         self.a.role_i2s = vocab_role.id2word
278 |         self.a.word_i2s = WordsField.vocab.itos
279 | # self.a.label_i2s = LabelField.vocab.itos
280 | # self.a.role_i2s = EventsField.vocab.itos
281 | writer = SummaryWriter(os.path.join(self.a.out, "exp"))
282 | self.a.writer = writer
283 |
284 | grounding_train(
285 | model=model,
286 | train_set=train_set,
287 | dev_set=dev_set,
288 | test_set=test_set,
289 | optimizer_constructor=optimizer_constructor,
290 | epochs=self.a.epochs,
291 | tester=tester,
292 | parser=self.a,
293 | other_testsets={
294 | # "dev 1/1": dev_set1,
295 | # "test 1/1": test_set1,
296 | },
297 | transform=transform,
298 | vocab_objlabel=vocab_noun.word2id
299 | )
300 | log('Done!')
301 |
302 |
303 | if __name__ == "__main__":
304 | GroundingRunner().run()
305 |
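306 | # Hypothetical invocation (flag names match the argparse options above; the
307 | # paths and hyperparameter dicts are placeholders -- the real commands live in
308 | # scripts/train/train_grounding.sh):
309 | #
310 | #   python src/engine/Groundingrunner.py \
311 | #       --train train.json --dev dev.json --test test.json \
312 | #       --img_dir images/ --vocab vocab_dir/ \
313 | #       --wnebd noun_emb.npy --wvebd verb_emb.npy --wrebd role_emb.npy \
314 | #       --sr_hps "{...}" --ee_hps "{...}" \
315 | #       --batch 32 --device cuda --shuffle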
--------------------------------------------------------------------------------