├── t5_large.yml ├── t5_base.yml ├── e2e_t5_base.yml ├── COCO_t5_base_nocaps.yml ├── dataset ├── diversity.py ├── data_utils.py ├── word_norm.json ├── vocabulary.py ├── blacklist_only_no_people.json ├── nocaps_novel_constraint.txt ├── EvalAI.py ├── e2e_dataset.py ├── dataset.py ├── pymteval.py ├── coco_dataset.py ├── reader.py └── new_copy_vocab_all.txt ├── utils.py ├── .gitignore ├── config.py ├── checkpointing.py ├── constraint.py ├── README.md ├── LICENSE ├── train_e2e_T5.py ├── train_T5.py └── train_COCO_T5.py /t5_large.yml: -------------------------------------------------------------------------------- 1 | train_path: dataset/commongen.train.jsonl 2 | dev_path: dataset/commongen.dev.jsonl 3 | test_path: dataset/commongen.test.jsonl 4 | batch_size: 16 5 | gradient_accumulation_steps: 12 6 | learning_rate: 5e-5 7 | checkpoint_every_step: 500 8 | max_epoch: 20 9 | use_pointer: False 10 | use_mention_flag: True 11 | copy_vocab_path: dataset/new_copy_vocab.txt 12 | lm_type: t5-large 13 | freeze_param: True 14 | use_orginal_enc_pos_embs: True 15 | freeze_enc_pos_param: False 16 | -------------------------------------------------------------------------------- /t5_base.yml: -------------------------------------------------------------------------------- 1 | train_path: dataset/commongen.train.jsonl 2 | dev_path: dataset/commongen.dev.jsonl 3 | test_path: dataset/commongen.test.jsonl 4 | batch_size: 48 5 | gradient_accumulation_steps: 4 6 | learning_rate: 5e-5 7 | checkpoint_every_step: 500 8 | max_epoch: 20 9 | use_pointer: False 10 | use_mention_flag: True 11 | copy_vocab_path: dataset/new_copy_vocab.txt 12 | lm_type: t5-base 13 | freeze_param: True 14 | do_pretrain_lm_init: True 15 | static_mf: False 16 | mention_flag_state: 3 17 | use_orginal_enc_pos_embs: True 18 | freeze_enc_pos_param: False 19 | -------------------------------------------------------------------------------- /e2e_t5_base.yml: -------------------------------------------------------------------------------- 1 | train_path: dataset/e2e_train_old.json 2 | dev_path: dataset/e2e_dev_old.json 3 | test_path: dataset/e2e_test_old.json 4 | copy_vocab_path: dataset/new_copy_vocab_all.txt 5 | max_generation_len: 50 6 | batch_size: 50 7 | gradient_accumulation_steps: 1 8 | learning_rate: 5e-5 9 | checkpoint_every_step: 1000 10 | max_epoch: 15 11 | use_pointer: False 12 | use_mention_flag: True 13 | mention_flag_state: 3 14 | lm_type: t5-base 15 | freeze_param: True 16 | enable_visual: False 17 | rm_dumplicated_caption: True 18 | shuffle_data: True 19 | rm_punctuation: False 20 | relative_pos_num: 55 21 | use_orginal_enc_pos_embs: True 22 | freeze_enc_pos_param: False 23 | do_pretrain_lm_init: True 24 | -------------------------------------------------------------------------------- /COCO_t5_base_nocaps.yml: -------------------------------------------------------------------------------- 1 | train_path: dataset/captions_train2017.json 2 | dev_path: dataset/nocaps_val_captions.json 3 | copy_vocab_path: dataset/new_object_class_v2.txt 4 | train_obj_h5_path: dataset/train_adaptive_101.h5 5 | dev_obj_h5_path: dataset/nocaps_val_adaptive.h5 6 | train_copy_obj_h5_path: dataset/tf_faster_rcnn_inception_resnet_v2_atrous_oid_v4_boxes.h5 7 | dev_copy_obj_h5_path: dataset/tf_faster_rcnn_inception_resnet_v2_atrous_oid_v4_boxes.h5 8 | word_norm_jsonpath: dataset/word_norm.json 9 | max_generation_len: 20 10 | batch_size: 25 11 | gradient_accumulation_steps: 2 12 | learning_rate: 5e-5 13 | checkpoint_every_step: 6000 14 | max_epoch: 15 15 | 
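# Hedged note on the Mention Flag switches below (inferred from config.py and the
# dataset code in this repo): use_mention_flag turns the flags on;
# mention_flag_state is the number of distinct flag values (4 in this nocaps
# config, 3 in t5_base.yml / e2e_t5_base.yml); static_mf keeps a flag at its
# initial value rather than updating it once a constraint has been generated;
# use_mf_scalar corresponds to the "Scalar MF" ablation listed in the README tables.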
use_pointer: False 16 | use_mention_flag: True 17 | lm_type: t5-base 18 | freeze_param: True 19 | enable_visual: True 20 | use_copy_obj: True 21 | object_blacklist_path: dataset/blacklist_only_no_people.json 22 | rm_dumplicated_caption: True 23 | shuffle_data: True 24 | external_eval: False 25 | rm_punctuation: False 26 | mention_flag_state: 4 27 | relative_pos_num: 55 28 | use_orginal_enc_pos_embs: True 29 | freeze_enc_pos_param: False 30 | static_mf: False 31 | do_pretrain_lm_init: True 32 | use_mf_scalar: False 33 | -------------------------------------------------------------------------------- /dataset/diversity.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import json 3 | 4 | def pad_sequence(sequence, n, pad_left=False, pad_right=False, 5 | left_pad_symbol=None, right_pad_symbol=None): 6 | sequence = iter(sequence) 7 | if pad_left: 8 | sequence = chain((left_pad_symbol,) * (n - 1), sequence) 9 | if pad_right: 10 | sequence = chain(sequence, (right_pad_symbol,) * (n - 1)) 11 | return sequence 12 | 13 | 14 | def ngrams(sequence, n, pad_left=False, pad_right=False, 15 | left_pad_symbol=None, right_pad_symbol=None): 16 | sequence = pad_sequence(sequence, n, pad_left, pad_right, 17 | left_pad_symbol, right_pad_symbol) 18 | 19 | history = [] 20 | while n > 1: 21 | history.append(next(sequence)) 22 | n -= 1 23 | for item in sequence: 24 | history.append(item) 25 | yield tuple(history) 26 | del history[0] 27 | 28 | def distinct_n_sentence_level(sentence, n): 29 | if len(sentence) == 0: 30 | return 0.0 # Prevent a zero division 31 | distinct_ngrams = set(ngrams(sentence, n)) 32 | return len(distinct_ngrams) / len(sentence) 33 | 34 | 35 | def distinct_n(sentences, n): 36 | return sum(distinct_n_sentence_level(sentence, n) for sentence in sentences) / len(sentences) -------------------------------------------------------------------------------- /dataset/data_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | 5 | def get_position_emb_index(distance, num_buckets=16, max_distance=128, right=False): 6 | max_exact = num_buckets // 2 7 | if distance < max_exact: 8 | return distance if not right else distance + num_buckets 9 | else: 10 | pos = max_exact + math.log(distance / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 11 | pos = int(min(pos, num_buckets - 1)) 12 | return pos if not right else pos + num_buckets 13 | 14 | 15 | def process_tensor(tensor_list, last_dim, output_mask=False): 16 | tensor_len = [d.shape[0] for d in tensor_list] 17 | tensor_max_lenth = max(tensor_len) 18 | d_type = tensor_list[0].dtype 19 | if last_dim > 0: 20 | tensor_np = np.zeros((len(tensor_list), tensor_max_lenth, last_dim), dtype=d_type) 21 | else: 22 | tensor_np = np.zeros((len(tensor_list), tensor_max_lenth), dtype=d_type) 23 | mask_np = np.zeros((len(tensor_list), tensor_max_lenth), dtype=np.float32) 24 | for i, (d, l) in enumerate(zip(tensor_list, tensor_len)): 25 | if l > 0: 26 | tensor_np[i, :l] = d 27 | mask_np[i, :l] = 1 28 | if output_mask: 29 | return torch.from_numpy(tensor_np), torch.from_numpy(mask_np) 30 | else: 31 | return torch.from_numpy(tensor_np) -------------------------------------------------------------------------------- /dataset/word_norm.json: -------------------------------------------------------------------------------- 1 | { 2 | "t.v": "television", 3 | "r.v": "television", 4 | "tv": 
"television", 5 | "tvs": "televisions", 6 | "persons": "person", 7 | "childrens": "children", 8 | "teddybear": "teddy bear", 9 | "streetlight": "street light", 10 | "streetlights": "street lights", 11 | "hair drier": "hair dryer", 12 | "hairdryer": "hair dryer", 13 | "air plane": "airplane", 14 | "donut": "doughnut", 15 | "donuts": "doughnuts", 16 | "racquet": "racket", 17 | "racquets": "rackets", 18 | "mitt": "glove", 19 | "mitts": "gloves", 20 | "knapsack": "backpack", 21 | "bike": "bicycle", 22 | "bikes": "bicycles", 23 | "busses": "buses", 24 | "hydrant": "fire hydrant", 25 | "hydrants": "fire hydrants", 26 | "scarves": "scarfs", 27 | "hotdogs": "hot dogs", 28 | "hotdog": "hot dog", 29 | "sofa": "couch", 30 | "surfboard": "surf board", 31 | "street signs": "traffic signs", 32 | "stop light": "traffic sign", 33 | "cow": "cattle", 34 | "cows": "cattles", 35 | "pepole": "people", 36 | "bike": "bicycle", 37 | "bikes": "bicycles", 38 | "motor cycle": "motorcycle", 39 | "motor cycles": "motorcycles", 40 | "guys": "people", 41 | "guy": "person", 42 | "sailboat": "sail boat", 43 | "sailboats": "sail boats", 44 | "eephants": "elephants", 45 | "she": "a woman", 46 | "peerson": "person", 47 | "plane": "airplane", 48 | "passenger jet": "airplane", 49 | "kayaks": "canoes", 50 | "kayak":"canoe", 51 | "stop sigm": "stop sign", 52 | "kitty": "cat", 53 | "jetliner": "airplane", 54 | "jetliners": "airplanes", 55 | "aircraft": "airplane", 56 | "aircrafts": "airplanes", 57 | "track": "truck", 58 | "bow": "boy" 59 | } -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule 2 | 3 | def get_output(output, config, copy_vocab, tokenzier): 4 | _BOUNDARY = tokenzier.eos_token_id 5 | N, D = output.size() 6 | output = output.detach().cpu() 7 | 8 | out = [] 9 | for i in range(N): 10 | txt = [] 11 | for j in range(D): 12 | ix = output[i, j].item() 13 | if ix == _BOUNDARY: break 14 | if ix < config.vocab_size: 15 | txt.append(ix) 16 | else: 17 | ix = ix - config.vocab_size 18 | txt += copy_vocab.token_fg_w[ix] 19 | out.append(txt) 20 | return out 21 | 22 | 23 | class AdamWOpt(object): 24 | def __init__(self, optimizer, scheduler): 25 | self.optimizer = optimizer 26 | self.scheduler = scheduler 27 | 28 | def __getattr__(self, name): 29 | return getattr(self.optimizer, name) 30 | 31 | def step(self): 32 | self.optimizer.step() 33 | self.scheduler.step() 34 | 35 | def build_optimizer(opt, model): 36 | no_decay = ["bias", "LayerNorm.weight"] 37 | optimizer_grouped_parameters = [ 38 | { 39 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 40 | "weight_decay": opt.weight_decay, 41 | }, 42 | { 43 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 44 | "weight_decay": 0.0, 45 | }, 46 | ] 47 | assert opt.num_training_steps > 0 48 | optimizer = AdamW(optimizer_grouped_parameters, lr=opt.learning_rate, eps=opt.adam_epsilon) 49 | scheduler = get_constant_schedule(optimizer) #get_linear_schedule_with_warmup(optimizer, opt.warmup_step, opt.num_training_steps, -1) 50 | 51 | return AdamWOpt(optimizer, scheduler) 52 | -------------------------------------------------------------------------------- /dataset/vocabulary.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | class 
T5CopyVocabulary(object): 4 | 5 | def __init__(self, vocab_path, tokenizer, sep=','): 6 | with open(vocab_path) as out: 7 | self.d_to_w_group = {} 8 | self.i_to_w = {} 9 | self.w_to_i = {} 10 | self.i_to_cls = {} 11 | self.id_to_category = {} 12 | self.word_to_category_id = {} 13 | for idx, line in enumerate(out): 14 | items = line.strip().split(sep) 15 | self.d_to_w_group[idx] = [] 16 | for w in items: 17 | w = w.lower() 18 | assert len(w) > 0, "empty line %s" % line.strip() 19 | fg_index = len(self.i_to_w) 20 | self.d_to_w_group[idx].append((w, fg_index)) 21 | self.i_to_w[fg_index] = w 22 | self.w_to_i[w] = fg_index 23 | self.i_to_cls[fg_index] = idx 24 | self.id_to_category[len(self.id_to_category)] = items[0] 25 | self.word_to_category_id[items[0]] = len(self.word_to_category_id) 26 | self.detection_size = len(self.id_to_category) 27 | 28 | self.token_fg_w = {} 29 | for (fg_index, w) in self.i_to_w.items(): 30 | token_word = tokenizer(w, return_tensors="np")['input_ids'][0, :-1].tolist() 31 | self.token_fg_w[fg_index] = token_word 32 | 33 | self.token_class = {} 34 | for cls_index, w in self.id_to_category.items(): 35 | token_word = tokenizer(w, return_tensors="np")['input_ids'][0, :-1].tolist() 36 | self.token_class[cls_index] = token_word 37 | 38 | def get_detection_size(self): 39 | return self.detection_size 40 | 41 | def get_fg_size(self): 42 | return len(self.i_to_w) 43 | 44 | def get_category(self): 45 | return self.id_to_category 46 | 47 | 48 | -------------------------------------------------------------------------------- /dataset/blacklist_only_no_people.json: -------------------------------------------------------------------------------- 1 | { 2 | "blacklist_categories": [ 3 | "person", 4 | "people", 5 | "man", 6 | "woman", 7 | "men", 8 | "women", 9 | "boy", 10 | "girls", 11 | "boys", 12 | "girl", 13 | "lady", 14 | "ladies", 15 | "male", 16 | "female", 17 | "males", 18 | "females", 19 | "child", 20 | "children", 21 | "kid", 22 | "kids", 23 | "adult", 24 | "adults", 25 | "NONE", 26 | "__background__", 27 | "eye", 28 | "eyes", 29 | "Human eye", 30 | "Skull", 31 | "Human head", 32 | "head", 33 | "heads", 34 | "face", 35 | "faces", 36 | "mouth", 37 | "mouths", 38 | "ear", 39 | "ears", 40 | "nose", 41 | "noses", 42 | "hair", 43 | "hairs", 44 | "foot", 45 | "feet", 46 | "arm", 47 | "arms", 48 | "leg", 49 | "legs", 50 | "beard", 51 | "body", 52 | "Human face", 53 | "Human mouth", 54 | "Human ear", 55 | "Human nose", 56 | "Human hair", 57 | "Human hand", 58 | "Human foot", 59 | "Human arm", 60 | "Human leg", 61 | "Human beard", 62 | "Human body", 63 | "Vehicle registration plate", 64 | "Wheel", 65 | "Wheels", 66 | "front wheel", 67 | "back wheel", 68 | "steering wheel", 69 | "Seat belt", 70 | "Tire", 71 | "Bicycle wheel", 72 | "Auto part", 73 | "Door handle", 74 | "Clothing", 75 | "Footwear", 76 | "Fashion accessory", 77 | "Sports equipment", 78 | "Hiking equipment", 79 | "Mammal", 80 | "Personal care", 81 | "Bathroom accessory", 82 | "Plumbing fixture", 83 | "Land vehicle", 84 | "train front", 85 | "background", 86 | "NONE" 87 | ], 88 | "val_blacklist_categories": [] 89 | } 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | *.DS_Store 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 
16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | from yacs.config import CfgNode as CN 3 | 4 | class Config(object): 5 | 6 | def __init__(self, config_yaml: str, config_override: List[Any] = []): 7 | self._C = CN() 8 | self._C.random_seed = 0 9 | self._C.train_path = "" 10 | self._C.dev_path = "" 11 | self._C.test_path = "" 12 | self._C.train_obj_h5_path = "" 13 | self._C.dev_obj_h5_path = "" 14 | self._C.test_obj_h5_path = "" 15 | self._C.train_copy_obj_h5_path = "" 16 | self._C.dev_copy_obj_h5_path = "" 17 | self._C.test_copy_obj_h5_path = "" 18 | self._C.object_blacklist_path = "" 19 | self._C.copy_vocab_path = "" 20 | self._C.lm_type = "t5-large" 21 | self._C.vocab_size = 0 22 | self._C.use_pointer = False 23 | self._C.batch_size = 192 24 | self._C.max_epoch = 20 25 | self._C.gradient_accumulation_steps = 1 26 | self._C.checkpoint_every_step = 1000 27 | self._C.weight_decay = 0.0 28 | self._C.adam_epsilon = 1e-8 29 | self._C.learning_rate = 5e-5 30 | self._C.warmup_step = 400 31 | self._C.num_training_steps = 0 32 | self._C.grad_clip_value = 0 33 | self._C.use_mention_flag = False 34 | self._C.mention_flag_state = 3 35 | 
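        # Hedged note (inferred from dataset/e2e_dataset.py and dataset/dataset.py):
        # mention flag values appear to be 0 for source tokens that carry no
        # constraint, 1 for a constraint that has not been generated yet, and 2 for
        # a constraint that has already been mentioned; mention_flag_state is the
        # size of that flag vocabulary (3 here, 4 in COCO_t5_base_nocaps.yml).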
self._C.max_generation_len = 25 36 | self._C.relation_map_path = "" 37 | self._C.entity_map_path = "" 38 | self._C.word_norm_jsonpath = "" 39 | self._C.enable_visual = False 40 | self._C.roi_dim = 2048 41 | self._C.box_dim = 8 42 | self._C.use_copy_obj = False 43 | self._C.rm_dumplicated_caption = False 44 | self._C.shuffle_data = False 45 | self._C.rm_punctuation = False 46 | self._C.external_eval = False 47 | self._C.relative_pos_num = 0 48 | self._C.use_orginal_enc_pos_embs = False 49 | self._C.freeze_param = True 50 | self._C.freeze_enc_pos_param = True 51 | self._C.decode_constrain = "" 52 | self._C.static_mf = False 53 | self._C.do_pretrain_lm_init = True 54 | self._C.use_mf_scalar = False 55 | self._C.use_mf_merged = False 56 | 57 | 58 | # Override parameter values from YAML file first, then from override list. 59 | self._C.merge_from_file(config_yaml) 60 | self._C.merge_from_list(config_override) 61 | 62 | # Make an instantiated object of this class immutable. 63 | self._C.freeze() 64 | 65 | def dump(self, file_path: str): 66 | self._C.dump(stream=open(file_path, "w")) 67 | 68 | def __getattr__(self, attr: str): 69 | return self._C.__getattr__(attr) 70 | 71 | def __str__(self): 72 | return _config_str(self) 73 | 74 | def __repr__(self): 75 | return self._C.__repr__() 76 | 77 | 78 | def _config_str(config: Config) -> str: 79 | r""" 80 | Collect a subset of config in sensible order (not alphabetical) according to phase. Used by 81 | :func:`Config.__str__()`. 82 | 83 | Parameters 84 | ---------- 85 | config: Config 86 | A :class:`Config` object which is to be printed. 87 | """ 88 | _C = config 89 | 90 | __C: CN = CN({"RANDOM_SEED": _C.random_seed}) 91 | common_string: str = str(__C) + "\n" 92 | 93 | return common_string -------------------------------------------------------------------------------- /dataset/nocaps_novel_constraint.txt: -------------------------------------------------------------------------------- 1 | 1604 2 | 1605 3 | 1610 4 | 1613 5 | 1614 6 | 1616 7 | 1617 8 | 1621 9 | 1622 10 | 1632 11 | 1634 12 | 1635 13 | 1636 14 | 1637 15 | 1642 16 | 1643 17 | 1647 18 | 1650 19 | 1653 20 | 1654 21 | 1657 22 | 1658 23 | 1665 24 | 1670 25 | 1671 26 | 1674 27 | 1677 28 | 1678 29 | 1680 30 | 1684 31 | 1687 32 | 1690 33 | 1692 34 | 1693 35 | 1695 36 | 1697 37 | 1704 38 | 1705 39 | 1707 40 | 1708 41 | 1711 42 | 1713 43 | 1715 44 | 1716 45 | 1717 46 | 1718 47 | 1719 48 | 1724 49 | 1726 50 | 1727 51 | 1728 52 | 1729 53 | 1731 54 | 1736 55 | 1738 56 | 1739 57 | 1740 58 | 1741 59 | 1742 60 | 1745 61 | 1747 62 | 1749 63 | 1751 64 | 1754 65 | 1756 66 | 1757 67 | 1759 68 | 1760 69 | 1763 70 | 1764 71 | 1765 72 | 1766 73 | 1767 74 | 1770 75 | 1774 76 | 1775 77 | 1776 78 | 1778 79 | 1780 80 | 1783 81 | 1785 82 | 1786 83 | 1788 84 | 1789 85 | 1791 86 | 1792 87 | 1793 88 | 1795 89 | 1796 90 | 1797 91 | 1799 92 | 1800 93 | 1801 94 | 1802 95 | 1803 96 | 1805 97 | 1810 98 | 1811 99 | 1814 100 | 1815 101 | 1816 102 | 1822 103 | 1823 104 | 1825 105 | 1830 106 | 1831 107 | 1833 108 | 1838 109 | 1839 110 | 1840 111 | 1841 112 | 1844 113 | 1847 114 | 1848 115 | 1849 116 | 1853 117 | 1854 118 | 1855 119 | 1856 120 | 1857 121 | 1862 122 | 1863 123 | 1866 124 | 1867 125 | 1868 126 | 1869 127 | 1872 128 | 1873 129 | 1874 130 | 1875 131 | 1876 132 | 1877 133 | 1878 134 | 1879 135 | 1880 136 | 1883 137 | 1884 138 | 1885 139 | 1888 140 | 1889 141 | 1890 142 | 1891 143 | 1892 144 | 1893 145 | 1895 146 | 1896 147 | 1897 148 | 1898 149 | 1899 150 | 1902 151 | 1906 152 | 1908 153 | 1910 154 | 1912 155 | 1921 
156 | 1922 157 | 1923 158 | 1924 159 | 1926 160 | 1930 161 | 1931 162 | 1933 163 | 1937 164 | 1938 165 | 1939 166 | 1943 167 | 1947 168 | 1949 169 | 1951 170 | 1952 171 | 1953 172 | 1955 173 | 1956 174 | 1957 175 | 1961 176 | 1962 177 | 1963 178 | 1968 179 | 1969 180 | 1971 181 | 1972 182 | 1973 183 | 1974 184 | 1975 185 | 1976 186 | 1977 187 | 1979 188 | 1983 189 | 1984 190 | 1996 191 | 1997 192 | 1999 193 | 2002 194 | 2004 195 | 2007 196 | 2012 197 | 2013 198 | 2015 199 | 2017 200 | 2029 201 | 2032 202 | 2033 203 | 2034 204 | 2035 205 | 2037 206 | 2039 207 | 2044 208 | 2045 209 | 2046 210 | 2052 211 | 2053 212 | 2054 213 | 2055 214 | 2056 215 | 2058 216 | 2059 217 | 2060 218 | 2062 219 | 2066 220 | 2067 221 | 2068 222 | 2074 223 | 2079 224 | 2084 225 | 2085 226 | 2088 227 | 2089 228 | 2095 229 | 2096 230 | 2097 231 | 2101 232 | 2103 233 | 2104 234 | 2107 235 | 2108 236 | 2109 237 | 2111 238 | 2115 239 | 2122 240 | 2124 241 | 2125 242 | 2126 243 | 2127 244 | 2128 245 | 2134 246 | 2137 247 | 2138 248 | 2139 249 | 2140 250 | 2144 251 | 2145 252 | 2147 253 | 2148 254 | 2149 255 | 2150 256 | 2151 257 | 2152 258 | 2157 259 | 2161 260 | 2163 261 | 2168 262 | 2169 263 | 2171 264 | 2174 265 | 2175 266 | 2178 267 | 2181 268 | 2183 269 | 2184 270 | 2185 271 | 2186 272 | 2187 273 | 2188 274 | 2192 275 | 2193 276 | 2194 277 | 2197 278 | 2199 279 | 2200 280 | 2201 281 | -------------------------------------------------------------------------------- /dataset/EvalAI.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import json 3 | import re 4 | import subprocess 5 | import tempfile 6 | import time 7 | from typing import Any, Dict, List 8 | 9 | from mypy_extensions import TypedDict 10 | 11 | 12 | Prediction = TypedDict("Prediction", {"image_id": int, "caption": str}) 13 | 14 | 15 | class NocapsEvaluator(object): 16 | def __init__(self, phase: str = "val"): 17 | 18 | # Constants specific to EvalAI. 19 | self._challenge_id = 355 20 | self._phase_id = 742 if phase == "val" else 743 21 | 22 | def evaluate(self, predictions: List[Prediction], request_metrics=None): 23 | # Save predictions as a json file first. 24 | _, predictions_filename = tempfile.mkstemp(suffix=".json", text=True) 25 | json.dump(predictions, open(predictions_filename, "w")) 26 | 27 | submission_command = f"evalai challenge {self._challenge_id} phase {self._phase_id} " \ 28 | f"submit --file {predictions_filename}" 29 | 30 | submission_command_subprocess = subprocess.Popen( 31 | submission_command.split(), 32 | stdout=subprocess.PIPE, 33 | stdin=subprocess.PIPE, 34 | stderr=subprocess.STDOUT, 35 | ) 36 | 37 | # This terminal output will have submission ID we need to check. 38 | submission_command_stdout = submission_command_subprocess.communicate(input=b"N\n")[0].decode("utf-8") 39 | 40 | submission_id_regex = re.search("evalai submission ([0-9]+)", submission_command_stdout) 41 | 42 | # Get an integer submission ID (as a string). 43 | submission_id = submission_id_regex.group(0).split()[-1] # type: ignore 44 | 45 | # Placeholder stdout for a pending submission. 46 | result_stdout: str = "The Submission is yet to be evaluated." 47 | num_tries: int = 0 48 | 49 | # Query every 10 seconds for result until it appears. 
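        # The loop below polls `evalai submission <id> result` until the metrics
        # (detected via "CIDEr" in the output) appear, giving up after 60 tries (~10 minutes).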
50 | while "CIDEr" not in result_stdout: 51 | 52 | time.sleep(10) 53 | 54 | result_stdout = subprocess.check_output( 55 | ["evalai", "submission", submission_id, "result"] 56 | ).decode("utf-8") 57 | 58 | num_tries += 1 59 | 60 | # Raise error if it takes more than 10 minutes. 61 | if num_tries == 60: 62 | raise ConnectionError("Unable to get results from EvalAI within 10 minutes!") 63 | 64 | # Convert result to json. 65 | # keys: {"in-domain", "near-domain", "out-domain", "entire"} 66 | # In each of these, keys: {"B1", "B2", "B3", "B4", "METEOR", "ROUGE-L", "CIDEr", "SPICE"} 67 | metrics = json.loads(result_stdout, encoding="utf-8") 68 | 69 | # Restructure the metrics dict for better tensorbaord logging. 70 | metrics = { 71 | "in-domain": metrics[0]["in-domain"], 72 | "near-domain": metrics[1]["near-domain"], 73 | "out-domain": metrics[2]["out-domain"], 74 | "entire": metrics[3]["entire"], 75 | } 76 | 77 | flipped_metrics: Dict[str, Any] = defaultdict(dict) 78 | for key, val in metrics.items(): 79 | for subkey, subval in val.items(): 80 | flipped_metrics[subkey][key] = subval 81 | 82 | # keys: {"B1", "B2", "B3", "B4", "METEOR", "ROUGE-L", "CIDEr", "SPICE"} 83 | # In each of these, keys: keys: {"in-domain", "near-domain", "out-domain", "entire"} 84 | metrics = flipped_metrics 85 | 86 | return metrics 87 | -------------------------------------------------------------------------------- /checkpointing.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from typing import Any, Dict, Optional, Union, Type 4 | 5 | import torch 6 | from torch import nn, optim 7 | 8 | 9 | class CheckpointManager(object): 10 | r""" 11 | A :class:`CheckpointManager` periodically serializes models and optimizer as .pth files during 12 | training, and keeps track of best performing checkpoint based on an observed metric. 13 | 14 | Extended Summary 15 | ---------------- 16 | It saves state dicts of models and optimizer as ``.pth`` files in a specified directory. This 17 | class closely follows the API of PyTorch optimizers and learning rate schedulers. 18 | 19 | Notes 20 | ----- 21 | For :class:`~torch.nn.DataParallel` objects, ``.module.state_dict()`` is called instead of 22 | ``.state_dict()``. 23 | 24 | Parameters 25 | ---------- 26 | models: Dict[str, torch.nn.Module] 27 | Models which need to be serialized as a checkpoint. 28 | optimizer: torch.optim.Optimizer 29 | Optimizer which needs to be serialized as a checkpoint. 30 | serialization_dir: str 31 | Path to an empty or non-existent directory to save checkpoints. 32 | mode: str, optional (default="max") 33 | One of ``min``, ``max``. In ``min`` mode, best checkpoint will be recorded when metric 34 | hits a lower value; in `max` mode it will be recorded when metric hits a higher value. 35 | filename_prefix: str, optional (default="checkpoint") 36 | Prefix of the to-be-saved checkpoint files. 37 | 38 | Examples 39 | -------- 40 | >>> model = torch.nn.Linear(10, 2) 41 | >>> optimizer = torch.optim.Adam(model.parameters()) 42 | >>> ckpt_manager = CheckpointManager({"model": model}, optimizer, "/tmp/ckpt", mode="min") 43 | >>> num_epochs = 20 44 | >>> for epoch in range(num_epochs): 45 | ... train(model) 46 | ... val_loss = validate(model) 47 | ... 
ckpt_manager.step(val_loss, epoch) 48 | """ 49 | 50 | def __init__( 51 | self, 52 | models: Union[nn.Module, Dict[str, nn.Module]], 53 | serialization_dir: str, 54 | mode: str = "max", 55 | filename_prefix: str = "model", 56 | ): 57 | 58 | # Convert single model to a dict. 59 | if isinstance(models, nn.Module): 60 | models = {"model": models} 61 | 62 | for key in models: 63 | if not isinstance(models[key], nn.Module): 64 | raise TypeError("{} is not a Module".format(type(models).__name__)) 65 | 66 | self._models = models 67 | self._serialization_dir = serialization_dir 68 | 69 | self._mode = mode 70 | self._filename_prefix = filename_prefix 71 | 72 | # Initialize members to hold state dict of best checkpoint and its performance. 73 | self._best_metric: Optional[Union[float, torch.Tensor]] = None 74 | 75 | 76 | def step(self, metric: Union[float, torch.Tensor]): 77 | r"""Serialize checkpoint and update best checkpoint based on metric and mode.""" 78 | 79 | # Update best checkpoint based on metric and metric mode. 80 | if not self._best_metric: 81 | self._best_metric = metric 82 | 83 | models_state_dict: Dict[str, Any] = {} 84 | for key in self._models: 85 | if isinstance(self._models[key], nn.DataParallel): 86 | models_state_dict[key] = self._models[key].module.state_dict() 87 | else: 88 | models_state_dict[key] = self._models[key].state_dict() 89 | 90 | if (self._mode == "min" and metric <= self._best_metric) or ( 91 | self._mode == "max" and metric >= self._best_metric 92 | ): 93 | self._best_metric = metric 94 | 95 | # Serialize checkpoint corresponding to current epoch (or iteration). 96 | torch.save( 97 | models_state_dict, 98 | os.path.join( 99 | self._serialization_dir, f"{self._filename_prefix}-best.pth" 100 | ), 101 | ) 102 | 103 | return True 104 | 105 | return False 106 | -------------------------------------------------------------------------------- /constraint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class cbs_matrix: 5 | 6 | def __init__(self, vocab_size): 7 | self.matrix = None 8 | self.vocab_size = vocab_size 9 | 10 | def init_matrix(self, state_size): 11 | self.matrix = np.zeros((1, state_size, state_size, self.vocab_size), dtype=np.uint8) 12 | 13 | def add_connect(self, from_state, to_state, w_group): 14 | assert self.matrix is not None 15 | for w_index in w_group: 16 | self.matrix[0, from_state, to_state, w_index] = 1 17 | self.matrix[0, from_state, from_state, w_index] = 0 18 | 19 | def add_connect_except(self, from_state, to_state, w_group): 20 | excluded_group_word = [w for w in range(self.vocab_size) if w not in w_group] 21 | self.add_connect(from_state, to_state, excluded_group_word) 22 | 23 | def init_row(self, state_index): 24 | assert self.matrix is not None 25 | self.matrix[0, state_index, state_index, :] = 1 26 | 27 | def get_matrix(self): 28 | return self.matrix 29 | 30 | def CBSConstraint(CBS_type, max_constrain_num): 31 | if CBS_type == 'Two': 32 | assert max_constrain_num <= 2 33 | return TwoConstraint() 34 | elif CBS_type == 'GBS': 35 | return GBSConstraint(max_constrain_num) 36 | else: 37 | raise NotImplementedError 38 | 39 | class Constraint: 40 | 41 | constraint_max_length = 6 42 | _num_cls = {} 43 | _cache = {} 44 | 45 | def connect_edge(self, M, additional_state, from_state, to_state, constraint): 46 | queue = [(from_state, c) for c in constraint] 47 | new_queue = [] 48 | index2state = {} 49 | while len(queue) > 0: 50 | (f_state, c) = queue.pop(0) 51 | if 
len(c) == 1: 52 | M.add_connect(f_state, to_state, c) 53 | else: 54 | if c[0] not in index2state: 55 | index2state[c[0]] = additional_state 56 | additional_state += 1 57 | M.add_connect(f_state, index2state[c[0]], [c[0]]) 58 | if not f_state == from_state: 59 | M.add_connect_except(f_state, from_state, [c[0]]) 60 | new_queue.append((index2state[c[0]], c[1:])) 61 | 62 | if len(queue) == 0 and len(new_queue) > 0: 63 | queue = new_queue 64 | new_queue = [] 65 | index2state = {} 66 | 67 | return M, additional_state 68 | 69 | class TwoConstraint(Constraint): 70 | 71 | def __init__(self): 72 | super(TwoConstraint).__init__() 73 | self.state_size = 100 #4 * self.constraint_max_length 74 | 75 | def select_state_func(self, beam_prediction, image_ids): 76 | bp = [] 77 | for i, image_id in enumerate(image_ids): 78 | if self._num_cls[image_id] == 0: 79 | bp.append(beam_prediction[i, 0].unsqueeze(0)) 80 | elif self._num_cls[image_id] == 1: 81 | bp.append(beam_prediction[i, 1].unsqueeze(0)) 82 | elif self._num_cls[image_id] == 2: 83 | bp.append(beam_prediction[i, 3].unsqueeze(0)) 84 | return torch.cat(bp, dim=0) 85 | 86 | def get_state_matrix(self, output_size, constraints, image_id): 87 | assert len(constraints) <= 2 88 | M = cbs_matrix(output_size) 89 | M.init_matrix(self.state_size) 90 | 91 | self._num_cls[image_id] = len(constraints) 92 | con_str = [] 93 | for c in constraints: 94 | c_list = ['#'.join([str(i) for i in x]) for x in c] 95 | con_str.append('^'.join(c_list)) 96 | marker = '*'.join(con_str) if len(con_str) > 0 else '***' 97 | if marker not in self._cache: 98 | if self._num_cls[image_id] == 0: 99 | additional_state = 1 100 | for i in range(1): 101 | M.init_row(i) 102 | elif self._num_cls[image_id] == 1: 103 | for i in range(2): 104 | M.init_row(i) 105 | additional_state = 2 106 | c1 = constraints[0] 107 | c1 = [w[:self.constraint_max_length + 1] for w in c1] 108 | M, additional_state = self.connect_edge(M, additional_state, 0, 1, c1) 109 | else: 110 | for i in range(4): 111 | M.init_row(i) 112 | additional_state = 4 113 | c1, c2 = constraints[0], constraints[1] 114 | c1 = [w[:self.constraint_max_length + 1] for w in c1] 115 | c2 = [w[:self.constraint_max_length + 1] for w in c2] 116 | M, additional_state = self.connect_edge(M, additional_state, 0, 1, c1) 117 | M, additional_state = self.connect_edge(M, additional_state, 0, 2, c2) 118 | M, additional_state = self.connect_edge(M, additional_state, 1, 3, c2) 119 | M, additional_state = self.connect_edge(M, additional_state, 2, 3, c1) 120 | 121 | self._cache[marker] = (M.get_matrix(), additional_state) 122 | 123 | return self._cache[marker] 124 | 125 | 126 | class GBSConstraint(Constraint): 127 | 128 | def __init__(self, max_constrain_num): 129 | super(GBSConstraint).__init__() 130 | self.state_size = 100 #(max_constrain_num ** 2) * (self.constraint_max_length - 1) + max_constrain_num + 1 131 | self.max_constrain_num = max_constrain_num 132 | 133 | def get_state_matrix(self, output_size, constraints, image_id): 134 | assert len(constraints) <= self.max_constrain_num 135 | 136 | M = cbs_matrix(output_size) 137 | M.init_matrix(self.state_size) 138 | 139 | self._num_cls[image_id] = len(constraints) 140 | con_str = [] 141 | for c in constraints: 142 | c_list = ['#'.join([str(i) for i in x]) for x in c] 143 | con_str.append('^'.join(c_list)) 144 | marker = '*'.join(con_str) if len(con_str) > 0 else '***' 145 | 146 | if marker not in self._cache: 147 | comb_constrains = [] 148 | for c in constraints: 149 | comb_constrains += c 150 | 151 | 
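            # Hedged reading of the automaton built below: states 0..len(constraints)
            # count how many constraints have been satisfied so far, and each state i
            # gets an edge to i + 1 that fires on any of the pooled constraint word
            # sequences; select_state_func later keeps the beam that reached the state
            # equal to the number of constraints for that image.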
additional_state = len(constraints) + 1 152 | for i in range(additional_state): 153 | M.init_row(i) 154 | 155 | for i in range(len(constraints)): 156 | M, additional_state = self.connect_edge(M, additional_state, i, i + 1, comb_constrains) 157 | 158 | self._cache[marker] = (M.get_matrix(), additional_state) 159 | return self._cache[marker] 160 | 161 | def select_state_func(self, beam_prediction, image_ids): 162 | bp = [] 163 | for i, image_id in enumerate(image_ids): 164 | bp.append(beam_prediction[i, self._num_cls[image_id]].unsqueeze(0)) 165 | return torch.cat(bp, dim=0) 166 | -------------------------------------------------------------------------------- /dataset/e2e_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from dataset.data_utils import * 3 | import sys 4 | import json 5 | import copy 6 | from tqdm import tqdm 7 | import random 8 | 9 | class E2EDataset(Dataset): 10 | 11 | def __init__(self, config, json_path, tokenizer, copy_vocab, is_training=False): 12 | super(E2EDataset, self).__init__() 13 | 14 | self.config = config 15 | self.tokenizer = tokenizer 16 | self.is_training = is_training 17 | self.copy_vocab = copy_vocab 18 | np.set_printoptions(threshold=sys.maxsize) 19 | 20 | self.keyword_norm = {'eatType': 'type', 'familyFriendly': 'family friendly', 'priceRange': 'price range'} 21 | self.key_index = {'type': 1, 'family friendly': 3, 'price range': 4, 'near': 5, 'name': 6, 'food': 7, 'area': 8, 'customer rating': 9} 22 | 23 | self.read_content(json_path) 24 | 25 | def __len__(self): 26 | return len(self.record) 27 | 28 | def __getitem__(self, index): 29 | ins_id, enc_input, enc_cls, mf, out, gt, gt_mr = self.record[index] 30 | 31 | data_item = { 32 | "ins_id": ins_id, 33 | "encoder_input_ids": enc_input, 34 | "encoder_class": enc_cls, 35 | "gt": gt, 36 | "gt_mr": gt_mr, 37 | "mention_flag": mf 38 | } 39 | if out is not None: 40 | data_item['cap'] = out 41 | 42 | return data_item 43 | 44 | def get_mention_index(self, cls_index, list_v): 45 | for (_, fg_index) in self.copy_vocab.d_to_w_group[cls_index]: 46 | fg_ch_list = self.copy_vocab.token_fg_w[fg_index] 47 | s1 = '&'.join([str(f) for f in fg_ch_list]) 48 | for ch_idx, first_ch in enumerate(list_v): 49 | if first_ch == fg_ch_list[0]: 50 | s2 = '&'.join([str(f) for f in list_v[ch_idx: ch_idx + len(fg_ch_list)]]) 51 | if s1 == s2: 52 | return ch_idx, ch_idx + len(fg_ch_list) 53 | return -1,-1 54 | 55 | def read_content(self, json_path): 56 | print("reading data from %s ..." 
% json_path) 57 | 58 | self.record = [] 59 | 60 | with open(json_path) as out: 61 | instances = json.loads(out.read()) 62 | 63 | total_input = 0 64 | match_input = 0 65 | type_info_ratio = {} 66 | for ins_id, instance in tqdm(enumerate(instances)): 67 | gt = copy.deepcopy(instance['ref']) 68 | gt_mr = [] 69 | new_mr_input = [] 70 | for mr in instance['mr']: 71 | each_mr = mr.replace('[', ' ') 72 | each_mr = each_mr.replace(']', ' ') 73 | each_mr = each_mr.strip() 74 | words = each_mr.split() 75 | if words[0] in self.keyword_norm: 76 | tag = self.keyword_norm[words[0]] 77 | value = ' '.join([x.strip() for x in words[1:]]) 78 | else: 79 | tag = words[0] if not words[0] == 'customer' else ' '.join(words[:2]) 80 | value = ' '.join([x.strip() for x in words[1:]]) if not words[0] == 'customer' else ' '.join([x.strip() for x in words[2:]]) 81 | new_mr_input.append(tag + ' ' + value) 82 | gt_mr.append((mr, value)) 83 | 84 | input_cls_info = [] 85 | for (text, gt_m) in zip(new_mr_input, gt_mr): 86 | cls_id = self.copy_vocab.word_to_category_id[gt_m[0]] 87 | m_input = self.tokenizer(text, return_tensors="np")['input_ids'][0, :-1].tolist() 88 | input_cls_info.append((cls_id, m_input, gt_m[1])) 89 | 90 | encoder_input = [] 91 | encoder_cls = [] 92 | for (cls_id, m_input, _) in input_cls_info: 93 | encoder_input += m_input 94 | encoder_cls += [cls_id] * len(m_input) 95 | encoder_input.append(self.tokenizer.eos_token_id) 96 | encoder_cls.append(0) 97 | encoder_input = np.array(encoder_input, dtype=np.int64) 98 | encoder_cls = np.array(encoder_cls, dtype=np.int64) 99 | 100 | if not self.is_training: 101 | mention_flag = np.array(encoder_cls > 0, dtype=np.int64) 102 | mention_flag = mention_flag[np.newaxis, :] 103 | self.record.append((ins_id, encoder_input, encoder_cls, mention_flag, None, gt, gt_mr)) 104 | else: 105 | for v in instance['ref']: 106 | ref = v 107 | v = ' '.join(v.split()).lower() 108 | v = self.tokenizer(v, return_tensors="np")['input_ids'][0, :self.config.max_generation_len] 109 | list_v = v.tolist() 110 | 111 | mentioned_cls_pos = [] 112 | for (cls_id, m_input, name) in input_cls_info: 113 | s_pos, e_pos = self.get_mention_index(cls_id, list_v) 114 | mentioned_cls_pos.append((s_pos, e_pos, cls_id)) 115 | total_input += 1 116 | if s_pos >= 0 and e_pos >= 0: 117 | match_input += 1 118 | # else: 119 | # print(self.copy_vocab.d_to_w_group[cls_id]) 120 | # print(ref) 121 | # print("----------") 122 | 123 | encoder_input = np.array(encoder_input, dtype=np.int64) 124 | encoder_cls = np.array(encoder_cls, dtype=np.int64) 125 | mention_flag = np.zeros((v.shape[0], encoder_input.shape[0]), dtype=np.int64) 126 | 127 | for (s_pos, e_pos, cls_id) in mentioned_cls_pos: 128 | for e_index in range(encoder_cls.shape[0]): 129 | if encoder_cls[e_index] == cls_id: 130 | if e_pos >= 0: 131 | mention_flag[:e_pos, e_index] = 1 132 | if not self.config.static_mf: 133 | mention_flag[e_pos:, e_index] = 2 134 | else: 135 | mention_flag[e_pos:, e_index] = 1 136 | else: 137 | mention_flag[:, e_index] = 0 if not self.config.use_mf_merged else 1 138 | self.record.append((ins_id, encoder_input, encoder_cls, mention_flag, v, gt, gt_mr)) 139 | 140 | if self.is_training: 141 | random.shuffle(self.record) 142 | print("Match Ratio %.2f" % (100 * match_input / total_input)) 143 | 144 | 145 | def data_wrapper(dataset): 146 | new_dataset = {'gt': [d['gt'] for d in dataset], 'gt_mr': [d['gt_mr'] for d in dataset], 'ins_id': [d['ins_id'] for d in dataset]} 147 | 148 | encoder_input_ids, encoder_mask = 
process_tensor([d['encoder_input_ids'] for d in dataset], 0, output_mask=True) 149 | encoder_class = process_tensor([d['encoder_class'] for d in dataset], 0, output_mask=False) 150 | new_dataset['encoder_input_ids'] = encoder_input_ids 151 | new_dataset['encoder_mask'] = encoder_mask 152 | new_dataset['encoder_cls'] = encoder_class 153 | 154 | max_gen_len = 1 155 | if 'cap' in dataset[0]: 156 | cap_decoder_input_ids, cap_decoder_mask = process_tensor([d['cap'] for d in dataset], 0, output_mask=True) 157 | cap_decoder_input_ids[cap_decoder_mask == 0] = -100 158 | new_dataset['cap_decoder_input_ids'] = cap_decoder_input_ids 159 | max_gen_len = cap_decoder_input_ids.size(1) 160 | 161 | batch_size = len(dataset) 162 | max_encoder_len = encoder_input_ids.size(1) 163 | mention_flag = np.zeros((batch_size, max_gen_len, max_encoder_len), dtype=np.int64) 164 | for i, d in enumerate(dataset): 165 | mention_flag[i, :d['mention_flag'].shape[0], :d['mention_flag'].shape[1]] = d['mention_flag'] 166 | new_dataset['mention_flag'] = torch.from_numpy(mention_flag) 167 | 168 | return new_dataset 169 | 170 | 171 | def get_data_loader(dataset, batch_size): 172 | collate_fn = lambda d: data_wrapper(d) 173 | return DataLoader(dataset, 174 | batch_size=batch_size, 175 | num_workers=0, 176 | collate_fn=collate_fn 177 | ) 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ACL2021MF 2 | Source Code For ACL 2021 Paper "Mention Flags (MF): Constraining Transformer-based Text Generators" 3 | 4 | ## Data Download 5 | Please download the evaluation code from [here](https://drive.google.com/drive/folders/10pZHQwNxzTPALzDXqNZokSQOJJtRjB5p?usp=sharing) and put it into the dataset/ folder. 6 | 7 | The pre-trained models are available in [here](https://drive.google.com/drive/folders/1pOY_G4ygQ8C76mgGlchyc7jbEtwoY_r9?usp=sharing). Please download each file and put them into the dataset/ folder. 8 | 9 | The training, dev and test data for Commonsense Generation and E2E task are available in [here](https://drive.google.com/drive/folders/1i_rua8e3Pl230K9vy3su_wkSZrGykrT2?usp=sharing). Please download each file and put them into the dataset/ folder. 10 | 11 | The training, dev and test data for is coming soon. 
12 | 13 | ## Dependency 14 | Before running the code, please install following dependencies: 15 | - python==3.6.1 16 | - transformers==3.5.1 17 | - numpy==1.19.2 18 | - yacs==0.1.6 19 | - tqdm==4.49.0 20 | - torch==1.4.0a0+f067088 21 | - h5py==2.7.0 22 | - anytree==2.7.3 23 | - dataclasses==0.7 24 | - typing==3.6.6 25 | 26 | ## Running Models 27 | 28 | ### CommonSen 29 | 30 | #### Training all models in the paper 31 | 32 | | Model | Command | 33 | |------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 34 | | Trans, L3 Baseline | python train_T5.py --config t5_base.yml --config-override use_mention_flag False do_pretrain_lm_init False freeze_param False --serialization-dir dataset/commonGen_transL3_baseline --train | 35 | | Trans, L3 Mention Flag | python train_T5.py --config t5_base.yml --config-override do_pretrain_lm_init False freeze_param False --serialization-dir dataset/commonGen_transL3_mf --train | 36 | | T5-Base Baseline | python train_T5.py --config t5_base.yml --config-override use_mention_flag False --serialization-dir dataset/commonGen_t5_base_baseline --train | 37 | | T5-Base Mention Flag | python train_T5.py --config t5_base.yml --serialization-dir dataset/commonGen_t5_base_mf --train | 38 | | T5-Large Baseline | python train_T5.py --config t5_large.yml --config-override use_mention_flag False --serialization-dir dataset/commonGen_t5_large_baseline --train | 39 | | T5-Large Mention Flag | python train_T5.py --config t5_large.yml --serialization-dir dataset/commonGen_t5_large_mf --train | 40 | | T5-Base Scalar Mf | python train_T5.py --config t5_base.yml --config-override use_mf_scalar True --serialization-dir dataset/commonGen_t5_base_scalar_mf --train | 41 | | T5-Base Static Mf | python train_T5.py --config t5_base.yml --config-override static_mf True --serialization-dir dataset/commonGen_t5_base_static_mf --train | 42 | 43 | #### Evluating models 44 | 45 | | Model | Command | 46 | |:---------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| 47 | | T5-Base Mention Flag | python train_T5.py --config dataset/commonGen_t5_base_mf/config.yml --start-from-checkpoint dataset/commonGen_t5_base_mf --test --seen-constraint-path dataset/commonGen_seen_constraint.txt | 48 | | T5-Large Mention Flag | python train_T5.py --config dataset/commonGen_t5_large_mf/config.yml --start-from-checkpoint dataset/commonGen_t5_large_mf --test --seen-constraint-path dataset/commonGen_seen_constraint.txt | 49 | 50 | ### E2E 51 | 52 | #### Training all models in the paper 53 | 54 | | Model | Command | 55 | |:----------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| 56 | | T5-Base Baseline | python train_e2e_T5.py --config e2e_t5_base.yml --serialization-dir dataset/e2e_baseline --train --config-override use_mention_flag False | 57 | | T5-Base Mention Flag | python train_e2e_T5.py --config e2e_t5_base.yml --serialization-dir dataset/e2e_mf --train | 58 | | Trans, L3 Baseline | python train_e2e_T5.py --config e2e_t5_base.yml --serialization-dir dataset/e2e_transL3_baseline --train --config-override use_mention_flag False 
do_pretrain_lm_init False freeze_param False | 59 | | Trans, L3 Mention Flag | python train_e2e_T5.py --config e2e_t5_base.yml --serialization-dir dataset/e2e_transL3_mf --train --config-override do_pretrain_lm_init False freeze_param False | 60 | | T5-Base Static MF | python train_e2e_T5.py --config e2e_t5_base.yml --serialization-dir dataset/e2e_static_mf --train --config-override static_mf True | 61 | | T5-Base Scalar MF | python train_e2e_T5.py --config e2e_t5_base.yml --serialization-dir dataset/e2e_scalar_mf --train --config-override use_mf_scalar True | 62 | | T5-Base Merged MF | python train_e2e_T5.py --config e2e_t5_base.yml --serialization-dir dataset/e2e_merged_mf --train --config-override use_mf_merged True | 63 | 64 | #### Evaluating models 65 | 66 | | Model | Command | 67 | |:-------:|:-------------------------------------------------------------------------------------------------------:| 68 | | T5-Base | python train_e2e_T5.py --config dataset/e2e_mf/config.yml --start-from-checkpoint dataset/e2e_mf --test | 69 | 70 | 71 | ### nocaps 72 | 73 | #### Training all models in the paper 74 | 75 | | Model | Command | 76 | |:----------------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| 77 | | T5-Base Baseline | python train_COCO_T5.py --config COCO_t5_base_nocaps.yml --serialization-dir dataset/nocaps_baseline --train --config-override use_mention_flag False | 78 | | T5-Base Mention Flags | python train_COCO_T5.py --config COCO_t5_base_nocaps.yml --serialization-dir dataset/nocaps_mf --train | 79 | | Trans L3 Baseline | python train_COCO_T5.py --config COCO_t5_base_nocaps.yml --serialization-dir dataset/nocaps_baseline_transL3 --train --config-override use_mention_flag False do_pretrain_lm_init False freeze_param False | 80 | | Trans L3 Mention Flags | python train_COCO_T5.py --config COCO_t5_base_nocaps.yml --serialization-dir dataset/nocaps_mf_transL3 --train --config-override use_mention_flag True do_pretrain_lm_init False freeze_param False | 81 | | T5-Base Scalar MF | python train_COCO_T5.py --config COCO_t5_base_nocaps.yml --serialization-dir dataset/nocaps_scalar_mf --train --config-override use_mf_scalar True | 82 | | T5-Base Static MF | python train_COCO_T5.py --config COCO_t5_base_nocaps.yml --serialization-dir dataset/nocaps_static_mf --train --config-override static_mf True | 83 | 84 | #### Evaluating models 85 | 86 | | Model | Command | 87 | |:--------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| 88 | | T5-Base Mention Flag | python train_COCO_T5.py --config dataset/nocaps_mf/config.yml --start-from-checkpoint dataset/nocaps_mf --validation --novel-constraint-path dataset/nocaps_novel_constraint.txt | 89 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import json 3 | import random 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | import copy 8 | import re 9 | import string 10 | import sys 11 | 12 | class CommonGenDataset(Dataset): 13 | 14 | def __init__(self, config, json_path, tokenizer, copy_vocab, decoder_start_token_id, is_training=False, 
attachable_index=None): 15 | super(CommonGenDataset, self).__init__() 16 | 17 | self.config = config 18 | self.copy_vocab = copy_vocab 19 | self.tokenizer = tokenizer 20 | self.is_training = is_training 21 | self.decoder_start_token_id = decoder_start_token_id 22 | self.attachable_index = attachable_index 23 | np.set_printoptions(threshold=sys.maxsize) 24 | self.read_content(json_path) 25 | 26 | def read_content(self, json_path): 27 | print("reading data from %s ..." % json_path) 28 | self.record = [] 29 | with open(json_path) as out: 30 | lines = out.readlines() 31 | for l in tqdm(lines): 32 | item = json.loads(l.strip()) 33 | concept_set = ' '.join(item['concept_set'].split('#')) 34 | concept_set = 'generate a sentence with these concepts : ' 35 | concept_set_input_ids = self.tokenizer(concept_set, return_tensors="np")['input_ids'][0, :-1].tolist() 36 | concept_cls = [] 37 | for concept in item['concept_set'].split('#'): 38 | if concept in self.copy_vocab.word_to_category_id: 39 | concept_cls.append(self.copy_vocab.word_to_category_id[concept]) 40 | else: 41 | fg_index = self.copy_vocab.w_to_i[concept] 42 | concept_cls.append(self.copy_vocab.i_to_cls[fg_index]) 43 | 44 | assert len(concept_cls) <= 5 45 | start_pos = [] 46 | for i, c_cls in enumerate(concept_cls): 47 | start_pos.append(len(concept_set_input_ids)) 48 | concept_set_input_ids += self.copy_vocab.token_class[c_cls] 49 | if i == len(concept_cls) - 1: 50 | concept_set_input_ids.append(self.tokenizer.eos_token_id) 51 | start_pos.append(len(concept_set_input_ids) - 1) 52 | 53 | position_indicator = np.zeros((5, len(concept_set_input_ids)), dtype=np.float32) 54 | for i in range(len(concept_cls)): 55 | position_indicator[i, start_pos[i]:start_pos[i+1]] = 1 56 | sum_check = np.sum(position_indicator, axis=1) 57 | for i in range(len(concept_cls)): 58 | assert sum_check[i] > 0 59 | 60 | cls_on_input = np.zeros((len(concept_set_input_ids), ), dtype=np.int64) 61 | for i, cls_ in enumerate(concept_cls): 62 | cls_on_input[start_pos[i]:start_pos[i+1]] = cls_ 63 | 64 | gt = copy.deepcopy(item['scene']) 65 | if self.is_training: 66 | for c in item['scene']: 67 | c = c.lower() 68 | c_input_ids = self.tokenizer(c, return_tensors="np")['input_ids'][0] 69 | string_caption = ' '.join([str(x) for x in c_input_ids]) 70 | if self.config.use_pointer: 71 | for c_cls in concept_cls: 72 | for _, fg_index in self.copy_vocab.d_to_w_group[c_cls]: 73 | fg_word_index = self.copy_vocab.token_fg_w[fg_index] 74 | fg_str = ' '.join([str(x) for x in fg_word_index]) 75 | fg_softmax_index = self.config.vocab_size + fg_index 76 | 77 | string_caption = re.sub(' %s ' % fg_str, ' (%d) ' % fg_softmax_index, string_caption) 78 | string_caption = re.sub(' %s$' % fg_str, ' (%d)' % fg_softmax_index, string_caption) 79 | string_caption = re.sub('^%s ' % fg_str, '(%d) ' % fg_softmax_index, string_caption) 80 | 81 | c_input_ids = [] 82 | str_id_list = string_caption.split() 83 | copy_mention_flag = np.zeros((len(str_id_list) + 1, 5)) 84 | decoder_mention_flag = np.zeros((len(str_id_list), len(concept_set_input_ids))) 85 | use_this_record = True 86 | if self.config.use_pointer: 87 | for index, w in enumerate(str_id_list): 88 | if w.startswith('(') and w.endswith(')'): 89 | fg_index = int(w[1:-1]) 90 | cls_index = self.copy_vocab.i_to_cls[fg_index - self.config.vocab_size] 91 | for j in range(len(concept_cls)): 92 | if concept_cls[j] == cls_index: 93 | copy_mention_flag[:index + 1, j] = 1 94 | copy_mention_flag[index + 1:, j] = 2 95 | decoder_mention_flag[:index + 1, 
start_pos[j]:start_pos[j+1]] = 1 96 | decoder_mention_flag[index + 1:, start_pos[j]:start_pos[j+1]] = 2 97 | c_input_ids.append(fg_index) 98 | 99 | assert c_input_ids[-1] > self.config.vocab_size 100 | else: 101 | c_input_ids.append(int(w)) 102 | assert c_input_ids[-1] <= self.config.vocab_size 103 | else: 104 | for index, w in enumerate(str_id_list): 105 | c_input_ids.append(int(w)) 106 | assert c_input_ids[-1] <= self.config.vocab_size 107 | 108 | for j, cls_index in enumerate(concept_cls): 109 | if cls_index == 0: continue 110 | 111 | for (_, fg_index) in self.copy_vocab.d_to_w_group[cls_index]: 112 | fg_ch_list = self.copy_vocab.token_fg_w[fg_index] 113 | s1 = '&'.join([str(f) for f in fg_ch_list]) 114 | 115 | for ch_idx, first_ch in enumerate(c_input_ids): 116 | if first_ch == fg_ch_list[0]: 117 | s2 = '&'.join([str(f) for f in c_input_ids[ch_idx: ch_idx + len(fg_ch_list)]]) 118 | if s1 == s2: 119 | if ch_idx + len(fg_ch_list) >= len(c_input_ids) - 1 or c_input_ids[ch_idx + len(fg_ch_list)] not in self.attachable_index: 120 | decoder_mention_flag[:ch_idx + len(fg_ch_list), start_pos[j]:start_pos[j+1]] = 1 121 | if not self.config.static_mf: 122 | decoder_mention_flag[ch_idx + len(fg_ch_list):, start_pos[j]:start_pos[j+1]] = 2 123 | else: 124 | decoder_mention_flag[ch_idx + len(fg_ch_list):, start_pos[j]:start_pos[j+1]] = 1 125 | 126 | break 127 | 128 | if not self.config.static_mf: 129 | for j in range(len(concept_cls)): 130 | for jj in range(start_pos[j], start_pos[j+1]): 131 | if not decoder_mention_flag[-1, jj] == 2: 132 | use_this_record = False 133 | 134 | if use_this_record: 135 | instance_tuple = (concept_set_input_ids, position_indicator, cls_on_input, concept_cls, copy_mention_flag, decoder_mention_flag, c_input_ids, gt, item['concept_set'].split('#')) 136 | self.record.append(instance_tuple) 137 | 138 | else: 139 | copy_mention_flag = np.zeros((1, 5)) 140 | copy_mention_flag[0, :len(concept_cls)] = 1 141 | decoder_mention_flag = np.zeros((1, len(concept_set_input_ids))) 142 | decoder_mention_flag[0, start_pos[0]: start_pos[-1]] = 1 143 | self.record.append((concept_set_input_ids, position_indicator, cls_on_input, concept_cls, copy_mention_flag, decoder_mention_flag, None, gt, item['concept_set'].split('#'))) 144 | 145 | if self.is_training: random.shuffle(self.record) 146 | 147 | def __len__(self): 148 | return len(self.record) 149 | 150 | def __getitem__(self, index): 151 | concept_set, pos, cls_on_input, concept_cls, copy_mention_flag, decoder_mention_flag, gen, gt, gt_concept = self.record[index] 152 | 153 | item = { 154 | "gt": gt, 155 | "gt_concepts": gt_concept, 156 | "concept_set": concept_set, 157 | "copy_pos": pos, 158 | "concept_cls": concept_cls, 159 | "copy_mention_flag": copy_mention_flag, 160 | "decoder_mention_flag": decoder_mention_flag, 161 | "cls_on_input": cls_on_input 162 | } 163 | 164 | if self.is_training: 165 | item['gen'] = gen 166 | 167 | return item 168 | 169 | 170 | def data_wrapper(dataset, tokenizer, decoder_start_token_id): 171 | batch_size = len(dataset) 172 | new_dataset = {'gt': [d['gt'] for d in dataset], 'gt_concepts': [d['gt_concepts'] for d in dataset]} 173 | 174 | _PAD = tokenizer.pad_token_id 175 | _EOS = tokenizer.eos_token_id 176 | _BOS = decoder_start_token_id 177 | 178 | max_concept_len = max([len(d['concept_set']) for d in dataset]) 179 | concept_set_input = np.full((batch_size, max_concept_len), _PAD, dtype=np.int64) 180 | cls_on_input = np.full((batch_size, max_concept_len), 0, dtype=np.int64) 181 | for i, d in 
enumerate(dataset): 182 | concept_set_input[i, :len(d['concept_set'])] = d['concept_set'] 183 | cls_on_input[i, :d['cls_on_input'].shape[0]] = d['cls_on_input'] 184 | new_dataset['input_ids'] = torch.from_numpy(concept_set_input) 185 | new_dataset['cls_on_input'] = torch.from_numpy(cls_on_input) 186 | new_dataset['attention_mask'] = (new_dataset['input_ids'] != _PAD).float() 187 | 188 | copy_pos = np.zeros((batch_size, 5, max_concept_len), dtype=np.float32) 189 | for i, d in enumerate(dataset): 190 | copy_pos[i, :, :d['copy_pos'].shape[-1]] = d['copy_pos'] 191 | new_dataset['copy_pos'] = torch.from_numpy(copy_pos) 192 | 193 | concept_cls = np.zeros((batch_size, 5), dtype=np.int64) 194 | for i, d in enumerate(dataset): 195 | concept_cls[i, :len(d['concept_cls'])] = d['concept_cls'] 196 | new_dataset['concept_cls'] = torch.from_numpy(concept_cls) 197 | 198 | max_gen_len = 1 199 | if 'gen' in dataset[0]: 200 | max_gen_len = max([len(d['gen']) for d in dataset]) + 1 201 | gen_out_seqs = np.full((batch_size, max_gen_len), -100, dtype=np.int64) 202 | gen_input_seqs = np.full((batch_size, max_gen_len), _EOS, dtype=np.int64) 203 | for i, d in enumerate(dataset): 204 | gen_input_seqs[i, 1:len(d['gen']) + 1] = d['gen'] 205 | gen_out_seqs[i, :len(d['gen'])] = d['gen'] 206 | gen_input_seqs[:, 0] = _BOS 207 | new_dataset['labels'] = torch.from_numpy(gen_out_seqs) 208 | new_dataset['decoder_input_ids'] = torch.from_numpy(gen_input_seqs) 209 | new_dataset['decoder_input_mask'] = (torch.from_numpy(gen_input_seqs) != _EOS).bool() 210 | 211 | copy_mention_flag = np.zeros((batch_size, max_gen_len, 5), dtype=np.int64) 212 | for i, d in enumerate(dataset): 213 | copy_mention_flag[i, :d['copy_mention_flag'].shape[0]] = d['copy_mention_flag'] 214 | new_dataset['copy_mention_flag'] = torch.from_numpy(copy_mention_flag) 215 | 216 | decoder_mention_flag = np.zeros((batch_size, max_gen_len, max_concept_len), dtype=np.int64) 217 | for i, d in enumerate(dataset): 218 | decoder_mention_flag[i, :d['decoder_mention_flag'].shape[0], :d['decoder_mention_flag'].shape[1]] = d['decoder_mention_flag'] 219 | new_dataset['decoder_mention_flag'] = torch.from_numpy(decoder_mention_flag) 220 | 221 | return new_dataset 222 | 223 | def get_data_loader(dataset, batch_size): 224 | collate_fn = lambda d: data_wrapper(d, dataset.tokenizer, dataset.decoder_start_token_id) 225 | return DataLoader(dataset, 226 | batch_size=batch_size, 227 | num_workers=0, 228 | collate_fn=collate_fn 229 | ) 230 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /train_e2e_T5.py: -------------------------------------------------------------------------------- 1 | from dataset.vocabulary import T5CopyVocabulary 2 | from dataset.e2e_dataset import E2EDataset, get_data_loader 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | from config import Config 7 | import numpy as np 8 | from transformers import T5Tokenizer 9 | from checkpointing import CheckpointManager 10 | from t5 import get_lm_representation 11 | import utils 12 | from tqdm import tqdm 13 | import math 14 | import os, sys 15 | from speaksee import evaluation 16 | import random 17 | from dataset.pymteval import BLEUScore, NISTScore 18 | from dataset.diversity import distinct_n 19 | from constraint import CBSConstraint 20 | import json 21 | 22 | parser = argparse.ArgumentParser("Train a CommonGen T5") 23 | parser.add_argument( 24 | "--config", required=True, help="Path to a config file with all configuration parameters." 25 | ) 26 | parser.add_argument( 27 | "--config-override", 28 | default=[], 29 | nargs="*", 30 | help="A sequence of key-value pairs specifying certain config arguments (with dict-like " 31 | "nesting) using a dot operator. The actual config will be updated and recorded in " 32 | "the serialization directory.", 33 | ) 34 | parser.add_argument( 35 | "--serialization-dir", 36 | default=None, 37 | help="Path to a (non-existent) directory for serializing checkpoints and tensorboard logs.", 38 | ) 39 | parser.add_argument( 40 | "--start-from-checkpoint", 41 | default=None, 42 | help="Path to load checkpoint and continue training [only supported for module_training].", 43 | ) 44 | parser.add_argument( 45 | "--constraint-vocab", 46 | default=None, 47 | help="Path to load constraint vocab", 48 | ) 49 | parser.add_argument( 50 | "--output-path", 51 | default=None, 52 | help="Path to save output captions", 53 | ) 54 | group = parser.add_mutually_exclusive_group() 55 | group.add_argument('--train', action='store_true') 56 | group.add_argument('--validation', action='store_true') 57 | group.add_argument('--test', action='store_true') 58 | 59 | def run_eval(_C, model, eval_data_iter, tokenizer, copy_vocab, device, decode_constraint=None, constraint_vocab=None, output_path=None): 60 | model.eval() 61 | if decode_constraint is not None: 62 | assert constraint_vocab is not None 63 | constraint_vocab_dict = {} 64 | with open(constraint_vocab) as out: 65 | for line in out: 66 | line = line.strip() 67 | items = line.split('@') 68 | constraint_vocab_dict[items[0]] = items[1:] 69 | 70 | 71 | gt_cap, pred = [], [] 72 | obj_coverage = [0, 0] 73 | with torch.no_grad(): 74 | for batch in tqdm(eval_data_iter): 75 | for n in batch: 76 | if n not in ['gt', 'gt_mr', 'ins_id']: 77 | batch[n] = batch[n].to(device) 78 | 79 | if decode_constraint is not None: 80 | constraint_dict = {} 81 | for id_, gt_mr in enumerate(batch['gt_mr']): 82 | constraint_dict[id_] = [] 83 | for (mr, _) in gt_mr: 84 | if mr in constraint_vocab_dict: 85 | c = [] 86 | for fg_w in 
constraint_vocab_dict[mr]: 87 | fg_index = copy_vocab.w_to_i[fg_w] 88 | c.append(copy_vocab.token_fg_w[fg_index]) 89 | constraint_dict[id_].append(c) 90 | 91 | state_transform_list = [] 92 | state_num_list = [] 93 | for image_id in range(len(batch['gt_mr'])): 94 | state_matrix, state_num = decode_constraint.get_state_matrix(_C.vocab_size, constraint_dict[image_id], image_id) 95 | state_transform_list.append(state_matrix) 96 | state_num_list.append(state_num) 97 | max_size = max(state_num_list) 98 | state_transform_list = [s[:, :max_size, :max_size]for s in state_transform_list] 99 | state_transition = np.concatenate(state_transform_list, axis=0) 100 | state_transition = torch.from_numpy(state_transition).bool().to(device) 101 | else: 102 | state_transition = None 103 | 104 | outputs = model.search( 105 | input_ids=batch['encoder_input_ids'], 106 | attention_mask=batch['encoder_mask'], 107 | decoder_mention_flag=batch['mention_flag'], 108 | decoder_cls_on_input=batch['encoder_cls'], 109 | state_transition=state_transition, 110 | num_beams=5, 111 | length_penalty=1.0, 112 | max_length=_C.max_generation_len, 113 | min_length=2, 114 | no_repeat_ngram_size=3, 115 | early_stopping=True 116 | ) 117 | 118 | if decode_constraint is not None: 119 | outputs = decode_constraint.select_state_func(outputs, [i for i in range(len(batch['gt_mr']))]) 120 | 121 | dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs] 122 | for ins_id, d, gt, gt_mr in zip(batch['ins_id'], dec, batch['gt'], batch['gt_mr']): 123 | gt_cap.append(gt) 124 | pred.append((ins_id, d)) 125 | gt_count = 0 126 | lower_d = d.lower() 127 | for (fullname, g_class_name) in gt_mr: 128 | gt_count += 1 129 | cls_id = copy_vocab.word_to_category_id[fullname] 130 | 131 | has_found = False 132 | for (w, _) in copy_vocab.d_to_w_group[cls_id]: 133 | if w.lower() in lower_d: 134 | obj_coverage[0] += 1 135 | has_found = True 136 | break 137 | 138 | # if not has_found: 139 | # print(d) 140 | # print(copy_vocab.d_to_w_group[cls_id]) 141 | # print([gt_mr]) 142 | # print("-------") 143 | obj_coverage[1] += gt_count 144 | 145 | for p in pred[:20]: 146 | print(p) 147 | 148 | if output_path is not None: 149 | output_list = [] 150 | for _id, out in pred: 151 | output_list.append({"image_id": _id, "caption": out}) 152 | with open(output_path, 'w') as out: 153 | out.write(json.dumps(output_list)) 154 | 155 | pred = [p[1] for p in pred] 156 | gts = evaluation.PTBTokenizer.tokenize(gt_cap) 157 | gen = evaluation.PTBTokenizer.tokenize(pred) 158 | 159 | print("Object Coverage %.2f" % (100 * obj_coverage[0] / obj_coverage[1])) 160 | 161 | diversity_sen = [v[0].split() for (_, v) in gen.items()] 162 | print("Diversity-1 %.2f" % distinct_n(diversity_sen, 1)) 163 | print("Diversity-2 %.2f" % distinct_n(diversity_sen, 2)) 164 | 165 | bleu = BLEUScore() 166 | nist = NISTScore() 167 | for sents_ref, sent_sys in zip(gt_cap, pred): 168 | bleu.append(sent_sys, sents_ref) 169 | nist.append(sent_sys, sents_ref) 170 | print("NIST %.2f" % (nist.score())) 171 | print("BLEU %.2f" % (bleu.score() * 100)) 172 | 173 | val_meteor, _ = evaluation.Meteor().compute_score(gts, gen) 174 | print('METEOR %.2f' % (val_meteor * 100)) 175 | 176 | val_cider, individual_cider = evaluation.Cider().compute_score(gts, gen) 177 | print('CIDEr %.2f' % (val_cider)) 178 | 179 | val_rouge, _ = evaluation.Rouge().compute_score(gts, gen) 180 | print('ROUGE_L %.2f' % (val_rouge * 100)) 181 | 182 | metric_dict = {"CIDEr": {"entire": val_cider}} 183 | 
metric_dict.update({"METEOR": {"entire": val_meteor}}) 184 | 185 | return metric_dict 186 | 187 | 188 | if __name__ == "__main__": 189 | _A = parser.parse_args() 190 | 191 | _C = Config(_A.config, _A.config_override) 192 | 193 | np.random.seed(_C.random_seed) 194 | random.seed(_C.random_seed) 195 | torch.manual_seed(_C.random_seed) 196 | torch.cuda.manual_seed_all(_C.random_seed) 197 | torch.backends.cudnn.benchmark = False 198 | torch.backends.cudnn.deterministic = True 199 | 200 | 201 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 202 | 203 | tokenizer = T5Tokenizer.from_pretrained(_C.lm_type, cache_dir='.') 204 | copy_vocab = T5CopyVocabulary(_C.copy_vocab_path, tokenizer, sep='@') 205 | lm = get_lm_representation(_C, tokenizer, copy_vocab) 206 | model = lm['t5'] 207 | model = model.to(device) 208 | _C.vocab_size = model.config.vocab_size 209 | 210 | total_parameter_count = 0 211 | trainable_parameter_count = 0 212 | for p in model.parameters(): 213 | total_parameter_count += p.numel() 214 | if p.requires_grad: 215 | trainable_parameter_count += p.numel() 216 | print('Total Parameter Count %d' % total_parameter_count) 217 | print('Trainable Parameter Count %d' % trainable_parameter_count) 218 | 219 | if len(_C.decode_constrain) > 0: 220 | decode_constraint = CBSConstraint(_C.decode_constrain, 2) 221 | else: 222 | decode_constraint = None 223 | 224 | if _A.train: 225 | train_data = E2EDataset(_C, _C.train_path, tokenizer, copy_vocab, is_training=True) 226 | train_data_loader = get_data_loader(train_data, _C.batch_size) 227 | train_iter = iter(train_data_loader) 228 | 229 | dev_data = E2EDataset(_C, _C.dev_path if (_A.validation or _A.train) else _C.test_path, tokenizer, copy_vocab) 230 | dev_data_loader = get_data_loader(dev_data, _C.batch_size) 231 | 232 | print(_C) 233 | for arg in vars(_A): 234 | print("{:<20}: {}".format(arg, getattr(_A, arg))) 235 | 236 | if _A.validation or _A.test: 237 | if torch.cuda.is_available(): 238 | model.load_state_dict(torch.load(os.path.join(_A.start_from_checkpoint, 'model-best.pth'))['model'], strict=False) 239 | else: 240 | model.load_state_dict(torch.load(os.path.join(_A.start_from_checkpoint, 'model-best.pth'), map_location=torch.device('cpu'))['model'], strict=False) 241 | 242 | run_eval(_C, model, dev_data_loader, tokenizer, copy_vocab, device, decode_constraint=decode_constraint, constraint_vocab=_A.constraint_vocab, output_path=_A.output_path) 243 | 244 | 245 | if _A.train: 246 | _C.num_training_steps = len(train_iter) * _C.max_epoch / _C.gradient_accumulation_steps 247 | epoch_num = math.ceil(_C.num_training_steps / _C.checkpoint_every_step) 248 | 249 | checkpoint_manager = CheckpointManager(model, _A.serialization_dir, mode="max") 250 | optimizer = utils.build_optimizer(_C, model) 251 | 252 | os.makedirs(_A.serialization_dir, exist_ok=True) 253 | _C.dump(os.path.join(_A.serialization_dir, "config.yml")) 254 | 255 | eval_every = _C.checkpoint_every_step * _C.gradient_accumulation_steps 256 | total_step = 0 257 | 258 | for epoch in range(epoch_num): 259 | print('EPOCH %d / %d' % (epoch + 1, epoch_num)) 260 | run_step = eval_every if total_step + eval_every < len(train_iter) * _C.max_epoch else len(train_iter) * _C.max_epoch - total_step 261 | model.train() 262 | 263 | with tqdm(total=math.ceil(run_step / _C.gradient_accumulation_steps), file=sys.stdout) as pbar: 264 | for step in range(run_step): 265 | try: 266 | batch = next(train_iter) 267 | except: 268 | train_iter = iter(train_data_loader) 269 | batch = 
next(train_iter) 270 | 271 | for n in batch: 272 | if n not in ['gt', 'gt_mr', 'ins_id']: 273 | batch[n] = batch[n].to(device) 274 | # optimizer.zero_grad() 275 | total_step += 1 276 | outputs = model( 277 | input_ids=batch['encoder_input_ids'], 278 | attention_mask=batch['encoder_mask'], 279 | decoder_mention_flag=batch['mention_flag'], 280 | decoder_cls_on_input=batch['encoder_cls'], 281 | labels=batch['cap_decoder_input_ids'] 282 | ) 283 | loss = outputs.loss 284 | loss = loss / _C.gradient_accumulation_steps 285 | loss.backward() 286 | 287 | if _C.grad_clip_value > 0: 288 | torch.nn.utils.clip_grad_value_(model.parameters(), _C.grad_clip_value) 289 | if (step + 1) % _C.gradient_accumulation_steps == 0: 290 | optimizer.step() 291 | if torch.cuda.is_initialized(): 292 | torch.cuda.synchronize() 293 | pbar.set_description("loss %.2f" % (loss.item() * _C.gradient_accumulation_steps)) 294 | pbar.update(1) 295 | optimizer.zero_grad() 296 | 297 | eval_result = run_eval(_C, model, dev_data_loader, tokenizer, copy_vocab, device) 298 | checkpoint_manager.step(eval_result["CIDEr"]["entire"]) 299 | -------------------------------------------------------------------------------- /dataset/pymteval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | BLEU & NIST measurements -- should be compatible with mteval-v13a.pl (basic tokenization). 6 | Also provides BLEU +1 smoothing (if set to work like that). 7 | 8 | TODO: International tokenization 9 | TODO: NIST with variable number of references is not the same as the edited mteval-v13a.pl, 10 | but this should be the proper way to compute it. Should be fixed there. 11 | """ 12 | 13 | from __future__ import unicode_literals 14 | from __future__ import division 15 | from builtins import zip 16 | from builtins import range 17 | from past.utils import old_div 18 | from builtins import object 19 | from collections import defaultdict 20 | import math 21 | import re 22 | 23 | 24 | class NGramScore(object): 25 | """Base class for BLEU & NIST, providing tokenization and some basic n-gram matching 26 | functions.""" 27 | 28 | def __init__(self, max_ngram, case_sensitive): 29 | """Create the scoring object. 30 | @param max_ngram: the n-gram level to compute the score for 31 | @param case_sensitive: use case-sensitive matching? 32 | """ 33 | self.max_ngram = max_ngram 34 | self.case_sensitive = case_sensitive 35 | 36 | def reset(self): 37 | """Reset the object, zero all counters.""" 38 | raise NotImplementedError() 39 | 40 | def append(self, pred_sent, ref_sents): 41 | """Add a sentence to the statistics. 42 | @param pred_sent: system output / predicted sentence 43 | @param ref_sents: reference sentences 44 | """ 45 | raise NotImplementedError() 46 | 47 | def score(self): 48 | """Compute the current score based on sentences added so far.""" 49 | raise NotImplementedError() 50 | 51 | def ngrams(self, n, sent): 52 | """Given a sentence, return n-grams of nodes for the given N. Lowercases 53 | everything if the measure should not be case-sensitive. 54 | 55 | @param n: n-gram 'N' (1 for unigrams, 2 for bigrams etc.) 
56 | @param sent: the sent in question
57 | @return: n-grams of nodes, as tuples of tuples (t-lemma & formeme)
58 | """
59 | if not self.case_sensitive:
60 | return list(zip(*[[tok.lower() for tok in sent[i:]] for i in range(n)]))
61 | return list(zip(*[sent[i:] for i in range(n)]))
62 | 
63 | def check_tokenized(self, pred_sent, ref_sents):
64 | """Tokenize the predicted sentence and reference sentences, if they are not tokenized.
65 | @param pred_sent: system output / predicted sentence
66 | @param ref_sent: a list of corresponding reference sentences
67 | @return: a tuple of (pred_sent, ref_sent) where everything is tokenized
68 | """
69 | # tokenize if needed
70 | pred_sent = pred_sent if isinstance(pred_sent, list) else self.tokenize(pred_sent)
71 | ref_sents = [ref_sent if isinstance(ref_sent, list) else self.tokenize(ref_sent)
72 | for ref_sent in ref_sents]
73 | return pred_sent, ref_sents
74 | 
75 | def get_ngram_counts(self, n, sents):
76 | """Returns a dictionary with counts of all n-grams in the given sentences.
77 | @param n: the "n" in n-grams (how long the n-grams should be)
78 | @param sents: list of sentences for n-gram counting
79 | @return: a dictionary (ngram: count) listing counts of n-grams attested in any of the sentences
80 | """
81 | merged_ngrams = {}
82 | 
83 | for sent in sents:
84 | ngrams = defaultdict(int)
85 | 
86 | for ngram in self.ngrams(n, sent):
87 | ngrams[ngram] += 1
88 | for ngram, cnt in ngrams.items():
89 | merged_ngrams[ngram] = max((merged_ngrams.get(ngram, 0), cnt))
90 | return merged_ngrams
91 | 
92 | def tokenize(self, sent):
93 | """This tries to mimic multi-bleu-detok from Moses, and by extension mteval-v13b.
94 | Code taken directly from there and attempted rewrite into Python."""
95 | # language-independent part:
96 | sent = re.sub(r'<skipped>', r'', sent) # strip "skipped" tags
97 | sent = re.sub(r'-\n', r'', sent) # strip end-of-line hyphenation and join lines
98 | sent = re.sub(r'\n', r' ', sent) # join lines
99 | sent = re.sub(r'&quot;', r'"', sent) # convert SGML tag for quote to "
100 | sent = re.sub(r'&amp;', r'&', sent) # convert SGML tag for ampersand to &
101 | sent = re.sub(r'&lt;', r'<', sent) # convert SGML tag for less-than to <
102 | sent = re.sub(r'&gt;', r'>', sent) # convert SGML tag for greater-than to >
103 | 
104 | # language-dependent part (assuming Western languages):
105 | sent = " " + sent + " " # pad with spaces
106 | sent = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 ', sent) # tokenize punctuation
107 | sent = re.sub(r'([^0-9])([\.,])', r'\1 \2 ', sent) # tokenize period and comma unless preceded by a digit
108 | sent = re.sub(r'([\.,])([^0-9])', r' \1 \2', sent) # tokenize period and comma unless followed by a digit
109 | sent = re.sub(r'([0-9])(-)', r'\1 \2 ', sent) # tokenize dash when preceded by a digit
110 | sent = re.sub(r'\s+', r' ', sent) # one space only between words
111 | sent = sent.strip() # remove padding
112 | 
113 | return sent.split(' ')
114 | 
115 | 
116 | class BLEUScore(NGramScore):
117 | """An accumulator object capable of computing BLEU score using multiple references.
118 | 
119 | The BLEU score is always smoothed a bit so that it's never undefined. For sentence-level
120 | measurements, proper smoothing should be used via the smoothing parameter (set to 1.0 for
121 | the same behavior as default Moses's MERT sentence BLEU).
122 | """
123 | 
124 | TINY = 1e-15
125 | SMALL = 1e-9
126 | 
127 | def __init__(self, max_ngram=4, case_sensitive=False, smoothing=0.0):
128 | """Create the scoring object. 
129 | @param max_ngram: the n-gram level to compute the score for (default: 4) 130 | @param case_sensitive: use case-sensitive matching (default: no) 131 | @param smoothing: constant to add for smoothing (defaults to 0.0, sentBLEU uses 1.0) 132 | """ 133 | super(BLEUScore, self).__init__(max_ngram, case_sensitive) 134 | self.smoothing = smoothing 135 | self.reset() 136 | 137 | def reset(self): 138 | """Reset the object, zero all counters.""" 139 | self.ref_len = 0 140 | self.cand_lens = [0] * self.max_ngram 141 | self.hits = [0] * self.max_ngram 142 | 143 | def append(self, pred_sent, ref_sents): 144 | """Append a sentence for measurements, increase counters. 145 | 146 | @param pred_sent: the system output sentence (string/list of tokens) 147 | @param ref_sents: the corresponding reference sentences (list of strings/lists of tokens) 148 | """ 149 | pred_sent, ref_sents = self.check_tokenized(pred_sent, ref_sents) 150 | 151 | # compute n-gram matches 152 | for i in range(self.max_ngram): 153 | self.hits[i] += self.compute_hits(i + 1, pred_sent, ref_sents) 154 | self.cand_lens[i] += len(pred_sent) - i 155 | 156 | # take the reference that is closest in length to the candidate 157 | # (if there are two of the same distance, take the shorter one) 158 | closest_ref = min(ref_sents, key=lambda ref_sent: (abs(len(ref_sent) - len(pred_sent)), len(ref_sent))) 159 | self.ref_len += len(closest_ref) 160 | 161 | def score(self): 162 | """Return the current BLEU score, according to the accumulated counts.""" 163 | return self.bleu() 164 | 165 | def compute_hits(self, n, pred_sent, ref_sents): 166 | """Compute clipped n-gram hits for the given sentences and the given N 167 | 168 | @param n: n-gram 'N' (1 for unigrams, 2 for bigrams etc.) 169 | @param pred_sent: the system output sentence (tree/tokens) 170 | @param ref_sents: the corresponding reference sentences (list/tuple of trees/tokens) 171 | """ 172 | merged_ref_ngrams = self.get_ngram_counts(n, ref_sents) 173 | pred_ngrams = self.get_ngram_counts(n, [pred_sent]) 174 | 175 | hits = 0 176 | for ngram, cnt in pred_ngrams.items(): 177 | hits += min(merged_ref_ngrams.get(ngram, 0), cnt) 178 | 179 | return hits 180 | 181 | def bleu(self): 182 | """Return the current BLEU score, according to the accumulated counts.""" 183 | # brevity penalty (smoothed a bit: if candidate length is 0, we change it to 1e-5 184 | # to avoid division by zero) 185 | bp = 1.0 186 | if (self.cand_lens[0] <= self.ref_len): 187 | bp = math.exp(1.0 - old_div(self.ref_len, 188 | (float(self.cand_lens[0]) if self.cand_lens[0] else 1e-5))) 189 | 190 | return bp * self.ngram_precision() 191 | 192 | def ngram_precision(self): 193 | """Return the current n-gram precision (harmonic mean of n-gram precisions up to max_ngram) 194 | according to the accumulated counts.""" 195 | prec_log_sum = 0.0 196 | for n_hits, n_len in zip(self.hits, self.cand_lens): 197 | n_hits += self.smoothing # pre-set smoothing 198 | n_len += self.smoothing 199 | n_hits = max(n_hits, self.TINY) # forced smoothing just a litle to make BLEU defined 200 | n_len = max(n_len, self.SMALL) # only applied for zeros 201 | prec_log_sum += math.log(old_div(n_hits, n_len)) 202 | 203 | return math.exp((1.0 / self.max_ngram) * prec_log_sum) 204 | 205 | 206 | class NISTScore(NGramScore): 207 | """An accumulator object capable of computing NIST score using multiple references.""" 208 | 209 | # NIST beta parameter setting (copied from mteval-13a.pl) 210 | BETA = old_div(- math.log(0.5), math.log(1.5) ** 2) 211 | 212 | def 
__init__(self, max_ngram=5, case_sensitive=False): 213 | """Create the scoring object. 214 | @param max_ngram: the n-gram level to compute the score for (default: 5) 215 | @param case_sensitive: use case-sensitive matching (default: no) 216 | """ 217 | super(NISTScore, self).__init__(max_ngram, case_sensitive) 218 | self.reset() 219 | 220 | def reset(self): 221 | """Reset the object, zero all counters.""" 222 | self.ref_ngrams = [defaultdict(int) for _ in range(self.max_ngram + 1)] # has 0-grams 223 | # these two don't have 0-grams 224 | self.hit_ngrams = [[] for _ in range(self.max_ngram)] 225 | self.cand_lens = [[] for _ in range(self.max_ngram)] 226 | self.avg_ref_len = 0.0 227 | 228 | def append(self, pred_sent, ref_sents): 229 | """Append a sentence for measurements, increase counters. 230 | 231 | @param pred_sent: the system output sentence (string/list of tokens) 232 | @param ref_sents: the corresponding reference sentences (list of strings/lists of tokens) 233 | """ 234 | pred_sent, ref_sents = self.check_tokenized(pred_sent, ref_sents) 235 | # collect ngram matches 236 | for n in range(self.max_ngram): 237 | self.cand_lens[n].append(len(pred_sent) - n) # keep track of output length 238 | merged_ref_ngrams = self.get_ngram_counts(n + 1, ref_sents) 239 | pred_ngrams = self.get_ngram_counts(n + 1, [pred_sent]) 240 | # collect ngram matches 241 | hit_ngrams = {} 242 | for ngram in pred_ngrams: 243 | hits = min(pred_ngrams[ngram], merged_ref_ngrams.get(ngram, 0)) 244 | if hits: 245 | hit_ngrams[ngram] = hits 246 | self.hit_ngrams[n].append(hit_ngrams) 247 | # collect total reference ngram counts 248 | for ref_sent in ref_sents: 249 | for ngram in self.ngrams(n + 1, ref_sent): 250 | self.ref_ngrams[n + 1][ngram] += 1 251 | # ref_ngrams: use 0-grams for information value as well 252 | ref_len_sum = sum(len(ref_sent) for ref_sent in ref_sents) 253 | self.ref_ngrams[0][()] += ref_len_sum 254 | # collect average reference length 255 | self.avg_ref_len += ref_len_sum / float(len(ref_sents)) 256 | 257 | def score(self): 258 | """Return the current NIST score, according to the accumulated counts.""" 259 | return self.nist() 260 | 261 | def info(self, ngram): 262 | """Return the NIST informativeness of an n-gram.""" 263 | if ngram not in self.ref_ngrams[len(ngram)]: 264 | return 0.0 265 | return math.log(self.ref_ngrams[len(ngram) - 1][ngram[:-1]] / 266 | float(self.ref_ngrams[len(ngram)][ngram]), 2) 267 | 268 | def nist_length_penalty(self, lsys, avg_lref): 269 | """Compute the NIST length penalty, based on system output length & average reference length. 
270 | @param lsys: total system output length 271 | @param avg_lref: total average reference length 272 | @return: NIST length penalty term 273 | """ 274 | ratio = lsys / float(avg_lref) 275 | if ratio >= 1: 276 | return 1 277 | if ratio <= 0: 278 | return 0 279 | return math.exp(-self.BETA * math.log(ratio) ** 2) 280 | 281 | def nist(self): 282 | """Return the current NIST score, according to the accumulated counts.""" 283 | # 1st NIST term 284 | hit_infos = [0.0 for _ in range(self.max_ngram)] 285 | for n in range(self.max_ngram): 286 | for hit_ngrams in self.hit_ngrams[n]: 287 | hit_infos[n] += sum(self.info(ngram) * hits for ngram, hits in hit_ngrams.items()) 288 | total_lens = [sum(self.cand_lens[n]) for n in range(self.max_ngram)] 289 | nist_sum = sum(old_div(hit_info, total_len) for hit_info, total_len in zip(hit_infos, total_lens)) 290 | # length penalty term 291 | bp = self.nist_length_penalty(sum(self.cand_lens[0]), self.avg_ref_len) 292 | return bp * nist_sum 293 | -------------------------------------------------------------------------------- /dataset/coco_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from dataset.reader import CocoCaptionsReader, ImageFeaturesReader 3 | from tqdm import tqdm 4 | import json 5 | import random 6 | import sys 7 | from dataset.data_utils import * 8 | import numpy as np 9 | 10 | class COCODataset(Dataset): 11 | 12 | def __init__(self, config, h5_path, tokenizer, copy_vocab, attachable_index, caption_path=None, copy_h5_path=None, is_training=False, in_memory=False, cbs_class_path=None): 13 | if caption_path is not None: 14 | self._captions_reader = CocoCaptionsReader(caption_path, config.word_norm_jsonpath if len(config.word_norm_jsonpath) > 0 else None, rm_dumplicated_caption=config.rm_dumplicated_caption, shuffle=config.shuffle_data, is_train=is_training, rm_punctuation=config.rm_punctuation) 15 | else: 16 | self._captions_reader = None 17 | 18 | np.set_printoptions(threshold=sys.maxsize) 19 | 20 | self._image_features_reader = ImageFeaturesReader(h5_path) 21 | if config.use_copy_obj: 22 | self._copy_image_features_reader = ImageFeaturesReader(copy_h5_path, start_index=1601) 23 | self.config = config 24 | self.is_training = is_training 25 | self.copy_vocab = copy_vocab 26 | self.tokenizer = tokenizer 27 | self.attachable_index = attachable_index 28 | self.cbs_class = None 29 | if cbs_class_path is not None: 30 | self.cbs_class = {} 31 | with open(cbs_class_path) as out: 32 | for line in out: 33 | line = line.strip() 34 | items = line.split(',') 35 | self.cbs_class[int(items[0])] = sorted([int(v) for v in items[1:]]) 36 | 37 | self._image_ids = sorted(list(self._image_features_reader._map.keys())) 38 | self.obj_cache = {} 39 | self.cap_cache = {} 40 | self.global_obj_cache = {} 41 | 42 | if len(config.object_blacklist_path) > 0: 43 | with open(config.object_blacklist_path) as out: 44 | blacklist = json.load(out) 45 | full_list = blacklist['blacklist_categories'] + (blacklist['val_blacklist_categories'] if not is_training else []) 46 | self._blacklist_categories = set([s.lower() for s in full_list]) 47 | else: 48 | self._blacklist_categories = None 49 | 50 | self.img_index = self.tokenizer("", return_tensors="np")['input_ids'][0, 0] 51 | self.background_index = self.tokenizer("background", return_tensors="np")['input_ids'][0, 0] 52 | 53 | if in_memory or (not is_training): 54 | self._image_features_reader.open_h5_file() 55 | if 
config.use_copy_obj: 56 | self._copy_image_features_reader.open_h5_file() 57 | 58 | for index in tqdm(range(len(self._captions_reader))): 59 | img_id, cap, _ = self._captions_reader[index] 60 | if img_id not in self.obj_cache: 61 | self.process_obj(img_id) 62 | if self._captions_reader is not None: 63 | if img_id not in self.cap_cache or cap not in self.cap_cache[img_id]: 64 | self.process_cap(img_id, cap) 65 | for img_id in tqdm(self.cap_cache): 66 | self.process_global_cap(img_id) 67 | 68 | self._image_features_reader.close_h5_file() 69 | if config.use_copy_obj: 70 | self._copy_image_features_reader.close_h5_file() 71 | 72 | def __len__(self): 73 | if self.is_training: 74 | return len(self._captions_reader) 75 | else: 76 | return len(self._image_ids) 77 | 78 | def process_global_cap(self, img_id): 79 | if self._captions_reader is not None: 80 | obj_count = {} 81 | for _, cap in self.cap_cache[img_id].items(): 82 | for c in cap['used_cls']: 83 | if c not in obj_count: 84 | obj_count[c] = 0 85 | obj_count[c] += 1 86 | sort_obj = sorted(obj_count.items(), key=lambda x: x[1], reverse=True) 87 | top_obj = [x[0] for x in sort_obj[:2]] 88 | else: 89 | top_obj = [] 90 | 91 | mention_flag = np.zeros((1, self.obj_cache[img_id]['encoder_input_ids'].shape[0]), dtype=np.int64) 92 | if self.cbs_class is not None: 93 | top_obj = self.cbs_class[img_id] 94 | 95 | for index, ecls in enumerate(self.obj_cache[img_id]['encoder_cls'].tolist()): 96 | if ecls in top_obj: 97 | mention_flag[0, index] = 1 98 | elif ecls < 1601: 99 | mention_flag[0, index] = 3 100 | 101 | self.global_obj_cache[img_id] = mention_flag 102 | 103 | def process_cap(self, img_id, cap): 104 | if img_id not in self.cap_cache: 105 | self.cap_cache[img_id] = {} 106 | 107 | if cap not in self.cap_cache[img_id]: 108 | self.cap_cache[img_id][cap] = {} 109 | self.cap_cache[img_id][cap]['input_ids'] = self.tokenizer(cap.lower(), return_tensors="np")['input_ids'][0, :self.config.max_generation_len] 110 | 111 | mention_flag = np.zeros((self.cap_cache[img_id][cap]['input_ids'].shape[0], self.obj_cache[img_id]['encoder_input_ids'].shape[0]), dtype=np.int64) 112 | c_input_ids = self.cap_cache[img_id][cap]['input_ids'].tolist() 113 | en_cls = self.obj_cache[img_id]['encoder_cls'].tolist() 114 | visit_en_cls = [] 115 | start_pos = {} 116 | for i, c in enumerate(en_cls): 117 | if c not in visit_en_cls: 118 | start_pos[len(visit_en_cls)] = i 119 | visit_en_cls.append(c) 120 | 121 | start_pos[len(visit_en_cls)] = len(en_cls) 122 | 123 | used_cls = [] 124 | for j, cls_index in enumerate(visit_en_cls): 125 | if cls_index >= 1601: 126 | found_word = False 127 | all_fgs = [fg_index for (_, fg_index) in self.copy_vocab.d_to_w_group[cls_index]] 128 | for fg_index in all_fgs: 129 | fg_ch_list = self.copy_vocab.token_fg_w[fg_index] 130 | s1 = '&'.join([str(f) for f in fg_ch_list]) 131 | 132 | for ch_idx, first_ch in enumerate(c_input_ids): 133 | if first_ch == fg_ch_list[0]: 134 | s2 = '&'.join([str(f) for f in c_input_ids[ch_idx: ch_idx + len(fg_ch_list)]]) 135 | if s1 == s2: 136 | if ch_idx + len(fg_ch_list) >= len(c_input_ids) - 1 or c_input_ids[ch_idx + len(fg_ch_list)] not in self.attachable_index: 137 | mention_flag[:ch_idx + len(fg_ch_list), start_pos[j]:start_pos[j+1]] = 1 138 | if not self.config.static_mf: 139 | mention_flag[ch_idx + len(fg_ch_list):, start_pos[j]:start_pos[j+1]] = 2 140 | else: 141 | mention_flag[ch_idx + len(fg_ch_list):, start_pos[j]:start_pos[j+1]] = 1 142 | used_cls.append(cls_index) 143 | found_word = True 144 | break 145 | 146 | 
if found_word: break 147 | else: 148 | mention_flag[:, start_pos[j]:start_pos[j+1]] = 3 149 | 150 | self.cap_cache[img_id][cap]['mention_flag'] = mention_flag 151 | self.cap_cache[img_id][cap]['used_cls'] = list(set(used_cls)) 152 | 153 | def process_obj(self, img_id): 154 | image_features, box_np, class_np = self._image_features_reader[img_id] 155 | if self.config.use_copy_obj: 156 | copy_image_features, copy_box_np, copy_class_np = self._copy_image_features_reader[img_id] 157 | self.process_input_for_encoder(img_id, image_features, box_np, class_np, copy_image_features, copy_box_np, copy_class_np) 158 | else: 159 | self.process_input_for_encoder(img_id, image_features, box_np, class_np) 160 | 161 | def process_input_for_encoder(self, img_id, obj_features, obj_boxes, obj_cls, copy_obj_features=None, copy_obj_boxes=None, copy_obj_cls=None): 162 | obj_size = obj_features.shape[0] 163 | 164 | cls2objindex = {} 165 | for obj_index, cls_ in enumerate(obj_cls): 166 | if self._blacklist_categories is not None and self.copy_vocab.id_to_category[cls_].lower() in self._blacklist_categories: 167 | cls_ = 0 168 | if cls_ not in cls2objindex: 169 | cls2objindex[cls_] = [] 170 | cls2objindex[cls_].append((obj_index, obj_boxes[obj_index, 6])) 171 | 172 | if self.config.use_copy_obj: 173 | for obj_index, cls_ in enumerate(copy_obj_cls): 174 | if self._blacklist_categories is not None and self.copy_vocab.id_to_category[cls_].lower() in self._blacklist_categories: 175 | cls_ = 0 176 | if cls_ not in cls2objindex: 177 | cls2objindex[cls_] = [] 178 | cls2objindex[cls_].append((obj_index + obj_size, copy_obj_boxes[obj_index, 6])) 179 | 180 | for cls_ in cls2objindex: 181 | cls2objindex[cls_] = sorted(cls2objindex[cls_], key=lambda x: x[1]) 182 | 183 | encoder_input_ids = [] 184 | encoder_img_mask = [] 185 | encoder_cls = [] 186 | rel_position = [] 187 | img_order = [] 188 | 189 | key_order = sorted([k for k in cls2objindex.keys()]) 190 | for cls_ in key_order: 191 | rel_position_list = [] 192 | 193 | if cls_ == 0: 194 | input_ids = [self.background_index] 195 | else: 196 | input_ids = self.copy_vocab.token_class[cls_] 197 | 198 | for img_i in range(len(cls2objindex[cls_])): 199 | each_img_rel = [48] * len(cls2objindex[cls_]) 200 | each_img_rel[img_i] = 0 201 | each_img_rel += [31 + get_position_emb_index(w_i + 1) for w_i in range(len(input_ids))] 202 | rel_position_list.append(each_img_rel) 203 | for word_i in range(len(input_ids)): 204 | each_word_rel = [49] * len(cls2objindex[cls_]) 205 | each_word_rel += [0 if ii == word_i else get_position_emb_index(abs(ii - word_i), right=ii > word_i) for ii in range(len(input_ids))] 206 | rel_position_list.append(each_word_rel) 207 | rel_position_np = np.array(rel_position_list, dtype=np.int64) 208 | assert rel_position_np.shape[0] == rel_position_np.shape[1] 209 | rel_position.append(rel_position_np) 210 | 211 | sub_span = [self.img_index] * len(cls2objindex[cls_]) + input_ids 212 | encoder_input_ids += sub_span 213 | encoder_img_mask += [1] * len(cls2objindex[cls_]) + [0] * len(input_ids) 214 | encoder_cls += [cls_] * len(sub_span) 215 | img_order += [o[0] for o in cls2objindex[cls_]] 216 | encoder_input_ids.append(self.tokenizer.eos_token_id) 217 | encoder_img_mask.append(0) 218 | encoder_cls.append(0) 219 | 220 | dim_shape = sum([r.shape[0] for r in rel_position]) 221 | encoder_rel_position_np = np.ones((dim_shape + 1, dim_shape + 1), dtype=np.int64) * 54 222 | if not self.config.use_orginal_enc_pos_embs: 223 | accumulate_dim = 0 224 | rel_start_position = [] 225 | 
for r in rel_position: 226 | encoder_rel_position_np[accumulate_dim: accumulate_dim + r.shape[0], accumulate_dim: accumulate_dim + r.shape[0]] = r 227 | rel_start_position.append(accumulate_dim) 228 | accumulate_dim += r.shape[0] 229 | encoder_rel_position_np[-1, -1] = 0 230 | 231 | for i, ri in enumerate(rel_position): 232 | for j, rj in enumerate(rel_position): 233 | if i == j: continue 234 | i_vis_end = len(cls2objindex[key_order[i]]) 235 | j_vis_end = len(cls2objindex[key_order[j]]) 236 | for i_index in range(ri.shape[0]): 237 | for j_index in range(rj.shape[0]): 238 | if i_index < i_vis_end and j_index < j_vis_end: 239 | encoder_rel_position_np[rel_start_position[i] + i_index, rel_start_position[j] + j_index] = 50 240 | elif i_index < i_vis_end and j_index >= j_vis_end: 241 | encoder_rel_position_np[rel_start_position[i] + i_index, rel_start_position[j] + j_index] = 51 242 | elif i_index >= i_vis_end and j_index < j_vis_end: 243 | encoder_rel_position_np[rel_start_position[i] + i_index, rel_start_position[j] + j_index] = 52 244 | elif i_index >= i_vis_end and j_index >= j_vis_end: 245 | encoder_rel_position_np[rel_start_position[i] + i_index, rel_start_position[j] + j_index] = 53 246 | 247 | obj_feature_np = np.zeros((len(encoder_input_ids), obj_features.shape[-1]), dtype=np.float32) 248 | obj_box_np = np.zeros((len(encoder_input_ids), obj_boxes.shape[-1]), dtype=np.float32) 249 | obj_index = 0 250 | for i, m in enumerate(encoder_img_mask): 251 | if m == 1: 252 | if img_order[obj_index] < obj_size: 253 | cur_index = img_order[obj_index] 254 | obj_feature_np[i] = obj_features[cur_index] 255 | obj_box_np[i] = obj_boxes[cur_index] 256 | else: 257 | cur_index = img_order[obj_index] - obj_size 258 | obj_feature_np[i] = copy_obj_features[cur_index] 259 | obj_box_np[i] = copy_obj_boxes[cur_index] 260 | obj_index += 1 261 | 262 | self.obj_cache[img_id] = { 263 | "encoder_rel_position": encoder_rel_position_np, 264 | "encoder_input_ids": np.array(encoder_input_ids, dtype=np.int64), 265 | "encoder_cls": np.array(encoder_cls, dtype=np.int64), 266 | "encoder_img_mask": np.array(encoder_img_mask, dtype=np.float32), 267 | "obj_feature_np": obj_feature_np, 268 | "obj_box_np": obj_box_np, 269 | "image_id": img_id 270 | } 271 | 272 | def __getitem__(self, index): 273 | if self.is_training: 274 | img_id, cap, gt = self._captions_reader[index] 275 | else: 276 | img_id = self._image_ids[index] 277 | if self._captions_reader is not None: 278 | gt = self._captions_reader.get_gt_by_image_id(img_id) 279 | else: 280 | gt = None 281 | 282 | cap = None 283 | 284 | if img_id not in self.obj_cache: 285 | self.process_obj(img_id) 286 | 287 | item = self.obj_cache[img_id] 288 | 289 | if gt is not None: 290 | item['gt'] = gt 291 | 292 | if cap is not None: 293 | if cap not in self.cap_cache: 294 | self.process_cap(img_id, cap) 295 | item['cap'] = self.cap_cache[img_id][cap]['input_ids'] 296 | item['mention_flag'] = self.cap_cache[img_id][cap]['mention_flag'] 297 | item['mention_flag'][:, -1] = 0 298 | else: 299 | item['mention_flag'] = self.global_obj_cache[img_id] 300 | item['mention_flag'][:, -1] = 0 301 | 302 | return item 303 | 304 | 305 | def data_wrapper(config, dataset): 306 | new_dataset = {'image_ids': [int(d['image_id']) for d in dataset]} 307 | new_dataset['gt'] = [d['gt'] for d in dataset] 308 | 309 | encoder_input_ids, encoder_mask = process_tensor([d['encoder_input_ids'] for d in dataset], 0, output_mask=True) 310 | encoder_img_mask = process_tensor([d['encoder_img_mask'] for d in dataset], 0) 311 | 
encoder_cls = process_tensor([d['encoder_cls'] for d in dataset], 0) 312 | obj_feature = process_tensor([d['obj_feature_np'] for d in dataset], 2048) 313 | obj_box = process_tensor([d['obj_box_np'] for d in dataset], 8) 314 | 315 | new_dataset['encoder_input_ids'] = encoder_input_ids 316 | new_dataset['encoder_mask'] = encoder_mask 317 | new_dataset['encoder_img_mask'] = encoder_img_mask 318 | new_dataset['encoder_obj_feature'] = obj_feature 319 | new_dataset['encoder_obj_box'] = obj_box 320 | new_dataset['encoder_cls'] = encoder_cls 321 | 322 | max_gen_len = 1 323 | if 'cap' in dataset[0]: 324 | cap_decoder_input_ids, cap_decoder_mask = process_tensor([d['cap'] for d in dataset], 0, output_mask=True) 325 | cap_decoder_input_ids[cap_decoder_mask == 0] = -100 326 | new_dataset['cap_decoder_input_ids'] = cap_decoder_input_ids 327 | max_gen_len = cap_decoder_input_ids.size(1) 328 | 329 | batch_size = len(dataset) 330 | max_encoder_len = encoder_input_ids.size(1) 331 | mention_flag = np.zeros((batch_size, max_gen_len, max_encoder_len), dtype=np.int64) 332 | for i, d in enumerate(dataset): 333 | mention_flag[i, :d['mention_flag'].shape[0], :d['mention_flag'].shape[1]] = d['mention_flag'] 334 | new_dataset['mention_flag'] = torch.from_numpy(mention_flag) 335 | 336 | encoder_rel_position = np.zeros((batch_size, max_encoder_len, max_encoder_len), dtype=np.int64) 337 | for i, d in enumerate(dataset): 338 | encoder_rel_position[i, :d['encoder_rel_position'].shape[0], :d['encoder_rel_position'].shape[1]] = d['encoder_rel_position'] 339 | new_dataset['encoder_rel_position'] = torch.from_numpy(encoder_rel_position) 340 | 341 | return new_dataset 342 | 343 | def get_data_loader(config, dataset): 344 | collate_fn = lambda d: data_wrapper(config, d) 345 | return DataLoader(dataset, 346 | batch_size=config.batch_size, 347 | num_workers=0, 348 | collate_fn=collate_fn 349 | ) 350 | -------------------------------------------------------------------------------- /train_T5.py: -------------------------------------------------------------------------------- 1 | from dataset.vocabulary import T5CopyVocabulary 2 | from dataset.dataset import CommonGenDataset, get_data_loader 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | from config import Config 7 | import numpy as np 8 | from transformers import T5Tokenizer 9 | from checkpointing import CheckpointManager 10 | from t5 import get_lm_representation 11 | import utils 12 | from tqdm import tqdm 13 | import math 14 | import os, sys 15 | from speaksee import evaluation 16 | import spacy 17 | import random 18 | from constraint import CBSConstraint 19 | from dataset.diversity import distinct_n 20 | import json 21 | 22 | nlp = spacy.load("en_core_web_sm") 23 | nlp.pipeline = [('tagger', nlp.tagger)] 24 | 25 | def tokenize(_list): 26 | new_dict = {} 27 | for item in _list: 28 | if isinstance(item, list): 29 | new_sentence_list = [] 30 | for sentence in item: 31 | a = '' 32 | for token in nlp(sentence): 33 | a += token.text 34 | a += ' ' 35 | new_sentence_list.append(a.rstrip()) 36 | new_dict[len(new_dict)] = new_sentence_list 37 | else: 38 | a = '' 39 | for token in nlp(item): 40 | a += token.text 41 | a += ' ' 42 | new_dict[len(new_dict)] = [a] 43 | 44 | return new_dict 45 | 46 | def get_coverage_score(gt_concepts, pred): 47 | covs = [] 48 | total_cs, match_cs = 0, 0 49 | for cs, p in zip(gt_concepts, pred): 50 | p = p.lower() 51 | if p.endswith('.'): 52 | p = p[:-1] 53 | p = p.strip() 54 | cs = set(cs) 55 | lemmas = set() 56 | for token in nlp(p): 57 
| lemmas.add(token.lemma_) 58 | match_cs += len(lemmas&cs) 59 | total_cs += len(cs) 60 | cov = len(lemmas&cs)/len(cs) 61 | covs.append(cov) 62 | return 100 * sum(covs) / len(covs), 100 * match_cs / total_cs 63 | 64 | parser = argparse.ArgumentParser("Train a CommonGen T5") 65 | parser.add_argument( 66 | "--config", required=True, help="Path to a config file with all configuration parameters." 67 | ) 68 | parser.add_argument( 69 | "--config-override", 70 | default=[], 71 | nargs="*", 72 | help="A sequence of key-value pairs specifying certain config arguments (with dict-like " 73 | "nesting) using a dot operator. The actual config will be updated and recorded in " 74 | "the serialization directory.", 75 | ) 76 | parser.add_argument( 77 | "--serialization-dir", 78 | default=None, 79 | help="Path to a (non-existent) directory for serializing checkpoints and tensorboard logs.", 80 | ) 81 | parser.add_argument( 82 | "--start-from-checkpoint", 83 | default=None, 84 | help="Path to load checkpoint and continue training [only supported for module_training].", 85 | ) 86 | parser.add_argument( 87 | "--output-path", 88 | default=None, 89 | help="Path to save output captions", 90 | ) 91 | parser.add_argument( 92 | "--seen-constraint-path", 93 | default=None, 94 | help="Path to novel constraints", 95 | ) 96 | group = parser.add_mutually_exclusive_group() 97 | group.add_argument('--train', action='store_true') 98 | group.add_argument('--validation', action='store_true') 99 | group.add_argument('--test', action='store_true') 100 | 101 | def run_eval(_C, model, eval_data_iter, copy_vocab, tokenizer, device, decoder_start_token_id, only_test=False, decode_constraint=None, output_path=None, seen_constraint_path=None): 102 | model.eval() 103 | gts, pred, gt_concepts = [], [], [] 104 | cls_recall = [0, 0] 105 | novel_cls_recall = [0, 0] 106 | seen_cls_recall = [0, 0] 107 | 108 | seen_constraint_list = [] 109 | if seen_constraint_path is not None: 110 | with open(seen_constraint_path) as out: 111 | for l in out: 112 | l = l.strip() 113 | seen_constraint_list.append(l) 114 | 115 | with torch.no_grad(): 116 | for batch in tqdm(eval_data_iter): 117 | for n in batch: 118 | if n not in ['gt', 'gt_concepts']: 119 | batch[n] = batch[n].to(device) 120 | 121 | cls_used = [] 122 | for i in range(batch['concept_cls'].size(0)): 123 | gt_cls = [] 124 | for j in range(batch['concept_cls'].size(1)): 125 | ix = batch['concept_cls'][i][j].item() 126 | if ix > 0: 127 | gt_cls.append(ix) 128 | cls_used.append(set(gt_cls)) 129 | 130 | if decode_constraint is not None: 131 | constraint_dict = {} 132 | for i in range(batch['concept_cls'].size(0)): 133 | constraint_dict[i] = [] 134 | for cls_index in cls_used[i]: 135 | c = [] 136 | for (_, fg_idx) in copy_vocab.d_to_w_group[cls_index]: 137 | c.append(copy_vocab.token_fg_w[fg_idx]) 138 | constraint_dict[i].append(c) 139 | 140 | state_transform_list = [] 141 | state_num_list = [] 142 | for i in range(batch['concept_cls'].size(0)): 143 | state_matrix, state_num = decode_constraint.get_state_matrix(_C.vocab_size, constraint_dict[i], i) 144 | state_transform_list.append(state_matrix) 145 | state_num_list.append(state_num) 146 | max_size = max(state_num_list) 147 | state_transform_list = [s[:, :max_size, :max_size]for s in state_transform_list] 148 | state_transition_np = np.concatenate(state_transform_list, axis=0) 149 | state_transition = torch.from_numpy(state_transition_np).bool().to(device) 150 | else: 151 | state_transition = None 152 | 153 | outputs = model.search( 154 | 
input_ids=batch['input_ids'], 155 | attention_mask=batch['attention_mask'], 156 | decoder_copy_pos=batch['copy_pos'], 157 | decoder_concept_cls=batch['concept_cls'], 158 | decoder_copy_mention_flag=batch['copy_mention_flag'], 159 | decoder_mention_flag=batch['decoder_mention_flag'], 160 | decoder_cls_on_input=batch['cls_on_input'], 161 | state_transition=state_transition, 162 | num_beams=5, 163 | length_penalty=1.0, 164 | max_length=25, 165 | min_length=2, 166 | no_repeat_ngram_size=3, 167 | early_stopping=True, 168 | decoder_start_token_id=decoder_start_token_id 169 | ) 170 | 171 | if decode_constraint is not None: 172 | outputs = decode_constraint.select_state_func(outputs, [i for i in range(batch['concept_cls'].size(0))]) 173 | 174 | dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs] 175 | for d, gt in zip(dec, batch['gt']): 176 | gts.append(gt) 177 | pred.append(d) 178 | gt_concepts += batch['gt_concepts'] 179 | 180 | N, D = outputs.size() 181 | for i in range(N): 182 | gt_cls = cls_used[i] 183 | 184 | mention_cls = [] 185 | if _C.use_pointer: 186 | for j in range(D): 187 | ix = outputs[i][j].item() 188 | if ix >= _C.vocab_size: 189 | ix = ix - _C.vocab_size 190 | _cls = copy_vocab.i_to_cls[ix] 191 | mention_cls.append(copy_vocab.id_to_category[_cls]) 192 | else: 193 | w_list = dec[i].split() 194 | if w_list[-1].endswith('.'): 195 | w_list[-1] = w_list[-1][:-1] 196 | w_list = [w[:-2] if w.endswith("'s") else w for w in w_list] 197 | w_list = [w[:-1] if w.endswith(",") else w for w in w_list] 198 | for gt_c in gt_cls: 199 | for (w, _) in copy_vocab.d_to_w_group[gt_c]: 200 | if w in w_list: 201 | mention_cls.append(gt_c) 202 | break 203 | 204 | mention_cls = set(mention_cls) 205 | 206 | novel_gt = set([c for c in gt_cls if copy_vocab.id_to_category[c] not in seen_constraint_list]) 207 | seen_gt = set([c for c in gt_cls if copy_vocab.id_to_category[c] in seen_constraint_list]) 208 | 209 | novel_mention = set([c for c in mention_cls if copy_vocab.id_to_category[c] not in seen_constraint_list]) 210 | seen_mention = set([c for c in mention_cls if copy_vocab.id_to_category[c] in seen_constraint_list]) 211 | 212 | cls_recall[1] += len(gt_cls) 213 | cls_recall[0] += len(gt_cls & mention_cls) 214 | 215 | novel_cls_recall[1] += len(novel_gt) 216 | seen_cls_recall[1] += len(seen_gt) 217 | 218 | novel_cls_recall[0] += len(novel_gt & novel_mention) 219 | seen_cls_recall[0] += len(seen_gt & seen_mention) 220 | 221 | 222 | 223 | # if len(gt_cls - (gt_cls & mention_cls)) > 0 and only_test: 224 | # remaining_cls = gt_cls - (gt_cls & mention_cls) 225 | # print([copy_vocab.id_to_category[c] for c in gt_cls], [copy_vocab.id_to_category[c] for c in remaining_cls], dec[i]) 226 | # print([copy_vocab.id_to_category[c] for c in gt_cls], dec[i]) 227 | 228 | for p in pred[:20]: 229 | print(p) 230 | 231 | if output_path is not None: 232 | output_list = [] 233 | for _id, out in enumerate(pred): 234 | output_list.append({"image_id": _id, "caption": out}) 235 | with open(output_path, 'w') as out: 236 | out.write(json.dumps(output_list)) 237 | 238 | gts = tokenize(gts) 239 | gen = tokenize(pred) 240 | 241 | coverage_score, overall_coverage = get_coverage_score(gt_concepts, pred) 242 | print("Coverage %.2f" % coverage_score) 243 | print("Macro Coverage %.2f" % overall_coverage) 244 | print("Token-Level Coverage %.2f" % (100 * cls_recall[0] / cls_recall[1])) 245 | if len(seen_constraint_list) > 0: 246 | print("Novel Token-Level Coverage %.2f" % (100 * 
novel_cls_recall[0] / novel_cls_recall[1])) 247 | print("Seen Token-Level Coverage %.2f" % (100 * seen_cls_recall[0] / seen_cls_recall[1])) 248 | 249 | 250 | diversity_sen = [v[0].split() for (_, v) in gen.items()] 251 | print("Diversity-1 %.2f" % distinct_n(diversity_sen, 1)) 252 | print("Diversity-2 %.2f" % distinct_n(diversity_sen, 2)) 253 | 254 | val_bleu, _ = evaluation.Bleu(n=4).compute_score(gts, gen) 255 | method = ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4'] 256 | metric_dict = {} 257 | for metric, score in zip(method, val_bleu): 258 | metric_dict[metric] = {'entire': score * 100} 259 | print('%s %.2f' % (metric, score * 100)) 260 | 261 | val_meteor, _ = evaluation.Meteor().compute_score(gts, gen) 262 | print('METEOR %.2f' % (val_meteor * 100)) 263 | 264 | val_rouge, _ = evaluation.Rouge().compute_score(gts, gen) 265 | print('ROUGE_L %.2f' % (val_rouge * 100)) 266 | 267 | val_cider, _ = evaluation.Cider().compute_score(gts, gen) 268 | print('CIDEr %.2f' % (val_cider * 100)) 269 | 270 | val_spice, _ = evaluation.Spice().compute_score(gts, gen) 271 | print('SPICE %.2f' % (val_spice * 100)) 272 | 273 | metric_dict.update({"CIDEr": {"entire": val_cider}, "ROUGE_L": {"entire": val_rouge}, "METEOR": {"entire": val_meteor}, "SPICE": {"entire": val_spice}}) 274 | return metric_dict 275 | 276 | 277 | if __name__ == "__main__": 278 | _A = parser.parse_args() 279 | 280 | _C = Config(_A.config, _A.config_override) 281 | 282 | np.random.seed(_C.random_seed) 283 | random.seed(_C.random_seed) 284 | torch.manual_seed(_C.random_seed) 285 | torch.cuda.manual_seed_all(_C.random_seed) 286 | torch.backends.cudnn.benchmark = False 287 | torch.backends.cudnn.deterministic = True 288 | 289 | 290 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 291 | 292 | tokenizer = T5Tokenizer.from_pretrained(_C.lm_type, cache_dir='.') 293 | copy_vocab = T5CopyVocabulary(_C.copy_vocab_path, tokenizer) 294 | lm = get_lm_representation(_C, tokenizer, copy_vocab) 295 | model = lm['t5'] 296 | model = model.to(device) 297 | _C.vocab_size = model.config.vocab_size 298 | 299 | if len(_C.decode_constrain) > 0: 300 | decode_constraint = CBSConstraint(_C.decode_constrain, 5) 301 | else: 302 | decode_constraint = None 303 | 304 | total_parameter_count = 0 305 | trainable_parameter_count = 0 306 | for p in model.parameters(): 307 | total_parameter_count += p.numel() 308 | if p.requires_grad: 309 | trainable_parameter_count += p.numel() 310 | print('Total Parameter Count %d' % total_parameter_count) 311 | print('Trainable Parameter Count %d' % trainable_parameter_count) 312 | 313 | if _A.train: 314 | train_data = CommonGenDataset(_C, _C.train_path, tokenizer, copy_vocab, model.config.decoder_start_token_id, attachable_index=lm['attachable_index'], is_training=True) 315 | train_data_loader = get_data_loader(train_data, _C.batch_size) 316 | train_iter = iter(train_data_loader) 317 | 318 | dev_data = CommonGenDataset(_C, _C.dev_path if (_A.validation or _A.train) else _C.test_path, tokenizer, copy_vocab, model.config.decoder_start_token_id) 319 | dev_data_loader = get_data_loader(dev_data, _C.batch_size) 320 | 321 | print(_C) 322 | for arg in vars(_A): 323 | print("{:<20}: {}".format(arg, getattr(_A, arg))) 324 | 325 | if _A.validation or _A.test: 326 | if torch.cuda.is_available(): 327 | model.load_state_dict(torch.load(os.path.join(_A.start_from_checkpoint, 'model-best.pth'))['model'], strict=False) 328 | else: 329 | model.load_state_dict(torch.load(os.path.join(_A.start_from_checkpoint, 'model-best.pth'), 
map_location=torch.device('cpu'))['model'], strict=False) 330 | 331 | run_eval(_C, model, dev_data_loader, copy_vocab, tokenizer, device, model.config.decoder_start_token_id, only_test=True, decode_constraint=decode_constraint, output_path=_A.output_path, seen_constraint_path=_A.seen_constraint_path) 332 | 333 | 334 | if _A.train: 335 | _C.num_training_steps = len(train_iter) * _C.max_epoch / _C.gradient_accumulation_steps 336 | epoch_num = math.ceil(_C.num_training_steps / _C.checkpoint_every_step) 337 | 338 | checkpoint_manager = CheckpointManager(model, _A.serialization_dir, mode="max") 339 | optimizer = utils.build_optimizer(_C, model) 340 | 341 | os.makedirs(_A.serialization_dir, exist_ok=True) 342 | _C.dump(os.path.join(_A.serialization_dir, "config.yml")) 343 | 344 | eval_every = _C.checkpoint_every_step * _C.gradient_accumulation_steps 345 | total_step = 0 346 | 347 | for epoch in range(epoch_num): 348 | print('EPOCH %d / %d' % (epoch + 1, epoch_num)) 349 | run_step = eval_every if total_step + eval_every < len(train_iter) * _C.max_epoch else len(train_iter) * _C.max_epoch - total_step 350 | model.train() 351 | 352 | with tqdm(total=math.ceil(run_step / _C.gradient_accumulation_steps), file=sys.stdout) as pbar: 353 | for step in range(run_step): 354 | try: 355 | batch = next(train_iter) 356 | except: 357 | train_iter = iter(train_data_loader) 358 | batch = next(train_iter) 359 | 360 | for n in batch: 361 | if n not in ['gt', 'gt_concepts']: 362 | batch[n] = batch[n].to(device) 363 | total_step += 1 364 | # optimizer.zero_grad() 365 | outputs = model( 366 | input_ids=batch['input_ids'], 367 | attention_mask=batch['attention_mask'], 368 | decoder_copy_pos=batch['copy_pos'], 369 | decoder_concept_cls=batch['concept_cls'], 370 | decoder_input_ids=batch['decoder_input_ids'], 371 | decoder_attention_mask=batch['decoder_input_mask'], 372 | decoder_copy_mention_flag=batch['copy_mention_flag'], 373 | decoder_mention_flag=batch['decoder_mention_flag'], 374 | decoder_cls_on_input=batch['cls_on_input'], 375 | labels=batch['labels'] 376 | ) 377 | loss = outputs.loss 378 | loss = loss / _C.gradient_accumulation_steps 379 | loss.backward() 380 | 381 | if _C.grad_clip_value > 0: 382 | torch.nn.utils.clip_grad_value_(model.parameters(), _C.grad_clip_value) 383 | if (step + 1) % _C.gradient_accumulation_steps == 0: 384 | optimizer.step() 385 | if torch.cuda.is_initialized(): 386 | torch.cuda.synchronize() 387 | pbar.set_description("loss %.2f" % (loss.item() * _C.gradient_accumulation_steps)) 388 | pbar.update(1) 389 | optimizer.zero_grad() 390 | 391 | eval_result = run_eval(_C, model, dev_data_loader, copy_vocab, tokenizer, device, model.config.decoder_start_token_id) 392 | checkpoint_manager.step(eval_result["CIDEr"]["entire"]) 393 | -------------------------------------------------------------------------------- /train_COCO_T5.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | from tqdm import tqdm 8 | import sys 9 | import json 10 | import utils 11 | import random 12 | 13 | from config import Config 14 | from dataset.coco_dataset import COCODataset, get_data_loader 15 | from checkpointing import CheckpointManager 16 | from dataset import evaluation 17 | from t5 import get_lm_representation 18 | from dataset.vocabulary import T5CopyVocabulary 19 | from transformers import T5Tokenizer 20 | from dataset.EvalAI import NocapsEvaluator 21 | from 
constraint import CBSConstraint 22 | from dataset.diversity import distinct_n 23 | 24 | parser = argparse.ArgumentParser("Train a Transformer Captioner with RL") 25 | parser.add_argument( 26 | "--config", required=True, help="Path to a config file with all configuration parameters." 27 | ) 28 | parser.add_argument( 29 | "--eval-split", help="Path to the evaluation split" 30 | ) 31 | parser.add_argument( 32 | "--in-memory", action="store_true", help="Whether to load image features in memory." 33 | ) 34 | parser.add_argument( 35 | "--serialization-dir", 36 | default=None, 37 | help="Path to a (non-existent) directory for serializing checkpoints and tensorboard logs.", 38 | ) 39 | parser.add_argument( 40 | "--config-override", 41 | default=[], 42 | nargs="*", 43 | help="A sequence of key-value pairs specifying certain config arguments (with dict-like " 44 | "nesting) using a dot operator. The actual config will be updated and recorded in " 45 | "the serialization directory.", 46 | ) 47 | parser.add_argument( 48 | "--start-from-checkpoint", 49 | default=None, 50 | help="Path to load checkpoint and continue training [only supported for module_training].", 51 | ) 52 | parser.add_argument( 53 | "--output-path", 54 | default=None, 55 | help="Path to save output captions", 56 | ) 57 | parser.add_argument( 58 | "--cbs-class-path", 59 | default=None, 60 | help="Path to a (non-existent) directory for CBS class path.", 61 | ) 62 | parser.add_argument( 63 | "--novel-constraint-path", 64 | default=None, 65 | help="Path to novel constraints", 66 | ) 67 | group = parser.add_mutually_exclusive_group() 68 | group.add_argument('--train', action='store_true') 69 | group.add_argument('--validation', action='store_true') 70 | group.add_argument('--test', action='store_true') 71 | parser.add_argument('--port', type=int, default=8083, help='port for server to run') 72 | parser.add_argument('--host', type=str, default='localhost', help='host for server to run') 73 | 74 | def run_eval(_C, model, eval_data_iter, tokenzier, copy_vocab, device, output_path=None, test=False, full_eval=False, decode_constraint=None, novel_constraint_path=None): 75 | 76 | model.eval() 77 | predictions = [] 78 | gen, gts, img_ids = [], [], [] 79 | mentioned_cls = [] 80 | novel_mentioned_cls = [] 81 | used_cls = {} 82 | macro_mention = [0, 0] 83 | novel_macro_mention = [0, 0] 84 | 85 | novel_constraints = [] 86 | if novel_constraint_path is not None: 87 | with open(novel_constraint_path) as out: 88 | for l in out: 89 | novel_constraints.append(int(l.strip())) 90 | 91 | 92 | with torch.no_grad(): 93 | for batch in tqdm(eval_data_iter): 94 | for n in batch: 95 | if n in ['gt', 'image_ids']: continue 96 | batch[n] = batch[n].to(device) 97 | 98 | encoder_cls = batch['encoder_cls'].detach().cpu().numpy() 99 | mention_flag = batch['mention_flag'].detach().cpu().numpy() 100 | 101 | cls_used = [] 102 | for b_idx in range(encoder_cls.shape[0]): 103 | e_cls = encoder_cls[b_idx].tolist() 104 | mf = mention_flag[b_idx, 0].tolist() 105 | visited_cls = set() 106 | for cls_, m in zip(e_cls, mf): 107 | if m == 1: visited_cls.add(cls_) 108 | cls_used.append(list(visited_cls)) 109 | 110 | if decode_constraint is not None: 111 | constraint_dict = {} 112 | for i, image_id in enumerate(batch['image_ids']): 113 | constraint_dict[image_id] = [] 114 | for cls_index in cls_used[i]: 115 | c = [] 116 | for (_, fg_idx) in copy_vocab.d_to_w_group[cls_index]: 117 | c.append(copy_vocab.token_fg_w[fg_idx]) 118 | constraint_dict[image_id].append(c) 119 | 120 | 
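# Each constrained class has been expanded into all of its tokenised surface forms via
# `copy_vocab.d_to_w_group` / `token_fg_w`; `get_state_matrix` below turns those word lists into a
# per-image finite-state machine over the T5 vocabulary, and every matrix is sliced to the largest
# state count in the batch before being stacked into a single `state_transition` tensor.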
state_transform_list = [] 121 | state_num_list = [] 122 | for image_id in batch['image_ids']: 123 | state_matrix, state_num = decode_constraint.get_state_matrix(_C.vocab_size, constraint_dict[image_id], image_id) 124 | state_transform_list.append(state_matrix) 125 | state_num_list.append(state_num) 126 | max_size = max(state_num_list) 127 | state_transform_list = [s[:, :max_size, :max_size]for s in state_transform_list] 128 | state_transition = torch.from_numpy(np.concatenate(state_transform_list, axis=0)).bool().to(device) 129 | else: 130 | state_transition = None 131 | 132 | outputs = model.search( 133 | input_ids=batch['encoder_input_ids'], 134 | attention_mask=batch['encoder_mask'], 135 | encoder_img_mask=batch['encoder_img_mask'], 136 | encoder_obj_feature=batch['encoder_obj_feature'], 137 | encoder_obj_box=batch['encoder_obj_box'], 138 | encoder_relative_pos_index=batch['encoder_rel_position'], 139 | decoder_mention_flag=batch['mention_flag'], 140 | decoder_cls_on_input=batch['encoder_cls'], 141 | state_transition=state_transition, 142 | num_beams=5, 143 | length_penalty=0.6, 144 | max_length=_C.max_generation_len, 145 | min_length=2, 146 | no_repeat_ngram_size=3, 147 | early_stopping=True 148 | ) 149 | 150 | if decode_constraint is not None: 151 | outputs = decode_constraint.select_state_func(outputs, batch['image_ids']) 152 | 153 | out = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs] 154 | gen += out 155 | if not test: 156 | gts += batch['gt'] 157 | img_ids += batch['image_ids'] 158 | 159 | # for index, o in enumerate(out): 160 | # print([copy_vocab.d_to_w_group[cls_index] for cls_index in cls_used[index]]) 161 | # print(o) 162 | # print("----------------") 163 | 164 | for b_idx in range(encoder_cls.shape[0]): 165 | cls_count = 0 166 | total_count = 0 167 | novel_total_count = 0 168 | novel_cls_count = 0 169 | single_img_used_cls = [] 170 | for cls_ in cls_used[b_idx]: 171 | total_count += 1 172 | if cls_ in novel_constraints: 173 | novel_total_count += 1 174 | 175 | for w, _ in copy_vocab.d_to_w_group[cls_]: 176 | if w in out[b_idx]: 177 | cls_count += 1 178 | if cls_ in novel_constraints: 179 | novel_cls_count += 1 180 | single_img_used_cls.append(cls_) 181 | break 182 | used_cls[batch['image_ids'][b_idx]] = single_img_used_cls 183 | 184 | # if cls_count < total_count and total_count > 0: 185 | # print([copy_vocab.d_to_w_group[cls_] for cls_ in cls_used[b_idx]]) 186 | # print(batch['image_ids'][b_idx], out[b_idx]) 187 | 188 | macro_mention[0] += cls_count 189 | macro_mention[1] += total_count 190 | if total_count > 0: 191 | mentioned_cls.append(100 * cls_count / total_count) 192 | 193 | novel_macro_mention[0] += novel_cls_count 194 | novel_macro_mention[1] += novel_total_count 195 | if novel_total_count > 0: 196 | novel_mentioned_cls.append(100 * cls_count / total_count) 197 | 198 | for c in gen[:20]: 199 | print(c) 200 | 201 | predictions = [] 202 | for img_id, p in zip(img_ids, gen): 203 | predictions.append({'image_id': img_id, "caption": p}) 204 | 205 | if output_path is not None: 206 | with open(output_path, 'w') as out: 207 | out.write(json.dumps(predictions) + '\n') 208 | 209 | with open('used_cls.txt', 'w') as out: 210 | for c in used_cls: 211 | list_ = [c] + used_cls[c] 212 | list_ = [str(s) for s in list_] 213 | out.write(','.join(list_) + '\n') 214 | 215 | if len(mentioned_cls) > 0 and macro_mention[1] > 0: 216 | print("Averaged Mentione Ratio %.2f" % (sum(mentioned_cls) / len(mentioned_cls))) 217 | print("Macro 
Mention Ratio %.2f" % (100 * macro_mention[0] / macro_mention[1])) 218 | 219 | if len(novel_constraints) > 0: 220 | print("Averaged Novel Mention Ratio %.2f" % (sum(novel_mentioned_cls) / len(novel_mentioned_cls))) 221 | print("Macro Novel Mention Ratio %.2f" % (100 * novel_macro_mention[0] / novel_macro_mention[1])) 222 | 223 | 224 | if not test: 225 | if not _C.external_eval: 226 | gts = evaluation.PTBTokenizer.tokenize(gts) 227 | gen = evaluation.PTBTokenizer.tokenize(gen) 228 | 229 | diversity_sen = [v[0].split() for (_, v) in gen.items()] 230 | print("Diversity-1 %.2f" % distinct_n(diversity_sen, 1)) 231 | print("Diversity-2 %.2f" % distinct_n(diversity_sen, 2)) 232 | 233 | val_bleu, _ = evaluation.Bleu(n=4).compute_score(gts, gen) 234 | method = ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4'] 235 | metric_dict = {} 236 | for metric, score in zip(method, val_bleu): 237 | metric_dict[metric] = {'entire': score * 100} 238 | print('%s %.2f' % (metric, score * 100)) 239 | 240 | val_cider, _ = evaluation.Cider().compute_score(gts, gen) 241 | print('CIDEr %.2f' % (val_cider * 100)) 242 | metric_dict['CIDEr'] = {"entire": val_cider} 243 | 244 | if full_eval: 245 | val_spice, _ = evaluation.Spice().compute_score(gts, gen) 246 | print('SPICE %.2f' % (val_spice * 100)) 247 | 248 | val_meteor, _ = evaluation.Meteor().compute_score(gts, gen) 249 | print('METEOR %.2f' % (val_meteor * 100)) 250 | 251 | val_rouge, _ = evaluation.Rouge().compute_score(gts, gen) 252 | print('ROUGE_L %.2f' % (val_rouge * 100)) 253 | else: 254 | evaluator = NocapsEvaluator(phase="val" if not test else "test") 255 | metric_dict = evaluator.evaluate(predictions) 256 | for metric_name in metric_dict: 257 | for domain in metric_dict[metric_name]: 258 | print(f"{metric_name} {domain}:", metric_dict[metric_name][domain]) 259 | print("") 260 | 261 | return metric_dict 262 | 263 | if __name__ == "__main__": 264 | 265 | _A = parser.parse_args() 266 | os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" 267 | _C = Config(_A.config, _A.config_override) 268 | 269 | np.random.seed(_C.random_seed) 270 | random.seed(_C.random_seed) 271 | torch.manual_seed(_C.random_seed) 272 | torch.cuda.manual_seed_all(_C.random_seed) 273 | torch.backends.cudnn.benchmark = False 274 | torch.backends.cudnn.deterministic = True 275 | 276 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 277 | 278 | tokenizer = T5Tokenizer.from_pretrained(_C.lm_type, cache_dir='.') 279 | copy_vocab = T5CopyVocabulary(_C.copy_vocab_path, tokenizer) 280 | lm = get_lm_representation(_C, tokenizer, copy_vocab) 281 | model = lm['t5'] 282 | model = model.to(device) 283 | attachable_index = lm['attachable_index'] 284 | _C.vocab_size = model.config.vocab_size 285 | 286 | if len(_C.decode_constrain) > 0: 287 | decode_constraint = CBSConstraint(_C.decode_constrain, 2) 288 | else: 289 | decode_constraint = None 290 | 291 | total_parameter_count = 0 292 | trainable_parameter_count = 0 293 | for p in model.parameters(): 294 | total_parameter_count += p.numel() 295 | if p.requires_grad: 296 | trainable_parameter_count += p.numel() 297 | print('Total Parameter Count %d' % total_parameter_count) 298 | print('Trainable Parameter Count %d' % trainable_parameter_count) 299 | 300 | if _C.use_copy_obj: 301 | train_copy_obj_h5_path = _C.train_copy_obj_h5_path 302 | dev_copy_obj_h5_path = _C.dev_copy_obj_h5_path 303 | test_copy_obj_h5_path = _C.test_copy_obj_h5_path 304 | else: 305 | train_copy_obj_h5_path, dev_copy_obj_h5_path, test_copy_obj_h5_path = None, None, None 306 | 307 | if 
_A.train: 308 | train_data = COCODataset(_C, _C.train_obj_h5_path, tokenizer, copy_vocab, attachable_index, caption_path=_C.train_path, copy_h5_path=train_copy_obj_h5_path, in_memory=_A.in_memory, is_training=True) 309 | train_data_loader = get_data_loader(_C, train_data) 310 | train_iter = iter(train_data_loader) 311 | 312 | if not _A.test: 313 | val_data = COCODataset(_C, _C.dev_obj_h5_path, tokenizer, copy_vocab, attachable_index, caption_path=_C.dev_path, copy_h5_path=dev_copy_obj_h5_path, is_training=False, in_memory=_A.in_memory, cbs_class_path=_A.cbs_class_path) 314 | else: 315 | val_data = COCODataset(_C, _C.test_obj_h5_path, tokenizer, copy_vocab, attachable_index, caption_path=_C.test_path, copy_h5_path=test_copy_obj_h5_path, is_training=False, in_memory=_A.in_memory, cbs_class_path=_A.cbs_class_path) 316 | val_data_loader = get_data_loader(_C, val_data) 317 | 318 | if _A.start_from_checkpoint is not None: 319 | if torch.cuda.is_available(): 320 | model.load_state_dict(torch.load(os.path.join(_A.start_from_checkpoint, 'model-best.pth'))['model'], strict=False) 321 | else: 322 | model.load_state_dict(torch.load(os.path.join(_A.start_from_checkpoint, 'model-best.pth'), map_location=torch.device('cpu'))['model'], strict=False) 323 | 324 | if _A.validation or _A.test: 325 | assert _A.start_from_checkpoint is not None, "evaluation must come along with pre-trained model" 326 | run_eval(_C, model, val_data_loader, tokenizer, copy_vocab, device, output_path=_A.output_path, test=_A.test, full_eval=True, decode_constraint=decode_constraint, novel_constraint_path=_A.novel_constraint_path) 327 | 328 | if _A.train: 329 | model.train() 330 | _C.num_training_steps = len(train_iter) * _C.max_epoch / _C.gradient_accumulation_steps 331 | epoch_num = math.ceil(_C.num_training_steps / _C.checkpoint_every_step) 332 | 333 | optimizer = utils.build_optimizer(_C, model) 334 | checkpoint_manager = CheckpointManager(model, _A.serialization_dir, mode="max") 335 | eval_every = _C.checkpoint_every_step * _C.gradient_accumulation_steps 336 | 337 | total_step = 0 338 | 339 | print(_C) 340 | for arg in vars(_A): 341 | print("{:<20}: {}".format(arg, getattr(_A, arg))) 342 | 343 | os.makedirs(_A.serialization_dir, exist_ok=True) 344 | _C.dump(os.path.join(_A.serialization_dir, "config.yml")) 345 | 346 | for epoch in range(epoch_num): 347 | print('EPOCH %d / %d' % (epoch + 1, epoch_num)) 348 | run_step = eval_every if total_step + eval_every < len(train_iter) * _C.max_epoch else len(train_iter) * _C.max_epoch - total_step 349 | model.train() 350 | 351 | with tqdm(total=math.ceil(run_step / _C.gradient_accumulation_steps), file=sys.stdout) as pbar: 352 | for step in range(run_step): 353 | try: 354 | batch = next(train_iter) 355 | except: 356 | train_iter = iter(train_data_loader) 357 | batch = next(train_iter) 358 | 359 | if torch.cuda.is_available(): 360 | for n in batch: 361 | if n in ['gt', 'image_ids']: continue 362 | batch[n] = batch[n].cuda() 363 | 364 | total_step += 1 365 | 366 | # optimizer.zero_grad() 367 | 368 | outputs = model( 369 | input_ids=batch['encoder_input_ids'], 370 | attention_mask=batch['encoder_mask'], 371 | encoder_img_mask=batch['encoder_img_mask'], 372 | encoder_obj_feature=batch['encoder_obj_feature'], 373 | encoder_obj_box=batch['encoder_obj_box'], 374 | encoder_relative_pos_index=batch['encoder_rel_position'], 375 | decoder_mention_flag=batch['mention_flag'], 376 | decoder_cls_on_input=batch['encoder_cls'], 377 | labels=batch['cap_decoder_input_ids'] 378 | ) 379 | #training 380 | loss 
= outputs.loss 381 | loss = loss / _C.gradient_accumulation_steps 382 | loss.backward() 383 | if _C.grad_clip_value != 0: 384 | torch.nn.utils.clip_grad_value_(model.parameters(), _C.grad_clip_value) 385 | if (step + 1) % _C.gradient_accumulation_steps == 0: 386 | optimizer.step() 387 | pbar.set_description("loss %.2f" % (loss.item() * _C.gradient_accumulation_steps)) 388 | pbar.update(1) 389 | optimizer.zero_grad() 390 | 391 | eval_result = run_eval(_C, model, val_data_loader, tokenizer, copy_vocab, device, output_path=_A.output_path) 392 | checkpoint_manager.step(eval_result["CIDEr"]["entire"]) 393 | 394 | -------------------------------------------------------------------------------- /dataset/reader.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | import json 3 | import h5py 4 | import numpy as np 5 | from tqdm import tqdm 6 | import random 7 | from anytree import AnyNode 8 | from anytree.search import findall_by_attr,findall 9 | import copy 10 | import string 11 | 12 | class OIDictImporter(object): 13 | ''' Importer that works on Open Images json hierarchy ''' 14 | def __init__(self, nodecls=AnyNode): 15 | self.nodecls = nodecls 16 | 17 | def import_(self, data): 18 | """Import tree from `data`.""" 19 | return self.__import(data) 20 | 21 | 22 | def __import(self, data, parent=None): 23 | assert isinstance(data, dict) 24 | assert "parent" not in data 25 | attrs = dict(data) 26 | children = attrs.pop("Subcategory", []) 27 | node = self.nodecls(parent=parent, **attrs) 28 | for child in children: 29 | self.__import(child, parent=node) 30 | return node 31 | 32 | class HierarchyFinder(object): 33 | 34 | def __init__(self, class_structure_path, abstract_list_path): 35 | importer = OIDictImporter() 36 | with open(class_structure_path) as f: 37 | self.class_structure = importer.import_(json.load(f)) 38 | 39 | with open(abstract_list_path) as out: 40 | self.abstract_list = json.load(out) 41 | 42 | def find_key(self, label): 43 | if label in self.abstract_list: 44 | return label 45 | return None 46 | 47 | def find_parent(self, label): 48 | target_node = findall(self.class_structure, filter_=lambda node: node.LabelName.lower() in (label))[0] 49 | while self.find_key(target_node.LabelName.lower()) is None: 50 | target_node = target_node.parent 51 | return self.find_key(target_node.LabelName.lower()) 52 | 53 | def nms(dets, classes, hierarchy, thresh=0.8): 54 | # Non-max suppression of overlapping boxes where score is based on 'height' in the hierarchy, 55 | # defined as the number of edges on the longest path to a leaf 56 | scores = [findall(hierarchy, filter_=lambda node: node.LabelName.lower() == cls)[0].height for cls in classes] 57 | 58 | x1 = dets[:, 0] 59 | y1 = dets[:, 1] 60 | x2 = dets[:, 2] 61 | y2 = dets[:, 3] 62 | 63 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 64 | 65 | scores = np.array(scores) 66 | order = scores.argsort() 67 | 68 | keep = [] 69 | while order.size > 0: 70 | i = order[0] 71 | keep.append(i) 72 | xx1 = np.maximum(x1[i], x1[order[1:]]) 73 | yy1 = np.maximum(y1[i], y1[order[1:]]) 74 | xx2 = np.minimum(x2[i], x2[order[1:]]) 75 | yy2 = np.minimum(y2[i], y2[order[1:]]) 76 | 77 | w = np.maximum(0.0, xx2 - xx1 + 1) 78 | h = np.maximum(0.0, yy2 - yy1 + 1) 79 | inter = w * h 80 | 81 | # check the score, objects with smaller or equal number of layers cannot be removed. 
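# In other words, a candidate box is suppressed only if it is strictly more generic than the kept box
# (greater hierarchy height) and overlaps it with IoU above `thresh`; boxes at the same or a more
# specific level are always kept.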
82 | keep_condition = np.logical_or(scores[order[1:]] <= scores[i], \ 83 | inter / (areas[i] + areas[order[1:]] - inter) <= thresh) 84 | 85 | inds = np.where(keep_condition)[0] 86 | order = order[inds + 1] 87 | 88 | return keep 89 | 90 | class ImageFeaturesReader(object): 91 | r""" 92 | A reader for H5 files containing pre-extracted image features. A typical image features file 93 | should have at least two H5 datasets, named ``image_id`` and ``features``. It may optionally 94 | have other H5 datasets, such as ``boxes`` (for bounding box coordinates), ``width`` and 95 | ``height`` for image size, and others. This reader only reads image features, because our 96 | UpDown captioner baseline does not require anything other than image features. 97 | 98 | Example of an h5 file:: 99 | 100 | image_bottomup_features.h5 101 | |--- "image_id" [shape: (num_images, )] 102 | |--- "features" [shape: (num_images, num_boxes, feature_size)] 103 | +--- .attrs {"split": "coco_train2017"} 104 | 105 | Parameters 106 | ---------- 107 | features_h5path : str 108 | Path to an H5 file containing image ids and features corresponding to one of the four 109 | ``split``s used: "coco_train2017", "coco_val2017", "nocaps_val", "nocaps_test". 110 | in_memory : bool 111 | Whether to load the features in memory. Beware, these files are sometimes tens of GBs 112 | in size. Set this to true if you have sufficient RAM. 113 | """ 114 | 115 | def __init__(self, features_h5path: str, start_index: int = 0): 116 | self.features_h5path = features_h5path 117 | 118 | # Keys are all the image ids, values depend on ``self._in_memory``. 119 | # If ``self._in_memory`` is True, values are image features corresponding to the image id. 120 | # Else values will be integers; indices in the files to read features from. 
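# In this implementation the map always stores row indices: only `image_id`, `width` and `height`
# are read eagerly, while per-image features, boxes, classes and scores are loaded lazily in
# `__getitem__` (optionally through a persistent handle opened with `open_h5_file`).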
121 | self._map: Dict[int, Union[int, np.ndarray]] = {} 122 | 123 | features_h5 = h5py.File(self.features_h5path, "r") 124 | self._width = features_h5["width"][:] 125 | self._height = features_h5["height"][:] 126 | self.start_index = start_index 127 | 128 | image_id_np = np.array(features_h5["image_id"]) 129 | self._map = { 130 | image_id_np[index]: index for index in range(image_id_np.shape[0]) 131 | } 132 | 133 | features_h5.close() 134 | 135 | self.features_h5 = None 136 | 137 | def open_h5_file(self): 138 | self.features_h5 = h5py.File(self.features_h5path, "r") 139 | 140 | def close_h5_file(self): 141 | self.features_h5.close() 142 | self.features_h5 = None 143 | 144 | def __len__(self): 145 | return len(self._map) 146 | 147 | def process_box(self, index, box_np, score): 148 | new_box_np = np.zeros((box_np.shape[0], 8), dtype=np.float32) 149 | 150 | if score.shape[0] > box_np.shape[0]: 151 | score = score[:box_np.shape[0]] 152 | 153 | box_np[:, 0] /= self._width[index] 154 | box_np[:, 2] /= self._width[index] 155 | box_np[:, 1] /= self._height[index] 156 | box_np[:, 3] /= self._height[index] 157 | if box_np.shape[0] > 0: 158 | new_box_np[:, :4] = box_np 159 | new_box_np[:, 4] = box_np[:, 2] - box_np[:, 0] 160 | new_box_np[:, 5] = box_np[:, 3] - box_np[:, 1] 161 | new_box_np[:, 6] = (box_np[:, 2] - box_np[:, 0]) * (box_np[:, 3] - box_np[:, 1]) 162 | min_size = min(score.shape[0], box_np.shape[0]) 163 | new_box_np[:min_size, 7] = score[:min_size] 164 | 165 | return new_box_np 166 | 167 | def __getitem__(self, image_id: int): 168 | if self.features_h5 is None: 169 | features_h5 = h5py.File(self.features_h5path, "r") 170 | else: 171 | features_h5 = self.features_h5 172 | index = self._map[image_id] 173 | image_id_features = features_h5["features"][index].astype('float32').reshape(-1, 2048) 174 | class_ = features_h5["classes"][index].astype('int64') 175 | score = features_h5["scores"][index].reshape(-1) 176 | box_np = features_h5["boxes"][index].reshape(-1, 4) 177 | new_box_np = self.process_box(index, box_np, score) 178 | if self.features_h5 is None: 179 | features_h5.close() 180 | min_size = min([image_id_features.shape[0], new_box_np.shape[0], class_.shape[0]]) 181 | return image_id_features[:min_size], new_box_np[:min_size], class_[:min_size] + self.start_index 182 | 183 | class CocoCaptionsReader(object): 184 | def __init__(self, captions_jsonpath, captions_word_norm_jsonpath=None, rm_dumplicated_caption=False, shuffle=False, is_train=True, rm_punctuation=False): 185 | 186 | self._captions_jsonpath = captions_jsonpath 187 | 188 | with open(captions_jsonpath) as cap: 189 | captions_json = json.load(cap) 190 | 191 | vocab_norm = None 192 | if captions_word_norm_jsonpath is not None: 193 | with open(captions_word_norm_jsonpath) as word_norm: 194 | vocab_norm = json.load(word_norm) 195 | # List of (image id, caption) tuples. 
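# Captions are grouped per image id: each annotation keeps its raw text in `gt` for evaluation,
# while `caption` is lower-cased and stripped, optionally cleared of punctuation, rewritten with the
# word-normalisation map when one is given, and exact duplicates can be dropped during training via
# `rm_dumplicated_caption`.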
196 | _captions_dict = {} 197 | caption_set = set() 198 | rm_dump_cap = 0 199 | c = copy.deepcopy 200 | print(f"Tokenizing captions from {captions_jsonpath}...") 201 | for caption_item in tqdm(captions_json["annotations"]): 202 | if 'unable' in caption_item["caption"]: 203 | continue 204 | if caption_item["caption"] in caption_set and rm_dumplicated_caption and is_train: 205 | rm_dump_cap += 1 206 | continue 207 | else: 208 | caption_set.add(caption_item["caption"]) 209 | caption_item["gt"] = c(caption_item["caption"]) 210 | caption_item["caption"]: str = caption_item["caption"].lower().strip() 211 | if rm_punctuation: 212 | for p in string.punctuation: 213 | caption_item["caption"] = caption_item["caption"].replace(p, ' ') 214 | if vocab_norm is not None: 215 | for key, value in vocab_norm.items(): 216 | caption_item['caption'] = caption_item['caption'].replace(' ' + key + ' ', ' ' + value + ' ') 217 | if caption_item["image_id"] not in _captions_dict: 218 | _captions_dict[caption_item["image_id"]] = [] 219 | _captions_dict[caption_item["image_id"]].append(caption_item) 220 | 221 | self._captions_together = [] 222 | for (k, captions) in _captions_dict.items(): 223 | gt = [c['gt'] for c in captions] 224 | for cap in captions: 225 | self._captions_together.append((k, cap['caption'], gt)) 226 | self._captions_dict = _captions_dict 227 | 228 | if rm_dumplicated_caption and is_train: 229 | print("remove duplicate captions %d" % rm_dump_cap) 230 | 231 | if shuffle: 232 | random.shuffle(self._captions_together) 233 | 234 | 235 | def __len__(self): 236 | return len(self._captions_together) 237 | 238 | def __getitem__(self, index): 239 | img_id, cap, gt = self._captions_together[index] 240 | return img_id, cap, gt 241 | 242 | def get_gt_by_image_id(self, image_id): 243 | caps = self._captions_dict[image_id] 244 | return [c['gt'] for c in caps] 245 | 246 | class BoxesReader(object): 247 | """ 248 | A reader for H5 files containing bounding boxes, classes and confidence scores inferred using 249 | an object detector. A typical H5 file should at least have the following structure: 250 | ``` 251 | image_boxes.h5 252 | |--- "image_id" [shape: (num_images, )] 253 | |--- "width" [shape: (num_images, )] 254 | |--- "height" [shape: (num_images, )] 255 | |--- "boxes" [shape: (num_images, max_num_boxes, 4)] 256 | |--- "classes" [shape: (num_images, max_num_boxes, )] 257 | +--- "scores" [shape: (num_images, max_num_boxes, )] 258 | ``` 259 | Box coordinates are of form [X1, Y1, X2, Y2], _not_ normalized by image width and height. Class 260 | IDs start from 1, i-th ID corresponds to (i-1)-th category in "categories" field of 261 | corresponding annotation file for this split (in COCO format). 262 | Parameters 263 | ---------- 264 | boxes_h5path : str 265 | Path to an H5 file containing boxes, classes and scores of a particular dataset split. 
266 | """ 267 | 268 | def __init__(self, 269 | boxes_h5path: str, 270 | detection_dict: Dict[int, str], 271 | object_blacklist_path: str, 272 | class_structure_path: str = None, 273 | abstract_list_path: str = None, 274 | min_score: float = 0.01, 275 | top_k: int = 3, 276 | is_val: bool = False, 277 | cls_start_index: int = 0, 278 | object_filtering: bool = True, 279 | variant_copy_candidates: bool = True, 280 | in_memory: bool = False, 281 | copy_candidate_clear_up=False): 282 | 283 | with open(object_blacklist_path) as out: 284 | blacklist = json.load(out) 285 | full_list = blacklist['blacklist_categories'] + (blacklist['val_blacklist_categories'] if is_val else []) 286 | self._blacklist_categories = set([s.lower() for s in full_list]) 287 | self._boxes_h5path = boxes_h5path 288 | self.detection_dict = detection_dict 289 | self.min_score = min_score 290 | self.is_val = is_val 291 | self.object_filtering = object_filtering 292 | self.top_k = top_k 293 | self.cls_start_index = cls_start_index 294 | self.in_memory = in_memory 295 | self.variant_copy_candidates = variant_copy_candidates 296 | self.copy_candidate_clear_up = copy_candidate_clear_up 297 | 298 | if abstract_list_path is not None and class_structure_path is not None: 299 | self.hierarchy_finder = HierarchyFinder(class_structure_path, abstract_list_path) 300 | else: 301 | self.hierarchy_finder = None 302 | 303 | self.cache = {} 304 | 305 | with h5py.File(self._boxes_h5path, "r") as boxes_h5: 306 | self._width = boxes_h5["width"][:] 307 | self._height = boxes_h5["height"][:] 308 | self._image_ids = boxes_h5["image_id"][:].tolist() 309 | self._image_ids = { 310 | image_id: index for index, image_id in enumerate(self._image_ids) 311 | } 312 | 313 | if self.in_memory: 314 | for image_id in tqdm(self._image_ids): 315 | self.process_single_image(image_id, self._image_ids[image_id], boxes_h5) 316 | 317 | def __len__(self): 318 | return len(self._image_ids) 319 | 320 | def process_single_image(self, image_id, i, boxes_h5): 321 | feature = boxes_h5["features"][i].reshape(-1, 2048) 322 | box_np = boxes_h5["boxes"][i].reshape(-1, 4) 323 | box_score = boxes_h5["scores"][i] 324 | class_list = (boxes_h5["classes"][i] + self.cls_start_index).tolist() 325 | 326 | new_box_np = np.zeros((box_np.shape[0], 8)) 327 | if box_np.shape[0] > 0: 328 | box_np[:, 0] /= self._width[i] 329 | box_np[:, 2] /= self._width[i] 330 | box_np[:, 1] /= self._height[i] 331 | box_np[:, 3] /= self._height[i] 332 | 333 | new_box_np[:, :4] = box_np 334 | new_box_np[:, 4] = box_np[:, 2] - box_np[:, 0] 335 | new_box_np[:, 5] = box_np[:, 3] - box_np[:, 1] 336 | new_box_np[:, 6] = (box_np[:, 2] - box_np[:, 0]) * (box_np[:, 3] - box_np[:, 1]) 337 | 338 | min_size = min(box_score.shape[0], box_np.shape[0]) 339 | new_box_np[:min_size, 7] = box_score[:min_size] 340 | 341 | feature = feature[:box_np.shape[0]] 342 | new_box_np = new_box_np[:box_np.shape[0]] 343 | box_score = box_score[:box_np.shape[0]] 344 | class_list = class_list[:box_np.shape[0]] 345 | 346 | if not self.variant_copy_candidates: 347 | _class = [] 348 | _box = [] 349 | _feature = [] 350 | 351 | min_size = min(feature.shape[0], new_box_np.shape[0]) 352 | feature = feature[:min_size] 353 | new_box_np = new_box_np[:min_size] 354 | box_score = box_score[:min_size] 355 | class_list = class_list[:min_size] 356 | 357 | for idx, box_cls in enumerate(class_list): 358 | if box_cls not in self.detection_dict or box_score[idx] < self.min_score: 359 | continue 360 | _box.append(new_box_np[idx]) 361 | _class.append(box_cls) 362 | 
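# keep the matching 2048-d region feature as well, so boxes, classes and features stay row-aligned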
_feature.append(feature[idx]) 363 | 364 | new_box_np = np.zeros((len(_box), 8)) 365 | for i, bb in enumerate(_box): 366 | new_box_np[i] = bb 367 | feature_np = np.array(_feature) 368 | class_np = np.array(_class) 369 | obj_mask = np.ones((len(_class),), dtype=np.float32) 370 | 371 | for idx, box_cls in enumerate(class_np): 372 | text_class = self.detection_dict[box_cls] 373 | if text_class in self._blacklist_categories: 374 | obj_mask[idx] = 0.0 375 | 376 | if self.object_filtering: 377 | keep = nms(new_box_np, [self.detection_dict[box_cls] for box_cls in _class], self.hierarchy_finder.class_structure) 378 | for idx in range(len(_class)): 379 | if idx not in keep: 380 | obj_mask[idx] = 0.0 381 | 382 | if new_box_np.shape[0] > 0: 383 | anns = [] 384 | for idx, (box, cls_, mask) in enumerate(zip(new_box_np, class_np, obj_mask)): 385 | if mask == 1.0: 386 | anns.append((box, cls_, idx)) 387 | 388 | anns = sorted(anns, key=lambda x:x[0][7], reverse=True) 389 | 390 | if self.object_filtering: 391 | anns = anns[:self.top_k] 392 | 393 | seen_class = [] 394 | for box, cls_, idx in anns: 395 | if cls_ not in seen_class: 396 | seen_class.append(cls_) 397 | obj_mask[idx] = 2.0 398 | 399 | obj_mask[obj_mask < 2.0] = 0.0 400 | obj_mask[obj_mask == 2.0] = 1.0 401 | 402 | class_list = class_np.tolist() 403 | text_class = [self.detection_dict[v] for v in class_list] 404 | if self.hierarchy_finder is not None: 405 | parent_class = [self.hierarchy_finder.find_parent(v) for v in text_class] 406 | parent_class_index = [self.hierarchy_finder.abstract_list[v] for v in parent_class] 407 | else: 408 | parent_class_index = [0 for _ in range(len(text_class))] 409 | 410 | new_box_np = new_box_np.astype('float32') 411 | class_np = np.array(class_list).astype('int64') 412 | feature_np = feature_np.astype('float32') 413 | parent_class_np = np.array(parent_class_index).astype('int64') 414 | 415 | if not self.copy_candidate_clear_up: 416 | self.cache[image_id] = { 417 | "predicted_boxes": new_box_np, 418 | "predicted_classes": class_np, 419 | "predicted_feature": feature_np, 420 | "parent_classes": parent_class_np, 421 | "predicted_mask": obj_mask 422 | } 423 | else: 424 | self.cache[image_id] = { 425 | "predicted_boxes": new_box_np[obj_mask == 1], 426 | "predicted_classes": class_np[obj_mask == 1], 427 | "predicted_feature": feature_np[obj_mask == 1], 428 | "parent_classes": parent_class_np[obj_mask == 1], 429 | "predicted_mask": obj_mask[obj_mask == 1] 430 | } 431 | else: 432 | text_class = [self.detection_dict[v] for v in class_list] 433 | if self.hierarchy_finder is not None: 434 | parent_class = [self.hierarchy_finder.find_parent(v) for v in text_class] 435 | parent_class_index = [self.hierarchy_finder.abstract_list[v] for v in parent_class] 436 | else: 437 | parent_class_index = [0 for _ in range(len(text_class))] 438 | 439 | new_box_np = new_box_np.astype('float32') 440 | class_np = np.array(class_list).astype('int64') 441 | feature_np = feature.astype('float32') 442 | parent_class_np = np.array(parent_class_index).astype('int64') 443 | obj_mask = np.ones_like(class_np) 444 | 445 | if not self.copy_candidate_clear_up: 446 | self.cache[image_id] = { 447 | "predicted_boxes": new_box_np, 448 | "predicted_classes": class_np, 449 | "predicted_feature": feature_np, 450 | "parent_classes": parent_class_np, 451 | "predicted_mask": obj_mask 452 | } 453 | else: 454 | self.cache[image_id] = { 455 | "predicted_boxes": new_box_np[obj_mask == 1], 456 | "predicted_classes": class_np[obj_mask == 1], 457 | "predicted_feature": 
feature_np[obj_mask == 1], 458 | "parent_classes": parent_class_np[obj_mask == 1], 459 | "predicted_mask": obj_mask[obj_mask == 1] 460 | } 461 | 462 | 463 | 464 | def __getitem__(self, image_id: int): 465 | i = self._image_ids[image_id] 466 | 467 | if image_id in self.cache: 468 | return self.cache[image_id] 469 | 470 | with h5py.File(self._boxes_h5path, "r") as boxes_h5: 471 | self.process_single_image(image_id, i, boxes_h5) 472 | 473 | d = self.cache[image_id] 474 | return {key: np.array(d[key], copy=True) for key in d} 475 | -------------------------------------------------------------------------------- /dataset/new_copy_vocab_all.txt: -------------------------------------------------------------------------------- 1 | NONE 2 | eatType[pub]@pub 3 | name[The Vaults]@the vaults 4 | near[Café Adriatic]@café adriatic 5 | customer rating[5 out of 5]@5 star@high customer rating@5-star@5 out of 5@5 out of 5.@5 out of 5,@ratings at a high@high reviews@five star@5 out of 5 stars@rating and a high@five stars@great reviews@great customer rating@highly rated@five out of five@high customer ratings@5 stars@high rating of 5 out of 5@high rating of 5 out of 5.@high rating of 5 out of 5,@excellent customer rating@high quality@rating and high@high ratings@high customer reviews@great customer rating of 5 out of 5@great customer rating of 5 out of 5.@great customer rating of 5 out of 5,@great quality@great rating@excellent customer ratings@5 out of 5 star@excellent reviews@high customer rating of 5 out of 5@high customer rating of 5 out of 5.@high customer rating of 5 out of 5,@ratings and high@five-star@quality and high@5 of 5 stars@excellent customer rating of 5 out of 5@excellent customer rating of 5 out of 5.@excellent customer rating of 5 out of 5,@quality food at high@highly rated with 5 out of 5@highly rated with 5 out of 5.@highly rated with 5 out of 5,@high customer satisfaction rating@excellent 5-star rated@highly priced and rated@high rating@high customer rating as 5 out of 5@high customer rating as 5 out of 5.@high customer rating as 5 out of 5,@highly-rated@quality for high@5 of 5@5 of 5.@5 of 5,@highly rated at 5 out of 5@highly rated at 5 out of 5.@highly rated at 5 out of 5,@high priced average rating@rating for their high@excellent rating@five out of five star@high satisfaction rating@high-quality@excellent ratings@high ratings of 5 out of 5@high ratings of 5 out of 5.@high ratings of 5 out of 5,@high and is rated@great restaurant rated 5 out of 5 star@rating with high@great customer ratings@high standards@high and ratings@high rated@high but the rating@high priced and rated five stars@excellently rated@excellent ratings of 5 out of 5@excellent ratings of 5 out of 5.@excellent ratings of 5 out of 5,@very good ratings@high the quality@rating and is high@excellent average customer rating@reviews with high@excellent customer reviews@high and rated@high priced ad rated@high customer ratings of 5 out of 5@high customer ratings of 5 out of 5.@high customer ratings of 5 out of 5,@great customer ratings at 5 out of 5@great customer ratings at 5 out of 5.@great customer ratings at 5 out of 5,@rated very high@rating in the high@high class quality@rating at a high@rating with a high@reviews and a high@highly rated 5 out of 5@highly rated 5 out of 5.@highly rated 5 out of 5,@5 out of five@high rated as 5 out of 5@high rated as 5 out of 5.@high rated as 5 out of 5,@high and customer rating@high priced food rated@high end quality@high ratings at 5 of 5@high ratings at 5 of 5.@high ratings at 5 of 
5,@rated high@ratings and are high@rating is high@ratings and a high@quality with a high@great rating of 5 of 5@great rating of 5 of 5.@great rating of 5 of 5,@very good customer rating@rated with a high@high but reviews@five out of five stars@rating for the high@high but its rating@ratings are not high@highly rated by customers at 5 out of 5@highly rated by customers at 5 out of 5.@highly rated by customers at 5 out of 5,@high customer rating of five out of five@excellent customer review of 5 out of 5@excellent customer review of 5 out of 5.@excellent customer review of 5 out of 5,@great ratings@high priced and rated@high 5 star rating@high but the quality@high-rated@great reviews with a five star@high standard@highly on customer rating with 5 out of 5@highly on customer rating with 5 out of 5.@highly on customer rating with 5 out of 5,@review in the high@highly rated by customers of 5 out of 5@highly rated by customers of 5 out of 5.@highly rated by customers of 5 out of 5,@greatly rated@high customer rating of a 5 out of 5@high customer rating of a 5 out of 5.@high customer rating of a 5 out of 5,@excellent rating of 5 out of 5@excellent rating of 5 out of 5.@excellent rating of 5 out of 5,@ratings and is high@high price and rating@great rating of 5 out of 5@great rating of 5 out of 5.@great rating of 5 out of 5,@high and customer ratings@rating in high@high priced pub rated@excellent review@highly-rated at 5 out of 5@highly-rated at 5 out of 5.@highly-rated at 5 out of 5,@high-priced but rated five stars@ratings are high@high-priced low-rated@highly rated with ratings 5 out of 5@highly rated with ratings 5 out of 5.@highly rated with ratings 5 out of 5,@very good rating@5-stars@high price average rating@rating but high@ratings in the high@high average customer rating@great customer ratings of 5 out of 5@great customer ratings of 5 out of 5.@great customer ratings of 5 out of 5,@very good reviews@high eve though ratings@rated and priced high@high rating of 5 out of 5 stars@rating for a high@excellent customer ratings of 5 out of 5@excellent customer ratings of 5 out of 5.@excellent customer ratings of 5 out of 5,@rated chinese at high@high prices and ratings@rating that serves high@high customer rated@high-rating@high customer rating of 5 stars@high standard of quality@high customer review@rating and are high@great customer reviews of 5 out of 5@great customer reviews of 5 out of 5.@great customer reviews of 5 out of 5,@excellent-rated@rate fitzbillies with high@highly rated with a 5 out of 5@highly rated with a 5 out of 5.@highly rated with a 5 out of 5,@rating rating is high@rating is a high@great customer reviews@rating to a high@ratings high@ratings are very high@high prices customer rated 6 | name[The Cambridge Blue]@the cambridge blue 7 | near[Café Brazil]@café brazil 8 | priceRange[cheap]@cheap@affordable@inexpensive@low price@low priced@prices are low@price range but low@low prices@affordably@low-priced@cheaply@prices low@price range and a low@price range but has low@prices and a low@price range and low@price is low@low-price@prices with a low@prices and low@price restaurant whit a low@prices but low@price range with the low@prices range and customer low@price as low@price range with low@price range is high with low@price range but has a low@price items with a low@inexpensively@price range with a low@price range and has a low@price range but a low@price range is low@prices and has received low@price and low@low-prices@prices and is low@price tag and a low@prices are very 
low@price range chinese restaurant with low@price and has a low@price but low@low -priced@prices with low@price range although a low@price is very low 9 | area[riverside]@riverside@by the river@near the river@along the river@on the river@off the river@close to the river@at the river@beside the river 10 | eatType[coffee shop]@coffee shop@coffee@café@coffee-shop 11 | familyFriendly[yes]@family friendly@child friendly@family-friendly@kid friendly@child-friendly@children friendly@for kids@for families@kids friendly@kids-friendly@welcomes families@children is allowed@welcomes kids@kids is allowed@kid-friendly@family-oriented@children-friendly@for children@for family@welcomes children@kids are welcome@family oriented@allows kids@allow children@family place@allows children@welcome families@children are welcome@families are welcome@children- friendly@welcomes families with children@welcomes the whole family@family is welcome@welcome family@welcome kids@welcomes family@accept children@allows families with children@welcomes families and children@family- oriented@place to bring the family@family orientated@welcome children@accepts children@accept your family@welcomes the family@family- friendly@welcomes all families@children are allowed@welcome families with children@welcomes the entire family@accepts family@welcomes your whole family@accepts families@kids are allowed@welcome families and children@welcome including children@welcomes all the family@family are welcome@kid -friendly@families are allowed@allow families@welcome to children 12 | food[Japanese]@japanese 13 | name[The Eagle]@the eagle 14 | near[Burger King]@burger king 15 | priceRange[less than £20]@less than £20@less than £20.@less than £20,@low price@cheap@inexpensive@moderately priced@affordable@less than 20@less than 20.@less than 20,@low-priced@low priced@under £20@under £20.@under £20,@low prices@price range with a low@price range but has a low@cheaply@prices but a low@less than £ 20@less than £ 20.@less than £ 20,@under 20@under 20.@under 20,@inexpensively@price range has a low@price range an has a low@prices and low@low-price@prices are in the low@price range and a low@price range but with a low@price is very low@prices are low@price range and low@price range and has a low@prices with low@low- priced@low-prices@price and a low@price range which is rated low@price low@price is low@affordably@price range is low@below average price@priced below average 16 | customer rating[low]@low rated@low customer rating@ratings are quite low@1 star@poor customer rating@one star@1 out of 5@1 out of 5.@1 out of 5,@low rating@low customer ratings@1 of 5 star@low ratings@one out of five stars@1 out of 5 stars@rating is low@low-rated@low rating english rating@rated low@rating that is low@rating of low@low customer reviews@ratings are low@low quality@low customer review@low consumer rating@1-star@low-customer-rated@one-star@low customer rated@low customers rating@poorly rated@poor rating@rated it low@poor customer ratings@rating score of low@poor reviews@low costumer rating@rating is rather low@rating low@poorly-rated@low cost low rating@doesn't get very good customer ratings@lowly rated@rating being low@low satisfaction rating@badly rated@low star rating@low customer service rating@low customer satisfaction rating@rated to a low@rating is quite low@low prices and ratings@doesn't get good reviews@rating is poor@reviews are low@rating are low@lowly-rated@poor customer reviews@low consumer ratings@low a rating@rating is very low@rated low but low@poor 
ratings@not have good reviews@poor rated@low in customer ratings@one out of five@doesn't have good customer reviews@not well rated@rate it low@low-ratings@low-rating@low approval rating@low reviews@one out of five star@ratings low@rated this low@1 out of 5 star@low client rating@not have very good reviews@rating and low@rate browns cambridge low@doesn't get good ratings@rated the location as low@rated the place low@rating it low@doesn't have very good customer ratings@rate blue spice low@doesn't have good customer ratings@bad reviews@doesn't have very good reviews@rating range is low@ratings is low@rating is pretty low@isn't well rated@rating loch fyne as low@low customer service ratings@bad review@bad rating@1 -star@lowly rated customer review@poor quality@rate low@rate it as low@rating is currently low@rated as low@low cost low quality@rating is very poor@low customers ratings@ratings there are low@rates low@ratings are poor@poor-quality@rated it as low@low customer review ratings@low average customer rating@not have good customer ratings@rating with low@rating is not low@low customer-rating@rating is bit low@rate this low@bad customer ratings@low in the ratings@quality is quite poor@rate midsummer house low@low review@low customer approval rating@low-quality@low quality reviews@rating in low 17 | food[French]@french 18 | name[The Mill]@the mill 19 | near[The Sorrento]@the sorrento 20 | name[Loch Fyne]@loch fyne 21 | near[The Rice Boat]@the rice boat 22 | eatType[restaurant]@restaurant 23 | food[English]@english@british 24 | name[Bibimbap House]@bibimbap house 25 | near[Clare Hall]@clare hall 26 | priceRange[moderate]@moderately priced@moderate priced@moderate price@moderately-priced@price range is moderate@moderate-priced@affordable@price is moderate@price range and has an average@average price@average prices@average priced@prices are moderate@moderate prices@price ranges are moderate@average-priced@prices are below average@price range moderate@price and an average@prices are above average@price range and a average@prices with average@moderately price@price range and average@price range and an average@price range with an average@prices and average@price range is reasonable@reasonable prices@prices but has an average@mid-range price@mid range prices@mid-range prices@price moderate@averagely priced@prices it has an average@price is about average@moderately prices@price range of moderate@price with average@price range with average@price range is average@mid-range priced@prices but lower than average@price range however the average@reasonable priced@price but below average@price range and with an average@not cheap@prices than average@prices but only has average@prices are ok@price ranges and average@price range that has average@reasonable price@prices are higher than average@prices above average@price range in moderate@price range is rated as average@prices are average@price range and rating of average@prices and an average@price range restaurant with an average@prices fall in the moderate@price range is from very moderate@prices average@price chinese average@price range is lower than average@price and has average@price range and and average@price with an average@prices range is moderate@average- priced@price range that is rated average@prices are quite reasonable@price are really moderate@prices moderate@price and average@prices named clowns with average@price range and gets average@price range yet average@price range that is moderate@mid range price@price range above average@price 
average@prices are reasonable@prices for meals are moderate@price range but a average@prices for the moderate@price range that is rather moderate@price range is high with average@mid range priced@prices in the moderate@price is more than average@price range there is moderate@prices and has average@averagely price@price range and rated average@prices fall into the moderate@price range and has a average@average-price@price range is slightly above average@price range and has received average@price range it has an average@price range is above average@price range which is moderate@prices slightly higher than average@moderate- priced@price is pretty reasonable@prices in the average@price range but an average@price range with the average@price ranging moderate@prices are higher then average@price no average@price range is higher than average@prices but also an average@prices ranges are moderate@price is average@price range is really moderate@price range for wildwood is moderate@prices are low to moderate@price range for you is moderate@price range very moderate@prices are very reasonable@prices with an average@price range and is rated average@price ranging at a moderate@price menu and average@prices are in the moderate@prices that are above average@prices are pretty moderate@moderate-price@price range but average@prices are fairly average@prices and moderate@price is high and average@price and has an average@price range and holds an average 27 | name[The Rice Boat]@the rice boat 28 | familyFriendly[no]@not family-friendly@not child friendly@not kids friendly@non family friendly@no children@not kid friendly@non family-friendly@not children friendly@isn't family-friendly@not kid-friendly@families are not welcome@not classified as family-friendly@no kids@not family friendly@does not welcome children@isn't child friendly@isn't children-friendly@non-child friendly@not a family friendly@does not allow children@no kids friendly@not suitable for families@not child-friendly@not a family-friendly@non-family-friendly@isn't kids friendly@not suitable for children@not a children friendly@not so children friendly@adults only@non-family friendly@child unfriendly@doesn't allow children@no families@adult only@not friendly to kids@not for families@not kids-friendly@non kid friendly@kid unfriendly@does not allow kids@adult-only@non children friendly@not being child-friendly@not a kid friendly@no children friendly@not so family-friendly@non-children friendly@not a family oriented@no a family friendly@child-unfriendly@no family@not opened to all age@non child friendly@not a child friendly@non-kids friendly@not family@non children-friendly@no good for children@isn't kid friendly@non-kid friendly@not being family friendly@not children-friendly@isn't children friendly@does not welcome families@not for children@not a kids friendly@isn't family friendly@non family oriented@adult establishment@not so family friendly@not begin child friendly@not friendly for kids@doesn't allow kids@not to families@not suitable for kids@no family-friendly@not considered family-friendly@does not allow families@not too kid friendly@adult place@not especially family-friendly@children are not allowed@non-child-friendly@no for families@not for the family@adults-only@no children-friendly@no family friendly@non kids friendly@kids are not allowed@not offer a child-friendly@not being family-friendly@no longer child friendly@non kid-friendly@non-family oriented@no means kids friendly@families are not allowed@no kids-friendly@not good for family@not 
being child friendly@not good for kids@no for children@isn't kid-friendly@not considered kid friendly@not considered a child friendly@non - kid friendly@kids are not welcome@not very kid friendly@not considered child friendly@doesn't welcome children@not very family-friendly@doesn't allow families@not considered family friendly@no to children@children are not welcome@not feel very family-friendly@not a child-friendly@not friendly to families@not known as family-friendly@not family oriented@non-kids-friendly@isn't a kids friendly@does not welcome kids@not very family friendly@not friendly to children@not very child friendly@not down as kid friendly@not friendly for children@not for family@no for kids@not quite family friendly@not considered a family friendly@not be kid friendly@isn't the family-friendly@isn't a family-friendly@no kid friendly@doesn't accept kids@family unfriendly@not provide a family-friendly@not so much family-friendly@not be child friendly@not known as kid friendly@isn't a kid friendly@not for kids@non -kid friendly@not good for families@non- kid friendly@adult venue@not be considered family-friendly@not kids@not friendly with families@isn't down as family-friendly@not viewed as child friendly@no time family friendly@not rated as family friendly@not being kid friendly@no a family-friendly@not only kid friendly@not really family-friendly@not provide a children friendly@non child-friendly@family-unfriendly@not a kid-friendly@no good for kids@not very children friendly@not considered children friendly@no a kid friendly@non-kid-friendly@not very kids friendly@does not accept families@not a very child friendly@not suitable for family@not friendly with children@non-children-friendly@non kids-friendly@not children@not friendly for family@not family- friendly@isn't very family friendly@not offer a kid friendly@isn't considered child friendly@not provide family-friendly@not the most family friendly@not only child friendly@only for adults@not exactly child-friendly@not only not children friendly@doesn't allow families with children@not however family friendly@not your average family-friendly@isn't kids-friendly@not so child friendly@no good for families@no to family-friendly@isn't very family-friendly@adult@adult oriented@family@families@kids@kid@adults 29 | name[The Wrestlers]@the wrestlers 30 | near[Raja Indian Cuisine]@raja indian cuisine 31 | area[city centre]@city centre@in the centre@city center@in the center@center of the city@centre of the city@centre of city@town centre@center of city@town center@centre of town 32 | name[Aromi]@aromi 33 | food[Fast food]@fast food 34 | name[The Phoenix]@the phoenix 35 | customer rating[3 out of 5]@3 out of 5@3 out of 5.@3 out of 5,@3 star@three star@good standard@average customer rating@good quality@average customer reviews@rated average@3 out of 5 stars@good reviews@averagely rated@average customer rating of 3 out of 5@average customer rating of 3 out of 5.@average customer rating of 3 out of 5,@average rating@three stars@three out of five@3 stars@average rating of 3 out of 5@average rating of 3 out of 5.@average rating of 3 out of 5,@3-star@good customer rating@3 out of 5 star@average customer ratings@3 out of five@three out of five stars@average rating of 3 out of 5 stars@three of five star@average ratings@good-quality@good rating@rated with average@3 of 5@3 of 5.@3 of 5,@well priced quality@average reviews@well rated 3 out of 5@well rated 3 out of 5.@well rated 3 out of 5,@average quality@average in customer ratings@average 
rated@average price and rating@3 out of five star@well rated@average rating is 3 out of 5@average rating is 3 out of 5.@average rating is 3 out of 5,@three-star@average price average rating@rating and an average@average customer rated@rating and has average@three out of five star@good ratings@rated it average@3 of 5 star@rating and over average@ratings average@average customer rating is 3 out of 5@average customer rating is 3 out of 5.@average customer rating is 3 out of 5,@average priced averagely rated@average ratings a 3 out of 5@average ratings a 3 out of 5.@average ratings a 3 out of 5,@average price rating@average customer rating is a 3 out of 5@average customer rating is a 3 out of 5.@average customer rating is a 3 out of 5,@rating is an average 3 out of 5@rating is an average 3 out of 5.@rating is an average 3 out of 5,@average priced and rated@3 of 5 stars@average customer rating of a three out of five@rating and average@ratings are average@good customer ratings@good rated@rating and with average@average reviews of 3 out of 5@average reviews of 3 out of 5.@average reviews of 3 out of 5,@average-rated@ratings are only average@average rated 3 out of 5@average rated 3 out of 5.@average rated 3 out of 5,@average customer rating of three out of five@rated fast food average@rated as average@rating and is average@three out of 5@three out of 5.@three out of 5,@good review@good customer rating at 3 out of 5@good customer rating at 3 out of 5.@good customer rating at 3 out of 5,@average rating 3 out of 5@average rating 3 out of 5.@average rating 3 out of 5,@ratings and average@rated the phoenix average@reviews are average@good with a rating@well-rated@ratings are below average@3-stars@rated restaurant with average@three of five stars@average low customer rating@rated kids friendly average@rated place within average@rating in the average@average customer ratings are 3 out of 5@average customer ratings are 3 out of 5.@average customer ratings are 3 out of 5,@rating with an average@rating below average@rating with the average@average customer rating of 3 out of 5 stars@ratings on average@rating is average@average customer rating 3 out of 5@average customer rating 3 out of 5.@average customer rating 3 out of 5,@three of five@rating average 36 | name[Browns Cambridge]@browns cambridge 37 | name[Taste of Cambridge]@taste of cambridge 38 | food[Italian]@italian 39 | name[Cocum]@cocum 40 | name[The Dumpling Tree]@the dumpling tree 41 | priceRange[high]@high prices@high price@high priced@expensive@high-priced@price range is high@prices are high@price is high@price high@price is in the high@price range of begin high@price range high@prices in the high@price range is quite high@prices are in the high@price rang is high@price range of high@prices are ridiculously high@prices being on the high@price range is on the high@price is a bit high@prices are on the high@price ranged high@price are little bit high@prices that range high@price range is a little high@prices are a little high@high-price@prices are somewhat high@price is a little high@prices are a bit high@price is pretty high@price a little high@prices are within the high@price range restaurant with high@price range is typically high@prices ranging in the high@prices a bit high@price is kind of high@prices are quite high@price is also very high@price ranges to high@high- priced 42 | food[Indian]@indian 43 | name[The Punter]@the punter 44 | name[The Golden Curry]@the golden curry 45 | near[Café Rouge]@café rouge 46 | customer rating[1 out of 5]@one 
out of five@poor rating@one star@rating is a low 1 out of 5@rating is a low 1 out of 5.@rating is a low 1 out of 5,@1 out of 5@1 out of 5.@1 out of 5,@low customer rating@1 star@low rated@1 out of 5 star@rating and low@1 of 5@1 of 5.@1 of 5,@1 out of 5 stars@low customer rating of 1 out of 5@low customer rating of 1 out of 5.@low customer rating of 1 out of 5,@ratings and low@low rating@one of five@1-star@low ratings@rating is low@poor rated@one out of five stars@poor customer rating@ratings are low@lowly rated@rated low@low customer service rating of 1 out of 5@low customer service rating of 1 out of 5.@low customer service rating of 1 out of 5,@low quality@rated but low@poor quality@poor customer rating of 1 out of 5@poor customer rating of 1 out of 5.@poor customer rating of 1 out of 5,@poorly rated@one out of five star@low-rated@low priced and rated@one-star@poor reviews@poor customer ratings@reviews and low@low customer ratings@poorly-rated@rate pretty low@low customer rated@rating is bad@low rated as 1 out of 5@low rated as 1 out of 5.@low rated as 1 out of 5,@low-quality@bad rating@rating and a low@low customer rating 1 out of 5@low customer rating 1 out of 5.@low customer rating 1 out of 5,@poor customer rated@low priced high quality@poor rating of 1 out of 5@poor rating of 1 out of 5.@poor rating of 1 out of 5,@low rating of only 1 out of 5@low rating of only 1 out of 5.@low rating of only 1 out of 5,@not received good reviews@quality low@bad reviews@rated average and low@poor ratings@low costumer rating@low consumer ratings@not very good reviews@low reviews@bad customer ratings@low customer rating of 1 star@low rating of 1 out of 5@low rating of 1 out of 5.@low rating of 1 out of 5,@not get good ratings@low rating of 1 star@rated a low 1 star@low rating of only 1 out of 5 stars@ratings are a low 1 out of 5@ratings are a low 1 out of 5.@ratings are a low 1 out of 5,@rating is somewhat low@poor standard@quality at a low@bad customer rating@not of very good quality@ratings with low@poor customer rating with 1 out of 5@poor customer rating with 1 out of 5.@poor customer rating with 1 out of 5,@1 stars@rating is low at 1 out of 5@rating is low at 1 out of 5.@rating is low at 1 out of 5,@one out of 5@one out of 5.@one out of 5,@rating are low@poor customer reviews@rating that serves low@quality food at low@1 of 5 stars@low star rating@rated and low@poor reviews of 1 out of 5@poor reviews of 1 out of 5.@poor reviews of 1 out of 5,@rated clowns as low@low customers rating@low and ratings@rating and medium low@quality at low@ratings and a low@low in customer rating@low customer review@low price rated@rating with a low@not got good ratings@rated it low@rated low at 1 out of 5@rated low at 1 out of 5.@rated low at 1 out of 5,@1 of five@rating is bad 1 out of 5@rating is bad 1 out of 5.@rating is bad 1 out of 5,@low customer reviews@low-priced pub rated@low rating 1 out of 5@low rating 1 out of 5.@low rating 1 out of 5,@low rated with 1 out of 5@low rated with 1 out of 5.@low rated with 1 out of 5,@ratings at low@1 out of five@rating is poor 1 out of 5@rating is poor 1 out of 5.@rating is poor 1 out of 5,@not well rated@low customer rating of 1 out of 5 stars@poor-rated@low customer rating of one out of five@rating is poor 47 | name[Alimentum]@alimentum 48 | near[Yippee Noodle Bar]@yippee noodle bar 49 | priceRange[£20-25]@20-25@20-25.@20-25,@prices are average@averagely priced@average priced@moderately priced@average price@moderately-priced@average prices@20 to 25@20 to 25.@20 to 
25,@affordable@reasonable prices@moderate priced@average-priced@moderate price@20 - 25@20 - 25.@20 - 25,@moderate prices@£20-25@£20-25.@£20-25,@prices that are very reasonable@20 -25@20 -25.@20 -25,@reasonable price@moderate-priced@moderately prices@average -priced@price range is average@moderately price@prices reasonable@price range is about average@20- 25@20- 25.@20- 25,@20--25@20--25.@20--25,@price range is an average@prices are reasonable@averagely-priced@averagely price@price range in the average@prices and moderate@price rang is about average@prices in the average@reasonable priced 50 | customer rating[high]@high customer rating@highly rated@rating is high@high customer ratings@5 out of 5 stars@five star@5 out of 5@5 out of 5.@5 out of 5,@rated high@high rated@rated as high@high costumer rating@high-rating@high-rated@high customer reviews@great customer ratings@5 star@highly-rated@5-star@high customers rating@high rating@high ratings@high customer service rating@5 stars@high quality@ratings is so high@well rated@high quality reviews@great 5 star review@rating high@ratings are high@great reviews@rate wildwood high@high costumer ratings@great customer reviews@rating is relative high@high customer rated@five stars@high customer satisfaction rating@five-star@excellent customer rating@high reviews@rating are high@rating of high@rate high@excellent rating@highly customer rated@great ratings@highly- rated@rate this restaurant high@rated it high@great quality@high average customer rating@rating is very high@very good rating@rate it high@five out of five stars@ratings is high@rating it high@5-stars@great customer rating@excellent customer reviews@well-rated@excellent reviews@five out of five@rate quite high@excellently rated@excellent service and quality@high customer rating reviews@excellent 5-star rated@high-quality@excellent standard@5 out of 5 star@great rating@excellent customer ratings@high customer satisfaction ratings@rated it as high@rate it as high@ratings are very high@rated in high@rating is so high@rate this establishment as high@rate is high@ratings where high@high-ratings@high priced quality@rating that is high@high customer approval ratings@high customer approval rating@rates high@excellent quality@highly -rated@high customer-rating@high which provides quality@rated like high@five out of five star@excellent 5 star rating@quality meals at high@high customer review@excellent ratings 51 | food[Chinese]@chinese@chinesee 52 | near[The Portland Arms]@the portland arms 53 | priceRange[more than £30]@more than £30@more than £30.@more than £30,@high price@prices are quite high@over £30@over £30.@over £30,@more than 30@more than 30.@more than 30,@expensive@high prices@high-priced@over 30@over 30.@over 30,@price range with high@price range is high@prices and high@high priced@price range and high@more than £ 30@more than £ 30.@more than £ 30,@price range with a high@not cheap@prices and very high@price range and a high@prices are in the high@prices while boasting a high@price range is the high@prices are on the high@price is high@prices with a high@price range and has a high@high-price@prices are high@prices and a high@prices and has a high@price range along with a high@price rating and a high@prices for high@price for the high@price range is very high@price range and holds a high@price with a high@prices and features a high@prices has a high@prices venue with a high@prices and have high@prices restaurant that provides high@prices are a bit high@prices with high@price range that has a 
high@price a little high@price range with spectacular high@price with high@price range of high@price is on the high@price range restaurant with high@higher priced@more that 30 pounds 54 | name[Midsummer House]@midsummer house 55 | near[All Bar One]@all bar one 56 | near[Express by Holiday Inn]@express by holiday inn 57 | name[Blue Spice]@blue spice 58 | name[Strada]@strada 59 | customer rating[average]@average customer ratings@average rating@average customer rating@rated average@3 out of 5 star@average rated@3 star@averagely rated@rating is average@rated it average@average customer review@rated the food as average@three star@average ratings@rate as average@3 out of 5 stars@three stars@good quality@rated about average@average consumer rated@ratings are average@rating of average@three-stars@well as the reviews@average consumer reviews@good rating@3-star@three of five star@average consumer rating@3 stars@average-rated@three out of five stars@rates average@three out of five@rate it as average@rate is as average@average customer reviews@3 out of 5@3 out of 5.@3 out of 5,@averagely-rated@three out of five star@average reviews@average customer service rating@3 of 5 stars@rated as average@good ratings@rated this as average@good reviews@average price rating@rating average@rate it average@three-star@average review@average customer satisfaction rating@rating on the average@good customer ratings@average in ratings@well-rated@average customer rated@good standard@rate them average@average costumer rating@average quality@average customer-rating@reviews have been average@ratings is average@rating this pub average@average in customer rating@rated it as average@rated an average@average satisfaction rating@rating is only average@rated as an average@well rated@average customer service rated@average customers reviews@ratings are only average@ratings average@average-quality@rated the food average@rating is about average@rated to be average@rated in average@rating being average@rated strada restaurant as average@rated the punter as average@average in both ratings@rating is above average@rate the phoenix as average@rating is just average@average restaurant rating@ratings which are average@rate the establishment as average@rate the waterman as average@average user rating@3 of 5 star@rating are average@rating on average@rated by customers as average@average- rated@rate the eagle as average@good customer rating@reviews rating it as average@reviews are average@average and customer ratings@ratings being average@rated cocum as average@3 of 5@3 of 5.@3 of 5,@review of average@ratings of average@rate the alimentum as average@standards average@average-rating@rating is typically average@rating that is average@average by customer ratings@average 60 | name[The Waterman]@the waterman 61 | name[Zizzi]@zizzi 62 | near[The Bakers]@the bakers 63 | name[Green Man]@green man 64 | near[Café Sicilia]@café sicilia 65 | name[Clowns]@clowns 66 | near[Avalon]@avalon 67 | name[Giraffe]@giraffe 68 | name[The Olive Grove]@the olive grove 69 | near[The Six Bells]@the six bells 70 | name[The Twenty Two]@the twenty two 71 | name[The Cricketers]@the cricketers 72 | near[Ranch]@ranch 73 | near[Crowne Plaza Hotel]@crowne plaza hotel 74 | name[Wildwood]@wildwood 75 | near[Rainbow Vegetarian Café]@rainbow vegetarian café 76 | name[The Golden Palace]@the golden palace 77 | name[The Plough]@the plough 78 | name[Cotto]@cotto 79 | name[Fitzbillies]@fitzbillies 80 | name[Travellers Rest Beefeater]@travellers rest beefeater 81 | 
--------------------------------------------------------------------------------
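Reading note (not part of the repository): each non-empty line of the file above appears to be a slot[value] key followed by '@'-separated surface realisations, e.g. priceRange[less than £20]@less than £20@cheap@... . The sketch below shows one possible way to load that layout into a dictionary; the function name, return type, and the fallback branch for keys without a slot[value] pattern are illustrative assumptions, not code from this repo.

# Minimal sketch, assuming the slot[value]@surface@surface... layout shown above.
import re
from typing import Dict, List, Tuple

_KEY_RE = re.compile(r"^(?P<slot>[^\[\]@]+)\[(?P<value>[^\]]*)\]$")

def load_copy_vocab(path: str) -> Dict[Tuple[str, str], List[str]]:
    """Map (slot, value) -> list of surface forms, e.g.
    ('priceRange', 'less than £20') -> ['less than £20', 'cheap', ...]."""
    vocab: Dict[Tuple[str, str], List[str]] = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            key, *surfaces = line.split("@")
            m = _KEY_RE.match(key)
            if m:  # structured key such as priceRange[less than £20]
                vocab[(m.group("slot"), m.group("value"))] = surfaces
            else:  # fallback: treat the raw key as both slot and value
                vocab[(key, key)] = surfaces
    return vocab

if __name__ == "__main__":
    cv = load_copy_vocab("dataset/new_copy_vocab_all.txt")
    print(len(cv), "slot-value entries loaded")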