├── src ├── .gitkeep ├── __init__.py ├── eval │ ├── .gitkeep │ ├── __init__.py │ ├── EEtesting.py │ ├── Groundingtesting.py │ ├── SRtesting.py │ └── EEvisualizing.py ├── util │ ├── .gitkeep │ ├── __init__.py │ ├── consts.py │ ├── helper.py │ ├── vocab.py │ ├── util_model.py │ ├── constant.py │ └── util_img.py ├── dataflow │ ├── .gitkeep │ ├── __init__.py │ ├── numpy │ │ ├── __init__.py │ │ ├── .DS_Store │ │ ├── dataset_image_download.py │ │ ├── anno_mapping.py │ │ └── prepare_vocab.py │ └── torch │ │ ├── __init__.py │ │ ├── Corpus.py │ │ ├── Data.py │ │ └── Sentence.py ├── engine │ ├── .gitkeep │ ├── __init__.py │ ├── TestRunnerEE.py │ └── Groundingrunner.py └── models │ ├── .gitkeep │ ├── __init__.py │ ├── modules │ ├── .gitkeep │ ├── __init__.py │ ├── HighWay.py │ ├── FMapLayerImage.py │ ├── EmbeddingLayerImage.py │ ├── model.py │ ├── SelfAttention.py │ ├── DynamicLSTM.py │ ├── GCN.py │ └── EmbeddingLayer.py │ └── ace_classifier.py ├── scripts ├── .gitkeep ├── eval │ ├── .gitkeep │ ├── test_ee_m2e2.sh │ ├── test_sr_m2e2_obj.sh │ ├── test_sr_m2e2.sh │ ├── test_joint_object.sh │ └── test_joint.sh └── train │ ├── .gitkeep │ ├── train_ee.sh │ ├── train_grounding.sh │ ├── train_sr_object.sh │ ├── train_sr.sh │ ├── train_joint_obj.sh │ └── train_joint_att.sh ├── training-testing.jpg ├── data ├── m2e2_annotations.zip ├── ace │ └── ace_sr_mapping.txt └── object │ └── class-descriptions-boxable.csv ├── .idea ├── dictionaries │ └── lynn.xml ├── vcs.xml ├── misc.xml ├── modules.xml └── m2e2-multimedia-event-extraction.iml ├── requirements.txt └── README.md /src/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/eval/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/util/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/eval/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/train/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dataflow/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/engine/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/engine/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/eval/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/models/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dataflow/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/modules/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dataflow/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dataflow/torch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training-testing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limanling/m2e2/HEAD/training-testing.jpg -------------------------------------------------------------------------------- /data/m2e2_annotations.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limanling/m2e2/HEAD/data/m2e2_annotations.zip -------------------------------------------------------------------------------- /src/dataflow/numpy/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limanling/m2e2/HEAD/src/dataflow/numpy/.DS_Store -------------------------------------------------------------------------------- /.idea/dictionaries/lynn.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /src/util/consts.py: -------------------------------------------------------------------------------- 1 | UNK_LABEL = "<unk>" 2 | UNK_IDX = 0 3 | PADDING_LABEL = "<pad>" 4 | PADDING_IDX = 1 5 | 6 | VOCAB_START_IDX = 2 7 | 8 | CUTOFF = 50 9 | TRIGGER_GOLDEN_COMPENSATION = 12 10 | ARGUMENT_GOLDEN_COMPENSATION = 27 11 | 12 | O_LABEL = None 13 | ROLE_O_LABEL = None 14 | 15 | O_LABEL_NAME = "O" 16 | ROLE_O_LABEL_NAME = "OTHER" --------------------------------------------------------------------------------
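The `consts.py` constants above pin down the vocabulary indexing convention used across the codebase: index 0 is reserved for the unknown token, index 1 for padding, and real vocabulary entries start at `VOCAB_START_IDX`. Below is a minimal sketch of a label table that honors this layout; it is not part of the repo, the `build_label_index` helper and the sample labels are illustrative only, and it assumes the repo root is on `PYTHONPATH`:

```python
from src.util import consts

def build_label_index(labels):
    # Hypothetical helper: reserve slots 0 and 1 for the unknown and
    # padding labels, then give real labels consecutive ids starting
    # at VOCAB_START_IDX, matching the convention in consts.py.
    s2i = {consts.UNK_LABEL: consts.UNK_IDX,
           consts.PADDING_LABEL: consts.PADDING_IDX}
    for offset, label in enumerate(sorted(set(labels))):
        s2i[label] = consts.VOCAB_START_IDX + offset
    i2s = {idx: label for label, idx in s2i.items()}
    return s2i, i2s

s2i, i2s = build_label_index(["O", "B-Attack", "I-Attack"])
assert i2s[consts.UNK_IDX] == consts.UNK_LABEL
assert s2i["B-Attack"] >= consts.VOCAB_START_IDX
```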
/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchtext==0.4.0 2 | ujson==1.35 3 | seqeval==0.0.12 4 | six==1.12.0 5 | jieba==0.39 6 | nltk==3.4.5 7 | scipy==1.3.0 8 | torchvision==0.3.0 9 | Pillow==8.1.1 10 | common==0.1.2 11 | constants==0.6.0 12 | graphstate==1.0.6 13 | ipdb==0.12.2 14 | preprocessing==0.1.13 15 | stanfordcorenlp==3.9.1.1 16 | tensorboardX==1.8 17 | utils==0.9.0 18 | -------------------------------------------------------------------------------- /.idea/m2e2-multimedia-event-extraction.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /src/dataflow/torch/Corpus.py: -------------------------------------------------------------------------------- 1 | from torchtext.data import Dataset 2 | 3 | 4 | class Corpus(Dataset): 5 | def __init__(self, path, fields, amr, **kwargs): 6 | ''' 7 | Create a corpus given a path, field list, and a filter function. 8 | 9 | :param path: str, Path to the data file 10 | :param fields: dict[str: tuple(str, Field)], 11 | If using a dict, the keys should be a subset of the JSON keys or CSV/TSV 12 | columns, and the values should be tuples of (name, field). 13 | Keys not present in the input dictionary are ignored. 14 | This allows the user to rename columns from their JSON/CSV/TSV key names 15 | and also enables selecting a subset of columns to load. 16 | ''' 17 | self.path = path 18 | self._size = None 19 | 20 | examples = self.parse_example(path, fields, amr, **kwargs) 21 | fields = list(fields.values()) 22 | super(Corpus, self).__init__(examples, fields, **kwargs) 23 | 24 | def parse_example(self, path, fields, amr, **kwargs): 25 | raise NotImplementedError -------------------------------------------------------------------------------- /src/models/ace_classifier.py: -------------------------------------------------------------------------------- 1 | from src.models.modules.model import Model 2 | from src.util.util_model import BottledXavierLinear 3 | from torch import nn 4 | 5 | class ACEClassifier(Model): 6 | def __init__(self, common_dim, type_num, role_num, device): 7 | # self.common_dim, hyps["oc"], hyps["ae_oc"] 8 | super(ACEClassifier, self).__init__() 9 | 10 | self.device = device 11 | 12 | # Output Linear 13 | self.ol = BottledXavierLinear(in_features=common_dim, out_features=type_num).to(device=device) 14 | 15 | # AE Output Linear 16 | self.ae_ol = BottledXavierLinear(in_features=2 * common_dim, out_features=role_num).to(device=device) 17 | # self.ae_l1 = nn.Linear(in_features=2 * common_dim, out_features=common_dim) 18 | # self.ae_bn1 = nn.BatchNorm1d(num_features=common_dim) 19 | # self.ae_l2 = nn.Linear(in_features=common_dim, out_features=role_num) 20 | 21 | # Move to right device 22 | self.to(self.device) 23 | 24 | def forward_type(self, feature_in): 25 | ed_logits = self.ol(feature_in) 26 | return ed_logits 27 | 28 | def forward_role(self, entity_feature_in): 29 | ae_logits = self.ae_ol(entity_feature_in) 30 | return ae_logits 31 | 32 | 33 | -------------------------------------------------------------------------------- /scripts/train/train_ee.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=$2 4 | 5 | logdir='../../log/ee/'$1 6 | mkdir -p $logdir 7 | 8 | datadir="/scratch/manling2/data/mm-event-graph" 9 | glovedir="/scratch/manling2/data/glove" 10 | 11 | python ../../src/engine/EErunner.py \ 12 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \ 13 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \ 14 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \ 15 | --earlystop 10 --optimizer "adam" --lr 1e-4 \ 16 | --webd $glovedir"/glove.840B.300d.txt" \ 17 | --batch 32 --epochs 100 --device "cuda" --out $logdir \ 18 | --shuffle \ 19 | --hps "{ 20 | 'wemb_dim': 300, 21 | 'wemb_ft': True, 22 | 'wemb_dp': 0.5, 23 | 'pemb_dim': 50, 24 | 'pemb_dp': 0.5, 25 | 'eemb_dim': 50, 26 | 'eemb_dp': 0.5, 27 | 'psemb_dim': 50, 28 | 'psemb_dp': 0.5, 29 | 'lstm_dim': 220, 30 | 'lstm_layers': 1, 31 | 'lstm_dp': 0, 32 | 'gcn_et': 3, 33 | 'gcn_use_bn': True, 34 | 'gcn_layers': 2, 35 | 'gcn_dp': 0.5, 36 | 'sa_dim': 300, 37 | 'use_highway': True, 38 | 'loss_alpha': 5 39 | }" \ 40 | >& $logdir/stdout.log & 41 | -------------------------------------------------------------------------------- /src/dataflow/numpy/dataset_image_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import urllib.request 4 | import json 5 | 6 | def download_image(url_image, path_save): 7 | urllib.request.urlretrieve(url_image, path_save) 8 | 9 | def download_image_list(meta_json, dir_save): 10 | if not os.path.exists(meta_json): 11 | print('[ERROR] input_metadata_json does not exist.') 12 | return 13 | metadata = json.load(open(meta_json)) 14 | for doc_id in metadata: 15 | for img_id in metadata[doc_id]: 16 | url_image = metadata[doc_id][img_id]['url'] 17 | suffix_image = metadata[doc_id][img_id]['url'].split('.')[-1] 18 | image_path_save = os.path.join(dir_save, '%s_%s.%s' % (doc_id, img_id, suffix_image)) 19 | # if not url_image.endswith('.jpg'): 20 | # print(image_path_save, url_image) 21 | download_image(url_image, image_path_save) 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('input_metadata_json', type=str, 27 | help='input_metadata_json') 28 | parser.add_argument('output_image_dir', type=str, 29 | help='output_image_dir') 30 | args = parser.parse_args() 31 | 32 | input_metadata_json = args.input_metadata_json 33 | output_image_dir = args.output_image_dir 34 | 35 | download_image_list(input_metadata_json, output_image_dir) -------------------------------------------------------------------------------- /src/models/modules/HighWay.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | 4 | import sys 5 | #sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2') 6 | from src.util.util_model import BottledXavierLinear 7 | 8 | 9 | class HighWay(nn.Module): 10 | def __init__(self, size, num_layers=1, dropout_ratio=0.5): 11 | super(HighWay, self).__init__() 12 | self.size = size 13 | self.num_layers = num_layers 14 | self.trans = nn.ModuleList() 15 | self.gate = nn.ModuleList() 16 | self.dropout = dropout_ratio 17 | 18 | for i in range(num_layers): 19 | tmptrans = BottledXavierLinear(size, size) 20 | tmpgate = BottledXavierLinear(size, size) 21 | self.trans.append(tmptrans) 22 | self.gate.append(tmpgate) 23 | 24 | def
forward(self, x): 25 | ''' 26 | forward this module 27 | :param x: torch.FloatTensor, (N, D) or (N1, N2, D) 28 | :return: torch.FloatTensor, (N, D) or (N1, N2, D) 29 | ''' 30 | 31 | g = F.sigmoid(self.gate[0](x)) 32 | h = F.relu(self.trans[0](x)) 33 | x = g * h + (1 - g) * x 34 | 35 | for i in range(1, self.num_layers): 36 | x = F.dropout(x, p=self.dropout, training=self.training) 37 | g = F.sigmoid(self.gate[i](x)) 38 | h = F.relu(self.trans[i](x)) 39 | x = g * h + (1 - g) * x 40 | 41 | return x 42 | -------------------------------------------------------------------------------- /src/models/modules/FMapLayerImage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from torch.nn import functional as F 5 | 6 | class Flatten(torch.nn.Module): 7 | def forward(self, x): 8 | batch_size = x.shape[0] 9 | return x.view(batch_size, -1) 10 | 11 | class FMapLayerImage(nn.Module): 12 | def __init__(self, fine_tune=False, dropout=0.5, 13 | device=torch.device("cpu"), backbone='resnet152'): 14 | 15 | super(FMapLayerImage, self).__init__() 16 | 17 | net = getattr(models, backbone)(pretrained=True) 18 | #print(net) 19 | if backbone=='vgg16': 20 | b1, pool, b2 = list(net.children()) 21 | modules_1 = list(b1.children()) 22 | modules_2 = [pool, Flatten()] + list(b2.children())[:-1] 23 | 24 | else: 25 | b1 = list(net.children()) 26 | modules_1 = b1[:-2] 27 | modules_2 = [b1[-2]] 28 | 29 | 30 | self.backbone = nn.Sequential(*modules_1) 31 | self.pooler = nn.Sequential(*modules_2) 32 | #print(self.backbone, self.pooler) 33 | 34 | 35 | for p in self.backbone.parameters(): 36 | p.requires_grad = fine_tune 37 | for p in self.pooler.parameters(): 38 | p.requires_grad = fine_tune 39 | 40 | self.device = device 41 | self.to(device) 42 | 43 | def forward(self, images): 44 | fmap = self.backbone(images) 45 | emb = self.pooler(fmap) 46 | return fmap, emb -------------------------------------------------------------------------------- /scripts/eval/test_ee_m2e2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=$2 4 | 5 | logdir='../../log/ee_test/'$1 6 | mkdir -p $logdir 7 | 8 | datadir="/scratch/manling2/data/mm-event-graph" 9 | glovedir="/scratch/manling2/data/glove" 10 | checkpoint="/scratch/manling2/mm-event-graph/log/joint/obj_gcn2/model/model_ee_17.pt" 11 | checkpoint_params="/scratch/manling2/mm-event-graph/log/joint/obj_gcn2/ee_hyps.json" 12 | 13 | python ../../src/engine/TestRunnerEE_m2e2.py \ 14 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \ 15 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \ 16 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \ 17 | --img_dir_grounding $datadir"/voa/rawdata/VOA_image_en/" \ 18 | --object_detection_pkl_file_g $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \ 19 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \ 20 | --object_detection_threshold 0.2 \ 21 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \ 22 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \ 23 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \ 24 | --webd $glovedir"/glove.840B.300d.txt" \ 25 | --batch 32 --device "cuda" --out $logdir \ 26 | --finetune ${checkpoint} \ 27 | --hps_path ${checkpoint_params} \ 28 | --gt_voa_text $datadir"/voa_anno_m2e2/article_event.json" \ 
29 | --keep_events 1 \ 30 | --load_grounding \ 31 | >& $logdir/testout.log & 32 | 33 | -------------------------------------------------------------------------------- /src/models/modules/EmbeddingLayerImage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from torch.nn import functional as F 5 | 6 | class Flatten(torch.nn.Module): 7 | def forward(self, x): 8 | batch_size = x.shape[0] 9 | return x.view(batch_size, -1) 10 | 11 | class EmbeddingLayerImage(nn.Module): 12 | def __init__(self, fine_tune=False, dropout=0.5, 13 | device=torch.device("cpu"), backbone='resnet152'): 14 | 15 | super(EmbeddingLayerImage, self).__init__() 16 | 17 | resnet = getattr(models, backbone)(pretrained=True) 18 | #print(resnet) 19 | if backbone=='vgg16': 20 | b1, pool, b2 = list(resnet.children()) 21 | modules = list(b1.children()) + [pool, Flatten()] + list(b2.children())[:-1] 22 | self.dropout = None 23 | else: 24 | modules = list(resnet.children())[:-1] 25 | self.dropout = dropout if type(dropout) == float and -1e-7 < dropout < 1 + 1e-7 else None 26 | 27 | self.resnet = nn.Sequential(*modules) 28 | #print(self.resnet) 29 | 30 | 31 | for p in self.resnet.parameters(): 32 | p.requires_grad = fine_tune 33 | 34 | self.device = device 35 | self.to(device) 36 | 37 | def forward(self, images): 38 | if self.dropout is not None: 39 | return F.dropout(self.resnet(images), p=self.dropout, training=self.training) 40 | else: 41 | return self.resnet(images) # batchsize * 2048 * imageLength * imageLength (imageLength=1) -------------------------------------------------------------------------------- /src/util/helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions. 3 | """ 4 | 5 | import os 6 | import subprocess 7 | import json 8 | import argparse 9 | 10 | ### IO 11 | def check_dir(d): 12 | if not os.path.exists(d): 13 | print("Directory {} does not exist. Exit.".format(d)) 14 | exit(1) 15 | 16 | def check_files(files): 17 | for f in files: 18 | if f is not None and not os.path.exists(f): 19 | print("File {} does not exist. Exit.".format(f)) 20 | exit(1) 21 | 22 | def ensure_dir(d, verbose=True): 23 | if not os.path.exists(d): 24 | if verbose: 25 | print("Directory {} does not exist; creating...".format(d)) 26 | os.makedirs(d) 27 | 28 | def save_config(config, path, verbose=True): 29 | with open(path, 'w') as outfile: 30 | json.dump(config, outfile, indent=2) 31 | if verbose: 32 | print("Config saved to file {}".format(path)) 33 | return config 34 | 35 | def load_config(path, verbose=True): 36 | with open(path) as f: 37 | config = json.load(f) 38 | if verbose: 39 | print("Config loaded from file {}".format(path)) 40 | return config 41 | 42 | def print_config(config): 43 | info = "Running with the following configs:\n" 44 | for k,v in config.items(): 45 | info += "\t{} : {}\n".format(k, str(v)) 46 | print("\n" + info + "\n") 47 | return 48 | 49 | class FileLogger(object): 50 | """ 51 | A file logger that opens the file periodically and writes to it.
52 | """ 53 | def __init__(self, filename, header=None): 54 | self.filename = filename 55 | if os.path.exists(filename): 56 | # remove the old file 57 | os.remove(filename) 58 | if header is not None: 59 | with open(filename, 'w') as out: 60 | print(header, file=out) 61 | 62 | def log(self, message): 63 | with open(self.filename, 'a') as out: 64 | print(message, file=out) 65 | 66 | -------------------------------------------------------------------------------- /scripts/train/train_grounding.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=$2 4 | 5 | logdir='../../log/grounding/'$1 6 | mkdir -p $logdir 7 | 8 | datadir="/scratch/manling2/data/mm-event-graph" 9 | glovedir="/scratch/manling2/data/glove" 10 | 11 | python ../../src/engine/Groundingrunner.py \ 12 | --train $datadir"/grounding/grounding_train.json" \ 13 | --test $datadir"/grounding/grounding_test.json" \ 14 | --dev $datadir"/grounding/grounding_valid.json" \ 15 | --earlystop 10 --restart 999999 --optimizer "adam" --lr 0.0001 \ 16 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \ 17 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \ 18 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \ 19 | --vocab $datadir"/vocab/" \ 20 | --webd $glovedir"/glove.840B.300d.txt" \ 21 | --img_dir $datadir"/voa/rawdata/VOA_image_en/" \ 22 | --shuffle \ 23 | --batch 32 --epochs 100 --device "cuda" --out $logdir \ 24 | --ee_hps "{ 25 | 'wemb_dim': 300, 26 | 'wemb_ft': True, 27 | 'wemb_dp': 0.5, 28 | 'pemb_dim': 50, 29 | 'pemb_dp': 0.5, 30 | 'eemb_dim': 50, 31 | 'eemb_dp': 0.5, 32 | 'psemb_dim': 50, 33 | 'psemb_dp': 0.5, 34 | 'lstm_dim': 150, 35 | 'lstm_layers': 1, 36 | 'lstm_dp': 0, 37 | 'gcn_et': 3, 38 | 'gcn_use_bn': True, 39 | 'gcn_layers': 3, 40 | 'gcn_dp': 0.5, 41 | 'sa_dim': 300, 42 | 'use_highway': True, 43 | 'loss_alpha': 5 44 | }" \ 45 | --sr_hps "{ 46 | 'wemb_dim': 300, 47 | 'wemb_ft': True, 48 | 'wemb_dp': 0.0, 49 | 'iemb_backbone': 'vgg16', 50 | 'iemb_dim':4096, 51 | 'iemb_ft': False, 52 | 'iemb_dp': 0.0, 53 | 'posemb_dim': 512, 54 | 'fmap_dim': 512, 55 | 'fmap_size': 7, 56 | 'att_dim': 1024, 57 | 'loss_weight_verb': 1.0, 58 | 'loss_weight_noun': 0.1, 59 | 'loss_weight_role': 0.0 60 | }" \ 61 | >& $logdir/stdout.log & 62 | -------------------------------------------------------------------------------- /scripts/train/train_sr_object.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=$2 4 | 5 | logdir='../../log/sr/'$1 6 | mkdir -p $logdir 7 | 8 | datadir="/scratch/manling2/data/mm-event-graph" 9 | glovedir="/scratch/manling2/data/glove" 10 | 11 | python ../../src/engine/SRrunner.py \ 12 | --train_sr $datadir"/imSitu/train.json" \ 13 | --test_sr $datadir"/imSitu/test.json" \ 14 | --dev_sr $datadir"/imSitu/dev.json" \ 15 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \ 16 | --webd $glovedir"/glove.840B.300d.txt" \ 17 | --earlystop 10 --restart 999999 --optimizer "adam" --lr 0.0001 \ 18 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \ 19 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \ 20 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \ 21 | --vocab $datadir"/vocab/" \ 22 | --image_dir $datadir"/imSitu/of500_images_resized" \ 23 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \ 24 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \ 25 | --object_class_map_file 
$datadir"/object/class-descriptions-boxable.csv" \ 26 | --object_detection_pkl_file $datadir"/imSitu/object_detection/det_results_imsitu_oi_1.pkl" \ 27 | --object_detection_threshold 0.2 \ 28 | --shuffle \ 29 | --add_object \ 30 | --filter_place \ 31 | --batch 12 --epochs 100 --device "cuda" --out $logdir \ 32 | --hps "{ 33 | 'wemb_dim': 300, 34 | 'wemb_ft': False, 35 | 'wemb_dp': 0.0, 36 | 'iemb_backbone': 'vgg16', 37 | 'iemb_dim':4096, 38 | 'iemb_ft': False, 39 | 'iemb_dp': 0.0, 40 | 'posemb_dim': 512, 41 | 'fmap_dim': 512, 42 | 'fmap_size': 7, 43 | 'att_dim': 1024, 44 | 'loss_weight_verb': 1.0, 45 | 'loss_weight_noun': 0.1, 46 | 'loss_weight_role': 0.0, 47 | 'gcn_layers': 1, 48 | 'gcn_dp': False, 49 | 'gcn_use_bn': False, 50 | 'use_highway': False, 51 | }" \ 52 | >& $logdir/stdout.log & 53 | 54 | #--filter_irrelevant_verbs 55 | # --train_ace \ -------------------------------------------------------------------------------- /scripts/train/train_sr.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=$2 4 | 5 | logdir='../../log/sr/'$1 6 | mkdir -p $logdir 7 | 8 | datadir="/scratch/manling2/data/mm-event-graph" 9 | glovedir="/scratch/manling2/data/glove" 10 | 11 | python ../../src/engine/SRrunner.py \ 12 | --train_sr $datadir"/imSitu/train.json" \ 13 | --test_sr $datadir"/imSitu/test.json" \ 14 | --dev_sr $datadir"/imSitu/dev.json" \ 15 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \ 16 | --webd $glovedir"/glove.840B.300d.txt" \ 17 | --earlystop 10 --restart 999999 \ 18 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \ 19 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \ 20 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \ 21 | --vocab $datadir"/vocab/" \ 22 | --image_dir $datadir"/imSitu/of500_images_resized" \ 23 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \ 24 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \ 25 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \ 26 | --object_detection_pkl_file $datadir"/imSitu/object_detection/det_results_imsitu_oi_1.pkl" \ 27 | --object_detection_threshold 0.2 \ 28 | --shuffle \ 29 | --filter_place \ 30 | --batch 8 --epochs 40 --device "cuda" --out $logdir \ 31 | --optimizer "adam" --lr 0.0001 \ 32 | --hps "{ 33 | 'wemb_dim': 300, 34 | 'wemb_ft': False, 35 | 'wemb_dp': 0.0, 36 | 'iemb_backbone': 'vgg16', 37 | 'iemb_dim':4096, 38 | 'iemb_ft': False, 39 | 'iemb_dp': 0.0, 40 | 'posemb_dim': 512, 41 | 'fmap_dim': 512, 42 | 'fmap_size': 7, 43 | 'att_dim': 1024, 44 | 'loss_weight_verb': 1.0, 45 | 'loss_weight_noun': 0.1, 46 | 'loss_weight_role': 0.0, 47 | 'gcn_layers': 1, 48 | 'gcn_dp': False, 49 | 'gcn_use_bn': False, 50 | 'use_highway': False, 51 | }" \ 52 | >& $logdir/stdout.log & 53 | 54 | # --train_ace 55 | # --filter_irrelevant_verbs 56 | # --optimizer "adam" --lr 0.0001 \ -------------------------------------------------------------------------------- /scripts/eval/test_sr_m2e2_obj.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=$2 4 | 5 | logdir='../../log/sr_test/'$1 6 | mkdir -p $logdir 7 | mkdir -p $logdir/'image_result' 8 | 9 | datadir="/scratch/manling2/data/mm-event-graph" 10 | glovedir="/scratch/manling2/data/glove" 11 | checkpoint="/scratch/manling2/mm-event-graph/log/joint/obj_gcn2/model/model_sr_37.pt" 12 | 
checkpoint_params="/scratch/manling2/mm-event-graph/log/joint/obj_gcn2/sr_hyps.json" 13 | 14 | python ../../src/engine/TestRunnerSR_m2e2.py \ 15 | --train_sr $datadir"/imSitu/train.json" \ 16 | --test_sr $datadir"/imSitu/test.json" \ 17 | --dev_sr $datadir"/imSitu/dev.json" \ 18 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \ 19 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \ 20 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \ 21 | --vocab $datadir"/vocab/" \ 22 | --image_dir $datadir"/imSitu/of500_images_resized" \ 23 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \ 24 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \ 25 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \ 26 | --filter_irrelevant_verbs --filter_place \ 27 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \ 28 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \ 29 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \ 30 | --webd $glovedir"/glove.840B.300d.txt" \ 31 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \ 32 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \ 33 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \ 34 | --batch 1 --device "cuda" --out $logdir \ 35 | --test_voa_image $datadir"/voa/rawdata/VOA_image_en/" \ 36 | --gt_voa_image $datadir"/voa_anno_m2e2/image_event.json" \ 37 | --gt_voa_text $datadir"/voa_anno_m2e2/article_event.json" \ 38 | --finetune_sr ${checkpoint} \ 39 | --sr_hps_path ${checkpoint_params} \ 40 | --object_detection_pkl_file $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \ 41 | --ignore_place_sr_test \ 42 | --ignore_time_test \ 43 | --object_detection_threshold 0.1 \ 44 | --keep_events 1 \ 45 | --keep_events_sr 0 \ 46 | --add_object \ 47 | >& $logdir/testout.log & 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /scripts/eval/test_sr_m2e2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=$2 4 | 5 | logdir='../../log/sr_test/'$1 6 | mkdir -p $logdir 7 | mkdir -p $logdir/'image_result' 8 | 9 | datadir="/scratch/manling2/data/mm-event-graph" 10 | glovedir="/scratch/manling2/data/glove" 11 | checkpoint="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/model/model_sr_4.pt" 12 | checkpoint_params="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/sr_hyps.json" 13 | 14 | python ../../src/engine/TestRunnerSR_m2e2.py \ 15 | --train_sr $datadir"/imSitu/train.json" \ 16 | --test_sr $datadir"/imSitu/test.json" \ 17 | --dev_sr $datadir"/imSitu/dev.json" \ 18 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \ 19 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \ 20 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \ 21 | --vocab $datadir"/vocab/" \ 22 | --image_dir $datadir"/imSitu/of500_images_resized" \ 23 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \ 24 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \ 25 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \ 26 | --filter_irrelevant_verbs --filter_place \ 27 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \ 28 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \ 29 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \ 30 | --webd $glovedir"/glove.840B.300d.txt" \ 31 | --train_grounding 
$datadir"/grounding/grounding_train_20000.json" \ 32 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \ 33 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \ 34 | --batch 1 --device "cuda" --out $logdir \ 35 | --test_voa_image $datadir"/voa/rawdata/VOA_image_en/" \ 36 | --gt_voa_image $datadir"/voa_anno_m2e2/image_event.json" \ 37 | --gt_voa_text $datadir"/voa_anno_m2e2/article_event.json" \ 38 | --finetune_sr ${checkpoint} \ 39 | --sr_hps_path ${checkpoint_params} \ 40 | --object_detection_pkl_file $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \ 41 | --ignore_place_sr_test \ 42 | --ignore_time_test \ 43 | --object_detection_threshold 0.1 \ 44 | --keep_events 1 \ 45 | --keep_events_sr 0 \ 46 | --visual_voa_sr_path $logdir/'image_result' \ 47 | >& $logdir/testout.log & 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /src/models/modules/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import sys 5 | #sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2') 6 | from src.util.util_model import log 7 | 8 | 9 | class Model(nn.Module): 10 | def __init__(self, *args, **kwargs): 11 | super(Model, self).__init__() 12 | self.hyperparams = None 13 | self.device = torch.device("cpu") 14 | 15 | def __getnewargs__(self): 16 | # for pickle 17 | return self.hyperparams 18 | 19 | def __new__(cls, *args, **kwargs): 20 | log('created %s with params %s' % (str(cls), str(args))) 21 | 22 | instance = super(Model, cls).__new__(cls) 23 | instance.__init__(*args, **kwargs) 24 | return instance 25 | 26 | def test_mode_on(self): 27 | self.test_mode = True 28 | self.eval() 29 | 30 | def test_mode_off(self): 31 | self.test_mode = False 32 | self.train() 33 | 34 | def parameters_requires_grads(self): 35 | return list(filter(lambda p: p.requires_grad, self.parameters())) 36 | 37 | def parameters_requires_grad_clipping(self): 38 | return self.parameters_requires_grads() 39 | 40 | def save_model(self, path): 41 | state_dict = self.state_dict() 42 | for key, value in state_dict.items(): 43 | # print('save key', key) 44 | state_dict[key] = value.cpu() 45 | torch.save(state_dict, path) 46 | 47 | def load_model(self, path, load_partial=False): 48 | pretrained_dict = torch.load(path) 49 | try: 50 | self.load_state_dict(pretrained_dict) 51 | except Exception as e: 52 | if load_partial: 53 | # load matched part 54 | model_dict = self.state_dict() 55 | ignore_keys = set(['ace_classifier.ol.linear.weight', 56 | 'ace_classifier.ol.linear.bias', 57 | 'ace_classifier.ae_ol.linear.weight', 58 | 'ace_classifier.ae_ol.linear.bias']) 59 | # 1. filter out unnecessary keys 60 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k not in ignore_keys} #in ignore_keys} 61 | # 2. overwrite entries in the existing state dict 62 | model_dict.update(pretrained_dict) 63 | # 3. 
load the new state dict 64 | self.load_state_dict(model_dict) 65 | else: 66 | print(e) 67 | exit(-1) 68 | -------------------------------------------------------------------------------- /src/models/modules/SelfAttention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AttentionLayer(nn.Module): 7 | def __init__(self, D, H=128, return_sequences=False): 8 | ''' 9 | A single attention unit 10 | :param D: int, input feature dim 11 | :param H: int, hidden feature dim 12 | :param return_sequences: boolean, whether to return the full sequence 13 | ''' 14 | super(AttentionLayer, self).__init__() 15 | 16 | # Config copying 17 | self.H = H 18 | self.return_sequences = return_sequences 19 | self.D = D 20 | self.linear1 = nn.Linear(D, H) 21 | self.linear2 =
nn.Linear(H, 1) 22 | 23 | def softmax_mask(self, x, mask): 24 | ''' 25 | Softmax with mask 26 | :param x: torch.FloatTensor, logits, [batch_size, seq_len] 27 | :param mask: torch.ByteTensor, masks for sentences, [batch_size, seq_len] 28 | :return: torch.FloatTensor, probabilities, [batch_size, seq_len] 29 | ''' 30 | x_exp = torch.exp(x) 31 | if mask is not None: 32 | x_exp = x_exp * mask.float() 33 | x_sum = torch.sum(x_exp, dim=-1, keepdim=True) + 1e-6 34 | x_exp /= x_sum 35 | return x_exp 36 | 37 | def forward(self, x_text, mask, x_attention=None): 38 | ''' 39 | Forward this module 40 | :param x_text: torch.FloatTensor, input features, [batch_size, seq_len, D] 41 | :param mask: torch.ByteTensor, masks for features, [batch_size, seq_len] 42 | :param x_attention: torch.FloatTensor, a second set of input features used to compute attention over x_text, [batch_size, seq_len, D] 43 | :return: torch.FloatTensor, output features, if return sequences, output shape is [batch, SEQ_LEN, D]; 44 | otherwise output shape is [batch, D] 45 | ''' 46 | if x_attention is None: 47 | x_attention = x_text 48 | SEQ_LEN = x_text.size()[-2] 49 | x_attention = x_attention.contiguous().view(-1, self.D) # [batch_size * seq_len, D] 50 | attention = F.tanh(self.linear1(x_attention)) # [batch_size * seq_len, H] 51 | attention = self.linear2(attention) # [batch_size * seq_len, 1] 52 | attention = attention.view(-1, SEQ_LEN) # [batch_size, seq_len] 53 | attention = self.softmax_mask(attention, mask) # [batch_size, seq_len] 54 | output = x_text * attention.unsqueeze(-1).expand_as(x_text) # [batch_size, seq_len, D] 55 | 56 | if not self.return_sequences: 57 | output = torch.sum(output, -2) 58 | output = output.squeeze(1) 59 | return output 60 | 61 | 62 | if __name__ == "__main__": 63 | al = AttentionLayer(2, 3, return_sequences=True) 64 | x = torch.randn(5, 3, 2) 65 | print(x.size()) 66 | mask = torch.ByteTensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]]) 67 | print(mask.size()) 68 | y = al(x, mask) 69 | print(y.size()) 70 | -------------------------------------------------------------------------------- /scripts/eval/test_joint_object.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=$2 4 | 5 | logdir='../../log/joint_test/'$1 6 | mkdir -p $logdir 7 | mkdir -p $logdir/grounding_result 8 | mkdir -p $logdir/image_result 9 | mkdir -p $logdir/text_result 10 | rm -rf $logdir/image_result/* 11 | rm -rf $logdir/grounding_result/* 12 | rm -rf $logdir/text_result/* 13 | ln -s "/scratch/manling2/data/mm-event-graph/voa/rawdata/VOA_image_en" $logdir/grounding_result/VOA_image_en 14 | 15 | datadir="/scratch/manling2/data/mm-event-graph" 16 | glovedir="/scratch/manling2/data/glove" 17 | checkpoint_sr="/scratch/manling2/mm-event-graph/log/joint/obj_gcn3/model/model_sr_52.pt" 18 | checkpoint_sr_params="/scratch/manling2/mm-event-graph/log/joint/obj_gcn3/sr_hyps.json" 19 | checkpoint_ee="/scratch/manling2/mm-event-graph/log/joint/obj_gcn3/model/model_ee_52.pt" 20 | checkpoint_ee_params="/scratch/manling2/mm-event-graph/log/joint/obj_gcn3/ee_hyps.json" 21 | 22 | python ../../src/engine/TestRunnerJOINT.py \ 23 | --train_sr $datadir"/imSitu/train.json" \ 24 | --test_sr $datadir"/imSitu/test.json" \ 25 | --dev_sr $datadir"/imSitu/dev.json" \ 26 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \ 27 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \ 28 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \ 29 | --vocab $datadir"/vocab/" \ 30 |
--image_dir $datadir"/imSitu/of500_images_resized" \ 31 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \ 32 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \ 33 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \ 34 | --filter_irrelevant_verbs --filter_place \ 35 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \ 36 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \ 37 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \ 38 | --webd $glovedir"/glove.840B.300d.txt" \ 39 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \ 40 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \ 41 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \ 42 | --batch 1 --device "cuda" --out $logdir \ 43 | --test_voa_image $datadir"/voa/rawdata/VOA_image_en/" \ 44 | --gt_voa_image $datadir"/voa_anno_m2e2/image_event.json" \ 45 | --gt_voa_text $datadir"/voa_anno_m2e2/article_event.json" \ 46 | --gt_voa_align $datadir"/voa_anno_m2e2/article_event.json" \ 47 | --object_detection_pkl_file $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \ 48 | --finetune_sr ${checkpoint_sr} \ 49 | --sr_hps_path ${checkpoint_sr_params} \ 50 | --finetune_ee ${checkpoint_ee} \ 51 | --ee_hps_path ${checkpoint_ee_params} \ 52 | --ignore_place_sr_test \ 53 | --ignore_time_test \ 54 | --object_detection_threshold 0.1 \ 55 | --add_object \ 56 | --keep_events 1 \ 57 | --keep_events_sr 0 \ 58 | --with_sentid \ 59 | --apply_ee_role_mask \ 60 | >& $logdir/stdout.log & 61 | -------------------------------------------------------------------------------- /scripts/eval/test_joint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=$2 4 | 5 | logdir='../../log/joint_test/'$1 6 | mkdir -p $logdir 7 | mkdir -p $logdir/grounding_result 8 | mkdir -p $logdir/image_result 9 | mkdir -p $logdir/text_result 10 | rm -rf $logdir/image_result/* 11 | rm -rf $logdir/grounding_result/* 12 | rm -rf $logdir/text_result/* 13 | ln -s "/scratch/manling2/data/mm-event-graph/voa/rawdata/VOA_image_en" $logdir/grounding_result/VOA_image_en 14 | 15 | datadir="/scratch/manling2/data/mm-event-graph" 16 | glovedir="/scratch/manling2/data/glove" 17 | checkpoint_sr="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/model/model_sr_4.pt" 18 | checkpoint_sr_params="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/sr_hyps.json" 19 | checkpoint_ee="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/model/model_ee_4.pt" 20 | checkpoint_ee_params="/scratch/manling2/mm-event-graph/log/joint/att_gcn3/ee_hyps.json" 21 | 22 | python ../../src/engine/TestRunnerJOINT.py \ 23 | --train_sr $datadir"/imSitu/train.json" \ 24 | --test_sr $datadir"/imSitu/test.json" \ 25 | --dev_sr $datadir"/imSitu/dev.json" \ 26 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \ 27 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \ 28 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \ 29 | --vocab $datadir"/vocab/" \ 30 | --image_dir $datadir"/imSitu/of500_images_resized" \ 31 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \ 32 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \ 33 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \ 34 | --filter_irrelevant_verbs --filter_place \ 35 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \ 36 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \ 37 
| --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \ 38 | --webd $glovedir"/glove.840B.300d.txt" \ 39 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \ 40 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \ 41 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \ 42 | --object_detection_pkl_file_g $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \ 43 | --batch 1 --device "cuda" --out $logdir \ 44 | --test_voa_image $datadir"/voa/rawdata/VOA_image_en/" \ 45 | --gt_voa_image $datadir"/voa_anno_m2e2/image_event.json" \ 46 | --gt_voa_text $datadir"/voa_anno_m2e2/article_event.json" \ 47 | --gt_voa_align $datadir"/voa_anno_m2e2/article_event.json" \ 48 | --object_detection_pkl_file $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \ 49 | --finetune_sr ${checkpoint_sr} \ 50 | --sr_hps_path ${checkpoint_sr_params} \ 51 | --finetune_ee ${checkpoint_ee} \ 52 | --ee_hps_path ${checkpoint_ee_params} \ 53 | --ignore_place_sr_test \ 54 | --ignore_time_test \ 55 | --object_detection_threshold 0.1 \ 56 | --keep_events 1 \ 57 | --keep_events_sr 0 \ 58 | --with_sentid \ 59 | --apply_ee_role_mask \ 60 | >& $logdir/stdout.log & 61 | -------------------------------------------------------------------------------- /scripts/train/train_joint_att.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=$2 4 | 5 | logdir='../../log/joint/'$1 6 | mkdir -p $logdir 7 | 8 | datadir="/scratch/manling2/data/mm-event-graph" 9 | glovedir="/scratch/manling2/data/glove" 10 | 11 | python ../../src/engine/JOINTRunner.py \ 12 | --train_sr $datadir"/imSitu/train.json" \ 13 | --test_sr $datadir"/imSitu/test.json" \ 14 | --dev_sr $datadir"/imSitu/dev.json" \ 15 | --wnebd $datadir"/vocab/embedding_situation_noun.npy" \ 16 | --wvebd $datadir"/vocab/embedding_situation_verb.npy" \ 17 | --wrebd $datadir"/vocab/embedding_situation_role.npy" \ 18 | --vocab $datadir"/vocab/" \ 19 | --image_dir $datadir"/imSitu/of500_images_resized" \ 20 | --imsitu_ontology_file $datadir"/imSitu/imsitu_space.json" \ 21 | --verb_mapping_file $datadir"/ace/ace_sr_mapping.txt" \ 22 | --object_class_map_file $datadir"/object/class-descriptions-boxable.csv" \ 23 | --object_detection_pkl_file $datadir"/imSitu/object_detection/det_results_imsitu_oi_1.pkl" \ 24 | --object_detection_threshold 0.2 \ 25 | --sr_hps "{ 26 | 'wemb_dim': 300, 27 | 'wemb_ft': False, 28 | 'wemb_dp': 0.0, 29 | 'iemb_backbone': 'vgg16', 30 | 'iemb_dim':4096, 31 | 'iemb_ft': False, 32 | 'iemb_dp': 0.0, 33 | 'posemb_dim': 512, 34 | 'fmap_dim': 512, 35 | 'fmap_size': 7, 36 | 'att_dim': 1024, 37 | 'loss_weight_verb': 1.0, 38 | 'loss_weight_noun': 0.1, 39 | 'loss_weight_role': 0.0, 40 | 'gcn_layers': 3, 41 | 'gcn_dp': 0.5, 42 | 'gcn_use_bn': True, 43 | 'use_highway': False, 44 | }" \ 45 | --train_ee $datadir"/ace/JMEE_train_filter_no_timevalue.json" \ 46 | --test_ee $datadir"/ace/JMEE_test_filter_no_timevalue.json" \ 47 | --dev_ee $datadir"/ace/JMEE_dev_filter_no_timevalue.json" \ 48 | --webd $glovedir"/glove.840B.300d.txt" \ 49 | --ee_hps "{ 50 | 'wemb_dim': 300, 51 | 'wemb_ft': True, 52 | 'wemb_dp': 0.5, 53 | 'pemb_dim': 50, 54 | 'pemb_dp': 0.5, 55 | 'eemb_dim': 50, 56 | 'eemb_dp': 0.5, 57 | 'psemb_dim': 50, 58 | 'psemb_dp': 0.5, 59 | 'lstm_dim': 150, 60 | 'lstm_layers': 1, 61 | 'lstm_dp': 0, 62 | 'gcn_et': 3, 63 | 'gcn_use_bn': True, 64 | 'gcn_layers': 3, 65 | 'gcn_dp': 0.5, 66 | 'sa_dim': 300, 67 | 'use_highway': True, 68 | 
'loss_alpha': 5 69 | }" \ 70 | --train_grounding $datadir"/grounding/grounding_train_20000.json" \ 71 | --test_grounding $datadir"/grounding/grounding_test_20000.json" \ 72 | --dev_grounding $datadir"/grounding/grounding_valid_20000.json" \ 73 | --img_dir_grounding $datadir"/voa/rawdata/VOA_image_en/" \ 74 | --object_detection_pkl_file_g $datadir"/voa/object_detection/det_results_voa_oi_1.pkl" \ 75 | --earlystop 10 --restart 999999 --optimizer "adadelta" --lr 1 \ 76 | --shuffle \ 77 | --filter_place \ 78 | --finetune_sr /scratch/manling2/mm-event-graph/log/sr/retest/model.pt \ 79 | --batch 8 --epochs 60 --device "cuda" --out $logdir \ 80 | >& $logdir/stdout.log & 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cross-media Structured Common Space for Multimedia Event Extraction 2 | 3 | Table of Contents 4 | ================= 5 | * [Overview](#overview) 6 | * [Requirements](#requirements) 7 | * [Data](#data) 8 | * [Quickstart](#quickstart) 9 | * [Citation](#citation) 10 | 11 | ## Overview 12 | The code for the paper [Cross-media Structured Common Space for Multimedia Event Extraction](http://blender.cs.illinois.edu/software/m2e2/). 13 | 14 |

15 | ![Photo](training-testing.jpg) 16 |

17 | 18 | ## Requirements 19 | 20 | You can install the required packages using `requirements.txt`: 21 | 22 | ```bash 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ## Data 27 | 28 | ### Situation Recognition (Visual Event Extraction Data) 29 | We download situation recognition data from [imSitu](http://imsitu.org/). Please find the preprocessed data in [PreprocessedSR](https://drive.google.com/drive/folders/1h0qwYWeGEoCx8m-zwH-XcoPSyffmrC-c?usp=sharing). 30 | 31 | ### ACE (Text Event Extraction Data) 32 | We preprocessed ACE following [JMEE](https://github.com/lx865712528/EMNLP2018-JMEE/tree/master). The sample data format is in [sample.json](https://github.com/lx865712528/EMNLP2018-JMEE/blob/master/ace-05-splits/sample.json). Due to licensing restrictions, the ACE 2005 dataset is only accessible to those with an LDC2006T06 license; please drop me an email at `manling2@illinois.edu` showing your possession of the license to obtain the processed data. 33 | 34 | ### Voice of America Image-Caption Pairs 35 | We crawled VOA image-caption pairs to train the common space. The [image-caption pairs](https://uofi.box.com/s/xtn9p6m8z5qtjbbi5tqrl45tn6apew4x) are provided, and the images can be downloaded from the URLs (we share image URLs instead of the downloaded images due to licensing issues) using the script [dataset_image_download.py](https://github.com/limanling/m2e2/blob/master/src/dataflow/numpy/dataset_image_download.py). We preprocess the data, including running object detection and parsing the text sentences. The preprocessed data is in [PreprocessedVOA](https://drive.google.com/drive/folders/1I9vMGIhWZpKqxQYip91eLoDRnrkqRxnt?usp=sharing). 36 | 37 | ### M2E2 (Multimedia Event Extraction Benchmark) 38 | 39 | The images and text articles are in [m2e2_rawdata](https://drive.google.com/file/d/1xtFMjt_eYgeBts5rBomOWbPo7wV_mnhy/view?usp=sharing), and the annotations are in [m2e2_annotation](http://blender.cs.illinois.edu/software/m2e2/m2e2_v0.1/m2e2_annotations.zip) under the `data` directory. 40 | 41 | ### Vocabulary 42 | The preprocessed vocabulary is in [PreprocessedVocab](https://drive.google.com/drive/folders/1MWklNkpedJp-P80wpKF-WOQAvYaYKzP2?usp=sharing). 43 | 44 | 45 | ## Quickstart 46 | 47 | ### Training 48 | 49 | We have two variants for parsing images into situation graphs: one parses images into role-driven attention graphs, and the other parses images into object graphs. 50 | 51 | (1) attention-graph based version: 52 | ```bash 53 | sh scripts/train/train_joint_att.sh 54 | ``` 55 | (2) object-graph based version: 56 | ```bash 57 | sh scripts/train/train_joint_obj.sh 58 | ``` 59 | Please specify the data paths `datadir` and `glovedir` in the scripts. 60 | 61 | 62 | ### Testing 63 | 64 | (1) attention-graph based version: 65 | ```bash 66 | sh scripts/eval/test_joint.sh 67 | ``` 68 | (2) object-graph based version: 69 | ```bash 70 | sh scripts/eval/test_joint_object.sh 71 | ``` 72 | 73 | Please specify the data paths `datadir`, `glovedir`, and the model paths `checkpoint_sr`, `checkpoint_sr_params`, `checkpoint_ee`, `checkpoint_ee_params` in the scripts. 74 | 75 | 76 | ## Citation 77 | 78 | Manling Li, Alireza Zareian, Qi Zeng, Spencer Whitehead, Di Lu, Heng Ji, Shih-Fu Chang. 2020. Cross-media Structured Common Space for Multimedia Event Extraction. In Proceedings of The 58th Annual Meeting of the Association for Computational Linguistics.
79 | ``` 80 | @inproceedings{li2020multimediaevent, 81 | title={Cross-media Structured Common Space for Multimedia Event Extraction}, 82 | author={Manling Li and Alireza Zareian and Qi Zeng and Spencer Whitehead and Di Lu and Heng Ji and Shih-Fu Chang}, 83 | booktitle={Proceedings of The 58th Annual Meeting of the Association for Computational Linguistics}, 84 | year={2020} 85 | } 86 | ``` 87 | -------------------------------------------------------------------------------- /src/util/vocab.py: -------------------------------------------------------------------------------- 1 | """ 2 | A class for basic vocab operations. 3 | """ 4 | 5 | from __future__ import print_function 6 | import os 7 | import random 8 | import numpy as np 9 | import pickle 10 | 11 | from . import constant 12 | 13 | random.seed(1234) 14 | np.random.seed(1234) 15 | 16 | def build_embedding(wv_file, vocab, wv_dim): 17 | vocab_size = len(vocab) 18 | emb = np.random.uniform(-1, 1, (vocab_size, wv_dim)) 19 | emb[constant.PAD_ID] = 0 # should be all 0 20 | 21 | w2id = {w: i for i, w in enumerate(vocab)} 22 | with open(wv_file, encoding="utf8") as f: 23 | for line in f: 24 | elems = line.split() 25 | token = ''.join(elems[0:-wv_dim]) 26 | if token in w2id: 27 | emb[w2id[token]] = [float(v) for v in elems[-wv_dim:]] 28 | return emb 29 | 30 | def load_glove_vocab(file, wv_dim): 31 | """ 32 | Load all words from glove. 33 | """ 34 | vocab = set() 35 | with open(file, encoding='utf8') as f: 36 | for line in f: 37 | elems = line.split() 38 | token = ''.join(elems[0:-wv_dim]) 39 | vocab.add(token) 40 | return vocab 41 | 42 | class Vocab(object): 43 | def __init__(self, filename, load=False, word_counter=None, threshold=0): 44 | if load: 45 | assert os.path.exists(filename), "Vocab file does not exist at " + filename 46 | # load from file and ignore all other params 47 | self.id2word, self.word2id = self.load(filename) 48 | self.size = len(self.id2word) 49 | print("Vocab size {} loaded from file".format(self.size)) 50 | else: 51 | print("Creating vocab from scratch...") 52 | assert word_counter is not None, "word_counter is not provided for vocab creation." 53 | self.word_counter = word_counter 54 | if threshold > 1: 55 | # remove words that occur fewer than threshold times 56 | self.word_counter = dict([(k,v) for k,v in self.word_counter.items() if v >= threshold]) 57 | self.id2word = sorted(self.word_counter, key=lambda k:self.word_counter[k], reverse=True) 58 | # add special tokens to the beginning 59 | self.id2word = [constant.PAD_TOKEN, constant.UNK_TOKEN] + self.id2word 60 | self.word2id = dict([(self.id2word[idx],idx) for idx in range(len(self.id2word))]) 61 | self.size = len(self.id2word) 62 | self.save(filename) 63 | print("Vocab size {} saved to file {}".format(self.size, filename)) 64 | 65 | def load(self, filename): 66 | with open(filename, 'rb') as infile: 67 | id2word = pickle.load(infile) 68 | word2id = dict([(id2word[idx], idx) for idx in range(len(id2word))]) 69 | return id2word, word2id 70 | 71 | def save(self, filename): 72 | if os.path.exists(filename): 73 | print("Overwriting old vocab file at " + filename) 74 | os.remove(filename) 75 | with open(filename, 'wb') as outfile: 76 | pickle.dump(self.id2word, outfile) 77 | return 78 | 79 | def map(self, token_list): 80 | """ 81 | Map a list of tokens to their ids. 82 | """ 83 | return [self.word2id[w] if w in self.word2id else constant.VOCAB_UNK_ID for w in token_list] 84 | 85 | def unmap(self, idx_list): 86 | """ 87 | Unmap ids back to tokens.
88 | """ 89 | return [self.id2word[idx] for idx in idx_list] 90 | 91 | def get_embeddings(self, word_vectors=None, dim=100): 92 | self.embeddings = 2 * constant.EMB_INIT_RANGE * np.random.rand(self.size, dim) - constant.EMB_INIT_RANGE 93 | if word_vectors is not None: 94 | assert len(list(word_vectors.values())[0]) == dim, \ 95 | "Word vectors does not have required dimension {}.".format(dim) 96 | for w, idx in self.word2id.items(): 97 | if w in word_vectors: 98 | self.embeddings[idx] = np.asarray(word_vectors[w]) 99 | return self.embeddings 100 | 101 | -------------------------------------------------------------------------------- /src/models/modules/DynamicLSTM.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class DynamicLSTM(nn.Module): 8 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0, 9 | bidirectional=False, device=torch.device("cpu")): 10 | """ 11 | Dynamic LSTM which can hold variable length sequence, use like TensorFlow's RNN(input, length...). 12 | 13 | :param input_size: The number of expected features in the input x 14 | :param hidden_size: The number of features in the hidden state h 15 | :param num_layers: Number of recurrent layers. 16 | :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True 17 | :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) 18 | :param dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer 19 | :param bidirectional: If True, becomes a bidirectional RNN. Default: False 20 | """ 21 | super(DynamicLSTM, self).__init__() 22 | self.input_size = input_size 23 | self.hidden_size = hidden_size 24 | self.num_layers = num_layers 25 | self.bias = bias 26 | self.batch_first = batch_first 27 | self.dropout = dropout 28 | self.bidirectional = bidirectional 29 | self.LSTM = nn.LSTM( 30 | input_size=input_size, 31 | hidden_size=hidden_size, 32 | num_layers=num_layers, 33 | bias=bias, 34 | batch_first=batch_first, 35 | dropout=dropout, 36 | bidirectional=bidirectional 37 | ) 38 | 39 | self.device = device 40 | self.to(device) 41 | 42 | def forward(self, x, x_len, only_use_last_hidden_state=False): 43 | """ 44 | sequence -> sort -> pad and pack -> process using RNN -> unpack -> unsort 45 | 46 | :param x: FloatTensor, pre-padded input sequence (batch_size, seq_len, feature_dim) 47 | :param x_len: numpy list, indicating corresponding actual sequence length 48 | :return: output, (h_n, c_n) 49 | - **output**: FloatTensor, packed output sequence (batch_size, seq_len, feature_dim * num_directions) 50 | containing the output features `(h_t)` from the last layer of the LSTM, for each t. 51 | - **h_n**: FloatTensor, (num_layers * num_directions, batch, hidden_size) 52 | containing the hidden state for `t = seq_len` 53 | - **c_n**: FloatTensor, (num_layers * num_directions, batch, hidden_size) 54 | containing the cell state for `t = seq_len` 55 | """ 56 | # 1. sort 57 | x_sort_idx = np.argsort(-x_len) 58 | x_unsort_idx = torch.LongTensor(np.argsort(x_sort_idx)).to(self.device) 59 | x_len = x_len[x_sort_idx] 60 | x = x[torch.LongTensor(x_sort_idx).to(self.device)] 61 | # 2. pack 62 | x_p = torch.nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=self.batch_first) 63 | # 3. process using RNN 64 | out_pack, (ht, ct) = self.LSTM(x_p, None) 65 | # 4. 
unsort h 66 | ht = torch.transpose(ht, 0, 1)[x_unsort_idx] 67 | ht = torch.transpose(ht, 0, 1) 68 | 69 | if only_use_last_hidden_state: 70 | return ht 71 | else: 72 | # 5. unpack output 73 | out = torch.nn.utils.rnn.pad_packed_sequence(out_pack, batch_first=self.batch_first) # (sequence, lengths) 74 | out = out[0] # 75 | # 6. unsort out c 76 | out = out[x_unsort_idx] 77 | ct = torch.transpose(ct, 0, 1)[x_unsort_idx] 78 | ct = torch.transpose(ct, 0, 1) 79 | return out, (ht, ct) 80 | 81 | 82 | if __name__ == "__main__": 83 | BATCH_SIZE = 5 84 | SEQ_LEN = 3 85 | D = 2 86 | aa = DynamicLSTM(input_size=2, hidden_size=2, batch_first=True) 87 | binp = torch.rand(BATCH_SIZE, SEQ_LEN, D) 88 | binlen = numpy.array([3, 3, 3, 2, 2], dtype=numpy.int8) 89 | boup, _ = aa(binp, binlen) 90 | print(boup) 91 | 92 | for i in range(BATCH_SIZE): 93 | sinp = binp[i].unsqueeze(0) 94 | sinlen = numpy.array([binlen[i]], dtype=numpy.int8) 95 | soup, _ = aa(sinp, sinlen) 96 | print(soup) 97 | -------------------------------------------------------------------------------- /src/eval/EEtesting.py: -------------------------------------------------------------------------------- 1 | from seqeval.metrics import f1_score, precision_score, recall_score 2 | 3 | 4 | class EDTester(): 5 | def __init__(self, type_i2s, role_i2s, ignore_time): 6 | self.voc_i2s = type_i2s 7 | self.role_i2s = role_i2s 8 | self.ignore_time = ignore_time 9 | 10 | def calculate_report(self, y, y_, transform=True): 11 | ''' 12 | calculating F1, P, R 13 | 14 | :param y: golden label, list 15 | :param y_: model output, list 16 | :return: 17 | ''' 18 | if transform: 19 | for i in range(len(y)): 20 | for j in range(len(y[i])): 21 | y[i][j] = self.voc_i2s[y[i][j]] 22 | for i in range(len(y_)): 23 | for j in range(len(y_[i])): 24 | y_[i][j] = self.voc_i2s[y_[i][j]] 25 | return precision_score(y, y_), recall_score(y, y_), f1_score(y, y_) 26 | 27 | @staticmethod 28 | def merge_segments(y, send_id=None): 29 | segs = {} 30 | tt = "" 31 | st, ed = -1, -1 32 | for i, x in enumerate(y): 33 | if x.startswith("B-"): 34 | if tt == "": 35 | tt = x[2:] 36 | if send_id is None: 37 | st = i 38 | else: 39 | st = '%s__%d' % (send_id, i) 40 | else: 41 | ed = i 42 | segs[st] = (ed, tt) 43 | tt = x[2:] 44 | if send_id is None: 45 | st = i 46 | else: 47 | st = '%s__%d' % (send_id, i) 48 | elif x.startswith("I-"): 49 | if tt == "": 50 | y[i] = "B" + y[i][1:] 51 | tt = x[2:] 52 | if send_id is None: 53 | st = i 54 | else: 55 | st = '%s__%d' % (send_id, i) 56 | else: 57 | if tt != x[2:]: 58 | ed = i 59 | segs[st] = (ed, tt) 60 | y[i] = "B" + y[i][1:] 61 | tt = x[2:] 62 | if send_id is None: 63 | st = i 64 | else: 65 | st = '%s__%d' % (send_id, i) 66 | else: 67 | ed = i 68 | if tt != "": 69 | segs[st] = (ed, tt) 70 | tt = "" 71 | 72 | if tt != "": 73 | segs[st] = (len(y), tt) 74 | return segs 75 | 76 | def calculate_sets(self, y, y_): 77 | ct, p1, p2 = 0, 0, 0 78 | for sent, sent_ in zip(y, y_): 79 | # trigger start 80 | for key, value in sent.items(): 81 | # key = trigger end, event type 82 | # value = args 83 | p1 += len(value) 84 | if key not in sent_: 85 | continue 86 | # matched sentences 87 | arguments = value 88 | arguments_ = sent_[key] 89 | for item, item_ in zip(arguments, arguments_): 90 | # print('item', self.role_i2s[item[2]], self.role_i2s[item_[2]]) 91 | if self.ignore_time and self.role_i2s[item[2]].upper().startswith('TIME'): 92 | continue 93 | if item[2] == item_[2]: 94 | ct += 1 95 | 96 | for key, value in sent_.items(): 97 | # p2_key += 1 98 | # p2 += len(value) 
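# Counting convention, with illustrative made-up numbers (not results from the paper):
# p1 = gold argument count, p2 = predicted argument count (TIME roles skipped when ignore_time is set),
# ct = predictions whose role label matches a gold argument.
# E.g. 4 gold arguments, 5 predictions, 3 matches:
# precision = 3/5 = 0.6, recall = 3/4 = 0.75, F1 = 2*0.6*0.75/(0.6+0.75) ~= 0.667.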
99 | # print('key', key) 100 | for item in sent_[key]: 101 | if self.ignore_time and self.role_i2s[item[2]].upper().startswith('TIME'): 102 | continue 103 | p2 += 1 104 | 105 | 106 | if ct == 0 or p1 == 0 or p2 == 0: 107 | return 0.0, 0.0, 0.0 108 | else: 109 | p = 1.0 * ct / p2 110 | r = 1.0 * ct / p1 111 | f1 = 2.0 * p * r / (p + r) 112 | print('ct', ct) 113 | print('p1', p1) 114 | print('p2', p2) 115 | return p, r, f1 116 | -------------------------------------------------------------------------------- /src/eval/Groundingtesting.py: -------------------------------------------------------------------------------- 1 | 2 | class GroundingTester(): 3 | def __init__(self): 4 | pass 5 | 6 | def calculate_lists(self, y, y_): 7 | ''' 8 | for a sequence, whether the prediction is correct 9 | note that len(y) == len(y_) 10 | :param y: 11 | :param y_: 12 | :return: 13 | ''' 14 | ct = 0 15 | p2 = len(y_) 16 | p1 = len(y) 17 | for i in range(p2): 18 | if y[i] == y_[i]: 19 | ct = ct + 1 20 | if ct == 0 or p1 == 0 or p2 == 0: 21 | return 0.0, 0.0, 0.0 22 | else: 23 | p = 1.0 * ct / p2 24 | r = 1.0 * ct / p1 25 | f1 = 2.0 * p * r / (p + r) 26 | return p, r, f1 27 | 28 | def calculate_sets_no_order(self, y, y_): 29 | ''' 30 | for each predicted item, whether it is in the gt 31 | :param y: [batch, items] 32 | :param y_: [batch, items] 33 | :return: 34 | ''' 35 | ct, p1, p2 = 0, 0, 0 36 | for batch, batch_ in zip(y, y_): 37 | value_set = set(batch) 38 | value_set_ = set(batch_) 39 | p1 += len(value_set) 40 | p2 += len(value_set_) 41 | 42 | for value_ in value_set_: 43 | # if value_ == '(0,0,0)': 44 | if value_ in value_set: 45 | ct += 1 46 | 47 | if ct == 0 or p1 == 0 or p2 == 0: 48 | return 0.0, 0.0, 0.0 49 | else: 50 | p = 1.0 * ct / p2 51 | r = 1.0 * ct / p1 52 | f1 = 2.0 * p * r / (p + r) 53 | return p, r, f1 54 | 55 | def calculate_sets_noun(self, y, y_): 56 | ''' 57 | for each ground truth entity, whether it is in the predicted entities 58 | :param y: [batch, role_num, multiple_args] 59 | :param y_: [batch, role_num, multiple_entities] 60 | :return: 61 | ''' 62 | # print('y', y) 63 | # print('y_', y_) 64 | ct, p1, p2 = 0, 0, 0 65 | # for batch_idx, batch_idx_ in zip(y, y_): 66 | # batch = y[batch_idx] 67 | # batch_ = y_[batch_idx_] 68 | for batch, batch_ in zip(y, y_): 69 | # print('batch', batch) 70 | # print('batch_', batch_) 71 | p1 += len(batch) 72 | p2 += len(batch_) 73 | for role in batch: 74 | found = False 75 | entities = batch[role] 76 | for entity in entities: 77 | for role_ in batch_: 78 | entities_ = batch_[role_] 79 | if entity in entities_: 80 | ct += 1 81 | found = True 82 | break 83 | if found: 84 | break 85 | 86 | if ct == 0 or p1 == 0 or p2 == 0: 87 | return 0.0, 0.0, 0.0 88 | else: 89 | p = 1.0 * ct / p2 90 | r = 1.0 * ct / p1 91 | f1 = 2.0 * p * r / (p + r) 92 | return p, r, f1 93 | 94 | def calculate_sets_triple(self, y, y_): 95 | ''' 96 | for each role, whether the predicted entities have overlap with the gt entities 97 | :param y: dict, role -> entities 98 | :param y_: dict, role -> entities 99 | :return: 100 | ''' 101 | ct, p1, p2 = 0, 0, 0 102 | # for batch_idx, batch_idx_ in zip(y, y_): 103 | # batch = y[batch_idx] 104 | # batch_ = y_[batch_idx_] 105 | for batch, batch_ in zip(y, y_): 106 | p1 += len(batch) 107 | p2 += len(batch_) 108 | for role in batch: 109 | entities = batch[role] 110 | if role in batch_: 111 | entities_ = batch_[role] 112 | for entity_ in entities_: 113 | if entity_ in entities: 114 | ct += 1 115 | break 116 | 117 | if ct == 0 or p1 == 0 or p2 == 0: 
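# Degenerate cases (no matches, or an empty gold/prediction side) score 0.0 to avoid division by zero.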
118 | return 0.0, 0.0, 0.0 119 | else: 120 | p = 1.0 * ct / p2 121 | r = 1.0 * ct / p1 122 | f1 = 2.0 * p * r / (p + r) 123 | return p, r, f1 124 | 125 | -------------------------------------------------------------------------------- /src/models/modules/GCN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from torch.optim import Adadelta 5 | 6 | import sys 7 | #sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2') 8 | from src.util.util_model import BottledOrthogonalLinear, log 9 | 10 | 11 | class GraphConvolution(nn.Module): 12 | def __init__(self, in_features, out_features, edge_types, dropout=0.5, bias=True, use_bn=False, 13 | device=torch.device("cpu")): 14 | """ 15 | Single Layer GraphConvolution 16 | 17 | :param in_features: The number of incoming features 18 | :param out_features: The number of output features 19 | :param edge_types: The number of edge types in the whole graph 20 | :param dropout: Dropout keep rate, if not bigger than 0, 0 or None, default 0.5 21 | :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True 22 | """ 23 | super(GraphConvolution, self).__init__() 24 | self.in_features = in_features 25 | self.out_features = out_features 26 | self.edge_types = edge_types 27 | self.dropout = dropout if type(dropout) == float and -1e-7 < dropout < 1 + 1e-7 else None 28 | # parameters for gates 29 | self.Gates = nn.ModuleList() 30 | # parameters for graph convolutions 31 | self.GraphConv = nn.ModuleList() 32 | # batch norm 33 | self.use_bn = use_bn 34 | if self.use_bn: 35 | self.bn = nn.BatchNorm1d(self.out_features) 36 | 37 | for _ in range(edge_types): 38 | self.Gates.append(BottledOrthogonalLinear(in_features=in_features, 39 | out_features=1, 40 | bias=bias)) 41 | self.GraphConv.append(BottledOrthogonalLinear(in_features=in_features, 42 | out_features=out_features, 43 | bias=bias)) 44 | self.device = device 45 | self.to(device) 46 | 47 | def forward(self, input, adj): 48 | """ 49 | 50 | :param input: FloatTensor, input feature tensor, (batch_size, seq_len, hidden_size) 51 | :param adj: FloatTensor (sparse.FloatTensor.to_dense()), adjacent matrix for provided graph of padded sequences, (batch_size, edge_types, seq_len, seq_len) 52 | :return: output 53 | - **output**: FloatTensor, output feature tensor with the same size of input, (batch_size, seq_len, hidden_size) 54 | """ 55 | 56 | adj_ = adj.transpose(0, 1) # (edge_types, batch_size, seq_len, seq_len) 57 | ts = [] 58 | for i in range(self.edge_types): 59 | gate_status = F.sigmoid(self.Gates[i](input)) # (batch_size, seq_len, 1) 60 | adj_hat_i = adj_[i] * gate_status # (batch_size, seq_len, seq_len) 61 | ts.append(torch.bmm(adj_hat_i, self.GraphConv[i](input))) 62 | ts = torch.stack(ts).sum(dim=0, keepdim=False).to(self.device) 63 | if self.use_bn: 64 | ts = ts.transpose(1, 2).contiguous() 65 | ts = self.bn(ts) 66 | ts = ts.transpose(1, 2).contiguous() 67 | ts = F.relu(ts) 68 | if self.dropout is not None: 69 | ts = F.dropout(ts, p=self.dropout, training=self.training) 70 | return ts 71 | 72 | def __repr__(self): 73 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 74 | 75 | 76 | if __name__ == "__main__": 77 | device = torch.device("cuda") 78 | 79 | BATCH_SIZE = 1 80 | SEQ_LEN = 8 81 | D = 6 82 | ET = 1 83 | CLASSN = 2 84 | adj = torch.sparse.FloatTensor( 85 | torch.LongTensor([[0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 86 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 87 | [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 0, 1, 2, 3, 4, 5, 6, 7], 88 | [1, 3, 0, 2, 1, 5, 0, 4, 3, 7, 2, 6, 5, 7, 4, 6, 0, 1, 2, 3, 4, 5, 6, 7]]), 89 | torch.FloatTensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 90 | torch.Size([BATCH_SIZE, ET, SEQ_LEN, SEQ_LEN])).to_dense().to(device) 91 | input = torch.randn(BATCH_SIZE, SEQ_LEN, D).to(device) 92 | label = torch.LongTensor([0, 1, 0, 1, 0, 1, 0, 1]).to(device) 93 | 94 | cc = GraphConvolution(in_features=D, out_features=D, edge_types=ET, device=device, use_bn=True) 95 | oo = BottledOrthogonalLinear(in_features=D, out_features=CLASSN).to(device) 96 | 97 | optimizer = Adadelta(list(cc.parameters()) + list(oo.parameters())) 98 | 99 | aloss = 1e9 100 | df = 1e9 101 | while df > 1e-7: 102 | optimizer.zero_grad() # clear gradients accumulated by the previous step 103 | output = oo(cc(input, adj)).view(BATCH_SIZE * SEQ_LEN, CLASSN) 104 | loss = F.cross_entropy(output, label) 105 | df = abs(aloss - loss.item()) 106 | aloss = loss.item() 107 | loss.backward() 108 | optimizer.step() 109 | log(aloss) 110 | 111 | log(F.softmax(output, dim=-1)) -------------------------------------------------------------------------------- /src/dataflow/numpy/anno_mapping.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | def event_type_norm(type_str): 4 | return type_str.replace('.', '||').replace(':', '||').replace('-', '|').upper() # e.g. 'Justice:Arrest-Jail' -> 'JUSTICE||ARREST|JAIL' 5 | 6 | 7 | def role_name_norm(type_str): 8 | return type_str.upper() 9 | 10 | entity_type_mapping_brat = { 11 | 'PER': 'PER', 12 | 'ORG': 'ORG', 13 | 'GPE': 'GPE', 14 | 'LOC': 'LOC', 15 | 'FAC': 'FAC', 16 | 'VEH': 'VEH', 17 | 'WEA': 'WEA', 18 | # 'TIM': 'TIME', 19 | # 'NUM': 'VALUE', 20 | # 'TIT', 21 | 'MON': 'VALUE', 22 | # 'URL', 23 | # 'RES', 24 | # 'BALLOT' 25 | } 26 | 27 | event_type_mapping_brat2ace = { 28 | 'Die': 'Life:Die', 29 | # 'Injure': 'Life:Injure', 30 | 'TransferMoney': 'Transaction:Transfer-Money', 31 | 'Attack': 'Conflict:Attack', 32 | 'Demonstrate': 'Conflict:Demonstrate', 33 | 'Correspondence': 'Contact:Phone-Write', 34 | 'Meet': 'Contact:Meet', 35 | 'ArrestJail': 'Justice:Arrest-Jail', 36 | # 'ReleaseParole': 'Justice:Release-Parole', 37 | 'TransportPerson': 'Movement:Transport', 38 | 'TransportArtifact': 'Movement:Transport', 39 | } 40 | 41 | event_type_mapping_ace2brat = {event_type_norm(v): k for k, v in event_type_mapping_brat2ace.items()} 42 | # print(event_type_mapping_ace2brat) 43 | 44 | event_type_mapping_aida2brat = { 45 | 'Life.Die': 'Die', 46 | 'Life.Injure': 'Injure', 47 | 'Transaction.TransferMoney': 'TransferMoney', 48 | 'Conflict.Attack': 'Attack', 49 | 'Conflict.Demonstrate': 'Demonstrate', 50 | 'Contact.Correspondence': 'Correspondence', 51 | 'Contact.Meet': 'Meet', 52 | 'Justice.ArrestJail': 'ArrestJail', 53 | 'Justice.ReleaseParole': 'ReleaseParole', 54 | 'Movement.TransportPerson': 'TransportPerson', 55 | 'Movement.TransportArtifact': 'TransportArtifact', 56 | } 57 | 58 | event_role_mapping_brat2ace = { # delete time-related ones when testing 59 | 'Die': {'Victim': 'Victim', 'Agent': 'Agent', 'Instrument': 'Instrument', 'Place': 'Place'}, #, 'Time': 'Time'}, 60 | # 'Injure': {'Victim': 'Victim', 'Agent': 'Agent', 'Instrument': 'Instrument', 'Place': 'Place'}, #, 'Time': 'Time'}, 61 | 'TransferMoney': {'Giver': 'Giver', 'Recipient': 'Recipient', 'Beneficiary': 'Beneficiary', 'Money': 'Money', 'Place': 'Place'}, #, 
'Time': 'Time'}, 62 | 'Attack': {'Attacker': 'Attacker', 'Instrument': 'Instrument', 'Place': 'Place', 'Target': 'Target'}, #, 'Time': 'Time'}, 63 | 'Demonstrate': {'Demonstrator': 'Entity', 'Place': 'Place'}, #, 'Time': 'Time'}, 64 | 'Correspondence': {'Participant': 'Entity', 'Place': 'Place'}, #, 'Time': 'Time'}, 65 | 'Meet': {'Participant': 'Entity', 'Place': 'Place'}, #, 'Time': 'Time'}, 66 | 'ArrestJail': {'Agent': 'Agent', 'Person': 'Person', 'Place': 'Place'}, #, 'Time': 'Time'}, 67 | # 'ReleaseParole': {'Agent': 'Entity', 'Person': 'Person', 'Place': 'Place'}, #, 'Time': 'Time'}, 68 | 'TransportPerson': {'Agent': 'Agent', 'Person': 'Artifact', 'Instrument': 'Vehicle', 'Destination': 'Destination', 'Origin': 'Origin'}, #, 'Time': 'Time'}, 69 | 'TransportArtifact': {'Agent': 'Agent', 'Artifact': 'Artifact', 'Instrument': 'Vehicle', 'Destination': 'Destination', 'Origin': 'Origin'}, #, 'Time': 'Time'}, 70 | } 71 | 72 | # event_role_mapping_ace2brat = {role_name_norm(v): k for t in event_role_mapping_brat2ace.items() for k, v in event_role_mapping_brat2ace[t]} 73 | event_role_mapping_ace2brat = defaultdict(lambda : defaultdict()) 74 | for t in event_role_mapping_brat2ace: 75 | for k, v in event_role_mapping_brat2ace[t].items(): 76 | event_type_ace = event_type_norm(event_type_mapping_brat2ace[t]) 77 | event_role_mapping_ace2brat[event_type_ace][role_name_norm(v)] = k 78 | # print(event_role_mapping_ace2brat) 79 | 80 | event_type_mapping_image2ace = { 81 | 'Movement.TransportPerson': 'Movement:Transport', 82 | 'Movement.TransportArtifact': 'Movement:Transport', 83 | 'Life.Die': 'Life:Die', 84 | 'Conflict.Attack': 'Conflict:Attack', 85 | 'Conflict.Demonstrate': 'Conflict:Demonstrate', 86 | 'Contact.Phone-Write': 'Contact:Phone-Write', 87 | 'Contact.Meet': 'Contact:Meet', 88 | 'Transaction.TransferMoney': 'Transaction:Transfer-Money', 89 | 'Justice.ArrestJail': 'Justice:Arrest-Jail', 90 | } 91 | 92 | event_role_mapping_image2ace = { 93 | 'Life.Die': {'victim': 'Victim', 'agent': 'Agent', 'instrument': 'Instrument', 'place': 'Place'}, #, 'Time': 'Time'}, 94 | 'Transaction.TransferMoney': {'instrument': 'Instrument', 'giver': 'Giver', 'recipient': 'Recipient', 'beneficiary': 'Beneficiary', 'money': 'Money', 'place': 'Place'}, #, 'Time': 'Time'}, 95 | 'Conflict.Attack': {'attacker': 'Attacker', 'instrument': 'Instrument', 'place': 'Place', 'target': 'Target', 'victim': 'Target'}, #, 'Time': 'Time'}, 96 | 'Conflict.Demonstrate': {'police':'Police', 'instrument': 'Instrument', 'demonstrator': 'Entity', 'place': 'Place', 'participant': 'Entity'}, #, 'Time': 'Time'}, 97 | 'Contact.Phone-Write': {'instrument':'Instrument', 'participant': 'Entity', 'place': 'Place'}, #, 'Time': 'Time'}, 98 | 'Contact.Meet': {'participant': 'Entity', 'place': 'Place'}, #, 'Time': 'Time'}, 99 | 'Justice.ArrestJail': {'instrument': 'Instrument', 'agent': 'Agent', 'person': 'Person', 'place': 'Place'}, #, 'Time': 'Time'}, 100 | 'Movement.TransportPerson': {'agent': 'Agent', 'person': 'Artifact', 'instrument': 'Instrument', 'destination': 'Destination', 'origin': 'Origin'}, #, 'Time': 'Time'}, 101 | 'Movement.TransportArtifact': {'person': 'Artifact', 'agent': 'Agent', 'artifact': 'Artifact', 'instrument': 'Instrument', 'destination': 'Destination', 'origin': 'Origin'}, #, 'Time': 'Time'}, 102 | } 103 | # 'TransportArtifact' has 'person' in image annotation -------------------------------------------------------------------------------- /src/dataflow/numpy/prepare_vocab.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Prepare vocabulary and initial word vectors. 3 | """ 4 | import json 5 | import pickle 6 | import argparse 7 | import numpy as np 8 | from collections import Counter 9 | 10 | import sys 11 | #sys.path.append("../") 12 | from src.util import vocab, constant, helper # these helpers live in src/util in this repo 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description='Prepare vocab for event extraction.') 16 | parser.add_argument('data_dir', help='Data directory with the JMEE_{train,dev,test}.json files.') 17 | # /data/m1/lim22/multimedia-common-space/Multimedia-Common-Space/ace/JMEE_data/head 18 | parser.add_argument('vocab_dir', help='Output vocab directory.') 19 | # /data/m1/lim22/multimedia-common-space/Multimedia-Common-Space/ace/vocab 20 | parser.add_argument('--glove_dir', default='/data/m1/lim22/env/glove', help='GloVe directory.') 21 | parser.add_argument('--wv_file', default='glove.840B.300d.txt', help='GloVe vector file.') 22 | parser.add_argument('--wv_dim', type=int, default=300, help='GloVe vector dimension.') 23 | parser.add_argument('--min_freq', type=int, default=0, help='If > 0, use min_freq as the cutoff.') 24 | parser.add_argument('--lower', action='store_true', help='If specified, lowercase all words.') 25 | 26 | args = parser.parse_args() 27 | return args 28 | 29 | def main(): 30 | args = parse_args() 31 | 32 | # input files 33 | train_file = args.data_dir + '/JMEE_train.json' 34 | dev_file = args.data_dir + '/JMEE_dev.json' 35 | test_file = args.data_dir + '/JMEE_test.json' 36 | wv_file = args.glove_dir + '/' + args.wv_file 37 | wv_dim = args.wv_dim 38 | 39 | # output files 40 | helper.ensure_dir(args.vocab_dir) 41 | vocab_file = args.vocab_dir + '/vocab.pkl' 42 | emb_file = args.vocab_dir + '/embedding.npy' 43 | 44 | # load files 45 | print("loading files...") 46 | train_tokens = load_tokens(train_file) 47 | dev_tokens = load_tokens(dev_file) 48 | test_tokens = load_tokens(test_file) 49 | if args.lower: 50 | train_tokens, dev_tokens, test_tokens = [[t.lower() for t in tokens] for tokens in\ 51 | (train_tokens, dev_tokens, test_tokens)] 52 | 53 | # load glove 54 | print("loading glove...") 55 | glove_vocab = vocab.load_glove_vocab(wv_file, wv_dim) 56 | print("{} words loaded from glove.".format(len(glove_vocab))) 57 | 58 | print("building vocab...") 59 | v = build_vocab(train_tokens, glove_vocab, args.min_freq) 60 | 61 | print("calculating oov...") 62 | datasets = {'train': train_tokens, 'dev': dev_tokens, 'test': test_tokens} 63 | for dname, d in datasets.items(): 64 | total, oov = count_oov(d, v) 65 | print("{} oov: {}/{} ({:.2f}%)".format(dname, oov, total, oov*100.0/total)) 66 | 67 | print("building embeddings...") 68 | embedding = vocab.build_embedding(wv_file, v, wv_dim) 69 | print("embedding size: {} x {}".format(*embedding.shape)) 70 | 71 | print("dumping to files...") 72 | with open(vocab_file, 'wb') as outfile: 73 | pickle.dump(v, outfile) 74 | np.save(emb_file, embedding) 75 | print("all done.") 76 | 77 | # def load_tokens(filename): 78 | # with open(filename) as infile: 79 | # data = json.load(infile) 80 | # tokens = [] 81 | # for d in data: 82 | # ts = d['token'] 83 | # ss, se, os, oe = d['subj_start'], d['subj_end'], d['obj_start'], d['obj_end'] 84 | # # do not create vocab for entity words 85 | # ts[ss:se+1] = ['<PAD>']*(se-ss+1) 86 | # ts[os:oe+1] = ['<PAD>']*(oe-os+1) 87 | # tokens += list(filter(lambda t: t!='<PAD>', ts)) 88 | # print("{} tokens from {} examples loaded from {}.".format(len(tokens), len(data), filename)) 89 | # return tokens
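# Example invocation (hypothetical paths, for illustration only):
#   python src/dataflow/numpy/prepare_vocab.py <data_dir> <vocab_dir> --glove_dir /path/to/glove --lower
# This writes <vocab_dir>/vocab.pkl and <vocab_dir>/embedding.npy, which downstream training code can load.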
90 | def load_tokens(filename): 91 | parsed_data = json.load(open(filename)) 92 | 93 | tokens = [] 94 | for sent in parsed_data: 95 | ts = sent['words'] 96 | # # do not create vocab for entity words 97 | # events = sent['golden-event-mentions'] 98 | # for event in events: 99 | # trigger = event['trigger'] 100 | # args = event['arguments'] 101 | # for arg in args: 102 | # arg_start = arg['start'] 103 | # arg_end = arg['end'] 104 | tokens += list(filter(lambda t: t != '', ts)) 105 | print("{} tokens from {} examples loaded from {}.".format(len(tokens), len(parsed_data), filename)) 106 | return tokens 107 | 108 | def build_vocab(tokens, glove_vocab, min_freq): 109 | """ build vocab from tokens and glove words. """ 110 | counter = Counter(t for t in tokens) 111 | # if min_freq > 0, use min_freq, otherwise keep all glove words 112 | if min_freq > 0: 113 | v = sorted([t for t in counter if counter.get(t) >= min_freq], key=counter.get, reverse=True) 114 | else: 115 | v = sorted([t for t in counter if t in glove_vocab], key=counter.get, reverse=True) 116 | # add special tokens and entity mask tokens 117 | v = constant.VOCAB_PREFIX + entity_masks() + v 118 | print("vocab built with {}/{} words.".format(len(v), len(counter))) 119 | return v 120 | 121 | def count_oov(tokens, vocab): 122 | c = Counter(t for t in tokens) 123 | total = sum(c.values()) 124 | matched = sum(c[t] for t in vocab) 125 | return total, total - matched 126 | 127 | def entity_masks(): 128 | """ Get all entity mask tokens as a list. According to constant.py""" 129 | masks = [] 130 | # subj_entities = list(constant.SUBJ_NER_TO_ID.keys())[2:] 131 | obj_entities = list(constant.OBJ_NER_TO_ID.keys())[2:] 132 | # masks += ["SUBJ-" + e for e in subj_entities] 133 | # masks += ["OBJ-" + e for e in obj_entities] 134 | masks += [e for e in obj_entities] 135 | print(masks) #['SUBJ-ORGANIZATION', 'SUBJ-PERSON', 'OBJ-PERSON', 'OBJ-ORGANIZATION', 'OBJ-DATE', 'OBJ-NUMBER', 'OBJ-TITLE', 'OBJ-COUNTRY', 'OBJ-LOCATION', 'OBJ-CITY', 'OBJ-MISC', 'OBJ-STATE_OR_PROVINCE', 'OBJ-DURATION', 'OBJ-NATIONALITY', 'OBJ-CAUSE_OF_DEATH', 'OBJ-CRIMINAL_CHARGE', 'OBJ-RELIGION', 'OBJ-URL', 'OBJ-IDEOLOGY'] 136 | return masks 137 | 138 | if __name__ == '__main__': 139 | main() 140 | 141 | 142 | -------------------------------------------------------------------------------- /src/util/util_model.py: -------------------------------------------------------------------------------- 1 | # import json # shadowed by the ujson import below 2 | import math 3 | import sys 4 | import ujson as json 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import init 9 | import numpy as np 10 | 11 | # import sys # duplicate import (sys is already imported above) 12 | #sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2') 13 | 14 | 15 | 16 | class SparseMM(torch.autograd.Function): 17 | """ 18 | Sparse x dense matrix multiplication with autograd support. 
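Usage sketch (hypothetical shapes): out = SparseMM(sparse_adj)(dense_h) behaves like sparse_adj @ dense_h, with gradients flowing back only to the dense operand.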
19 | 20 | Implementation by Soumith Chintala: 21 | https://discuss.pytorch.org/t/ 22 | does-pytorch-support-autograd-on-sparse-matrix/6156/7 23 | """ 24 | 25 | def __init__(self, sparse): 26 | super(SparseMM, self).__init__() 27 | self.sparse = sparse 28 | 29 | def forward(self, dense): 30 | return torch.mm(self.sparse, dense) 31 | 32 | def backward(self, grad_output): 33 | grad_input = None 34 | if self.needs_input_grad[0]: 35 | grad_input = torch.mm(self.sparse.t(), grad_output) 36 | return grad_input 37 | 38 | 39 | class Bottle(nn.Module): 40 | ''' Perform the reshape routine before and after an operation ''' 41 | 42 | def forward(self, input): 43 | size = input.size() 44 | out = super(Bottle, self).forward(input.contiguous().view(np.prod(size[:-1]), size[-1])) 45 | return out.view(*(size[:-1] + (-1,))) 46 | 47 | 48 | 49 | class XavierLinear(nn.Module): 50 | ''' 51 | Simple Linear layer with Xavier init 52 | 53 | Paper by Xavier Glorot and Yoshua Bengio (2010): 54 | Understanding the difficulty of training deep feedforward neural networks 55 | http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf 56 | ''' 57 | 58 | def __init__(self, in_features, out_features, bias=True): 59 | super(XavierLinear, self).__init__() 60 | self.linear = nn.Linear(in_features, out_features, bias=bias) 61 | init.xavier_normal_(self.linear.weight) 62 | 63 | def forward(self, x): 64 | return self.linear(x) 65 | 66 | class OrthogonalLinear(nn.Module): 67 | def __init__(self, in_features, out_features, bias=True): 68 | super(OrthogonalLinear, self).__init__() 69 | self.linear = nn.Linear(in_features, out_features, bias=bias) 70 | init.orthogonal_(self.linear.weight) 71 | 72 | def forward(self, x): 73 | return self.linear(x) 74 | 75 | 76 | class BottledLinear(Bottle, nn.Linear): 77 | pass 78 | 79 | 80 | class BottledXavierLinear(Bottle, XavierLinear): 81 | pass 82 | 83 | 84 | class BottledOrthogonalLinear(Bottle, OrthogonalLinear): 85 | pass 86 | 87 | 88 | class MLP(nn.Module): 89 | def __init__(self, dim_in_hid_out, act_fn='ReLU', last_act=False): 90 | super(MLP, self).__init__() 91 | layers = [] 92 | for i in range(len(dim_in_hid_out) - 1): 93 | layers.append(XavierLinear(dim_in_hid_out[i], dim_in_hid_out[i + 1])) 94 | if i < len(dim_in_hid_out) - 2 or last_act: 95 | layers.append(getattr(torch.nn, act_fn)()) 96 | self.model = torch.nn.Sequential(*layers) 97 | 98 | def forward(self, x): 99 | return self.model(x) 100 | 101 | class BottledMLP(Bottle, MLP): 102 | pass 103 | 104 | 105 | 106 | def log(*args, **kwargs): 107 | print(file=sys.stdout, flush=True, *args, **kwargs) 108 | 109 | 110 | def logerr(*args, **kwargs): 111 | print(file=sys.stderr, flush=True, *args, **kwargs) 112 | 113 | 114 | def logonfile(fp, *args, **kwargs): 115 | fp.write(*args, **kwargs) 116 | 117 | 118 | def progressbar(cur, total, other_information): 119 | percent = '{:.2%}'.format(cur / total) 120 | if type(other_information) is str: 121 | log("\r[%-50s] %s %s" % ('=' * int(math.floor(cur * 50 / total)), percent, other_information)) 122 | else: 123 | log("\r[%-50s] %s" % ('=' * int(math.floor(cur * 50 / total)), percent)) 124 | 125 | 126 | def save_hyps(hyps, fp): 127 | json.dump(hyps, fp) 128 | 129 | 130 | def load_hyps(fp): 131 | hyps = json.load(fp) 132 | return hyps 133 | 134 | 135 | def masked_log_softmax(vector, mask, dim=-1): 136 | """ 137 | mask: [1,1,1,0,0]: the padded one is 0 138 | 139 | ``torch.nn.functional.log_softmax(vector)`` does not work if some elements of ``vector`` should be 140 | masked. 
This performs a log_softmax on just the non-masked portions of ``vector``. Passing 141 | ``None`` in for the mask is also acceptable; you'll just get a regular log_softmax. 142 | ``vector`` can have an arbitrary number of dimensions; the only requirement is that ``mask`` is 143 | broadcastable to ``vector's`` shape. If ``mask`` has fewer dimensions than ``vector``, we will 144 | unsqueeze on dimension 1 until they match. If you need a different unsqueezing of your mask, 145 | do it yourself before passing the mask into this function. 146 | In the case that the input vector is completely masked, the return value of this function is 147 | arbitrary, but not ``nan``. You should be masking the result of whatever computation comes out 148 | of this in that case, anyway, so the specific values returned shouldn't matter. Also, the way 149 | that we deal with this case relies on having single-precision floats; mixing half-precision 150 | floats with fully-masked vectors will likely give you ``nans``. 151 | If your logits are all extremely negative (i.e., the max value in your logit vector is -50 or 152 | lower), the way we handle masking here could mess you up. But if you've got logit values that 153 | extreme, you've got bigger problems than this. 154 | """ 155 | if mask is not None: 156 | mask = mask.float() 157 | while mask.dim() < vector.dim(): 158 | mask = mask.unsqueeze(1) 159 | # vector + mask.log() is an easy way to zero out masked elements in logspace, but it 160 | # results in nans when the whole vector is masked. We need a very small value instead of a 161 | # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely 162 | # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it 163 | # becomes 0 - this is just the smallest value we can actually use. 
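# Illustration with made-up numbers: for vector = [2.0, 1.0, 3.0] and mask = [1, 1, 0],
# the masked logit gets log(1e-45) (about -103.6) added, so after log_softmax it carries
# essentially zero probability mass while the unmasked entries are renormalized as usual.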
164 | vector = vector + (mask + 1e-45).log() # float('-inf') 165 | return torch.nn.functional.log_softmax(vector, dim=dim) -------------------------------------------------------------------------------- /src/models/modules/EmbeddingLayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class EmbeddingLayer(nn.Module): 7 | def __init__(self, embedding_size=None, embedding_matrix=None, 8 | fine_tune=True, dropout=0.5, 9 | padding_idx=None, 10 | max_norm=None, norm_type=2, scale_grad_by_freq=False, 11 | sparse=False, 12 | device=torch.device("cpu")): 13 | ''' 14 | Embedding Layer need at least one of `embedding_size` and `embedding_matrix` 15 | :param embedding_size: tuple, contains 2 integers indicating the shape of embedding matrix, eg: (20000, 300) 16 | :param embedding_matrix: torch.Tensor, the pre-trained value of embedding matrix 17 | :param fine_tune: boolean, whether fine tune embedding matrix 18 | :param dropout: float, dropout rate 19 | :param padding_idx: int, if given, pads the output with zeros whenever it encounters the index 20 | :param max_norm: float, if given, will renormalize the embeddings to always have a norm lesser than this 21 | :param norm_type: float, the p of the p-norm to compute for the max_norm option 22 | :param scale_grad_by_freq: boolean, if given, this will scale gradients by the frequency of the words in the mini-batch 23 | :param sparse: boolean, *unclear option copied from original module* 24 | ''' 25 | super(EmbeddingLayer, self).__init__() 26 | 27 | if embedding_matrix is not None: 28 | embedding_size = embedding_matrix.size() 29 | else: 30 | embedding_matrix = torch.nn.init.uniform_(torch.FloatTensor(embedding_size[0], embedding_size[1]), 31 | a=-0.15, 32 | b=0.15) 33 | assert (embedding_size is not None) 34 | assert (embedding_matrix is not None) 35 | # Config copying 36 | self.matrix = nn.Embedding(num_embeddings=embedding_size[0], 37 | embedding_dim=embedding_size[1], 38 | padding_idx=padding_idx, 39 | max_norm=max_norm, 40 | norm_type=norm_type, 41 | scale_grad_by_freq=scale_grad_by_freq, 42 | sparse=sparse) 43 | self.matrix.weight.data.copy_(embedding_matrix) 44 | self.matrix.weight.requires_grad = fine_tune 45 | self.dropout = dropout if type(dropout) == float and -1e-7 < dropout < 1 + 1e-7 else None 46 | 47 | self.device = device 48 | self.to(device) 49 | 50 | # def init_embeddings(self): 51 | # if self.emb_matrix is None: 52 | # self.emb.weight.data[1:,:].uniform_(-1.0, 1.0) 53 | # else: 54 | # self.emb_matrix = torch.from_numpy(self.emb_matrix) 55 | # self.emb.weight.data.copy_(self.emb_matrix) 56 | # # decide finetuning 57 | # if self.opt['topn'] <= 0: 58 | # print("Do not finetune word embedding layer.") 59 | # self.emb.weight.requires_grad = False 60 | # elif self.opt['topn'] < self.opt['vocab_size']: 61 | # print("Finetune top {} word embeddings.".format(self.opt['topn'])) 62 | # self.emb.weight.register_hook(lambda x: \ 63 | # torch_utils.keep_partial_grad(x, self.opt['topn'])) 64 | # else: 65 | # print("Finetune all embeddings.") 66 | 67 | def forward(self, x): 68 | ''' 69 | Forward this module 70 | :param x: torch.LongTensor, token sequence or sentence, shape is [batch, sentence_len] 71 | :return: torch.FloatTensor, output data, shape is [batch, sentence_len, embedding_size] 72 | ''' 73 | if self.dropout is not None: 74 | return F.dropout(self.matrix(x), p=self.dropout, training=self.training) 75 | else: 76 
| return self.matrix(x) 77 | 78 | 79 | class MultiLabelEmbeddingLayer(nn.Module): 80 | def __init__(self, embedding_size=None, embedding_matrix=None, 81 | fine_tune=True, dropout=0.5, 82 | padding_idx=None, 83 | max_norm=None, norm_type=2, scale_grad_by_freq=False, 84 | sparse=False, 85 | device=torch.device("cpu")): 86 | ''' 87 | MultiLabelEmbeddingLayer Layer need at least one of `embedding_size` and `embedding_matrix` 88 | :param embedding_size: tuple, contains 2 integers indicating the shape of embedding matrix, eg: (20000, 300) 89 | :param embedding_matrix: torch.Tensor, the pre-trained value of embedding matrix 90 | :param fine_tune: boolean, whether fine tune embedding matrix 91 | :param dropout: float, dropout rate 92 | :param padding_idx: int, if given, pads the output with zeros whenever it encounters the index 93 | :param max_norm: float, if given, will renormalize the embeddings to always have a norm lesser than this 94 | :param norm_type: float, the p of the p-norm to compute for the max_norm option 95 | :param scale_grad_by_freq: boolean, if given, this will scale gradients by the frequency of the words in the mini-batch 96 | :param sparse: boolean, *unclear option copied from original module* 97 | ''' 98 | super(MultiLabelEmbeddingLayer, self).__init__() 99 | 100 | if embedding_matrix is not None: 101 | embedding_size = embedding_matrix.size() 102 | else: 103 | embedding_matrix = torch.randn(embedding_size[0], embedding_size[1]) 104 | assert (embedding_size is not None) 105 | assert (embedding_matrix is not None) 106 | # Config copying 107 | self.matrix = nn.Embedding(num_embeddings=embedding_size[0], 108 | embedding_dim=embedding_size[1], 109 | padding_idx=padding_idx, 110 | max_norm=max_norm, 111 | norm_type=norm_type, 112 | scale_grad_by_freq=scale_grad_by_freq, 113 | sparse=sparse) 114 | self.matrix.weight.data.copy_(embedding_matrix) 115 | self.matrix.weight.requires_grad = fine_tune 116 | self.dropout = dropout if type(dropout) == float and -1e-7 < dropout < 1 + 1e-7 else None 117 | 118 | self.device = device 119 | self.to(device) 120 | 121 | def forward(self, x): 122 | ''' 123 | Forward this module 124 | :param x: list, token sequence or sentence, shape is [batch, sentence_len, variable_size(>=1)] 125 | :return: torch.FloatTensor, output data, shape is [batch, sentence_len, embedding_size] 126 | ''' 127 | BATCH = len(x) 128 | SEQ_LEN = len(x[0]) 129 | x = [self.matrix(torch.LongTensor(x[i][j]).to(self.device)).sum(0) 130 | for i in range(BATCH) 131 | for j in range(SEQ_LEN)] 132 | x = torch.stack(x).view(BATCH, SEQ_LEN, -1) 133 | if self.dropout is not None: 134 | return F.dropout(x, p=self.dropout, training=self.training) 135 | else: 136 | return x 137 | -------------------------------------------------------------------------------- /src/eval/SRtesting.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class SRTester(): 4 | def __init__(self): 5 | pass 6 | 7 | def calculate_lists(self, y, y_): 8 | ''' 9 | for a sequence, whether the prediction is correct 10 | note that len(y) == len(y_) 11 | :param y: 12 | :param y_: 13 | :return: 14 | ''' 15 | ct = 0 16 | p2 = len(y_) 17 | p1 = len(y) 18 | for i in range(p2): 19 | if y[i] == y_[i]: 20 | ct = ct + 1 21 | if ct == 0 or p1 == 0 or p2 == 0: 22 | return 0.0, 0.0, 0.0 23 | else: 24 | p = 1.0 * ct / p2 25 | r = 1.0 * ct / p1 26 | f1 = 2.0 * p * r / (p + r) 27 | return p, r, f1 28 | 29 | def calculate_sets_no_order(self, y, y_): 30 | ''' 31 | for each predicted item, 
whether it is in the gt 32 | :param y: [batch, items] 33 | :param y_: [batch, items] 34 | :return: 35 | ''' 36 | ct, p1, p2 = 0, 0, 0 37 | for batch, batch_ in zip(y, y_): 38 | value_set = set(batch) 39 | value_set_ = set(batch_) 40 | p1 += len(value_set) 41 | p2 += len(value_set_) 42 | 43 | for value_ in value_set_: 44 | # if value_ == '(0,0,0)': 45 | if value_ in value_set: 46 | ct += 1 47 | 48 | if ct == 0 or p1 == 0 or p2 == 0: 49 | return 0.0, 0.0, 0.0 50 | else: 51 | p = 1.0 * ct / p2 52 | r = 1.0 * ct / p1 53 | f1 = 2.0 * p * r / (p + r) 54 | return p, r, f1 55 | 56 | def calculate_sets_noun(self, y, y_): 57 | ''' 58 | for each ground truth entity, whether it is in the predicted entities 59 | :param y: [batch, role_num, multiple_args] 60 | :param y_: [batch, role_num, multiple_entities] 61 | :return: 62 | ''' 63 | # print('y', y) 64 | # print('y_', y_) 65 | ct, p1, p2 = 0, 0, 0 66 | # for batch_idx, batch_idx_ in zip(y, y_): 67 | # batch = y[batch_idx] 68 | # batch_ = y_[batch_idx_] 69 | for batch, batch_ in zip(y, y_): 70 | # print('batch', batch) 71 | # print('batch_', batch_) 72 | p1 += len(batch) 73 | p2 += len(batch_) 74 | for role in batch: 75 | found = False 76 | entities = batch[role] 77 | for entity in entities: 78 | for role_ in batch_: 79 | entities_ = batch_[role_] 80 | if entity in entities_: 81 | ct += 1 82 | found = True 83 | break 84 | if found: 85 | break 86 | 87 | if ct == 0 or p1 == 0 or p2 == 0: 88 | return 0.0, 0.0, 0.0 89 | else: 90 | p = 1.0 * ct / p2 91 | r = 1.0 * ct / p1 92 | f1 = 2.0 * p * r / (p + r) 93 | return p, r, f1 94 | 95 | def calculate_sets_triple(self, y, y_): 96 | ''' 97 | for each role, whether the predicted entities have overlap with the gt entities 98 | :param y: dict, role -> entities 99 | :param y_: dict, role -> entities 100 | :return: 101 | ''' 102 | ct, p1, p2 = 0, 0, 0 103 | # for batch_idx, batch_idx_ in zip(y, y_): 104 | # batch = y[batch_idx] 105 | # batch_ = y_[batch_idx_] 106 | for batch, batch_ in zip(y, y_): 107 | p1 += len(batch) 108 | p2 += len(batch_) 109 | for role in batch: 110 | # is_correct = False 111 | entities = batch[role] 112 | if role in batch_: 113 | entities_ = batch_[role] 114 | for entity_ in entities_: 115 | if entity_ in entities: 116 | ct += 1 117 | # is_correct = True 118 | break 119 | # if not is_correct: 120 | # print('Wrong one:', role, batch[role]) 121 | 122 | if ct == 0 or p1 == 0 or p2 == 0: 123 | return 0.0, 0.0, 0.0 124 | else: 125 | p = 1.0 * ct / p2 126 | r = 1.0 * ct / p1 127 | f1 = 2.0 * p * r / (p + r) 128 | return p, r, f1 129 | 130 | 131 | def visualize_sets_triple(self, image_id_batch, y_verb, y_verb_, y, y_, verb_id2s, role_id2s, noun_id2s, sr_visualpath, image_path=None): 132 | ''' 133 | for each role, whether the predicted entities have overlap with the gt entities 134 | :param y: dict, role -> entities 135 | :param y_: dict, role -> entities 136 | :return: 137 | ''' 138 | 139 | # sr_visualpath = '/scratch/manling2/html/m2e2/sr_errors' 140 | image_path = '../of500_images_resized/' 141 | 142 | ct, p1, p2 = 0, 0, 0 143 | # for batch_idx, batch_idx_ in zip(y, y_): 144 | # batch = y[batch_idx] 145 | # batch_ = y_[batch_idx_] 146 | batch_idx = 0 147 | print('visualize image:', image_id_batch) 148 | 149 | if y_verb_[0] == y_verb[0]: 150 | f_html = open(os.path.join(sr_visualpath, 'verb_correct', '%s.html' % (image_id_batch[0])), 'w') 151 | else: 152 | f_html = open(os.path.join(sr_visualpath, 'verb_wrong', '%s.html' % (image_id_batch[0])), 'w') 153 | f_html.write("\n<br>") 154 | 155 | f_html.write('[verb_prediction] %s \n<br>' % verb_id2s[y_verb_[0]]) 156 | f_html.write('[verb_ground_truth] %s \n<br>\n<br>' % verb_id2s[y_verb[0]]) 157 | 158 | for batch, batch_ in zip(y, y_): 159 | # print('image_id', image_id_batch[batch_idx]) 160 | p1 += len(batch) 161 | p2 += len(batch_) 162 | for role in batch: 163 | is_correct = False 164 | entities = batch[role] 165 | if role in batch_: 166 | entities_ = batch_[role] 167 | for entity_ in entities_: 168 | if entity_ in entities: 169 | ct += 1 170 | is_correct = True 171 | break 172 | if not is_correct: 173 | f_html.write('[prediction]\n<br>') 174 | f_html.write('%s.%s = [' % (verb_id2s[y_verb_[batch_idx]], role_id2s[role])) 175 | for entity_ in entities_: 176 | f_html.write('%s, ' % noun_id2s[entity_]) 177 | f_html.write(']\n<br>') 178 | f_html.write('[ground truth]\n<br>') 179 | f_html.write('%s.%s = [' % (verb_id2s[y_verb[batch_idx]], role_id2s[role])) 180 | for entity in entities: 181 | f_html.write('%s, ' % noun_id2s[entity]) 182 | f_html.write(']\n\n<br><br>') 183 | batch_idx = batch_idx + 1 184 | f_html.flush() 185 | f_html.close() 186 | 187 | if ct == 0 or p1 == 0 or p2 == 0: 188 | return 0.0, 0.0, 0.0 189 | else: 190 | p = 1.0 * ct / p2 191 | r = 1.0 * ct / p1 192 | f1 = 2.0 * p * r / (p + r) 193 | return p, r, f1 -------------------------------------------------------------------------------- /src/util/constant.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define constants. 3 | """ 4 | EMB_INIT_RANGE = 1.0 5 | 6 | # vocab 7 | PAD_TOKEN = '<PAD>' 8 | PAD_ID = 0 9 | UNK_TOKEN = '<UNK>' 10 | UNK_ID = 1 11 | 12 | VOCAB_PREFIX = [PAD_TOKEN, UNK_TOKEN] 13 | 14 | # hard-coded mappings from fields to ids 15 | # SUBJ is trigger, no type 16 | # SUBJ_NER_TO_ID = {PAD_TOKEN: 0, UNK_TOKEN: 1, 'ORGANIZATION': 2, 'PERSON': 3} 17 | 18 | OBJ_NER_TO_ID = {PAD_TOKEN: 0, UNK_TOKEN: 1, 'Weapon': 2, 'Organization': 3, 'Geopolitical_Entity': 4, 'Location': 5, 'Time': 6, 'Vehicle': 7, 'Value': 8, 'Facility': 9, 'Person': 10} 19 | 20 | NER_NORM = {'WEA': 'Weapon', 'ORG': 'Organization', 'GPE': 'Geopolitical_Entity', 'LOC': 'Location', 'TIME': 'Time', 'VEH': 'Vehicle', 'VALUE': 'Value', 'FAC': 'Facility', 'PER': 'Person'} 21 | 22 | NER_NIL_LABEL = 'O' 23 | NUM_OTHER_NER = 3 24 | NER_TO_ID = {PAD_TOKEN: 0, UNK_TOKEN: 1, NER_NIL_LABEL: 2, 'Weapon': 3, 'Organization': 4, 'Geopolitical_Entity': 5, 'Location': 6, 'Time': 7, 'Vehicle': 8, 'Value': 9, 'Facility': 10, 'Person': 11} 25 | # {'ORG:Sports', 'VALUE:other', 'GPE:GPE-Cluster', 'VEH:Subarea-Vehicle', 'WEA:Nuclear', 'ORG:Religious', 'ORG:Government', 'FAC:Building-Grounds', 'LOC:Address', 'PER:Indeterminate', 'LOC:Boundary', 'ORG:Entertainment', 'WEA:Exploding', 'PER:Group', 'WEA:Blunt', 'VEH:Underspecified', 'WEA:Projectile', 'GPE:Nation', 'VEH:Water', 'GPE:Special', 'LOC:Region-General', 'LOC:Celestial', 'GPE:County-or-District', 'VEH:Air', 'WEA:Sharp', 'ORG:Educational', 'PER:Individual', 'FAC:Airport', 'WEA:Chemical', 'GPE:Population-Center', 'LOC:Region-International', 'TIME:other', 'ORG:Media', 'LOC:Water-Body', 'WEA:Biological', 'FAC:Subarea-Facility', 'WEA:Shooting', 'ORG:Commercial', 'ORG:Non-Governmental', 'GPE:Continent', 'ORG:Medical-Science', 'WEA:Underspecified', 'FAC:Path', 'GPE:State-or-Province', 'VEH:Land', 'LOC:Land-Region-Natural', 'FAC:Plant' 26 | 27 | POS_TO_ID = {PAD_TOKEN: 0, UNK_TOKEN: 1, 'NNP': 2, 'NN': 3, 'IN': 4, 'DT': 5, ',': 6, 'JJ': 7, 'NNS': 8, 'VBD': 9, 'CD': 10, 'CC': 11, '.': 12, 'RB': 13, 'VBN': 14, 'PRP': 15, 'TO': 16, 'VB': 17, 'VBG': 18, 'VBZ': 19, 'PRP$': 20, ':': 21, 'POS': 22, '\'\'': 23, '``': 24, '-RRB-': 25, '-LRB-': 26, 'VBP': 27, 'MD': 28, 'NNPS': 29, 'WP': 30, 'WDT': 31, 'WRB': 32, 'RP': 33, 'JJR': 34, 'JJS': 35, '$': 36, 'FW': 37, 'RBR': 38, 'SYM': 39, 'EX': 40, 'RBS': 41, 'WP$': 42, 'PDT': 43, 'LS': 44, 'UH': 45, '#': 46} 28 | 29 | DEPREL_TO_ID = {PAD_TOKEN: 0, UNK_TOKEN: 1, 'punct': 2, 'compound': 3, 'case': 4, 'nmod': 5, 'det': 6, 'nsubj': 7, 'amod': 8, 'conj': 9, 'dobj': 10, 'ROOT': 11, 'cc': 12, 'nmod:poss': 13, 'mark': 14, 'advmod': 15, 'appos': 16, 'nummod': 17, 'dep': 18, 'ccomp': 19, 'aux': 20, 'advcl': 21, 'acl:relcl': 22, 'xcomp': 23, 'cop': 24, 'acl': 25, 'auxpass': 26, 'nsubjpass': 27, 'nmod:tmod': 28, 'neg': 29, 'compound:prt': 30, 'mwe': 31, 'parataxis': 32, 'root': 33, 'nmod:npmod': 34, 'expl': 35, 'csubj': 36, 'cc:preconj': 37, 'iobj': 38, 'det:predet': 39, 'discourse': 40, 'csubjpass': 41} 30 | 31 | NEGATIVE_LABEL_TRIGGER = 'other_event' 32 | NUM_OTHER_TRIGGER = 2 33 | LABEL_TO_ID_TRIGGER = {PAD_TOKEN: 0, 
NEGATIVE_LABEL_TRIGGER: 1, 'Conflict:Attack': 2, 'Conflict:Demonstrate': 3, 34 | 'Contact:Meet': 4, 'Contact:Phone-Write': 5, 'Life:Die': 6, 'Movement:Transport': 7, 35 | 'Justice:Arrest-Jail': 8, 'Transaction:Transfer-Money': 9} # 'Life:Injure', 'Justice:Release-Parole' 36 | LABEL_TO_ID_TRIGGER_UNSEEN = {'Personnel:Elect': 2, 37 | 'Life:Marry': 3, 'Life:Injure': 4, 'Justice:Execute': 5, 'Justice:Trial-Hearing': 6, 38 | 'Life:Be-Born': 7, 'Justice:Convict': 8, 'Justice:Release-Parole': 9, 'Justice:Fine': 10} 39 | LABEL_TO_ID_TRIGGER_UNVISUAL = {'Personnel:End-Position': 3, 'Personnel:Start-Position': 4, 'Personnel:Nominate': 5, 40 | 'Justice:Sue': 7, 'Business:End-Org': 9, 'Business:Start-Org': 11, 41 | 'Transaction:Transfer-Ownership': 10, 'Justice:Sentence': 13, 42 | 'Justice:Charge-Indict': 16, 'Business:Declare-Bankruptcy': 18, 43 | 'Justice:Pardon': 21, 'Justice:Appeal': 22, 'Justice:Extradite': 23, 44 | 'Life:Divorce': 24, 'Business:Merge-Org': 25, 'Justice:Acquit': 26} 45 | LABEL_TO_ID_TRIGGER_ALL = {PAD_TOKEN: 0, UNK_TOKEN: 1, NEGATIVE_LABEL_TRIGGER: 2, 'Personnel:Elect': 3, 46 | 'Personnel:End-Position': 4, 'Personnel:Start-Position': 5, 47 | 'Movement:Transport': 6, 'Conflict:Attack': 7, 'Personnel:Nominate': 8, 'Contact:Meet': 9, 48 | 'Life:Marry': 10, 49 | 'Justice:Sue': 11, 'Contact:Phone-Write': 12, 'Transaction:Transfer-Money': 13, 50 | 'Conflict:Demonstrate': 14, 'Life:Injure': 15, 'Business:End-Org': 16, 'Life:Die': 17, 51 | 'Justice:Arrest-Jail': 18, 'Transaction:Transfer-Ownership': 19, 'Business:Start-Org': 20, 52 | 'Justice:Execute': 21, 'Justice:Sentence': 22, 'Justice:Trial-Hearing': 23, 'Life:Be-Born': 24, 53 | 'Justice:Charge-Indict': 25, 'Justice:Convict': 26, 'Business:Declare-Bankruptcy': 27, 54 | 'Justice:Release-Parole': 28, 'Justice:Fine': 29, 'Justice:Pardon': 30, 'Justice:Appeal': 31, 55 | 'Justice:Extradite': 32, 'Life:Divorce': 33, 'Business:Merge-Org': 34, 'Justice:Acquit': 35} 56 | 57 | NEGATIVE_LABEL_ROLE = 'other_role' 58 | NUM_OTHER_ROLE = 2 59 | LABEL_TO_ID_ROLE = {PAD_TOKEN: 0, NEGATIVE_LABEL_ROLE: 1, 'Buyer': 2, 'Target': 3, 'Agent': 4, 'Vehicle': 5, 60 | 'Instrument': 6, 'Person': 7, 'Victim': 8, 'Attacker': 9, 'Artifact': 10, 'Seller': 11, 61 | 'Recipient': 12, 'Money': 13, 'Giver': 14, 'Entity': 15, 'Place': 16, 'Defendant': 17, 62 | 'Destination': 18, 'Origin': 19} 63 | LABEL_TO_ID_ROLE_UNSEEN = {'Beneficiary': 2, 'Prosecutor': 3, 64 | 'Plaintiff': 4, 'Adjudicator': 5, 'Agent': 6, 'Vehicle': 7} 65 | LABEL_TO_ID_ROLE_UNVISUAL = {'Time-Holds': 8, 'Time-Starting': 10, 'Price': 11, 'Time-Before': 12, 66 | 'Time-After': 14, 'Time-Ending': 19, 'Time-At-End': 22, 67 | 'Time-At-Beginning': 28, 'Time-Within': 32, 'Org': 18, 68 | 'Sentence': 15, 'Crime': 16, 'Position': 24 } # 69 | LABEL_TO_ID_ROLE_ALL = {PAD_TOKEN: 0, UNK_TOKEN: 1, NEGATIVE_LABEL_ROLE: 2, 'Buyer': 3, 'Target': 4, 70 | 'Instrument': 5, 'Person': 6, 'Victim': 7, 71 | 'Time-Holds': 8, 'Attacker': 9, 'Time-Starting': 10, 'Price': 11, 'Time-Before': 12, 72 | 'Artifact': 13, 'Time-After': 14, 'Sentence': 15, 'Crime': 16, 'Destination': 17, 73 | 'Org': 18, 'Time-Ending': 19, 'Beneficiary': 20, 'Seller': 21, 'Time-At-End': 22, 74 | 'Recipient': 23, 'Position': 24, 'Money': 25, 'Giver': 26, 'Prosecutor': 27, 75 | 'Time-At-Beginning': 28, 'Entity': 29, 'Place': 30, 'Defendant': 31, 76 | 'Time-Within': 32, 'Plaintiff': 33, 'Adjudicator': 34, 'Agent': 35, 77 | 'Origin': 36, 'Vehicle': 37} 78 | 79 | # NEGATIVE_LABEL = 'no_relation' 80 | # LABEL_TO_ID = {'no_relation': 0, 'per:title': 1, 
'org:top_members/employees': 2, 'per:employee_of': 3, 'org:alternate_names': 4, 'org:country_of_headquarters': 5, 'per:countries_of_residence': 6, 'org:city_of_headquarters': 7, 'per:cities_of_residence': 8, 'per:age': 9, 'per:stateorprovinces_of_residence': 10, 'per:origin': 11, 'org:subsidiaries': 12, 'org:parents': 13, 'per:spouse': 14, 'org:stateorprovince_of_headquarters': 15, 'per:children': 16, 'per:other_family': 17, 'per:alternate_names': 18, 'org:members': 19, 'per:siblings': 20, 'per:schools_attended': 21, 'per:parents': 22, 'per:date_of_death': 23, 'org:member_of': 24, 'org:founded_by': 25, 'org:website': 26, 'per:cause_of_death': 27, 'org:political/religious_affiliation': 28, 'org:founded': 29, 'per:city_of_death': 30, 'org:shareholders': 31, 'org:number_of_employees/members': 32, 'per:date_of_birth': 33, 'per:city_of_birth': 34, 'per:charges': 35, 'per:stateorprovince_of_death': 36, 'per:religion': 37, 'per:stateorprovince_of_birth': 38, 'per:country_of_birth': 39, 'org:dissolved': 40, 'per:country_of_death': 41} 81 | 82 | INFINITY_NUMBER = 1e12 83 | 84 | # TYPE_ROLE_MAP : in encoder_ont 85 | # not mask, zero-shot, so only list the ones that belongs to the ontology 86 | -------------------------------------------------------------------------------- /data/ace/ace_sr_mapping.txt: -------------------------------------------------------------------------------- 1 | apprehending agent Justice||Arrest|Jail Agent 2 | arresting agent Justice||Arrest|Jail Agent 3 | catching agent Justice||Arrest|Jail Agent 4 | chasing agent Justice||Arrest|Jail Agent 5 | detaining agent Justice||Arrest|Jail Agent 6 | dragging agent Justice||Arrest|Jail Agent 7 | frisking agent Justice||Arrest|Jail Agent 8 | handcuffing agent Justice||Arrest|Jail Agent 9 | burying agent Life||Die Agent 10 | colliding agent Life||Die Agent 11 | crashing agent Life||Die Agent 12 | carrying agent Movement||Transport Agent 13 | carrying agentpart Movement||Transport Agent 14 | carrying item Movement||Transport Agent 15 | fetching agent Movement||Transport Agent 16 | hauling carrier Movement||Transport Agent 17 | lifting agent Movement||Transport Agent 18 | loading agent Movement||Transport Agent 19 | towing agent Movement||Transport Agent 20 | unloading agent Movement||Transport Agent 21 | unpacking agent Movement||Transport Agent 22 | disembarking agent Movement||Transport Agent 23 | landing agent Movement||Transport Agent 24 | marching agent Movement||Transport Agent 25 | piloting agent Movement||Transport Agent 26 | rafting agent Movement||Transport Agent 27 | rowing agent Movement||Transport Agent 28 | skidding agent Movement||Transport Agent 29 | taxiing agent Movement||Transport Agent 30 | wheeling agent Movement||Transport Agent 31 | boarding agent Movement||Transport Artifact 32 | boating boaters Movement||Transport Artifact 33 | fetching item Movement||Transport Artifact 34 | hauling item Movement||Transport Artifact 35 | lifting item Movement||Transport Artifact 36 | loading item Movement||Transport Artifact 37 | towing item Movement||Transport Artifact 38 | unloading item Movement||Transport Artifact 39 | unpacking item Movement||Transport Artifact 40 | wheeling item Movement||Transport Artifact 41 | aiming agent Conflict||Attack Attacker 42 | attacking agent Conflict||Attack Attacker 43 | burning agent Conflict||Attack Attacker 44 | butting agent Conflict||Attack Attacker 45 | deflecting agent Conflict||Attack Attacker 46 | destroying agent Conflict||Attack Attacker 47 | ejecting agent Conflict||Attack 
Attacker 48 | erupting agent Conflict||Attack Attacker 49 | flaming agent Conflict||Attack Attacker 50 | hitting agent Conflict||Attack Attacker 51 | launching agent Conflict||Attack Attacker 52 | punching agent Conflict||Attack Attacker 53 | ramming agent Conflict||Attack Attacker 54 | shooting agent Conflict||Attack Attacker 55 | slapping agent Conflict||Attack Attacker 56 | striking agent Conflict||Attack Attacker 57 | striking coagent Conflict||Attack Attacker 58 | subduing agent Conflict||Attack Attacker 59 | tackling agent Conflict||Attack Attacker 60 | launching source Conflict||Attack Attacker 61 | fetching destination Movement||Transport Destination 62 | lifting end Movement||Transport Destination 63 | loading destination Movement||Transport Destination 64 | landing destination Movement||Transport Destination 65 | piloting end Movement||Transport Destination 66 | confronting agent Conflict||Demonstrate Entity 67 | confronting confronted Conflict||Demonstrate Entity 68 | congregating individuals Conflict||Demonstrate Entity 69 | gathering gatherers Conflict||Demonstrate Entity 70 | parading agent Conflict||Demonstrate Entity 71 | protesting agent Conflict||Demonstrate Entity 72 | communicating adressee Contact||Meet Entity 73 | communicating agent Contact||Meet Entity 74 | discussing agents Contact||Meet Entity 75 | saluting agent Contact||Meet Entity 76 | saluting target Contact||Meet Entity 77 | scolding agent Contact||Meet Entity 78 | scolding victim Contact||Meet Entity 79 | shaking agent Contact||Meet Entity 80 | socializing agent Contact||Meet Entity 81 | socializing coagent Contact||Meet Entity 82 | talking agent Contact||Meet Entity 83 | talking listener Contact||Meet Entity 84 | calling agent Contact||Phone|Write Entity 85 | phoning agent Contact||Phone|Write Entity 86 | telephoning agent Contact||Phone|Write Entity 87 | writing agent Contact||Phone|Write Entity 88 | writing target Contact||Phone|Write Entity 89 | buying agent Transaction||Transfer|Money Giver 90 | paying agent Transaction||Transfer|Money Giver 91 | aiming item Conflict||Attack Instrument 92 | attacking weapon Conflict||Attack Instrument 93 | deflecting deflecteditem Conflict||Attack Instrument 94 | destroying tool Conflict||Attack Instrument 95 | ejecting item Conflict||Attack Instrument 96 | hitting tool Conflict||Attack Instrument 97 | launching item Conflict||Attack Instrument 98 | ramming rammingitem Conflict||Attack Instrument 99 | shooting firearm Conflict||Attack Instrument 100 | shooting projectile Conflict||Attack Instrument 101 | slapping tool Conflict||Attack Instrument 102 | striking tool Conflict||Attack Instrument 103 | calling tool Contact||Phone|Write Instrument 104 | phoning tool Contact||Phone|Write Instrument 105 | writing tool Contact||Phone|Write Instrument 106 | catching tool Justice||Arrest|Jail Instrument 107 | dragging tool Justice||Arrest|Jail Instrument 108 | burying tool Life||Die Instrument 109 | crashing item Life||Die Instrument 110 | buying payment Transaction||Transfer|MONEY Instrument 111 | fetching source Movement||Transport Origin 112 | lifting start Movement||Transport Origin 113 | unloading source Movement||Transport Origin 114 | piloting start Movement||Transport Origin 115 | dragging contact Justice||Arrest|Jail Person 116 | apprehending victim Justice||Arrest|Jail Person 117 | arresting suspect Justice||Arrest|Jail Person 118 | catching caughtitem Justice||Arrest|Jail Person 119 | chasing chasee Justice||Arrest|Jail Person 120 | detaining victim 
Justice||Arrest|Jail Person 121 | dragging item Justice||Arrest|Jail Person 122 | frisking victim Justice||Arrest|Jail Person 123 | handcuffing victim Justice||Arrest|Jail Person 124 | aiming place Conflict||Attack Place 125 | attacking place Conflict||Attack Place 126 | burning place Conflict||Attack Place 127 | butting place Conflict||Attack Place 128 | deflecting place Conflict||Attack Place 129 | destroying place Conflict||Attack Place 130 | ejecting place Conflict||Attack Place 131 | ejecting source Conflict||Attack Place 132 | erupting place Conflict||Attack Place 133 | flaming place Conflict||Attack Place 134 | hitting place Conflict||Attack Place 135 | launching place Conflict||Attack Place 136 | punching place Conflict||Attack Place 137 | ramming place Conflict||Attack Place 138 | slapping place Conflict||Attack Place 139 | striking place Conflict||Attack Place 140 | subduing place Conflict||Attack Place 141 | tackling place Conflict||Attack Place 142 | confronting place Conflict||Demonstrate Place 143 | congregating place Conflict||Demonstrate Place 144 | gathering place Conflict||Demonstrate Place 145 | parading place Conflict||Demonstrate Place 146 | protesting place Conflict||Demonstrate Place 147 | communicating place Contact||Meet Place 148 | discussing place Contact||Meet Place 149 | saluting place Contact||Meet Place 150 | scolding place Contact||Meet Place 151 | shaking place Contact||Meet Place 152 | socializing place Contact||Meet Place 153 | talking place Contact||Meet Place 154 | calling place Contact||Phone|Write Place 155 | phoning place Contact||Phone|Write Place 156 | telephoning place Contact||Phone|Write Place 157 | writing place Contact||Phone|Write Place 158 | apprehending place Justice||Arrest|Jail Place 159 | arresting place Justice||Arrest|Jail Place 160 | catching place Justice||Arrest|Jail Place 161 | chasing place Justice||Arrest|Jail Place 162 | detaining place Justice||Arrest|Jail Place 163 | dragging place Justice||Arrest|Jail Place 164 | frisking place Justice||Arrest|Jail Place 165 | handcuffing place Justice||Arrest|Jail Place 166 | burying destination Life||Die Place 167 | burying place Life||Die Place 168 | colliding place Life||Die Place 169 | crashing place Life||Die Place 170 | steering place Movement||Transport Place 171 | carrying place Movement||Transport Place 172 | fetching place Movement||Transport Place 173 | hauling place Movement||Transport Place 174 | lifting place Movement||Transport Place 175 | loading place Movement||Transport Place 176 | towing place Movement||Transport Place 177 | unloading place Movement||Transport Place 178 | unpacking place Movement||Transport Place 179 | boarding place Movement||Transport Place 180 | boating place Movement||Transport Place 181 | disembarking place Movement||Transport Place 182 | landing place Movement||Transport Place 183 | marching place Movement||Transport Place 184 | piloting place Movement||Transport Place 185 | rafting place Movement||Transport Place 186 | rowing place Movement||Transport Place 187 | skidding place Movement||Transport Place 188 | taxiing place Movement||Transport Place 189 | wheeling place Movement||Transport Place 190 | buying place Transaction||Transfer|Money Place 191 | paying place Transaction||Transfer|Money Place 192 | buying seller Transaction||Transfer|Money Recipient 193 | paying seller Transaction||Transfer|Money Recipient 194 | aiming target Conflict||Attack Target 195 | attacking victim Conflict||Attack Target 196 | burning target Conflict||Attack Target 197 
| butting target Conflict||Attack Target 198 | destroying destroyeditem Conflict||Attack Target 199 | hitting victim Conflict||Attack Target 200 | hitting victimpart Conflict||Attack Target 201 | punching bodypart Conflict||Attack Target 202 | punching victim Conflict||Attack Target 203 | ramming victim Conflict||Attack Target 204 | shooting target Conflict||Attack Target 205 | slapping victim Conflict||Attack Target 206 | slapping victimpart Conflict||Attack Target 207 | striking agentpart Conflict||Attack Target 208 | subduing target Conflict||Attack Target 209 | tackling victim Conflict||Attack Target 210 | ejecting destination Conflict||Attack Target 211 | launching destination Conflict||Attack Target 212 | boarding vehicle Movement||Transport Vehicle 213 | boating vehicle Movement||Transport Vehicle 214 | steering agent Movement||Transport Vehicle 215 | steering tool Movement||Transport Vehicle 216 | steering vehicle Movement||Transport Vehicle 217 | hauling tool Movement||Transport Vehicle 218 | loading tool Movement||Transport Vehicle 219 | unloading tool Movement||Transport Vehicle 220 | unpacking container Movement||Transport Vehicle 221 | disembarking vehicle Movement||Transport Vehicle 222 | piloting vehicle Movement||Transport Vehicle 223 | rowing vehicle Movement||Transport Vehicle 224 | wheeling carrier Movement||Transport Vehicle 225 | burying item Life||Die Victim 226 | crashing against Life||Die Victim 227 | colliding item Life||Die Victim -------------------------------------------------------------------------------- /src/dataflow/torch/Data.py: -------------------------------------------------------------------------------- 1 | import json 2 | # import ujson as json 3 | from collections import Counter, OrderedDict 4 | 5 | import codecs 6 | import six 7 | import torch 8 | from torchtext.data import Field, Example, Pipeline, Dataset 9 | 10 | # import sys 11 | # sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2') 12 | from src.dataflow.torch.Corpus import Corpus 13 | from src.dataflow.torch.Sentence import Sentence_ace 14 | 15 | 16 | class SparseField(Field): 17 | def process(self, batch, device=None, train=True): 18 | return batch 19 | 20 | 21 | class EntityField(Field): 22 | ''' 23 | Processes entity annotations; each sentence carries exactly one list of spans: 24 | 25 | [(2, 3, "entity_type")] 26 | ''' 27 | 28 | def preprocess(self, x): 29 | return x 30 | 31 | def pad(self, minibatch): 32 | return minibatch 33 | 34 | def numericalize(self, arr, device=None, train=True): 35 | return arr 36 | 37 | 38 | class EventField(Field): 39 | ''' 40 | Processes event annotations; each sentence carries exactly one mapping from trigger spans to argument spans: 41 | 42 | { 43 | (2, 3, "event_type_str") --> [(1, 2, "role_type_str"), ...] 44 | ...
45 | } 46 | ''' 47 | 48 | def preprocess(self, x): 49 | return x 50 | 51 | def build_vocab(self, *args, **kwargs): 52 | counter = Counter() 53 | sources = [] 54 | for arg in args: 55 | if isinstance(arg, Dataset): 56 | sources += [getattr(arg, name) for name, field in 57 | arg.fields.items() if field is self] 58 | else: 59 | sources.append(arg) 60 | for data in sources: 61 | for x in data: 62 | for key, value in x.items(): 63 | for v in value: 64 | counter.update([v[2]]) 65 | self.vocab = self.vocab_cls(counter, specials=["OTHER"], **kwargs) 66 | 67 | def pad(self, minibatch): 68 | return minibatch 69 | 70 | def numericalize(self, arr, device=None, train=True): 71 | if self.use_vocab: 72 | # arr = [{key: [(v[0], v[1], self.vocab.stoi[v[2]]) for v in value] for key, value in dd.items()} for dd in 73 | # arr] 74 | arr_num = list() 75 | for dd in arr: 76 | dd_dict = dict() 77 | for key, value in dd.items(): 78 | value_list = [] 79 | for v in value: 80 | if v[2] in self.vocab.stoi: 81 | value_list.append( (v[0], v[1], self.vocab.stoi[v[2]]) ) 82 | else: 83 | value_list.append( (v[0], v[1], 0) ) 84 | dd_dict[key] = value_list 85 | arr_num.append(dd_dict) 86 | return arr_num 87 | 88 | 89 | class MultiTokenField(Field): 90 | ''' 91 | Processing data like "[ ["A", "A", "A"], ["A", "A"], ["A", "A"], ["A"] ]" 92 | ''' 93 | 94 | def preprocess(self, x): 95 | """Load a single example using this field, tokenizing if necessary. 96 | 97 | If the input is a Python 2 `str`, it will be converted to Unicode 98 | first. If `sequential=True`, it will be tokenized. Then the input 99 | will be optionally lowercased and passed to the user-provided 100 | `preprocessing` Pipeline.""" 101 | if (six.PY2 and isinstance(x, six.string_types) and 102 | not isinstance(x, six.text_type)): # never 103 | x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x) 104 | if self.sequential and isinstance(x, six.text_type): # never 105 | x = self.tokenize(x.rstrip('\n')) 106 | if self.lower: 107 | x = [Pipeline(six.text_type.lower)(xx) for xx in x] 108 | if self.preprocessing is not None: 109 | return self.preprocessing(x) 110 | else: 111 | return x 112 | 113 | def build_vocab(self, *args, **kwargs): 114 | counter = Counter() 115 | sources = [] 116 | for arg in args: 117 | if isinstance(arg, Dataset): 118 | sources += [getattr(arg, name) for name, field in 119 | arg.fields.items() if field is self] 120 | else: 121 | sources.append(arg) 122 | for data in sources: 123 | for x in data: 124 | if not self.sequential: 125 | x = [x] 126 | for xx in x: 127 | counter.update(xx) 128 | specials = list(OrderedDict.fromkeys( 129 | tok for tok in [self.unk_token, self.pad_token, self.init_token, 130 | self.eos_token] 131 | if tok is not None)) 132 | self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) 133 | 134 | def pad(self, minibatch): 135 | minibatch = list(minibatch) 136 | if not self.sequential: 137 | return minibatch 138 | if self.fix_length is None: 139 | max_len = max(len(x) for x in minibatch) 140 | else: 141 | max_len = self.fix_length + ( 142 | self.init_token, self.eos_token).count(None) - 2 143 | padded, lengths = [], [] 144 | for x in minibatch: 145 | if self.pad_first: 146 | padded.append( 147 | [[self.pad_token]] * max(0, max_len - len(x)) + 148 | ([] if self.init_token is None else [[self.init_token]]) + 149 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 150 | ([] if self.eos_token is None else [[self.eos_token]])) 151 | else: 152 | padded.append( 153 | ([] if self.init_token is None else 
[[self.init_token]]) + 154 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 155 | ([] if self.eos_token is None else [[self.eos_token]]) + 156 | [[self.pad_token]] * max(0, max_len - len(x))) 157 | lengths.append(len(padded[-1]) - max(0, max_len - len(x))) 158 | 159 | if self.include_lengths: 160 | return (padded, lengths) 161 | return padded 162 | 163 | def numericalize(self, arr, device=None, train=True): 164 | if self.include_lengths and not isinstance(arr, tuple): 165 | raise ValueError("Field has include_lengths set to True, but " 166 | "input data is not a tuple of " 167 | "(data batch, batch lengths).") 168 | if isinstance(arr, tuple): 169 | arr, lengths = arr 170 | lengths = torch.LongTensor(lengths) 171 | 172 | if self.use_vocab: 173 | if self.sequential: 174 | arr = [[[self.vocab.stoi[xx] for xx in x] for x in ex] for ex in arr] 175 | 176 | if self.postprocessing is not None: 177 | arr = self.postprocessing(arr, self.vocab, train) 178 | 179 | if self.include_lengths: 180 | return arr, lengths 181 | return arr 182 | 183 | 184 | class ACE2005Dataset(Corpus): 185 | """ 186 | Defines a dataset composed of Examples along with its Fields. 187 | """ 188 | 189 | sort_key = None 190 | 191 | def __init__(self, path, fields, amr=False, keep_events=None, only_keep=False, **kwargs): 192 | ''' 193 | Create a corpus given a path, field list, and a filter function. 194 | 195 | :param path: str, Path to the data file 196 | :param fields: dict[str: tuple(str, Field)], 197 | If using a dict, the keys should be a subset of the JSON keys or CSV/TSV 198 | columns, and the values should be tuples of (name, field). 199 | Keys not present in the input dictionary are ignored. 200 | This allows the user to rename columns from their JSON/CSV/TSV key names 201 | and also enables selecting a subset of columns to load. 202 | :param keep_events: int, minimum sentence events. Default keep all. 
203 | ''' 204 | self.keep_events = keep_events 205 | self.only_keep = only_keep 206 | super(ACE2005Dataset, self).__init__(path, fields, amr, **kwargs) 207 | 208 | def parse_example(self, path, fields, amr, **kwargs): 209 | examples = [] 210 | 211 | _file = codecs.open(path, 'r', 'utf-8') 212 | jl = json.load(_file) 213 | print(path, len(jl)) 214 | for js in jl: 215 | ex = self.parse_sentence(js, fields, amr) 216 | if ex is not None: 217 | examples.append(ex) 218 | # for line in f: 219 | # line = line.strip() 220 | # if len(line) == 0: 221 | # continue 222 | # print(line) 223 | # jl = json.loads(line, encoding="utf-8") 224 | # for js in jl: 225 | # ex = self.parse_sentence(js, fields) 226 | # if ex is not None: 227 | # examples.append(ex) 228 | 229 | return examples 230 | 231 | def parse_sentence(self, js, fields, amr): 232 | SENTID = fields["sentence_id"] 233 | WORDS = fields["words"] 234 | POSTAGS = fields["pos-tags"] 235 | # LEMMAS = fields["lemma"] 236 | ENTITYLABELS = fields["golden-entity-mentions"] 237 | if amr: 238 | colcc = "simple-parsing" 239 | else: 240 | colcc = "combined-parsing" 241 | # print(colcc) 242 | ADJMATRIX = fields[colcc] 243 | LABELS = fields["golden-event-mentions"] 244 | EVENTS = fields["all-events"] 245 | ENTITIES = fields["all-entities"] 246 | 247 | sentence = Sentence_ace(json_content=js, graph_field_name=colcc) 248 | ex = Example() 249 | # print('sentence.wordList', WORDS[1].preprocess(sentence.wordList)) 250 | setattr(ex, SENTID[0], SENTID[1].preprocess(sentence.sentence_id)) 251 | setattr(ex, WORDS[0], WORDS[1].preprocess(sentence.wordList)) 252 | setattr(ex, POSTAGS[0], POSTAGS[1].preprocess(sentence.posLabelList)) 253 | # setattr(ex, LEMMAS[0], LEMMAS[1].preprocess(sentence.lemmaList)) 254 | setattr(ex, ENTITYLABELS[0], ENTITYLABELS[1].preprocess(sentence.entityLabelList)) 255 | setattr(ex, ADJMATRIX[0], (sentence.adjpos, sentence.adjv)) 256 | setattr(ex, LABELS[0], LABELS[1].preprocess(sentence.triggerLabelList)) 257 | setattr(ex, EVENTS[0], EVENTS[1].preprocess(sentence.events)) 258 | setattr(ex, ENTITIES[0], ENTITIES[1].preprocess(sentence.entities)) 259 | 260 | if self.keep_events is not None: 261 | if self.only_keep and sentence.containsEvents != self.keep_events: 262 | return None 263 | elif not self.only_keep and sentence.containsEvents < self.keep_events: 264 | return None 265 | else: 266 | return ex 267 | else: 268 | return ex 269 | 270 | def longest(self): 271 | return max([len(x.POSTAGS) for x in self.examples]) 272 | -------------------------------------------------------------------------------- /src/dataflow/torch/Sentence.py: -------------------------------------------------------------------------------- 1 | # import sys 2 | # sys.path.append('/dvmm-filer2/users/manling/mm-event-graph2') 3 | 4 | from src.util.consts import CUTOFF 5 | # from PIL import Image 6 | # import os 7 | 8 | 9 | def pretty_str(a): 10 | a = a.upper() 11 | if a == 'O': 12 | return a 13 | elif a[1] == '-': 14 | return a[:2] + "|".join(a[2:].split("-")).replace(":", "||") 15 | else: 16 | return "|".join(a.split("-")).replace(":", "||") 17 | 18 | 19 | class Sentence: 20 | def __init__(self, json_content, with_sentid=False): 21 | # self.wordList = json_content["words"][:CUTOFF] 22 | # self.posLabelList = json_content["pos-tags"][:CUTOFF] 23 | # # self.lemmaList = json_content["lemma"][:CUTOFF] 24 | # self.length = len(self.wordList) 25 | # 26 | # self.entityLabelList = self.generateEntityLabelList(json_content["golden-entity-mentions"]) 27 | # # self.triggerLabelList = 
self.generateTriggerLabelList(json_content["golden-event-mentions"]) 28 | # self.adjpos, self.adjv = self.generateAdjMatrix(json_content[graph_field_name]) 29 | # 30 | # self.entities = self.generateGoldenEntities(json_content["golden-entity-mentions"]) 31 | # # self.events = self.generateGoldenEvents(json_content["golden-event-mentions"]) 32 | # 33 | # # self.containsEvents = len(json_content["golden-event-mentions"]) 34 | # self.tokenList = self.makeTokenList() 35 | self.json_content = json_content 36 | self.with_sentid = with_sentid 37 | 38 | def generateEntityLabelList(self, entitiesJsonList): 39 | ''' 40 | Keep the overlapping entity labels 41 | :param entitiesJsonList: 42 | :return: 43 | ''' 44 | 45 | entityLabel = [["O"] for _ in range(self.length)] 46 | 47 | def assignEntityLabel(index, label): 48 | if index >= CUTOFF: 49 | return 50 | if len(entityLabel[index]) == 1 and entityLabel[index][0] == "O": 51 | entityLabel[index][0] = pretty_str(label) 52 | else: 53 | entityLabel[index].append(pretty_str(label)) 54 | 55 | for entityJson in entitiesJsonList: 56 | start = entityJson["start"] 57 | end = entityJson["end"] 58 | etype = entityJson["entity-type"].split(":")[0] 59 | assignEntityLabel(start, "B-" + etype) 60 | for i in range(start + 1, end): 61 | assignEntityLabel(i, "I-" + etype) 62 | 63 | return entityLabel 64 | 65 | def generateGoldenEntities(self, entitiesJson): 66 | ''' 67 | [(2, 3, "entity_type")] 68 | ''' 69 | golden_list = [] 70 | for entityJson in entitiesJson: 71 | start = entityJson["start"] 72 | if start >= CUTOFF: 73 | continue 74 | end = min(entityJson["end"], CUTOFF) 75 | etype = entityJson["entity-type"].split(":")[0] 76 | golden_list.append((start, end, etype)) 77 | return golden_list 78 | 79 | def generateGoldenEvents(self, eventsJson, with_sentid=False, sent_id=None): 80 | ''' 81 | 82 | { 83 | (2, 3, "event_type_str") --> [(1, 2, "role_type_str"), ...] 84 | ... 
85 | } 86 | 87 | ''' 88 | golden_dict = {} 89 | for eventJson in eventsJson: 90 | triggerJson = eventJson["trigger"] 91 | if triggerJson["start"] >= CUTOFF: 92 | continue 93 | if with_sentid: 94 | key = ('%s__%d' % (sent_id, triggerJson["start"]), 95 | min(triggerJson["end"], CUTOFF), pretty_str(eventJson["event_type"])) 96 | else: 97 | key = (triggerJson["start"], 98 | min(triggerJson["end"], CUTOFF), pretty_str(eventJson["event_type"])) 99 | values = [] 100 | for argumentJson in eventJson["arguments"]: 101 | if argumentJson["start"] >= CUTOFF: 102 | continue 103 | value = (argumentJson["start"], min(argumentJson["end"], CUTOFF), pretty_str(argumentJson["role"])) 104 | values.append(value) 105 | golden_dict[key] = list(sorted(values)) 106 | return golden_dict 107 | 108 | def generateTriggerLabelList(self, triggerJsonList): 109 | triggerLabel = ["O" for _ in range(self.length)] 110 | 111 | def assignTriggerLabel(index, label): 112 | if index >= CUTOFF: 113 | return 114 | triggerLabel[index] = pretty_str(label) 115 | 116 | for eventJson in triggerJsonList: 117 | triggerJson = eventJson["trigger"] 118 | start = triggerJson["start"] 119 | end = triggerJson["end"] 120 | etype = eventJson["event_type"] 121 | assignTriggerLabel(start, "B-" + etype) 122 | for i in range(start + 1, end): 123 | assignTriggerLabel(i, "I-" + etype) 124 | return triggerLabel 125 | 126 | def generateAdjMatrix(self, edgeJsonList): 127 | sparseAdjMatrixPos = [[], [], []] 128 | sparseAdjMatrixValues = [] 129 | 130 | def addedge(type_, from_, to_, value_): 131 | sparseAdjMatrixPos[0].append(type_) 132 | sparseAdjMatrixPos[1].append(from_) 133 | sparseAdjMatrixPos[2].append(to_) 134 | sparseAdjMatrixValues.append(value_) 135 | 136 | for edgeJson in edgeJsonList: 137 | ss = edgeJson.split("/") 138 | fromIndex = int(ss[-1].split("=")[-1]) 139 | toIndex = int(ss[-2].split("=")[-1]) 140 | etype = ss[0] 141 | if etype.lower() == "root" or fromIndex == -1 or toIndex == -1 or fromIndex >= CUTOFF or toIndex >= CUTOFF: 142 | continue 143 | addedge(0, fromIndex, toIndex, 1.0) 144 | addedge(1, toIndex, fromIndex, 1.0) 145 | 146 | for i in range(self.length): 147 | addedge(2, i, i, 1.0) 148 | 149 | return sparseAdjMatrixPos, sparseAdjMatrixValues 150 | 151 | # def makeTokenList(self): 152 | # # return [Token(self.wordList[i], self.posLabelList[i], self.lemmaList[i], self.entityLabelList[i], 153 | # # self.triggerLabelList[i]) 154 | # # for i in range(self.length)] 155 | # return [Token(self.wordList[i], self.posLabelList[i], self.entityLabelList[i], 156 | # self.triggerLabelList[i]) 157 | # for i in range(self.length)] 158 | 159 | def __len__(self): 160 | return self.length 161 | 162 | def __iter__(self): 163 | for x in self.tokenList: 164 | yield x 165 | 166 | def __getitem__(self, index): 167 | return self.tokenList[index] 168 | 169 | 170 | class Sentence_ace(Sentence): 171 | def __init__(self, json_content, graph_field_name): 172 | Sentence.__init__(self, json_content) 173 | if "sentence_id" in json_content: 174 | self.sentence_id = json_content["sentence_id"] 175 | else: 176 | self.sentence_id = "none" 177 | self.wordList = json_content["words"][:CUTOFF] 178 | self.posLabelList = json_content["pos-tags"][:CUTOFF] 179 | # self.lemmaList = json_content["lemma"][:CUTOFF] 180 | self.length = len(self.wordList) 181 | 182 | if "golden-entity-mentions" in json_content: 183 | self.entityLabelList = self.generateEntityLabelList(json_content["golden-entity-mentions"]) 184 | self.entities = 
self.generateGoldenEntities(json_content["golden-entity-mentions"]) 185 | else: 186 | self.entityLabelList = self.generateEntityLabelList(list()) 187 | self.entities = self.generateGoldenEntities(list()) 188 | if "golden-event-mentions" in json_content: 189 | self.triggerLabelList = self.generateTriggerLabelList(json_content["golden-event-mentions"]) 190 | self.events = self.generateGoldenEvents(json_content["golden-event-mentions"]) 191 | self.containsEvents = len(json_content["golden-event-mentions"]) 192 | else: 193 | self.triggerLabelList = self.generateTriggerLabelList(list()) 194 | self.events = self.generateGoldenEvents(list()) 195 | self.containsEvents = 0 196 | 197 | self.adjpos, self.adjv = self.generateAdjMatrix(json_content[graph_field_name]) 198 | self.tokenList = self.makeTokenList() 199 | 200 | def makeTokenList(self): 201 | # return [Token(self.wordList[i], self.posLabelList[i], self.lemmaList[i], self.entityLabelList[i], 202 | # self.triggerLabelList[i]) 203 | # for i in range(self.length)] 204 | return [Token(self.wordList[i], self.posLabelList[i], self.entityLabelList[i], 205 | self.triggerLabelList[i]) 206 | for i in range(self.length)] 207 | 208 | 209 | class Sentence_grounding(Sentence): 210 | def __init__(self, json_content, graph_field_name, img_dir, transform=None): 211 | Sentence.__init__(self, json_content) 212 | self.image_id = json_content["image"] 213 | self.sentence_id = json_content["sentence_id"] 214 | self.wordList = json_content["words"][:CUTOFF] 215 | self.posLabelList = json_content["pos-tags"][:CUTOFF] 216 | # self.lemmaList = json_content["lemma"][:CUTOFF] 217 | self.length = len(self.wordList) 218 | 219 | self.entityLabelList = self.generateEntityLabelList(json_content["golden-entity-mentions"]) 220 | self.adjpos, self.adjv = self.generateAdjMatrix(json_content[graph_field_name]) 221 | 222 | self.entities = self.generateGoldenEntities(json_content["golden-entity-mentions"]) 223 | 224 | self.tokenList = self.makeTokenList() 225 | 226 | # get the image vectors 227 | # img_path = os.path.join(img_dir, self.image_id) 228 | # try: 229 | # self.image_vec = Image.open(img_path).convert('RGB') 230 | # if transform is not None: 231 | # self.image_vec = transform(self.image_vec) 232 | # except: 233 | # self.image_vec = None 234 | 235 | def makeTokenList(self): 236 | return [Token(self.wordList[i], self.posLabelList[i], self.entityLabelList[i]) 237 | for i in range(self.length)] 238 | 239 | class Sentence_m2e2(Sentence): 240 | def __init__(self, json_content, graph_field_name, with_sentid=False): 241 | Sentence.__init__(self, json_content, with_sentid=with_sentid) 242 | self.image_id = json_content["image"] 243 | self.sentence_id = json_content["sentence_id"] 244 | self.wordList = json_content["words"][:CUTOFF] 245 | self.posLabelList = json_content["pos-tags"][:CUTOFF] 246 | self.length = len(self.wordList) 247 | 248 | self.entityLabelList = self.generateEntityLabelList(json_content["golden-entity-mentions"]) 249 | self.triggerLabelList = self.generateTriggerLabelList(json_content["golden-event-mentions"]) 250 | self.adjpos, self.adjv = self.generateAdjMatrix(json_content[graph_field_name]) 251 | 252 | self.entities = self.generateGoldenEntities(json_content["golden-entity-mentions"]) 253 | self.events = self.generateGoldenEvents(json_content["golden-event-mentions"], self.with_sentid, self.sentence_id) 254 | 255 | self.containsEvents = len(json_content["golden-event-mentions"]) 256 | self.tokenList = self.makeTokenList() 257 | 258 | def makeTokenList(self): 
259 | return [Token(self.wordList[i], self.posLabelList[i], self.entityLabelList[i], 260 | self.triggerLabelList[i]) 261 | for i in range(self.length)] 262 | 263 | 264 | class Token: 265 | # def __init__(self, word, posLabel, lemmaLabel, entityLabel, triggerLabel): 266 | def __init__(self, word, posLabel, entityLabel, triggerLabel=None): 267 | self.word = word 268 | self.posLabel = posLabel 269 | # self.lemmaLabel = lemmaLabel 270 | self.entityLabel = entityLabel 271 | self.triggerLabel = triggerLabel 272 | self.predictedLabel = None 273 | 274 | def addPredictedLabel(self, label): 275 | self.predictedLabel = label -------------------------------------------------------------------------------- /src/engine/TestRunnerEE.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import json 5 | import sys 6 | from functools import partial 7 | 8 | import numpy as np 9 | import torch 10 | from tensorboardX import SummaryWriter 11 | from torchtext.data import Field 12 | from torchtext.vocab import Vectors 13 | from torchtext.data import BucketIterator 14 | from math import ceil 15 | 16 | import sys 17 | #export PATH=/dvmm-filer2/users/manling/mm-event-graph2:$PATH 18 | sys.path.append('../..') 19 | 20 | from src.util import consts 21 | from src.dataflow.torch.Data import ACE2005Dataset, MultiTokenField, SparseField, EventField, EntityField 22 | from src.models.ee import EDModel 23 | from src.eval.EEtesting import EDTester 24 | from src.eval.EEvisualizing import EDVisualizer 25 | from src.engine.EEtraining import ee_train 26 | from src.util.util_model import log 27 | from src.engine.EEtraining import run_over_data 28 | from src.engine.EErunner import load_ee_model 29 | from src.util.util_model import progressbar 30 | from src.dataflow.torch.Sentence import Token 31 | from src.engine.EErunner import event_role_mask 32 | 33 | 34 | class EERunnerTest(object): 35 | def __init__(self): 36 | parser = argparse.ArgumentParser(description="neural networks tester") 37 | parser.add_argument("--test_ee", help="event extraction test set") 38 | parser.add_argument("--train_ee", help="event extraction training set", required=False) 39 | parser.add_argument("--dev_ee", help="event extraction development set", required=False) 40 | parser.add_argument("--webd", help="word embedding", required=False) 41 | parser.add_argument("--ignore_time_test", help="ignore time when testing the sr model", action='store_true') 42 | 43 | parser.add_argument("--batch", help="batch size", default=128, type=int) 44 | # parser.add_argument("--epochs", help="n of epochs", default=sys.maxsize, type=int) 45 | 46 | parser.add_argument("--seed", help="RNG seed", default=1111, type=int) 47 | # parser.add_argument("--optimizer", default="adam") 48 | # parser.add_argument("--lr", default=1, type=float) 49 | # parser.add_argument("--l2decay", default=0, type=float) 50 | parser.add_argument("--maxnorm", default=3, type=float) 51 | 52 | parser.add_argument("--out", help="output model path", default="out") 53 | parser.add_argument("--finetune", help="pretrained model path") 54 | # parser.add_argument("--earlystop", default=999999, type=int) 55 | # parser.add_argument("--restart", default=999999, type=int) 56 | # parser.add_argument("--shuffle", help="shuffle", action='store_true') 57 | parser.add_argument("--amr", help="use amr", action='store_true') 58 | 59 | parser.add_argument("--device", default="cpu") 60 | parser.add_argument("--hps_path", help="model 
hyperparams", required=False) 61 | parser.add_argument("--hps", help="model hyperparams", required=False) 62 | 63 | self.a = parser.parse_args() 64 | if self.a.hps_path: 65 | self.a.hps = json.load(open(self.a.hps_path)) 66 | print(self.a.hps) 67 | 68 | def set_device(self, device="cpu"): 69 | # self.device = torch.device(device) 70 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 71 | 72 | def get_device(self): 73 | return self.device 74 | 75 | def get_tester(self, voc_i2s, voc_role_i2s): 76 | return EDTester(voc_i2s, voc_role_i2s, self.a.ignore_time_test) 77 | 78 | def run(self): 79 | print("Running on", self.a.device) 80 | self.set_device(self.a.device) 81 | 82 | np.random.seed(self.a.seed) 83 | torch.manual_seed(self.a.seed) 84 | torch.backends.cudnn.benchmark = True 85 | 86 | # create training set 87 | if self.a.test_ee: 88 | log('loading event extraction corpus from %s' % self.a.test_ee) 89 | 90 | WordsField = Field(lower=True, include_lengths=True, batch_first=True) 91 | PosTagsField = Field(lower=True, batch_first=True) 92 | EntityLabelsField = MultiTokenField(lower=False, batch_first=True) 93 | AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True) 94 | LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None) 95 | EventsField = EventField(lower=False, batch_first=True) 96 | EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False) 97 | SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) 98 | if self.a.amr: 99 | colcc = 'simple-parsing' 100 | else: 101 | colcc = 'combined-parsing' 102 | print(colcc) 103 | 104 | train_ee_set = ACE2005Dataset(path=self.a.train_ee, 105 | fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), 106 | "pos-tags": ("POSTAGS", PosTagsField), 107 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), 108 | colcc: ("ADJM", AdjMatrixField), 109 | "golden-event-mentions": ("LABEL", LabelField), 110 | "all-events": ("EVENT", EventsField), 111 | "all-entities": ("ENTITIES", EntitiesField)}, 112 | amr=self.a.amr, keep_events=1) 113 | 114 | dev_ee_set = ACE2005Dataset(path=self.a.dev_ee, 115 | fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), 116 | "pos-tags": ("POSTAGS", PosTagsField), 117 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), 118 | colcc: ("ADJM", AdjMatrixField), 119 | "golden-event-mentions": ("LABEL", LabelField), 120 | "all-events": ("EVENT", EventsField), 121 | "all-entities": ("ENTITIES", EntitiesField)}, 122 | amr=self.a.amr, keep_events=0) 123 | 124 | test_ee_set = ACE2005Dataset(path=self.a.test_ee, 125 | fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), 126 | "pos-tags": ("POSTAGS", PosTagsField), 127 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), 128 | colcc: ("ADJM", AdjMatrixField), 129 | "golden-event-mentions": ("LABEL", LabelField), 130 | "all-events": ("EVENT", EventsField), 131 | "all-entities": ("ENTITIES", EntitiesField)}, 132 | amr=self.a.amr, keep_events=0) 133 | 134 | if self.a.webd: 135 | pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15)) 136 | WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, vectors=pretrained_embedding) 137 | else: 138 | WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS) 139 | PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS) 140 | 
EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS) 141 | LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL) 142 | EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT) 143 | consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME] 144 | # print("O label is", consts.O_LABEL) 145 | consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME] 146 | # print("O label for AE is", consts.ROLE_O_LABEL) 147 | 148 | self.a.label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5 149 | self.a.label_weight[consts.O_LABEL] = 1.0 150 | self.a.arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5 151 | # add role mask 152 | self.a.role_mask = event_role_mask(self.a.test_ee, self.a.train_ee, self.a.dev_ee, LabelField.vocab.stoi, 153 | EventsField.vocab.stoi, self.device) 154 | # print('self.a.hps', self.a.hps) 155 | if not self.a.hps_path: 156 | self.a.hps = eval(self.a.hps) 157 | if "wemb_size" not in self.a.hps: 158 | self.a.hps["wemb_size"] = len(WordsField.vocab.itos) 159 | if "pemb_size" not in self.a.hps: 160 | self.a.hps["pemb_size"] = len(PosTagsField.vocab.itos) 161 | if "psemb_size" not in self.a.hps: 162 | self.a.hps["psemb_size"] = max([train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest()]) + 2 163 | if "eemb_size" not in self.a.hps: 164 | self.a.hps["eemb_size"] = len(EntityLabelsField.vocab.itos) 165 | if "oc" not in self.a.hps: 166 | self.a.hps["oc"] = len(LabelField.vocab.itos) 167 | if "ae_oc" not in self.a.hps: 168 | self.a.hps["ae_oc"] = len(EventsField.vocab.itos) 169 | 170 | tester = self.get_tester(LabelField.vocab.itos, EventsField.vocab.itos) 171 | visualizer = EDVisualizer(self.a.test_ee) 172 | 173 | if self.a.finetune: 174 | log('init model from ' + self.a.finetune) 175 | model = load_ee_model(self.a.hps, self.a.finetune, WordsField.vocab.vectors, self.device) 176 | log('model loaded, there are %i sets of params' % len(model.parameters_requires_grads())) 177 | else: 178 | model = load_ee_model(self.a.hps, None, WordsField.vocab.vectors, self.device) 179 | log('model created from scratch, there are %i sets of params' % len(model.parameters_requires_grads())) 180 | 181 | self.a.word_i2s = WordsField.vocab.itos 182 | self.a.label_i2s = LabelField.vocab.itos 183 | self.a.role_i2s = EventsField.vocab.itos 184 | writer = SummaryWriter(os.path.join(self.a.out, "exp")) 185 | self.a.writer = writer 186 | 187 | # train_iter = BucketIterator(train_ee_set, batch_size=self.a.batch, 188 | # train=True, shuffle=False, device=-1, 189 | # sort_key=lambda x: len(x.POSTAGS)) 190 | # dev_iter = BucketIterator(dev_ee_set, batch_size=self.a.batch, train=False, 191 | # shuffle=False, device=-1, 192 | # sort_key=lambda x: len(x.POSTAGS)) 193 | test_iter = BucketIterator(test_ee_set, batch_size=self.a.batch, train=False, 194 | shuffle=False, device=-1, 195 | sort_key=lambda x: len(x.POSTAGS)) 196 | 197 | print("\nStarting testing ...\n") 198 | 199 | # Testing Phrase 200 | test_loss, test_ed_p, test_ed_r, test_ed_f1, \ 201 | test_ae_p, test_ae_r, test_ae_f1 = run_over_data(data_iter=test_iter, 202 | optimizer=None, 203 | model=model, 204 | need_backward=False, 205 | MAX_STEP=len(test_iter), 206 | tester=tester, 207 | visualizer=visualizer, 208 | hyps=model.hyperparams, 209 | device=model.device, 210 | maxnorm=self.a.maxnorm, 211 | word_i2s=self.a.word_i2s, 212 | label_i2s=self.a.label_i2s, 213 | role_i2s=self.a.role_i2s, 214 | weight=self.a.label_weight, 215 | arg_weight=self.a.arg_weight, 216 | save_output=os.path.join( 
217 | self.a.out, 218 | "test_final.txt"), 219 | role_mask=self.a.role_mask) 220 | 221 | print("\nFinally test loss: ", test_loss, 222 | "\ntest ed p: ", test_ed_p, 223 | " test ed r: ", test_ed_r, 224 | " test ed f1: ", test_ed_f1, 225 | "\ntest ae p: ", test_ae_p, 226 | " test ae r: ", test_ae_r, 227 | " test ae f1: ", test_ae_f1) 228 | 229 | if __name__ == "__main__": 230 | EERunnerTest().run() -------------------------------------------------------------------------------- /src/eval/EEvisualizing.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | import os 4 | import sys 5 | sys.path.append('../../') 6 | from src.dataflow.numpy.anno_mapping import entity_type_mapping_brat, event_type_mapping_ace2brat, event_role_mapping_ace2brat, event_role_mapping_brat2ace 7 | import shutil 8 | 9 | class EDVisualizer(): 10 | def __init__(self, JMEE_data_json): 11 | self.sent_info = defaultdict(lambda : defaultdict()) 12 | self._load_sent_info(JMEE_data_json) 13 | 14 | def _load_sent_info(self, JMEE_data_json): 15 | raw_data = json.load(open(JMEE_data_json)) 16 | for data_instance in raw_data: 17 | sent_id = data_instance['sentence_id'] 18 | words = data_instance['words'] 19 | # index_all = data_instance['index_all'] 20 | # index = data_instance['index'] 21 | # entity = data_instance['golden-entity-mentions'] 22 | sentence = data_instance['sentence'] 23 | 24 | self.sent_info[sent_id]['words'] = words 25 | # self.sent_info[sent_id]['index_all'] = index_all 26 | # self.sent_info[sent_id]['index'] = index 27 | # self.sent_info[sent_id]['entity'] = entity 28 | self.sent_info[sent_id]['sentence'] = sentence 29 | 30 | 31 | def save_json(self, predicted_event_triggers, predicted_events, sent_id, ee_role_i2s, text_result_json): 32 | sent_id = sent_id[0] 33 | for result_dict in predicted_events: 34 | for key in result_dict: 35 | # print('self.sent_info[sent_id]', self.sent_info) 36 | # print('sent_id', sent_id) 37 | eventstart_sentid, event_end, event_type_ace = key 38 | trigger_word = ' '.join(self.sent_info[sent_id]['words'][int(eventstart_sentid):int(event_end)]) 39 | text_result_json[eventstart_sentid]['sentence_id'] = sent_id 40 | text_result_json[eventstart_sentid]['sentence'] = self.sent_info[sent_id]['sentence'] 41 | text_result_json[eventstart_sentid]['sentence_tokens'] = self.sent_info[sent_id]['words'] 42 | text_result_json[eventstart_sentid]['pred_event_type'] = event_type_ace 43 | text_result_json[eventstart_sentid]['pred_trigger'] = {'index_start': eventstart_sentid, 'index_end':event_end, 'text':trigger_word} 44 | text_result_json[eventstart_sentid]['pred_roles'] = defaultdict(list) 45 | 46 | args = result_dict[key] 47 | for arg_start, arg_end, arg_type in args: 48 | arg_word = ' '.join(self.sent_info[sent_id]['words'][int(arg_start):int(arg_end)]) 49 | # arg_start_offset, arg_end_offset = self.get_offset_by_idx(sent_id, arg_start, arg_end) 50 | role_name_ace = ee_role_i2s[arg_type] 51 | text_result_json[eventstart_sentid]['pred_roles'][role_name_ace].append( {'index_start':arg_start, 'index_end':arg_end, 'text':arg_word} ) 52 | 53 | 54 | 55 | def visualize_html(self, predicted_event_triggers, predicted_events, sent_id, ee_role_i2s, visual_writer): 56 | for eventstart_sentid in predicted_event_triggers: 57 | # print('eventstart_sentid', eventstart_sentid) 58 | event_end, event_type_ace = predicted_event_triggers[eventstart_sentid] 59 | # if '__' in eventstart_sentid: 60 | event_start = 
int(eventstart_sentid[eventstart_sentid.find('__')+2:]) 61 | # print('predicted_event_triggers', predicted_event_triggers) 62 | # event_start_offset, event_end_offset = self.get_offset_by_idx(sent_id, event_start, event_end) 63 | visual_writer.write("sentence: %s<br>\n" % self.sent_info[sent_id]['sentence']) 64 | trigger_word = ' '.join(self.sent_info[sent_id]['words'][event_start:event_end]) # multiple words 65 | visual_writer.write("event type: %s<br>\n" % event_type_ace) 66 | visual_writer.write("trigger word: %s<br>\n" % trigger_word) 67 | if (eventstart_sentid, event_end, event_type_ace) in predicted_events: 68 | args = predicted_events[(eventstart_sentid, event_end, event_type_ace)] 69 | for arg_start, arg_end, arg_type in args: 70 | arg_word = ' '.join(self.sent_info[sent_id]['words'][arg_start:arg_end]) 71 | # arg_start_offset, arg_end_offset = self.get_offset_by_idx(sent_id, arg_start, arg_end) 72 | role_name_ace = ee_role_i2s[arg_type] 73 | visual_writer.write(" [%s]: %s<br>\n" % (role_name_ace, arg_word)) 74 | visual_writer.write('<br>\n<br>\n') 75 | 76 | visual_writer.flush() 77 | visual_writer.close() 78 | 79 | 80 | def visualize_brat(self, predicted_event_triggers, predicted_events, sent_id, ee_role_i2s, visual_writer, save_entity=False): 81 | print('predicted_event_triggers', predicted_event_triggers) 82 | print('predicted_events', predicted_events) 83 | for eventstart_sentid in predicted_event_triggers: 84 | # print('eventstart_sentid', eventstart_sentid) 85 | event_end, event_type_ace = predicted_event_triggers[eventstart_sentid] 86 | # if '__' in eventstart_sentid: 87 | event_start = int(eventstart_sentid[eventstart_sentid.find('__')+2:]) 88 | # print('predicted_event_triggers', predicted_event_triggers) 89 | event_start_offset, event_end_offset = self.get_offset_by_idx(sent_id, event_start, event_end) 90 | trigger_word = ' '.join(self.sent_info[sent_id]['words'][event_start:event_end]) # multiple words 91 | # T44 Attack 6595 6604 onslaught 92 | # event_id = 'T%s_%d_%d' % (sent_id[sent_id.find('_'):], event_start_offset, event_end_offset) 93 | token_id = 'T%d_%d' % (event_start_offset, event_end_offset) 94 | if event_type_ace not in event_type_mapping_ace2brat: 95 | print('ignored type', event_type_ace) 96 | continue 97 | event_type_brat = event_type_mapping_ace2brat[event_type_ace] 98 | visual_writer.write(token_id) 99 | visual_writer.write('\t%s' % event_type_brat) 100 | visual_writer.write(' %d %d' % (event_start_offset, event_end_offset)) 101 | visual_writer.write('\t%s\n' % trigger_word) 102 | 103 | # E11 Attack:T44 Attacker:T26493 Target:T45 104 | # save triggers 105 | event_id = token_id.replace('T', 'E') 106 | visual_writer.write('%s' % event_id) 107 | visual_writer.write('\t%s:%s' % (event_type_brat, token_id)) 108 | # save args 109 | if save_entity: 110 | entity_lines = list() 111 | # try: 112 | if (eventstart_sentid, event_end, event_type_ace) in predicted_events: 113 | args = predicted_events[(eventstart_sentid, event_end, event_type_ace)] 114 | for arg_start, arg_end, arg_type in args: 115 | # index -> offset index 116 | arg_word = ' '.join(self.sent_info[sent_id]['words'][arg_start:arg_end]) 117 | arg_start_offset, arg_end_offset = self.get_offset_by_idx(sent_id, arg_start, arg_end) 118 | role_name_ace = ee_role_i2s[arg_type] 119 | # print(event_type_ace, role_name_ace) 120 | role_name_brat = event_role_mapping_ace2brat[event_type_ace][role_name_ace] 121 | visual_writer.write(' %s:T%d_%d' % (role_name_brat, arg_start_offset, arg_end_offset)) 122 | if save_entity: 123 | # T44 Attack 6595 6604 onslaught 124 | arg_type_brat = 'VAL' # do not save the entity types? 
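# --------------------------------------------------------------------------
# Aside (not part of the original file): the brat .ann lines this method
# assembles, shown for one hypothetical prediction -- a Conflict.Attack
# trigger "onslaught" at character offsets 6595-6604 with a Target argument
# at offsets 6610-6620 (all ids and offsets here are made up for
# illustration, following the T.../E... examples in the comments above).
_tok = 'T%d_%d' % (6595, 6604)
_trigger_line = '%s\tAttack 6595 6604\tonslaught\n' % _tok
_event_line = '%s\tAttack:%s Target:T6610_6620\n' % (_tok.replace('T', 'E'), _tok)
# _trigger_line == 'T6595_6604\tAttack 6595 6604\tonslaught\n'
# _event_line   == 'E6595_6604\tAttack:T6595_6604 Target:T6610_6620\n'
# --------------------------------------------------------------------------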
125 | entity_lines.append('T%d_%d\t%s %d %d\t%s\n' % ( 126 | arg_start_offset, arg_end_offset, arg_type_brat, 127 | arg_start_offset, arg_end_offset, arg_word 128 | )) 129 | # except: 130 | # print('no arguments ', (eventstart_sentid, event_end, event_type_ace)) 131 | visual_writer.write('\n') 132 | if save_entity: 133 | visual_writer.write(''.join(entity_lines)) 134 | 135 | visual_writer.flush() 136 | visual_writer.close() 137 | 138 | def get_offset_by_idx(self, sent_id, idx_start, idx_end): 139 | # print('idx_start', idx_start, 'idx_end', idx_end) 140 | # print(self.sent_info[sent_id]['index_all']) 141 | start_offset, _ = self.sent_info[sent_id]['index_all'][idx_start] 142 | _, end_offset = self.sent_info[sent_id]['index_all'][idx_end - 1] 143 | end_offset = end_offset + 1 # brat is [start, end] not [start, end) 144 | return start_offset, end_offset 145 | 146 | def rewrite_brat(self, event_tmp_dir, ann_dir, save_entity=False): 147 | ''' 148 | The previous format is not brat format, need postprocessing 149 | :return: 150 | ''' 151 | # if os.path.isdir(event_tmp_dir): 152 | for event_tmp_file in os.listdir(event_tmp_dir): 153 | if event_tmp_file.endswith('.ann_tmp'): 154 | event_tmp_path = os.path.join(event_tmp_dir, event_tmp_file) 155 | ann_path = os.path.join(ann_dir, event_tmp_file.replace('.ann_tmp', '.ann')) 156 | if not save_entity: 157 | self._rewrite_brat(event_tmp_path, ann_path) 158 | else: 159 | # remove repeated lines? 160 | # simpler version: copy as final ann file, ignore the repeated lines 161 | shutil.copyfile(event_tmp_path, event_tmp_path.replace('_tmp', '')) 162 | 163 | # copy rsd txt 164 | rsd_path = os.path.join(ann_dir, event_tmp_file.replace('.ann_tmp', '.txt')) 165 | rsd_path_new = os.path.join(event_tmp_dir, event_tmp_file.replace('.ann_tmp', '.txt')) 166 | shutil.copyfile(rsd_path, rsd_path_new) 167 | 168 | def _rewrite_brat(self, event_tmp_file, anno_file): 169 | # get all entity offset and entity id mapping 170 | token_offsetid2realid = dict() 171 | 172 | entity_lines = list() 173 | if os.path.exists(anno_file): 174 | lines = open(anno_file).readlines() 175 | for line in lines: 176 | if line.startswith('T'): 177 | tabs = line.split('\t') 178 | id = tabs[0] 179 | subs = tabs[1].split(' ') 180 | type = subs[0] 181 | if type in entity_type_mapping_brat: 182 | start = int(subs[1]) 183 | end = int(subs[2]) 184 | # mention = tabs[2] 185 | offset_id = 'T%d_%d' % (start, end) 186 | token_offsetid2realid[offset_id] = id 187 | entity_lines.append(line) 188 | else: 189 | print('NoANN', anno_file) 190 | 191 | # write entities 192 | writer = open(event_tmp_file.replace('_tmp', ''), 'w') 193 | print(event_tmp_file.replace('_tmp', '')) 194 | writer.write(''.join(entity_lines)) 195 | 196 | # write events: 197 | for line in open(event_tmp_file): 198 | line = line.rstrip('\n') 199 | if line.startswith('T'): 200 | writer.write('%s\n' % line) 201 | else: 202 | # update the arg token_id 203 | # E11 Attack:T44 Attacker:T26493 Target:T45 204 | role_values = line.split('\t')[-1].split(' ') 205 | # event_id = role_values[0].split(':')[1] 206 | # event_type = role_values[0].split(':')[0] 207 | writer.write(line.split('\t')[0]) 208 | writer.write('\t') 209 | writer.write(role_values[0]) 210 | entity_added_lines=list() 211 | for role_value in role_values[1:]: 212 | # role_name_raw = role_value.split(':')[0].replace('2', '').replace('3', '').replace('4', '').replace( 213 | # '5', '') 214 | # if len(role_name_raw) == 0: 215 | # continue 216 | # if role_name_raw not in 
event_role_mapping_brat2ace[event_type]: 217 | # print('ignored role: ', event_id, role_value) 218 | # continue 219 | # role_name = event_role_mapping_brat2ace[event_type][role_name_raw] 220 | entity = role_value.split(':')[1] 221 | if entity not in token_offsetid2realid: 222 | print('[ERROR] entity id cannot be found in *.ann', entity, anno_file) 223 | entity_type = 'VAL' 224 | entity_str = 'VAL' 225 | entity_added_lines.append('%s\t%s %s %s\t%s\n' % (entity, entity_type, 226 | entity[1:].split('_')[0], 227 | entity[1:].split('_')[1], 228 | entity_str) ) 229 | entity_realid = entity 230 | else: 231 | entity_realid = token_offsetid2realid[entity] 232 | # print('entity', entity, 'entity_realid', entity_realid) 233 | writer.write(' ') 234 | writer.write(role_value.replace(entity, entity_realid)) 235 | writer.write('\n') 236 | writer.write(''.join(entity_added_lines)) 237 | 238 | writer.flush() 239 | writer.close() 240 | 241 | -------------------------------------------------------------------------------- /src/util/util_img.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Revised based on https://github.com/hassanhub/MultiGrounding/blob/master/code/utils.py 3 | ''' 4 | 5 | from skimage.feature import peak_local_max 6 | import cv2 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from scipy import ndimage as ndi 10 | 11 | 12 | # bbox generation config 13 | rel_peak_thr = .3 14 | rel_rel_thr = .3 15 | ioa_thr = .6 16 | topk_boxes = 3 17 | 18 | 19 | def heat2bbox(heat_map, original_image_shape): 20 | h, w = heat_map.shape 21 | 22 | bounding_boxes = [] 23 | 24 | heat_map = heat_map - np.min(heat_map) 25 | heat_map = heat_map / np.max(heat_map) 26 | 27 | bboxes = [] 28 | box_scores = [] 29 | 30 | peak_coords = peak_local_max(heat_map, exclude_border=False, 31 | threshold_rel=rel_peak_thr) # find local peaks of heat map 32 | 33 | heat_resized = cv2.resize(heat_map, ( 34 | original_image_shape[1], original_image_shape[0])) ## resize heat map to original image shape 35 | peak_coords_resized = ((peak_coords + 0.5) * 36 | np.asarray([original_image_shape]) / 37 | np.asarray([[h, w]]) 38 | ).astype('int32') 39 | 40 | for pk_coord in peak_coords_resized: 41 | pk_value = heat_resized[tuple(pk_coord)] 42 | mask = heat_resized > pk_value * rel_rel_thr 43 | labeled, n = ndi.label(mask) 44 | l = labeled[tuple(pk_coord)] 45 | yy, xx = np.where(labeled == l) 46 | min_x = np.min(xx) 47 | min_y = np.min(yy) 48 | max_x = np.max(xx) 49 | max_y = np.max(yy) 50 | bboxes.append((min_x, min_y, max_x, max_y)) 51 | box_scores.append(pk_value) # you can change to pk_value * probability of sentence matching image or etc. 
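# --------------------------------------------------------------------------
# Aside (not part of the original file): the region-growing step above on a
# toy 3x3 heat map.  Every pixel above pk_value * rel_rel_thr (0.3 here)
# that is connected to the peak joins its component, and the component's
# extent becomes the candidate box.
import numpy as np
from scipy import ndimage as ndi

_toy = np.array([[0.1, 0.2, 0.0],
                 [0.2, 1.0, 0.4],   # peak value 1.0 at (row=1, col=1)
                 [0.0, 0.4, 0.1]])
_mask = _toy > 1.0 * 0.3                        # pk_value * rel_rel_thr
_labeled, _n = ndi.label(_mask)                 # connected components
_yy, _xx = np.where(_labeled == _labeled[1, 1])
_box = (int(_xx.min()), int(_yy.min()), int(_xx.max()), int(_yy.max()))
print(_box)  # -> (1, 1, 2, 2): the component spanning the peak's neighbours
# --------------------------------------------------------------------------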
52 | 53 | ## Merging boxes that overlap too much 54 | box_idx = np.argsort(-np.asarray(box_scores)) 55 | box_idx = box_idx[:min(topk_boxes, len(box_scores))] 56 | bboxes = [bboxes[i] for i in box_idx] 57 | box_scores = [box_scores[i] for i in box_idx] 58 | 59 | to_remove = [] 60 | for iii in range(len(bboxes)): 61 | for iiii in range(iii): 62 | if iiii in to_remove: 63 | continue 64 | b1 = bboxes[iii] 65 | b2 = bboxes[iiii] 66 | isec = max(min(b1[2], b2[2]) - max(b1[0], b2[0]), 0) * max(min(b1[3], b2[3]) - max(b1[1], b2[1]), 0) 67 | ioa1 = isec / ((b1[2] - b1[0]) * (b1[3] - b1[1])) 68 | ioa2 = isec / ((b2[2] - b2[0]) * (b2[3] - b2[1])) 69 | if ioa1 > ioa_thr and ioa1 == ioa2: 70 | to_remove.append(iii) 71 | elif ioa1 > ioa_thr and ioa1 >= ioa2: 72 | to_remove.append(iii) 73 | elif ioa2 > ioa_thr and ioa2 >= ioa1: 74 | to_remove.append(iiii) 75 | 76 | for i in range(len(bboxes)): 77 | if i not in to_remove: 78 | bounding_boxes.append({ 79 | 'score': box_scores[i], 80 | 'bbox': bboxes[i], 81 | 'bbox_normalized': np.asarray([ 82 | bboxes[i][0] / heat_resized.shape[1], 83 | bboxes[i][1] / heat_resized.shape[0], 84 | bboxes[i][2] / heat_resized.shape[1], 85 | bboxes[i][3] / heat_resized.shape[0], 86 | ]), 87 | }) 88 | 89 | return bounding_boxes 90 | 91 | 92 | def img_heat_bbox_disp(image, heat_map, title='', en_name='', alpha=0.6, cmap='viridis', cbar='False', dot_max=False, 93 | bboxes=[], order=None, show=True): 94 | thr_hit = 1 # a bbox is acceptable if hit point is in middle 85% of bbox area 95 | thr_fit = .60 # the biggest acceptable bbox should not exceed 60% of the image 96 | H, W = image.shape[0:2] 97 | # resize heat map 98 | heat_map_resized = cv2.resize(heat_map, (H, W)) 99 | 100 | # display 101 | fig = plt.figure(figsize=(15, 5)) 102 | fig.suptitle(title, size=15) 103 | ax = plt.subplot(1, 3, 1) 104 | plt.imshow(image) 105 | if dot_max: 106 | max_loc = np.unravel_index(np.argmax(heat_map_resized, axis=None), heat_map_resized.shape) 107 | plt.scatter(x=max_loc[1], y=max_loc[0], edgecolor='w', linewidth=3) 108 | 109 | if len(bboxes) > 0: # it gets normalized bbox 110 | if order == None: 111 | order = 'xxyy' 112 | 113 | for i in range(len(bboxes)): 114 | bbox_norm = bboxes[i] 115 | if order == 'xxyy': 116 | x_min, x_max, y_min, y_max = int(bbox_norm[0] * W), int(bbox_norm[1] * W), int(bbox_norm[2] * H), int( 117 | bbox_norm[3] * H) 118 | elif order == 'xyxy': 119 | x_min, x_max, y_min, y_max = int(bbox_norm[0] * W), int(bbox_norm[2] * W), int(bbox_norm[1] * H), int( 120 | bbox_norm[3] * H) 121 | x_length, y_length = x_max - x_min, y_max - y_min 122 | box = plt.Rectangle((x_min, y_min), x_length, y_length, edgecolor='w', linewidth=3, fill=False) 123 | plt.gca().add_patch(box) 124 | if en_name != '': 125 | ax.text(x_min + .5 * x_length, y_min + 10, en_name, 126 | verticalalignment='center', horizontalalignment='center', 127 | # transform=ax.transAxes, 128 | color='white', fontsize=15) 129 | # an = ax.annotate(en_name, xy=(x_min,y_min), xycoords="data", va="center", ha="center", bbox=dict(boxstyle="round", fc="w")) 130 | # plt.gca().add_patch(an) 131 | 132 | plt.imshow(heat_map_resized, alpha=alpha, cmap=cmap) 133 | 134 | # plt.figure(2, figsize=(6, 6)) 135 | plt.subplot(1, 3, 2) 136 | plt.imshow(image) 137 | # plt.figure(3, figsize=(6, 6)) 138 | plt.subplot(1, 3, 3) 139 | plt.imshow(heat_map_resized) 140 | fig.tight_layout() 141 | fig.subplots_adjust(top=.85) 142 | 143 | if show: 144 | plt.show() 145 | else: 146 | plt.close() 147 | 148 | return fig 149 | 150 | 151 | def 
filter_bbox(bbox_dict, order=None): 152 | thr_fit = .99 # the biggest acceptable bbox should not exceed 99% of the image 153 | if order == None: 154 | order = 'xxyy' 155 | 156 | filtered_bbox = [] 157 | filtered_bbox_norm = [] 158 | filtered_score = [] 159 | if len(bbox_dict) > 0: # it gets normalized bbox 160 | for i in range(len(bbox_dict)): 161 | bbox = bbox_dict[i]['bbox'] 162 | bbox_norm = bbox_dict[i]['bbox_normalized'] 163 | bbox_score = bbox_dict[i]['score'] 164 | if order == 'xxyy': 165 | x_min, x_max, y_min, y_max = bbox_norm[0], bbox_norm[1], bbox_norm[2], bbox_norm[3] 166 | elif order == 'xyxy': 167 | x_min, x_max, y_min, y_max = bbox_norm[0], bbox_norm[2], bbox_norm[1], bbox_norm[3] 168 | if bbox_score > 0: 169 | x_length, y_length = x_max - x_min, y_max - y_min 170 | if x_length * y_length < thr_fit: 171 | filtered_score.append(bbox_score) 172 | filtered_bbox.append(bbox) 173 | filtered_bbox_norm.append(bbox_norm) 174 | return filtered_bbox, filtered_bbox_norm, filtered_score 175 | 176 | 177 | def crop_resize_im(image, bbox, size, order='xxyy'): 178 | H, W, _ = image.shape 179 | if order == 'xxyy': 180 | roi = image[int(bbox[2] * H):int(bbox[3] * H), int(bbox[0] * W):int(bbox[1] * W), :] 181 | elif order == 'xyxy': 182 | roi = image[int(bbox[1] * H):int(bbox[3] * H), int(bbox[0] * W):int(bbox[2] * W), :] 183 | roi = cv2.resize(roi, size) 184 | return roi 185 | 186 | 187 | def im2double(im): 188 | return cv2.normalize(im.astype('float'), None, 0.0, 1.0, cv2.NORM_MINMAX) 189 | 190 | 191 | def IoU(boxA, boxB): 192 | # order = xyxy 193 | xA = max(boxA[0], boxB[0]) 194 | yA = max(boxA[1], boxB[1]) 195 | xB = min(boxA[2], boxB[2]) 196 | yB = min(boxA[3], boxB[3]) 197 | 198 | # compute the area of intersection rectangle 199 | interArea = max(0, xB - xA) * max(0, yB - yA) 200 | 201 | # compute the area of both the prediction and ground-truth 202 | # rectangles 203 | boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]) 204 | boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1]) 205 | 206 | # compute the intersection over union by taking the intersection 207 | # area and dividing it by the sum of prediction + ground-truth 208 | # areas - the intersection area 209 | iou = interArea / float(boxAArea + boxBArea - interArea) 210 | 211 | # return the intersection over union value 212 | return iou 213 | 214 | 215 | def isCorrect(bbox_annot, bbox_pred, iou_thr=.4): 216 | iou_value_max = 0.0 217 | for bbox_p in bbox_pred: 218 | for bbox_a in bbox_annot: 219 | iou_value = IoU(bbox_p, bbox_a) 220 | iou_value_max = max(iou_value, iou_value_max) 221 | if iou_value >= iou_thr: 222 | return 1, iou_value 223 | return 0, iou_value_max 224 | 225 | 226 | def isCorrectHit(bbox_annot, heatmap, orig_img_shape): 227 | H, W = orig_img_shape 228 | heatmap_resized = cv2.resize(heatmap, (W, H)) 229 | max_loc = np.unravel_index(np.argmax(heatmap_resized, axis=None), heatmap_resized.shape) 230 | print('max_loc', max_loc) 231 | for bbox in bbox_annot: 232 | if bbox[0] <= max_loc[1] <= bbox[2] and bbox[1] <= max_loc[0] <= bbox[3]: 233 | return 1 234 | return 0 235 | 236 | 237 | def check_percent(bboxes): 238 | for bbox in bboxes: 239 | x_length = bbox[2] - bbox[0] 240 | y_length = bbox[3] - bbox[1] 241 | if x_length * y_length < .05: 242 | return False 243 | return True 244 | 245 | 246 | def union(bbox): 247 | if len(bbox) == 0: 248 | return [] 249 | if type(bbox[0]) == type(0.0) or type(bbox[0]) == type(0): 250 | bbox = [bbox] 251 | maxes = np.max(bbox, axis=0) 252 | mins = np.min(bbox, axis=0) 253 | return 
[[mins[0], mins[1], maxes[2], maxes[3]]] 254 | 255 | 256 | def attCorrectness(bbox_annot, heatmap, orig_img_shape): 257 | H, W = orig_img_shape 258 | heatmap_resized = cv2.resize(heatmap, (W, H)) 259 | h_s = np.sum(heatmap_resized) 260 | if h_s == 0: 261 | return 0 262 | else: 263 | heatmap_resized /= h_s 264 | att_correctness = 0 265 | for bbox in bbox_annot: 266 | x0, y0, x1, y1 = bbox 267 | att_correctness += np.sum(heatmap_resized[y0:y1, x0:x1]) 268 | return att_correctness 269 | 270 | 271 | def calc_correctness(annot, heatmap, orig_img_shape, iou_thr=.5): 272 | bbox_dict = heat2bbox(heatmap, orig_img_shape) 273 | bbox, bbox_norm, bbox_score = filter_bbox(bbox_dict=bbox_dict, order='xyxy') 274 | bbox_norm_annot = union(annot['bbox_norm']) 275 | bbox_annot = annot['bbox'] 276 | bbox_norm_pred = union(bbox_norm) 277 | # print('bbox_norm_annot', bbox_norm_annot) 278 | # print('bbox_norm_pred', bbox_norm_pred) 279 | # print('bbox_annot', bbox_annot) 280 | # print('bbox_norm', bbox_norm) 281 | bbox_correctness, bbox_iou = isCorrect(bbox_norm_annot, bbox_norm_pred, iou_thr=iou_thr) 282 | hit_correctness = isCorrectHit(bbox_annot, heatmap, orig_img_shape) 283 | att_correctness = attCorrectness(bbox_annot, heatmap, orig_img_shape) 284 | return bbox_correctness, hit_correctness, att_correctness, bbox_iou 285 | 286 | def precision_bbox(boxA, boxB): 287 | ''' 288 | 289 | :param boxA: predicted 290 | :param boxB: GT 291 | :return: 292 | ''' 293 | # order = xyxy 294 | xA = max(boxA[0], boxB[0]) 295 | yA = max(boxA[1], boxB[1]) 296 | xB = min(boxA[2], boxB[2]) 297 | yB = min(boxA[3], boxB[3]) 298 | 299 | # compute the area of intersection rectangle 300 | interArea = max(0, xB - xA) * max(0, yB - yA) 301 | 302 | # compute the area of both the prediction and ground-truth 303 | # rectangles 304 | boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]) 305 | # boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1]) 306 | 307 | # compute precision by taking the intersection 308 | # area and dividing it by the area of the 309 | # predicted rectangle 310 | precision = interArea / float(boxAArea) 311 | # recall = interArea / float(boxB) 312 | 313 | # return the precision value 314 | return precision 315 | 316 | def calc_correctness_box(annot_all_norm, bbox_one_norm, iou_thr=.5, count_inside=False): 317 | ''' 318 | 319 | :param annot_all_norm: list of annotated (ground-truth) boxes, normalized 320 | :param bbox_one_norm: one predicted box, normalized 321 | :return: 322 | ''' 323 | # bbox_norm_annot = annot['bbox_norm'] 324 | # bbox_annot = annot['bbox'] 325 | 326 | # bbox_norm = bboxes_role['bbox_norm'] 327 | # bbox = bboxes_role['bbox'] 328 | 329 | # for bbox_p in bbox_norm: 330 | iou_value_max = 0.0 331 | for bbox_a in annot_all_norm: 332 | iou_value = IoU(bbox_one_norm, bbox_a) 333 | iou_value_max = max(iou_value, iou_value_max) 334 | if iou_value >= iou_thr: 335 | return 1, iou_value 336 | if count_inside: 337 | precision = precision_bbox(bbox_one_norm, bbox_a) 338 | if precision >= 0.9: 339 | return 1, precision 340 | return 0, iou_value_max 341 | 342 | def calc_correctness_box_uion(annot_all_norm, bbox_all_norm, iou_thr=.5, count_inside=False): 343 | ''' 344 | 345 | :param annot_all_norm: list of annotated (ground-truth) boxes, normalized 346 | :param bbox_all_norm: list of predicted boxes, normalized 347 | :return: 348 | ''' 349 | # bbox_norm_annot = annot['bbox_norm'] 350 | # bbox_annot = annot['bbox'] 351 | 352 | # bbox_norm = bboxes_role['bbox_norm'] 353 | # bbox = bboxes_role['bbox'] 354 | 355 | # for bbox_p in bbox_norm: 356 | # iou_value_max = 0.0 357 | iou_value = 
0.0 358 | annot_all_norm_union = union(annot_all_norm) 359 | bbox_all_norm_union = union(bbox_all_norm) 360 | 361 | for bbox_a in annot_all_norm_union: 362 | for bbox_b in bbox_all_norm_union: 363 | iou_value = IoU(bbox_a, bbox_b) 364 | if iou_value >= iou_thr: 365 | return 1, iou_value 366 | if count_inside: 367 | precision = precision_bbox(bbox_b, bbox_a) 368 | if precision >= 0.9: 369 | return 1, precision 370 | return 0, iou_value 371 | -------------------------------------------------------------------------------- /data/object/class-descriptions-boxable.csv: -------------------------------------------------------------------------------- 1 | /m/011k07,Tortoise,0 2 | /m/011q46kg,Container,1 3 | /m/012074,Magpie,0 4 | /m/0120dh,Sea turtle,0 5 | /m/01226z,Football,0 6 | /m/012n7d,Ambulance,1 7 | /m/012w5l,Ladder,0 8 | /m/012xff,Toothbrush,0 9 | /m/012ysf,Syringe,0 10 | /m/0130jx,Sink,0 11 | /m/0138tl,Toy,0 12 | /m/013y1f,Organ,0 13 | /m/01432t,Cassette deck,0 14 | /m/014j1m,Apple,0 15 | /m/014sv8,Human eye,1 16 | /m/014trl,Cosmetics,0 17 | /m/014y4n,Paddle,0 18 | /m/0152hh,Snowman,0 19 | /m/01599,Beer,0 20 | /m/01_5g,Chopsticks,0 21 | /m/015h_t,Human beard,1 22 | /m/015p6,Bird,0 23 | /m/015qbp,Parking meter,0 24 | /m/015qff,Traffic light,0 25 | /m/015wgc,Croissant,0 26 | /m/015x4r,Cucumber,0 27 | /m/015x5n,Radish,0 28 | /m/0162_1,Towel,0 29 | /m/0167gd,Doll,0 30 | /m/016m2d,Skull,0 31 | /m/0174k2,Washing machine,0 32 | /m/0174n1,Glove,0 33 | /m/0175cv,Tick,0 34 | /m/0176mf,Belt,0 35 | /m/017ftj,Sunglasses,0 36 | /m/018j2,Banjo,0 37 | /m/018p4k,Cart,0 38 | /m/018xm,Ball,0 39 | /m/01940j,Backpack,0 40 | /m/0199g,Bicycle,1 41 | /m/019dx1,Home appliance,0 42 | /m/019h78,Centipede,0 43 | /m/019jd,Boat,1 44 | /m/019w40,Surfboard,0 45 | /m/01b638,Boot,0 46 | /m/01b7fy,Headphones,1 47 | /m/01b9xk,Hot dog,0 48 | /m/01bfm9,Shorts,0 49 | /m/01_bhs,Fast food,0 50 | /m/01bjv,Bus,1 51 | /m/01bl7v,Boy,1 52 | /m/01bms0,Screwdriver,0 53 | /m/01bqk0,Bicycle wheel,0 54 | /m/01btn,Barge,0 55 | /m/01c648,Laptop,0 56 | /m/01cmb2,Miniskirt,0 57 | /m/01d380,Drill,0 58 | /m/01d40f,Dress,0 59 | /m/01dws,Bear,0 60 | /m/01dwsz,Waffle,0 61 | /m/01dwwc,Pancake,0 62 | /m/01dxs,Brown bear,0 63 | /m/01dy8n,Woodpecker,0 64 | /m/01f8m5,Blue jay,0 65 | /m/01f91_,Pretzel,0 66 | /m/01fb_0,Bagel,0 67 | /m/01fdzj,Tower,0 68 | /m/01fh4r,Teapot,0 69 | /m/01g317,Person,1 70 | /m/01g3x7,Bow and arrow,0 71 | /m/01gkx_,Swimwear,0 72 | /m/01gllr,Beehive,0 73 | /m/01gmv2,Brassiere,0 74 | /m/01h3n,Bee,0 75 | /m/01h44,Bat,0 76 | /m/01h8tj,Starfish,0 77 | /m/01hrv5,Popcorn,0 78 | /m/01j3zr,Burrito,0 79 | /m/01j4z9,Chainsaw,0 80 | /m/01j51,Balloon,0 81 | /m/01j5ks,Wrench,0 82 | /m/01j61q,Tent,0 83 | /m/01jfm_,Vehicle registration plate,0 84 | /m/01jfsr,Lantern,0 85 | /m/01k6s3,Toaster,0 86 | /m/01kb5b,Flashlight,0 87 | /m/01knjb,Billboard,1 88 | /m/01krhy,Tiara,0 89 | /m/01lcw4,Limousine,0 90 | /m/01llwg,Necklace,0 91 | /m/01lrl,Carnivore,0 92 | /m/01lsmm,Scissors,0 93 | /m/01lynh,Stairs,0 94 | /m/01m2v,Computer keyboard,0 95 | /m/01m4t,Printer,0 96 | /m/01mqdt,Traffic sign,0 97 | /m/01mzpv,Chair,0 98 | /m/01n4qj,Shirt,0 99 | /m/01n5jq,Poster,1 100 | /m/01nkt,Cheese,0 101 | /m/01nq26,Sock,0 102 | /m/01pns0,Fire hydrant,0 103 | /m/01prls,Land vehicle,0 104 | /m/01r546,Earrings,0 105 | /m/01rkbr,Tie,0 106 | /m/01rzcn,Watercraft,1 107 | /m/01s105,Cabinetry,0 108 | /m/01s55n,Suitcase,0 109 | /m/01tcjp,Muffin,0 110 | /m/01vbnl,Bidet,0 111 | /m/01ww8y,Snack,0 112 | /m/01x3jk,Snowmobile,0 113 | /m/01x3z,Clock,0 114 | /m/01xgg_,Medical 
equipment,1 115 | /m/01xq0k1,Cattle,0 116 | /m/01xqw,Cello,1 117 | /m/01xs3r,Jet ski,1 118 | /m/01x_v,Camel,0 119 | /m/01xygc,Coat,0 120 | /m/01xyhv,Suit,0 121 | /m/01y9k5,Desk,0 122 | /m/01yrx,Cat,0 123 | /m/01yx86,Bronze sculpture,0 124 | /m/01z1kdw,Juice,0 125 | /m/02068x,Gondola,1 126 | /m/020jm,Beetle,0 127 | /m/020kz,Cannon,0 128 | /m/020lf,Computer mouse,0 129 | /m/021mn,Cookie,0 130 | /m/021sj1,Office building,1 131 | /m/0220r2,Fountain,0 132 | /m/0242l,Coin,0 133 | /m/024d2,Calculator,0 134 | /m/024g6,Cocktail,0 135 | /m/02522,Computer monitor,0 136 | /m/025dyy,Box,0 137 | /m/025fsf,Stapler,0 138 | /m/025nd,Christmas tree,0 139 | /m/025rp__,Cowboy hat,0 140 | /m/0268lbt,Hiking equipment,0 141 | /m/026qbn5,Studio couch,0 142 | /m/026t6,Drum,0 143 | /m/0270h,Dessert,0 144 | /m/0271qf7,Wine rack,0 145 | /m/0271t,Drink,0 146 | /m/027pcv,Zucchini,0 147 | /m/027rl48,Ladle,0 148 | /m/0283dt1,Human mouth,0 149 | /m/0284d,Dairy,0 150 | /m/029b3,Dice,0 151 | /m/029bxz,Oven,0 152 | /m/029tx,Dinosaur,0 153 | /m/02bm9n,Ratchet,0 154 | /m/02crq1,Couch,0 155 | /m/02ctlc,Cricket ball,0 156 | /m/02cvgx,Winter melon,0 157 | /m/02d1br,Spatula,0 158 | /m/02d9qx,Whiteboard,0 159 | /m/02ddwp,Pencil sharpener,0 160 | /m/02dgv,Door,0 161 | /m/02dl1y,Hat,0 162 | /m/02f9f_,Shower,0 163 | /m/02fh7f,Eraser,0 164 | /m/02fq_6,Fedora,0 165 | /m/02g30s,Guacamole,0 166 | /m/02gzp,Dagger,0 167 | /m/02h19r,Scarf,0 168 | /m/02hj4,Dolphin,0 169 | /m/02jfl0,Sombrero,0 170 | /m/02jnhm,Tin can,0 171 | /m/02jvh9,Mug,0 172 | /m/02jz0l,Tap,0 173 | /m/02l8p9,Harbor seal,0 174 | /m/02lbcq,Stretcher,1 175 | /m/02mqfb,Can opener,0 176 | /m/02_n6y,Goggles,0 177 | /m/02p0tk3,Human body,1 178 | /m/02p3w7d,Roller skates,0 179 | /m/02p5f1q,Coffee cup,0 180 | /m/02pdsw,Cutting board,0 181 | /m/02pjr4,Blender,0 182 | /m/02pkr5,Plumbing fixture,0 183 | /m/02pv19,Stop sign,0 184 | /m/02rdsp,Office supplies,0 185 | /m/02rgn06,Volleyball,0 186 | /m/02s195,Vase,0 187 | /m/02tsc9,Slow cooker,0 188 | /m/02vkqh8,Wardrobe,0 189 | /m/02vqfm,Coffee,0 190 | /m/02vwcm,Whisk,0 191 | /m/02w3r3,Paper towel,0 192 | /m/02w3_ws,Personal care,0 193 | /m/02wbm,Food,0 194 | /m/02wbtzl,Sun hat,0 195 | /m/02wg_p,Tree house,0 196 | /m/02wmf,Flying disc,0 197 | /m/02wv6h6,Skirt,0 198 | /m/02wv84t,Gas stove,0 199 | /m/02x8cch,Salt and pepper shakers,0 200 | /m/02x984l,Mechanical fan,0 201 | /m/02xb7qb,Face powder,0 202 | /m/02xqq,Fax,0 203 | /m/02xwb,Fruit,0 204 | /m/02y6n,French fries,0 205 | /m/02z51p,Nightstand,0 206 | /m/02zn6n,Barrel,0 207 | /m/02zt3,Kite,0 208 | /m/02zvsm,Tart,0 209 | /m/030610,Treadmill,0 210 | /m/0306r,Fox,0 211 | /m/03120,Flag,1 212 | /m/0319l,Horn,0 213 | /m/031b6r,Window blind,0 214 | /m/031n1,Human foot,0 215 | /m/0323sq,Golf cart,0 216 | /m/032b3c,Jacket,0 217 | /m/033cnk,Egg,0 218 | /m/033rq4,Street light,0 219 | /m/0342h,Guitar,0 220 | /m/034c16,Pillow,0 221 | /m/035r7c,Human leg,1 222 | /m/035vxb,Isopod,0 223 | /m/0388q,Grape,0 224 | /m/039xj_,Human ear,0 225 | /m/03bbps,Power plugs and sockets,0 226 | /m/03bj1,Panda,0 227 | /m/03bk1,Giraffe,0 228 | /m/03bt1vf,Woman,1 229 | /m/03c7gz,Door handle,0 230 | /m/03d443,Rhinoceros,0 231 | /m/03dnzn,Bathtub,0 232 | /m/03fj2,Goldfish,0 233 | /m/03fp41,Houseplant,0 234 | /m/03fwl,Goat,0 235 | /m/03g8mr,Baseball bat,0 236 | /m/03grzl,Baseball glove,0 237 | /m/03hj559,Mixing bowl,0 238 | /m/03hl4l9,Marine invertebrates,0 239 | /m/03hlz0c,Kitchen utensil,0 240 | /m/03jbxj,Light switch,0 241 | /m/03jm5,House,1 242 | /m/03k3r,Horse,0 243 | /m/03kt2w,Stationary bicycle,0 244 | 
/m/03l9g,Hammer,0 245 | /m/03ldnb,Ceiling fan,0 246 | /m/03m3pdh,Sofa bed,0 247 | /m/03m3vtv,Adhesive tape,0 248 | /m/03m5k,Harp,0 249 | /m/03nfch,Sandal,0 250 | /m/03p3bw,Bicycle helmet,0 251 | /m/03q5c7,Saucer,0 252 | /m/03q5t,Harpsichord,0 253 | /m/03q69,Human hair,0 254 | /m/03qhv5,Heater,0 255 | /m/03qjg,Harmonica,0 256 | /m/03qrc,Hamster,0 257 | /m/03rszm,Curtain,0 258 | /m/03ssj5,Bed,0 259 | /m/03s_tn,Kettle,0 260 | /m/03tw93,Fireplace,0 261 | /m/03txqz,Scale,0 262 | /m/03v5tg,Drinking straw,0 263 | /m/03vt0,Insect,0 264 | /m/03wvsk,Hair dryer,0 265 | /m/03_wxk,Kitchenware,0 266 | /m/03wym,Indoor rower,0 267 | /m/03xxp,Invertebrate,0 268 | /m/03y6mg,Food processor,0 269 | /m/03__z0,Bookcase,0 270 | /m/040b_t,Refrigerator,0 271 | /m/04169hn,Wood-burning stove,0 272 | /m/0420v5,Punching bag,0 273 | /m/043nyj,Common fig,0 274 | /m/0440zs,Cocktail shaker,0 275 | /m/0449p,Jaguar,0 276 | /m/044r5d,Golf ball,0 277 | /m/0463sg,Fashion accessory,0 278 | /m/046dlr,Alarm clock,0 279 | /m/047j0r,Filing cabinet,0 280 | /m/047v4b,Artichoke,0 281 | /m/04bcr3,Table,0 282 | /m/04brg2,Tableware,0 283 | /m/04c0y,Kangaroo,0 284 | /m/04cp_,Koala,0 285 | /m/04ctx,Knife,0 286 | /m/04dr76w,Bottle,0 287 | /m/04f5ws,Bottle opener,0 288 | /m/04g2r,Lynx,0 289 | /m/04gth,Lavender,0 290 | /m/04h7h,Lighthouse,0 291 | /m/04h8sr,Dumbbell,0 292 | /m/04hgtk,Human head,1 293 | /m/04kkgm,Bowl,0 294 | /m/04lvq_,Humidifier,0 295 | /m/04m6gz,Porch,0 296 | /m/04m9y,Lizard,0 297 | /m/04p0qw,Billiard table,0 298 | /m/04rky,Mammal,0 299 | /m/04rmv,Mouse,0 300 | /m/04_sv,Motorcycle,1 301 | /m/04szw,Musical instrument,0 302 | /m/04tn4x,Swim cap,0 303 | /m/04v6l4,Frying pan,0 304 | /m/04vv5k,Snowplow,0 305 | /m/04y4h8h,Bathroom cabinet,0 306 | /m/04ylt,Missile,1 307 | /m/04yqq2,Bust,0 308 | /m/04yx4,Man,1 309 | /m/04z4wx,Waffle iron,0 310 | /m/04zpv,Milk,0 311 | /m/04zwwv,Ring binder,0 312 | /m/050gv4,Plate,0 313 | /m/050k8,Mobile phone,1 314 | /m/052lwg6,Baked goods,0 315 | /m/052sf,Mushroom,0 316 | /m/05441v,Crutch,0 317 | /m/054fyh,Pitcher,0 318 | /m/054_l,Mirror,0 319 | /m/054xkw,Lifejacket,0 320 | /m/05_5p_0,Table tennis racket,0 321 | /m/05676x,Pencil case,0 322 | /m/057cc,Musical keyboard,0 323 | /m/057p5t,Scoreboard,0 324 | /m/0584n8,Briefcase,0 325 | /m/058qzx,Kitchen knife,0 326 | /m/05bm6,Nail,0 327 | /m/05ctyq,Tennis ball,0 328 | /m/05gqfk,Plastic bag,0 329 | /m/05kms,Oboe,0 330 | /m/05kyg_,Chest of drawers,0 331 | /m/05n4y,Ostrich,0 332 | /m/05r5c,Piano,0 333 | /m/05r655,Girl,1 334 | /m/05s2s,Plant,0 335 | /m/05vtc,Potato,0 336 | /m/05w9t9,Hair spray,0 337 | /m/05y5lj,Sports equipment,0 338 | /m/05z55,Pasta,0 339 | /m/05z6w,Penguin,0 340 | /m/05zsy,Pumpkin,0 341 | /m/061_f,Pear,0 342 | /m/061hd_,Infant bed,0 343 | /m/0633h,Polar bear,0 344 | /m/063rgb,Mixer,0 345 | /m/0642b4,Cupboard,0 346 | /m/065h6l,Jacuzzi,0 347 | /m/0663v,Pizza,0 348 | /m/06_72j,Digital clock,0 349 | /m/068zj,Pig,0 350 | /m/06bt6,Reptile,0 351 | /m/06c54,Rifle,1 352 | /m/06c7f7,Lipstick,0 353 | /m/06_fw,Skateboard,0 354 | /m/06j2d,Raven,0 355 | /m/06k2mb,High heels,0 356 | /m/06l9r,Red panda,0 357 | /m/06m11,Rose,0 358 | /m/06mf6,Rabbit,0 359 | /m/06msq,Sculpture,0 360 | /m/06ncr,Saxophone,0 361 | /m/06nrc,Shotgun,1 362 | /m/06nwz,Seafood,0 363 | /m/06pcq,Submarine sandwich,0 364 | /m/06__v,Snowboard,0 365 | /m/06y5r,Sword,0 366 | /m/06z37_,Picture frame,0 367 | /m/07030,Sushi,0 368 | /m/0703r8,Loveseat,0 369 | /m/071p9,Ski,0 370 | /m/071qp,Squirrel,0 371 | /m/073bxn,Tripod,0 372 | /m/073g6,Stethoscope,0 373 | /m/074d1,Submarine,0 374 | 
/m/0755b,Scorpion,0 375 | /m/076bq,Segway,0 376 | /m/076lb9,Training bench,0 377 | /m/078jl,Snake,0 378 | /m/078n6m,Coffee table,0 379 | /m/079cl,Skyscraper,1 380 | /m/07bgp,Sheep,0 381 | /m/07c52,Television,0 382 | /m/07c6l,Trombone,0 383 | /m/07clx,Tea,0 384 | /m/07cmd,Tank,1 385 | /m/07crc,Taco,0 386 | /m/07cx4,Telephone,0 387 | /m/07dd4,Torch,0 388 | /m/07dm6,Tiger,0 389 | /m/07fbm7,Strawberry,0 390 | /m/07gql,Trumpet,0 391 | /m/07j7r,Tree,0 392 | /m/07j87,Tomato,0 393 | /m/07jdr,Train,1 394 | /m/07k1x,Tool,1 395 | /m/07kng9,Picnic basket,0 396 | /m/07mcwg,Cooking spray,0 397 | /m/07mhn,Trousers,0 398 | /m/07pj7bq,Bowling equipment,0 399 | /m/07qxg_,Football helmet,0 400 | /m/07r04,Truck,1 401 | /m/07v9_z,Measuring cup,0 402 | /m/07xyvk,Coffeemaker,0 403 | /m/07y_7,Violin,0 404 | /m/07yv9,Vehicle,1 405 | /m/080hkjn,Handbag,0 406 | /m/080n7g,Paper cutter,0 407 | /m/081qc,Wine,0 408 | /m/083kb,Weapon,0 409 | /m/083wq,Wheel,0 410 | /m/084hf,Worm,0 411 | /m/084rd,Wok,0 412 | /m/084zz,Whale,0 413 | /m/0898b,Zebra,0 414 | /m/08dz3q,Auto part,1 415 | /m/08hvt4,Jug,0 416 | /m/08ks85,Pizza cutter,0 417 | /m/08p92x,Cream,0 418 | /m/08pbxl,Monkey,0 419 | /m/096mb,Lion,0 420 | /m/09728,Bread,0 421 | /m/099ssp,Platter,0 422 | /m/09b5t,Chicken,0 423 | /m/09csl,Eagle,0 424 | /m/09ct_,Helicopter,1 425 | /m/09d5_,Owl,0 426 | /m/09ddx,Duck,0 427 | /m/09dzg,Turtle,0 428 | /m/09f20,Hippopotamus,0 429 | /m/09f_2,Crocodile,0 430 | /m/09g1w,Toilet,0 431 | /m/09gtd,Toilet paper,0 432 | /m/09gys,Squid,0 433 | /m/09j2d,Clothing,0 434 | /m/09j5n,Footwear,0 435 | /m/09k_b,Lemon,0 436 | /m/09kmb,Spider,0 437 | /m/09kx5,Deer,0 438 | /m/09ld4,Frog,0 439 | /m/09qck,Banana,0 440 | /m/09rvcxw,Rocket,1 441 | /m/09tvcd,Wine glass,0 442 | /m/0b3fp9,Countertop,0 443 | /m/0bh9flk,Tablet computer,0 444 | /m/0bjyj5,Waste container,0 445 | /m/0b_rs,Swimming pool,0 446 | /m/0bt9lr,Dog,0 447 | /m/0bt_c3,Book,0 448 | /m/0bwd_0j,Elephant,0 449 | /m/0by6g,Shark,0 450 | /m/0c06p,Candle,0 451 | /m/0c29q,Leopard,0 452 | /m/0c2jj,Axe,0 453 | /m/0c3m8g,Hand dryer,0 454 | /m/0c3mkw,Soap dispenser,0 455 | /m/0c568,Porcupine,0 456 | /m/0c9ph5,Flower,0 457 | /m/0ccs93,Canary,0 458 | /m/0cd4d,Cheetah,0 459 | /m/0cdl1,Palm tree,0 460 | /m/0cdn1,Hamburger,0 461 | /m/0cffdh,Maple,0 462 | /m/0cgh4,Building,1 463 | /m/0ch_cf,Fish,0 464 | /m/0cjq5,Lobster,0 465 | /m/0cjs7,Asparagus,0 466 | /m/0c_jw,Furniture,0 467 | /m/0cl4p,Hedgehog,0 468 | /m/0cmf2,Airplane,1 469 | /m/0cmx8,Spoon,0 470 | /m/0cn6p,Otter,0 471 | /m/0cnyhnx,Bull,0 472 | /m/0_cp5,Oyster,0 473 | /m/0cqn2,Horizontal bar,0 474 | /m/0crjs,Convenience store,0 475 | /m/0ct4f,Bomb,1 476 | /m/0cvnqh,Bench,0 477 | /m/0cxn2,Ice cream,0 478 | /m/0cydv,Caterpillar,0 479 | /m/0cyf8,Butterfly,0 480 | /m/0cyfs,Parachute,0 481 | /m/0cyhj_,Orange,0 482 | /m/0czz2,Antelope,0 483 | /m/0d20w4,Beaker,0 484 | /m/0d_2m,Moths and butterflies,0 485 | /m/0d4v4,Window,0 486 | /m/0d4w1,Closet,0 487 | /m/0d5gx,Castle,0 488 | /m/0d8zb,Jellyfish,0 489 | /m/0dbvp,Goose,0 490 | /m/0dbzx,Mule,0 491 | /m/0dftk,Swan,0 492 | /m/0dj6p,Peach,0 493 | /m/0djtd,Coconut,0 494 | /m/0dkzw,Seat belt,0 495 | /m/0dq75,Raccoon,0 496 | /m/0_dqb,Chisel,0 497 | /m/0dt3t,Fork,0 498 | /m/0dtln,Lamp,0 499 | /m/0dv5r,Camera,0 500 | /m/0dv77,Squash,0 501 | /m/0dv9c,Racket,0 502 | /m/0dzct,Human face,1 503 | /m/0dzf4,Human arm,1 504 | /m/0f4s2w,Vegetable,0 505 | /m/0f571,Diaper,0 506 | /m/0f6nr,Unicycle,0 507 | /m/0f6wt,Falcon,0 508 | /m/0f8s22,Chime,0 509 | /m/0f9_l,Snail,0 510 | /m/0fbdv,Shellfish,0 511 | /m/0fbw6,Cabbage,0 512 | 
/m/0fj52s,Carrot,0 513 | /m/0fldg,Mango,0 514 | /m/0fly7,Jeans,0 515 | /m/0fm3zh,Flowerpot,0 516 | /m/0fp6w,Pineapple,0 517 | /m/0fqfqc,Drawer,0 518 | /m/0fqt361,Stool,0 519 | /m/0frqm,Envelope,0 520 | /m/0fszt,Cake,0 521 | /m/0ft9s,Dragonfly,0 522 | /m/0ftb8,Sunflower,0 523 | /m/0fx9l,Microwave oven,0 524 | /m/0fz0h,Honeycomb,0 525 | /m/0gd2v,Marine mammal,0 526 | /m/0gd36,Sea lion,0 527 | /m/0gj37,Ladybug,0 528 | /m/0gjbg72,Shelf,0 529 | /m/0gjkl,Watch,0 530 | /m/0gm28,Candy,0 531 | /m/0grw1,Salad,0 532 | /m/0gv1x,Parrot,0 533 | /m/0gxl3,Handgun,1 534 | /m/0h23m,Sparrow,0 535 | /m/0h2r6,Van,1 536 | /m/0h8jyh6,Grinder,0 537 | /m/0h8kx63,Spice rack,0 538 | /m/0h8l4fh,Light bulb,0 539 | /m/0h8lkj8,Corded phone,1 540 | /m/0h8mhzd,Sports uniform,0 541 | /m/0h8my_4,Tennis racket,0 542 | /m/0h8mzrc,Wall clock,0 543 | /m/0h8n27j,Serving tray,0 544 | /m/0h8n5zk,Kitchen & dining room table,0 545 | /m/0h8n6f9,Dog bed,0 546 | /m/0h8n6ft,Cake stand,0 547 | /m/0h8nm9j,Cat furniture,0 548 | /m/0h8nr_l,Bathroom accessory,0 549 | /m/0h8nsvg,Facial tissue holder,0 550 | /m/0h8ntjv,Pressure cooker,0 551 | /m/0h99cwc,Kitchen appliance,0 552 | /m/0h9mv,Tire,0 553 | /m/0hdln,Ruler,0 554 | /m/0hf58v5,Luggage and bags,0 555 | /m/0hg7b,Microphone,0 556 | /m/0hkxq,Broccoli,0 557 | /m/0hnnb,Umbrella,0 558 | /m/0hnyx,Pastry,0 559 | /m/0hqkz,Grapefruit,0 560 | /m/0j496,Band-aid,0 561 | /m/0jbk,Animal,0 562 | /m/0jg57,Bell pepper,0 563 | /m/0jly1,Turkey,0 564 | /m/0jqgx,Lily,0 565 | /m/0jwn_,Pomegranate,0 566 | /m/0jy4k,Doughnut,0 567 | /m/0jyfg,Glasses,0 568 | /m/0k0pj,Human nose,0 569 | /m/0k1tl,Pen,0 570 | /m/0_k2,Ant,0 571 | /m/0k4j,Car,1 572 | /m/0k5j,Aircraft,1 573 | /m/0k65p,Human hand,1 574 | /m/0km7z,Skunk,0 575 | /m/0kmg4,Teddy bear,0 576 | /m/0kpqd,Watermelon,0 577 | /m/0kpt_,Cantaloupe,0 578 | /m/0ky7b,Dishwasher,0 579 | /m/0l14j_,Flute,0 580 | /m/0l3ms,Balance beam,0 581 | /m/0l515,Sandwich,0 582 | /m/0ll1f78,Shrimp,0 583 | /m/0llzx,Sewing machine,0 584 | /m/0lt4_,Binoculars,0 585 | /m/0m53l,Rays and skates,0 586 | /m/0mcx2,Ipod,0 587 | /m/0mkg,Accordion,0 588 | /m/0mw_6,Willow,0 589 | /m/0n28_,Crab,0 590 | /m/0nl46,Crown,0 591 | /m/0nybt,Seahorse,0 592 | /m/0p833,Perfume,0 593 | /m/0pcr,Alpaca,0 594 | /m/0pg52,Taxi,0 595 | /m/0ph39,Canoe,0 596 | /m/0qjjc,Remote control,0 597 | /m/0qmmr,Wheelchair,0 598 | /m/0wdt60w,Rugby ball,0 599 | /m/0xfy,Armadillo,0 600 | /m/0xzly,Maracas,0 601 | /m/0zvk5,Helmet,0 -------------------------------------------------------------------------------- /src/engine/Groundingrunner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import json 5 | import sys 6 | from functools import partial 7 | 8 | import numpy as np 9 | import torch 10 | from tensorboardX import SummaryWriter 11 | from torchvision import transforms 12 | 13 | import sys 14 | sys.path.append('../..') 15 | from src.util import consts 16 | from src.dataflow.torch.Data import MultiTokenField, SparseField, EntityField 17 | from torchtext.data import Field 18 | from src.util.vocab import Vocab 19 | from torchtext.vocab import Vectors 20 | from src.dataflow.numpy.data_loader_grounding import GroundingDataset 21 | from src.models.grounding import GroundingModel 22 | from src.eval.Groundingtesting import GroundingTester 23 | from src.engine.Groundingtraining import grounding_train 24 | from src.engine.SRrunner import load_sr_model 25 | from src.engine.EErunner import load_ee_model 26 | from src.util.util_model import log 27 | 28 
| 29 | class GroundingRunner(object): 30 | def __init__(self): 31 | parser = argparse.ArgumentParser(description="neural networks trainer") 32 | parser.add_argument("--test", help="test set") 33 | parser.add_argument("--train", help="training set", required=False) 34 | parser.add_argument("--dev", help="development set", required=False) 35 | parser.add_argument("--webd", help="word embedding", required=False) 36 | parser.add_argument("--img_dir", help="Grounding images directory", required=False) 37 | parser.add_argument("--amr", help="use amr", action='store_true') 38 | 39 | # sr model parameters 40 | parser.add_argument("--wnebd", help="noun word embedding", required=False) 41 | parser.add_argument("--wvebd", help="verb word embedding", required=False) 42 | parser.add_argument("--wrebd", help="role word embedding", required=False) 43 | parser.add_argument("--add_object", help="add_object", action='store_true') 44 | parser.add_argument("--object_class_map_file", help="object_class_map_file", required=False) 45 | parser.add_argument("--object_detection_pkl_file", help="object_detection_pkl_file", required=False) 46 | parser.add_argument("--object_detection_threshold", default=0.2, type=float, help="object_detection_threshold", 47 | required=False) 48 | 49 | parser.add_argument("--vocab", help="vocab_dir", required=False) 50 | parser.add_argument("--sr_hps", help="sr model hyperparams", required=False) 51 | 52 | # ee model parameters 53 | parser.add_argument("--ee_hps", help="ee model hyperparams", required=False) 54 | 55 | parser.add_argument("--batch", help="batch size", default=128, type=int) 56 | parser.add_argument("--epochs", help="number of epochs", default=sys.maxsize, type=int) 57 | 58 | parser.add_argument("--seed", help="RNG seed", default=42, type=int) 59 | parser.add_argument("--optimizer", default="adam") 60 | parser.add_argument("--lr", default=1e-3, type=float) 61 | parser.add_argument("--l2decay", default=0, type=float) 62 | parser.add_argument("--maxnorm", default=3, type=float) 63 | 64 | parser.add_argument("--out", help="output model path", default="out") 65 | parser.add_argument("--finetune_sr", help="pretrained sr model path") 66 | parser.add_argument("--finetune_ee", help="pretrained ee model path") 67 | parser.add_argument("--earlystop", default=999999, type=int) 68 | parser.add_argument("--restart", default=999999, type=int) 69 | parser.add_argument("--shuffle", help="shuffle", action='store_true') 70 | 71 | parser.add_argument("--device", default="cpu") 72 | 73 | self.a = parser.parse_args() 74 | 75 | def set_device(self, device="cpu"): 76 | # honor the requested device, falling back to CPU when CUDA is unavailable 77 | self.device = torch.device(device if "cuda" not in device or torch.cuda.is_available() else "cpu") 78 | 79 | def get_device(self): 80 | return self.device 81 | 82 | # def load_model(self, ed_model, sr_model, fine_tune): 83 | # if fine_tune is None: 84 | # train_model = GroundingModel(ed_model, sr_model, self.get_device()) 85 | # return train_model 86 | # else: 87 | # mymodel = GroundingModel(ed_model, sr_model, self.get_device()) 88 | # mymodel.load_model(fine_tune) 89 | # mymodel.to(self.get_device()) 90 | # return mymodel 91 | 92 | def get_tester(self): 93 | return GroundingTester() 94 | 95 | def run(self): 96 | self.set_device(self.a.device) 97 | print("Running on", self.get_device()) 98 | 99 | np.random.seed(self.a.seed) 100 | torch.manual_seed(self.a.seed) 101 | torch.backends.cudnn.benchmark = True 102 | 103 | # create training set 104 | if self.a.train: 105 | log('loading corpus from %s' % self.a.train) 106
| 107 | transform = transforms.Compose([ 108 | transforms.Resize(256), 109 | transforms.RandomHorizontalFlip(), 110 | transforms.RandomCrop(224), 111 | transforms.ToTensor(), 112 | transforms.Normalize((0.485, 0.456, 0.406), 113 | (0.229, 0.224, 0.225))]) 114 | 115 | IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) 116 | SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) 117 | # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True) 118 | WordsField = Field(lower=True, include_lengths=True, batch_first=True) 119 | PosTagsField = Field(lower=True, batch_first=True) 120 | EntityLabelsField = MultiTokenField(lower=False, batch_first=True) 121 | AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True) 122 | EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False) 123 | 124 | if self.a.amr: 125 | colcc = 'simple-parsing' 126 | else: 127 | colcc = 'combined-parsing' 128 | print(colcc) 129 | 130 | train_set = GroundingDataset(path=self.a.train, 131 | img_dir=self.a.img_dir, 132 | fields={"id": ("IMAGEID", IMAGEIDField), 133 | "sentence_id": ("SENTID", SENTIDField), 134 | "words": ("WORDS", WordsField), 135 | "pos-tags": ("POSTAGS", PosTagsField), 136 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), 137 | colcc: ("ADJM", AdjMatrixField), 138 | "all-entities": ("ENTITIES", EntitiesField), 139 | # "image": ("IMAGE", IMAGEField), 140 | }, 141 | transform=transform, 142 | amr=self.a.amr, 143 | load_object=self.a.add_object, 144 | object_ontology_file=self.a.object_class_map_file, 145 | object_detection_pkl_file=self.a.object_detection_pkl_file, 146 | object_detection_threshold=self.a.object_detection_threshold, 147 | ) 148 | 149 | dev_set = GroundingDataset(path=self.a.dev, 150 | img_dir=self.a.img_dir, 151 | fields={"id": ("IMAGEID", IMAGEIDField), 152 | "sentence_id": ("SENTID", SENTIDField), 153 | "words": ("WORDS", WordsField), 154 | "pos-tags": ("POSTAGS", PosTagsField), 155 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), 156 | colcc: ("ADJM", AdjMatrixField), 157 | "all-entities": ("ENTITIES", EntitiesField), 158 | # "image": ("IMAGE", IMAGEField), 159 | }, 160 | transform=transform, 161 | amr=self.a.amr, 162 | load_object=self.a.add_object, 163 | object_ontology_file=self.a.object_class_map_file, 164 | object_detection_pkl_file=self.a.object_detection_pkl_file, 165 | object_detection_threshold=self.a.object_detection_threshold, 166 | ) 167 | 168 | test_set = GroundingDataset(path=self.a.test, 169 | img_dir=self.a.img_dir, 170 | fields={"id": ("IMAGEID", IMAGEIDField), 171 | "sentence_id": ("SENTID", SENTIDField), 172 | "words": ("WORDS", WordsField), 173 | "pos-tags": ("POSTAGS", PosTagsField), 174 | "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), 175 | colcc: ("ADJM", AdjMatrixField), 176 | "all-entities": ("ENTITIES", EntitiesField), 177 | # "image": ("IMAGE", IMAGEField), 178 | }, 179 | transform=transform, 180 | amr=self.a.amr, 181 | load_object=self.a.add_object, 182 | object_ontology_file=self.a.object_class_map_file, 183 | object_detection_pkl_file=self.a.object_detection_pkl_file, 184 | object_detection_threshold=self.a.object_detection_threshold, 185 | ) 186 | 187 | if self.a.webd: 188 | pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15)) 189 | WordsField.build_vocab(train_set.WORDS, dev_set.WORDS, vectors=pretrained_embedding) 190 | else: 191 | 
WordsField.build_vocab(train_set.WORDS, dev_set.WORDS) 192 | 193 | PosTagsField.build_vocab(train_set.POSTAGS, dev_set.POSTAGS) 194 | EntityLabelsField.build_vocab(train_set.ENTITYLABELS, dev_set.ENTITYLABELS) 195 | 196 | # sr model initialization; --sr_hps arrives as a Python dict literal 197 | self.a.sr_hps = eval(self.a.sr_hps) 198 | vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True) 199 | vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True) 200 | vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True) 201 | embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to(self.device) 202 | embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to(self.device) 203 | embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to(self.device) 204 | if "wvemb_size" not in self.a.sr_hps: 205 | self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word) 206 | if "wremb_size" not in self.a.sr_hps: 207 | self.a.sr_hps["wremb_size"] = len(vocab_role.id2word) 208 | if "wnemb_size" not in self.a.sr_hps: 209 | self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word) 210 | if "ae_oc" not in self.a.sr_hps: 211 | self.a.sr_hps["ae_oc"] = len(vocab_role.id2word) 212 | # ee model initialization; --ee_hps is likewise a dict literal, and 213 | self.a.ee_hps = eval(self.a.ee_hps) # missing sizes are filled from the data 214 | if "wemb_size" not in self.a.ee_hps: 215 | self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos) 216 | if "pemb_size" not in self.a.ee_hps: 217 | self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos) 218 | if "psemb_size" not in self.a.ee_hps: 219 | self.a.ee_hps["psemb_size"] = max([train_set.longest(), dev_set.longest(), test_set.longest()]) + 2 220 | if "eemb_size" not in self.a.ee_hps: 221 | self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos) 222 | if "oc" not in self.a.ee_hps: 223 | self.a.ee_hps["oc"] = 36 # fallback number of event-type output classes 224 | if "ae_oc" not in self.a.ee_hps: 225 | self.a.ee_hps["ae_oc"] = 20 # fallback number of argument-role output classes
226 | 227 | tester = self.get_tester() 228 | 229 | if self.a.finetune_sr: 230 | log('init sr model from ' + self.a.finetune_sr) 231 | sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune_sr, self.device) 232 | log('sr model loaded, there are %i sets of params' % len(sr_model.parameters_requires_grads())) 233 | else: 234 | sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, None, self.device) 235 | log('sr model created from scratch, there are %i sets of params' % len(sr_model.parameters_requires_grads())) 236 | 237 | if self.a.finetune_ee: 238 | log('init ee model from ' + self.a.finetune_ee) 239 | ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee, WordsField.vocab.vectors, self.device) 240 | log('ee model loaded, there are %i sets of params' % len(ee_model.parameters_requires_grads())) 241 | else: 242 | ee_model = load_ee_model(self.a.ee_hps, None, WordsField.vocab.vectors, self.device) 243 | log('ee model created from scratch, there are %i sets of params' % len(ee_model.parameters_requires_grads())) 244 | 245 | model = GroundingModel(ee_model, sr_model, self.get_device()) 246 | 247 | if self.a.optimizer == "adadelta": 248 | optimizer_constructor = partial(torch.optim.Adadelta, params=model.parameters_requires_grads(), 249 | weight_decay=self.a.l2decay) 250 | elif self.a.optimizer == "adam": 251 | optimizer_constructor = partial(torch.optim.Adam, params=model.parameters_requires_grads(), 252 | weight_decay=self.a.l2decay) 253 | else: 254 | optimizer_constructor = partial(torch.optim.SGD, params=model.parameters_requires_grads(), 255 | weight_decay=self.a.l2decay, 256 | momentum=0.9) 257 | 258 | log('optimizer in use: %s' % str(self.a.optimizer)) 259 | 260 | if not os.path.exists(self.a.out): 261 | os.makedirs(self.a.out) 262 | with open(os.path.join(self.a.out, "word.vec"), "wb") as f: 263 | pickle.dump(WordsField.vocab, f) 264 | with open(os.path.join(self.a.out, "pos.vec"), "wb") as f: 265 | pickle.dump(PosTagsField.vocab.stoi, f) 266 | with open(os.path.join(self.a.out, "entity.vec"), "wb") as f: 267 | pickle.dump(EntityLabelsField.vocab.stoi, f) 268 | with open(os.path.join(self.a.out, "ee_hyps.json"), "w") as f: 269 | json.dump(self.a.ee_hps, f) 270 | with open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f: 271 | json.dump(self.a.sr_hps, f) 272 | 273 | log('init complete\n') 274 | 275 | # label/role strings come from the SR vocabularies, word strings from the text vocabulary 276 | self.a.label_i2s = vocab_verb.id2word 277 | self.a.role_i2s = vocab_role.id2word 278 | self.a.word_i2s = WordsField.vocab.itos 279 | 280 | 281 | writer = SummaryWriter(os.path.join(self.a.out, "exp")) 282 | self.a.writer = writer 283 | 284 | grounding_train( 285 | model=model, 286 | train_set=train_set, 287 | dev_set=dev_set, 288 | test_set=test_set, 289 | optimizer_constructor=optimizer_constructor, 290 | epochs=self.a.epochs, 291 | tester=tester, 292 | parser=self.a, 293 | other_testsets={ 294 | # "dev 1/1": dev_set1, 295 | # "test 1/1": test_set1, 296 | }, 297 | transform=transform, 298 | vocab_objlabel=vocab_noun.word2id 299 | ) 300 | log('Done!') 301 | 302 | 303 | if __name__ == "__main__": 304 | GroundingRunner().run() 305 | --------------------------------------------------------------------------------
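For reference, a minimal usage sketch of the box-level correctness helper from the evaluation module above. It is not repository code: the coordinates are made-up normalized xyxy boxes, and it assumes that module's functions (IoU, union, calc_correctness_box) are importable.

    # hypothetical example, not part of the repo
    annot_all_norm = [[0.10, 0.10, 0.50, 0.50], [0.60, 0.60, 0.90, 0.90]]  # ground-truth boxes
    pred_norm = [0.12, 0.08, 0.48, 0.52]                                   # one predicted box
    hit, score = calc_correctness_box(annot_all_norm, pred_norm, iou_thr=0.5)
    # hit == 1: the prediction overlaps the first annotation with IoU of roughly 0.83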
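Likewise, a sketch of the hyperparameter strings GroundingRunner.run() expects: --sr_hps and --ee_hps are Python dict literals, which run() parses and then completes with data-derived sizes (vocabulary lengths, plus the oc/ae_oc fallbacks of 36 and 20 seen above). The keys shown here are illustrative assumptions, not the full set required by the SR/EE models; real configurations live in scripts/train/train_grounding.sh.

    # hypothetical values only
    sr_hps_str = "{'wemb_dim': 300}"
    ee_hps_str = "{'wemb_dim': 300, 'oc': 36, 'ae_oc': 20}"
    # inside run(): self.a.sr_hps = eval(sr_hps_str), after which missing entries
    # such as wvemb_size = len(vocab_verb.id2word) are filled in automatically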