├── .gitignore
├── README.md
├── cliora
│   ├── analysis
│   │   ├── __init__.py
│   │   ├── cky.py
│   │   ├── diora_tree.py
│   │   └── utils.py
│   ├── blocks
│   │   ├── __init__.py
│   │   └── negative_sampler.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── batch_iterator.py
│   │   ├── dataloader.py
│   │   ├── dataset.py
│   │   ├── embeddings.py
│   │   ├── preprocessing.py
│   │   └── reading.py
│   ├── external
│   │   ├── __init__.py
│   │   └── standalone_elmo.py
│   ├── logging
│   │   ├── __init__.py
│   │   ├── accumulator.py
│   │   └── configuration.py
│   ├── misc
│   │   └── convert_conll_to_jsonl.py
│   ├── net
│   │   ├── __init__.py
│   │   ├── cliora.py
│   │   ├── diora.py
│   │   ├── experiment_logger.py
│   │   ├── inside_index.py
│   │   ├── offset_cache.py
│   │   ├── outside_index.py
│   │   ├── trainer.py
│   │   ├── utils.py
│   │   └── vg.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── parse.py
│   │   ├── parse_diora.py
│   │   ├── phrase_embed.py
│   │   ├── phrase_embed_simple.py
│   │   ├── right_branch.py
│   │   └── train.py
│   └── utils
│       ├── __init__.py
│       ├── checkpoint.py
│       ├── flags.py
│       ├── fs.py
│       └── path.py
├── requirements.txt
├── test_cliora.sh
├── test_diora.sh
├── train_cliora.sh
└── train_diora.sh

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | ./outputs
4 | ./flickr_data
5 | ./coco_data

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## CLIORA
2 | 
3 | This is the official codebase for the ICLR 2022 **oral** paper: Unsupervised Vision-Language Grammar Induction with Shared Structure Modeling.
4 | 
5 | We introduce a new task of Unsupervised Vision-Language Grammar Induction and devise a model, the Contrastive Language-Image inside-Outside Recursive Autoencoder (CLIORA), to solve it. Please read our paper for more details: https://openreview.net/forum?id=N0n_QyQ5lBF.
6 | 
7 | This code follows the implementation architecture of [DIORA](https://github.com/iesl/diora).
8 | 
9 | Please cite our paper as follows:
10 | 
11 | ```
12 | @inproceedings{wan2022cliora,
13 |   title={Unsupervised Vision-Language Grammar Induction with Shared Structure Modeling},
14 |   author={Wan, Bo and Han, Wenjuan and Zheng, Zilong and Tuytelaars, Tinne},
15 |   booktitle={The International Conference on Learning Representations (ICLR)},
16 |   year={2022},
17 | }
18 | ```
19 | 
20 | ## Environment and Data
21 | 
22 | Install dependencies (using Conda to manage the virtual environment):
23 | 
24 | ```
25 | conda create -n cliora python=3.8
26 | source activate cliora
27 | pip install -r requirements.txt
28 | ```
29 | 
30 | 
31 | Download [flickr_data](https://esatkuleuvenbe-my.sharepoint.com/:u:/g/personal/bwan_esat_kuleuven_be/ERcLeIlPJxBDg7Jdf6IwOT0BU5kbcTHSRM7U_dPX_y4ftg?e=j9gyB9) and [outputs](https://esatkuleuvenbe-my.sharepoint.com/:u:/g/personal/bwan_esat_kuleuven_be/EYCdZiPIcj5OtQQqIH49B4gBcfT607sKdnGxrsdkYPKapQ?e=1aGlyk), and arrange the files in the following structure:
32 | 
33 | ```
34 | cliora
35 | ├───cliora
36 | │   ├─...
37 | │
38 | ├───flickr_data
39 | │   ├─flickr_feat_maf
40 | │
41 | └───outputs
42 |     ├─flickr
43 | ```
44 | 
45 | We use the same object features as [MAF](https://github.com/qinzzz/Multimodal-Alignment-Framework). Download [train_features_compress.hdf5](https://drive.google.com/file/d/1ABnF0SZMf6pOAC89LJXbXZLMW1X86O96/view?usp=sharing), [val_features_compress.hdf5](https://drive.google.com/file/d/1iK-yz6PHwRuAciRW1vGkg9Bkj-aBE8yJ/view?usp=sharing), and [test_features_compress.hdf5](https://drive.google.com/file/d/1pjntkbr20l2MiUBVQLVV6rQNWpXQymFs/view?usp=sharing) to `flickr_data/flickr_feat_maf`.
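As a quick sanity check after downloading (a minimal sketch, assuming `h5py` is installed; it reads the same HDF5 keys that `FlickrDataset` in `cliora/data/dataloader.py` uses):

```
import h5py

# Keys follow FlickrDataset in cliora/data/dataloader.py.
with h5py.File('flickr_data/flickr_feat_maf/val_features_compress.hdf5', 'r') as f:
    print(f['features'].shape)    # 2048-d region features, one row per detected box
    print(f['bboxes'].shape)      # predicted box coordinates, 4 values per box
    print(f['pos_bboxes'].shape)  # per-image (start, end) row indexes into the arrays above
```

If the three shapes print without errors, the data loaders should find everything they need.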
46 | 
47 | ## Running CLIORA
48 | ```
49 | export PYTHONPATH=$(pwd):$PYTHONPATH
50 | 
51 | 
52 | ## Train DIORA
53 | sh train_diora.sh
54 | 
55 | ## Test DIORA
56 | sh test_diora.sh
57 | 
58 | ## Train CLIORA based on DIORA
59 | sh train_cliora.sh
60 | 
61 | ## Test CLIORA
62 | sh test_cliora.sh
63 | ```
64 | 
65 | ## Multi-GPU Training
66 | Single-GPU training:
67 | ```
68 | export CUDA_VISIBLE_DEVICES=0
69 | python cliora/scripts/train.py \
70 |     --cuda \
71 |     ... # other args
72 | ```
73 | 
74 | Multi-GPU training:
75 | 
76 | ```
77 | export CUDA_VISIBLE_DEVICES=0,1,2,3
78 | export NGPUS=4
79 | python -m torch.distributed.launch --nproc_per_node=$NGPUS cliora/scripts/train.py \
80 |     --cuda \
81 |     --multigpu \
82 |     ... # other args
83 | ```
84 | 
85 | ## Visualization
86 | Download the [Flickr30K Entities Dataset](http://hockenmaier.cs.illinois.edu/DenotationGraph/) and put the image folder `flickr_images` under `flickr_data/`. Add the `--visualize` flag when running `test_cliora.sh`:
87 | ```
88 | # test_cliora.sh
89 | python cliora/scripts/parse.py \
90 |     --cuda \
91 |     --visualize \
92 |     --obj_feats \
93 |     ... # other args
94 | ```
95 | 
96 | ## Word Embedding
97 | 
98 | We provide randomly initialized, skip-thoughts, and ELMo word embeddings. If you use the ELMo embeddings and specify `--elmo_cache_dir`, the context-insensitive ELMo vectors are cached after the first run, which makes loading them much faster afterwards.
99 | 
100 | Example usage:
101 | 
102 | ```
103 | word_emb=elmo  # one of: none, skip, elmo
104 | 
105 | python cliora/scripts/train.py \
106 |     --emb $word_emb \
107 |     ... # other args
108 | ```
109 | 
110 | 
111 | ## License
112 | 
113 | Copyright 2018, University of Massachusetts Amherst
114 | 
115 | Licensed under the Apache License, Version 2.0 (the "License");
116 | you may not use this file except in compliance with the License.
117 | You may obtain a copy of the License at
118 | 
119 |     http://www.apache.org/licenses/LICENSE-2.0
120 | 
121 | Unless required by applicable law or agreed to in writing, software
122 | distributed under the License is distributed on an "AS IS" BASIS,
123 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124 | See the License for the specific language governing permissions and
125 | limitations under the License.
126 | -------------------------------------------------------------------------------- /cliora/analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bobwan1995/cliora/b064bdf967d4ccc4f3327183efd888b927bfb4fb/cliora/analysis/__init__.py -------------------------------------------------------------------------------- /cliora/analysis/cky.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from cliora.logging.configuration import get_logger 4 | 5 | 6 | class ParsePredictor(object): 7 | # def __init__(self, net, word2idx): 8 | def __init__(self, net): 9 | super(ParsePredictor, self).__init__() 10 | self.net = net 11 | # self.word2idx = word2idx 12 | # self.idx2word = {v: k for k, v in word2idx.items()} 13 | self.logger = get_logger() 14 | 15 | def parse_batch(self, batch_map): 16 | sentences = batch_map['sentences'] 17 | batch_size = sentences.shape[0] 18 | length = sentences.shape[1] 19 | scalars = self.net.saved_scalars 20 | device = self.net.device 21 | dtype = torch.float32 22 | 23 | # Assign missing scalars 24 | for i in range(length): 25 | scalars[0][i] = torch.full((batch_size, 1), 1, dtype=dtype, device=device) 26 | 27 | trees = self.batched_cky(batch_map, scalars) 28 | 29 | return trees 30 | 31 | def batched_cky(self, batch_map, scalars): 32 | sentences = batch_map['sentences'] 33 | batch_size = sentences.shape[0] 34 | length = sentences.shape[1] 35 | device = self.net.device 36 | dtype = torch.float32 37 | 38 | # Chart. 39 | chart = [torch.full((length-i, batch_size), 1, dtype=dtype, device=device) for i in range(length)] 40 | 41 | # Backpointers. 42 | bp = {} 43 | for ib in range(batch_size): 44 | bp[ib] = [[None] * (length - i) for i in range(length)] 45 | bp[ib][0] = [i for i in range(length)] 46 | 47 | for level in range(1, length): 48 | L = length - level 49 | N = level 50 | 51 | for pos in range(L): 52 | 53 | pairs, lps, rps, sps = [], [], [], [] 54 | 55 | # Assumes that the bottom-left most leaf is in the first constituent. 
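                # For the span at (level, pos), consider all N = level split points:
                # the left constituent sits at chart cell (l_level, l_pos) = (idx, pos)
                # and the right at (r_level, r_pos) = (level-idx-1, pos+idx+1). Each
                # candidate split is scored as lp + rp + sp below, and the argmax
                # split is recorded as a backpointer.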
56 | spbatch = scalars[level][pos] 57 | 58 | for idx in range(N): 59 | # (level, pos) 60 | l_level = idx 61 | l_pos = pos 62 | r_level = level-idx-1 63 | r_pos = pos+idx+1 64 | 65 | assert l_level >= 0 66 | assert l_pos >= 0 67 | assert r_level >= 0 68 | assert r_pos >= 0 69 | 70 | l = (l_level, l_pos) 71 | r = (r_level, r_pos) 72 | 73 | lp = chart[l_level][l_pos].view(-1, 1) 74 | rp = chart[r_level][r_pos].view(-1, 1) 75 | sp = spbatch[:, idx].view(-1, 1) 76 | 77 | lps.append(lp) 78 | rps.append(rp) 79 | sps.append(sp) 80 | 81 | pairs.append((l, r)) 82 | 83 | lps, rps, sps = torch.cat(lps, 1), torch.cat(rps, 1), torch.cat(sps, 1) 84 | 85 | ps = lps + rps + sps 86 | argmax = ps.argmax(1).long() 87 | 88 | valmax = ps[range(batch_size), argmax] 89 | chart[level][pos, :] = valmax 90 | 91 | for i, ix in enumerate(argmax.tolist()): 92 | bp[i][level][pos] = pairs[ix] 93 | 94 | trees = [] 95 | for i in range(batch_size): 96 | tree = self.follow_backpointers(bp[i], bp[i][-1][0]) 97 | trees.append(tree) 98 | 99 | return trees 100 | 101 | def follow_backpointers(self, bp, pair): 102 | if isinstance(pair, int): 103 | return pair 104 | 105 | l, r = pair 106 | lout = self.follow_backpointers(bp, bp[l[0]][l[1]]) 107 | rout = self.follow_backpointers(bp, bp[r[0]][r[1]]) 108 | 109 | return (lout, rout) 110 | 111 | -------------------------------------------------------------------------------- /cliora/analysis/diora_tree.py: -------------------------------------------------------------------------------- 1 | 2 | def spans_to_tree(spans, tokens): 3 | length = len(tokens) 4 | 5 | # Add missing spans. 6 | span_set = set(spans) 7 | for pos in range(length): 8 | if pos not in span_set: 9 | spans.append((pos, 1)) 10 | 11 | spans = sorted(spans, key=lambda x: (x[1], x[0])) # pos, level 12 | pos_to_node = {} 13 | # root_node = None 14 | 15 | for i, span in enumerate(spans): 16 | 17 | pos, size = span 18 | 19 | if i < length: 20 | assert i == pos 21 | node = (pos, size, tokens[i]) 22 | pos_to_node[pos] = node 23 | continue 24 | 25 | node = (pos, size, []) 26 | 27 | for i_pos in range(pos, pos+size): 28 | child = pos_to_node[i_pos] 29 | c_pos, c_size = child[0], child[1] 30 | 31 | if i_pos == c_pos: 32 | node[2].append(child) 33 | pos_to_node[i_pos] = node 34 | 35 | def helper(node): 36 | pos, size, tok = node 37 | if isinstance(tok, int): 38 | return tok 39 | return tuple([helper(x) for x in tok]) 40 | 41 | root_node = pos_to_node[0] 42 | tree = helper(root_node) 43 | 44 | return tree 45 | 46 | 47 | class TreesFromDiora(object): 48 | def __init__(self, net): 49 | self.diora = net 50 | 51 | def to_spans(self, lst): 52 | return [(pos, level + 1) for level, pos in lst] 53 | 54 | def parse_batch(self, batch_map): 55 | batch_size, length = batch_map['sentences'].shape 56 | root_level = length - 1 57 | tokens = [i for i in range(length)] 58 | 59 | trees = [] 60 | for i_b in range(batch_size): 61 | spans = self.to_spans(self.diora.inside_tree[(i_b, 0)][(root_level, 0)]) 62 | binary_tree = spans_to_tree(spans, tokens) 63 | trees.append(binary_tree) 64 | return trees 65 | -------------------------------------------------------------------------------- /cliora/analysis/utils.py: -------------------------------------------------------------------------------- 1 | import types 2 | 3 | def get_actions(tree, SHIFT = 0, REDUCE = 1, OPEN='(', CLOSE=')'): 4 | #input tree in bracket form: ((A B) (C D)) 5 | #output action sequence: S S R S S R R 6 | actions = [] 7 | tree = tree.strip() 8 | i = 0 9 | num_shift = 0 10 | num_reduce = 0 
11 | left = 0 12 | right = 0 13 | while i < len(tree): 14 | if tree[i] != ' ' and tree[i] != OPEN and tree[i] != CLOSE: #terminal 15 | if tree[i-1] == OPEN or tree[i-1] == ' ': 16 | actions.append(SHIFT) 17 | num_shift += 1 18 | elif tree[i] == CLOSE: 19 | actions.append(REDUCE) 20 | num_reduce += 1 21 | right += 1 22 | elif tree[i] == OPEN: 23 | left += 1 24 | i += 1 25 | assert(num_shift == num_reduce + 1) 26 | return actions 27 | 28 | 29 | def get_spans(actions, SHIFT = 0, REDUCE = 1): 30 | sent = list(range((len(actions)+1) // 2)) 31 | spans = [] 32 | pointer = 0 33 | stack = [] 34 | for action in actions: 35 | if action == SHIFT: 36 | word = sent[pointer] 37 | stack.append(word) 38 | pointer += 1 39 | elif action == REDUCE: 40 | right = stack.pop() 41 | left = stack.pop() 42 | if isinstance(left, int): 43 | left = (left, None) 44 | if isinstance(right, int): 45 | right = (None, right) 46 | new_span = (left[0], right[1]) 47 | spans.append(new_span) 48 | stack.append(new_span) 49 | return spans 50 | 51 | 52 | def get_stats(span1, span2): 53 | tp = 0 54 | fp = 0 55 | fn = 0 56 | for span in span1: 57 | if span in span2: 58 | tp += 1 59 | else: 60 | fp += 1 61 | for span in span2: 62 | if span not in span1: 63 | fn += 1 64 | return tp, fp, fn 65 | 66 | 67 | def override_init_with_batch(var): 68 | init_with_batch = var.init_with_batch 69 | 70 | def func(self, *args, **kwargs): 71 | init_with_batch(*args, **kwargs) 72 | self.saved_scalars = {i: {} for i in range(self.length)} 73 | self.saved_scalars_out = {i: {} for i in range(self.length)} 74 | 75 | var.init_with_batch = types.MethodType(func, var) 76 | 77 | 78 | def override_inside_hook(var): 79 | def func(self, level, h, c, s): 80 | length = self.length 81 | B = self.batch_size 82 | L = length - level 83 | 84 | assert s.shape[0] == B 85 | assert s.shape[1] == L 86 | # assert s.shape[2] == N 87 | assert s.shape[3] == 1 88 | assert len(s.shape) == 4 89 | smax = s.max(2, keepdim=True)[0] 90 | s = s - smax 91 | 92 | for pos in range(L): 93 | self.saved_scalars[level][pos] = s[:, pos, :] 94 | 95 | var.inside_hook = types.MethodType(func, var) -------------------------------------------------------------------------------- /cliora/blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bobwan1995/cliora/b064bdf967d4ccc4f3327183efd888b927bfb4fb/cliora/blocks/__init__.py -------------------------------------------------------------------------------- /cliora/blocks/negative_sampler.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def choose_negative_samples(negative_sampler, k_neg): 10 | neg_samples = negative_sampler.sample(k_neg) 11 | neg_samples = torch.from_numpy(neg_samples) 12 | return neg_samples 13 | 14 | 15 | def calculate_freq_dist(data, vocab_size): 16 | # TODO: This becomes really slow on large datasets. 
17 | counter = Counter() 18 | for i in range(vocab_size): 19 | counter[i] = 0 20 | for x in tqdm(data, desc='freq_dist'): 21 | counter.update(x) 22 | freq_dist = [v for k, v in sorted(counter.items(), key=lambda x: x[0])] 23 | freq_dist = np.asarray(freq_dist, dtype=np.float32) 24 | return freq_dist 25 | 26 | 27 | class NegativeSampler: 28 | def __init__(self, freq_dist, dist_power, epsilon=10**-2): 29 | self.dist = freq_dist ** dist_power + epsilon * (1/len(freq_dist)) 30 | self.dist = self.dist / sum(self.dist) # Final distribution should be normalized 31 | self.rng = np.random.RandomState() 32 | 33 | def set_seed(self, seed): 34 | self.rng.seed(seed) 35 | 36 | def sample(self, num_samples): 37 | return self.rng.choice(len(self.dist), num_samples, p=self.dist, replace=False) 38 | -------------------------------------------------------------------------------- /cliora/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bobwan1995/cliora/b064bdf967d4ccc4f3327183efd888b927bfb4fb/cliora/data/__init__.py -------------------------------------------------------------------------------- /cliora/data/batch_iterator.py: -------------------------------------------------------------------------------- 1 | from cliora.data.dataloader import FixedLengthBatchSampler, SimpleDataset, COCODataset, FlickrDataset 2 | from cliora.blocks.negative_sampler import choose_negative_samples 3 | 4 | import torch 5 | import numpy as np 6 | 7 | def get_config(config, **kwargs): 8 | for k, v in kwargs.items(): 9 | if k in config: 10 | config[k] = v 11 | return config 12 | 13 | 14 | def get_default_config(): 15 | 16 | default_config = dict( 17 | batch_size=16, 18 | forever=False, 19 | drop_last=False, 20 | sort_by_length=True, 21 | shuffle=True, 22 | random_seed=None, 23 | filter_length=None, 24 | workers=16, 25 | pin_memory=False, 26 | include_partial=False, 27 | cuda=False, 28 | ngpus=1, 29 | k_neg=3, 30 | negative_sampler=None, 31 | options_path=None, 32 | weights_path=None, 33 | vocab=None, 34 | length_to_size=None, 35 | rank=None, 36 | data_type=None, 37 | use_obj=False, 38 | mode=None, 39 | ) 40 | 41 | return default_config 42 | 43 | 44 | class BatchIterator(object): 45 | 46 | def __init__(self, sentences, extra={}, **kwargs): 47 | self.sentences = sentences 48 | self.config = get_config(get_default_config(), **kwargs) 49 | self.extra = extra 50 | self.loader = None 51 | self.get_dataset() 52 | 53 | def chunk(self, tensor, chunks, dim=0, i=0): 54 | if isinstance(tensor, torch.Tensor): 55 | return torch.chunk(tensor, chunks, dim=dim)[i] 56 | index = torch.chunk(torch.arange(len(tensor)), chunks, dim=dim)[i] 57 | return [tensor[ii] for ii in index] 58 | 59 | def partition(self, tensor, rank, device_ids): 60 | if tensor is None: 61 | return None 62 | if isinstance(tensor, dict): 63 | for k, v in tensor.items(): 64 | tensor[k] = self.partition(v, rank, device_ids) 65 | return tensor 66 | return self.chunk(tensor, len(device_ids), 0, rank) 67 | 68 | def get_dataset_size(self): 69 | return len(self.sentences) 70 | 71 | def get_dataset_minlen(self): 72 | return min(map(len, self.sentences)) 73 | 74 | def get_dataset_maxlen(self): 75 | return max(map(len, self.sentences)) 76 | 77 | def get_dataset_stats(self): 78 | return 'size={} minlen={} maxlen={}'.format( 79 | self.get_dataset_size(), self.get_dataset_minlen(), self.get_dataset_maxlen() 80 | ) 81 | 82 | def choose_negative_samples(self, negative_sampler, k_neg): 83 | return 
choose_negative_samples(negative_sampler, k_neg) 84 | 85 | def get_dataset(self): 86 | data_type = self.config.get('data_type') 87 | use_obj = self.config.get('use_obj') 88 | mode = self.config.get('mode') 89 | if use_obj and data_type == 'coco': 90 | dataset = COCODataset(self.sentences, self.extra['example_ids']) 91 | elif use_obj and data_type == 'flickr': 92 | dataset = FlickrDataset(self.sentences, self.extra['example_ids'], mode) 93 | else: 94 | dataset = SimpleDataset(self.sentences) 95 | self.dataset = dataset 96 | 97 | def get_iterator(self, **kwargs): 98 | config = get_config(self.config.copy(), **kwargs) 99 | 100 | random_seed = config.get('random_seed') 101 | batch_size = config.get('batch_size') 102 | filter_length = config.get('filter_length') 103 | pin_memory = config.get('pin_memory') 104 | include_partial = config.get('include_partial') 105 | cuda = config.get('cuda') 106 | ngpus = config.get('ngpus') 107 | rank = config.get('rank') 108 | k_neg = config.get('k_neg') 109 | negative_sampler = config.get('negative_sampler', None) 110 | workers = config.get('workers') 111 | length_to_size = config.get('length_to_size', None) 112 | # data_type = config.get('data_type') 113 | # use_obj = config.get('use_obj') 114 | # debug = config.get('debug') 115 | 116 | def collate_fn(batch): 117 | index, sents, obj_feats, boxes, obj_cates = zip(*batch) 118 | sents = torch.from_numpy(np.array(sents)).long() 119 | obj_feats = torch.from_numpy(np.array(obj_feats)) 120 | boxes = torch.from_numpy(np.array(boxes)) 121 | obj_cates = torch.from_numpy(np.array(obj_cates)).long() 122 | 123 | batch_map = {} 124 | batch_map['index'] = index 125 | batch_map['sents'] = sents 126 | batch_map['obj_feats'] = obj_feats 127 | batch_map['boxes'] = boxes 128 | batch_map['obj_cates'] = obj_cates 129 | 130 | for k, v in self.extra.items(): 131 | batch_map[k] = [v[idx] for idx in index] 132 | batch_map['image_feats'] = torch.from_numpy(np.array(batch_map['image_feats'])) 133 | 134 | if ngpus > 1: 135 | for k in batch_map.keys(): 136 | batch_map[k] = self.partition(batch_map[k], rank, range(ngpus)) 137 | 138 | return batch_map 139 | 140 | if self.loader is None: 141 | rng = np.random.RandomState(seed=random_seed) 142 | sampler = FixedLengthBatchSampler(self.dataset, batch_size=batch_size, rng=rng, 143 | maxlen=filter_length, include_partial=include_partial, length_to_size=length_to_size) 144 | loader = torch.utils.data.DataLoader(self.dataset, shuffle=(sampler is None), num_workers=workers, pin_memory=pin_memory,batch_sampler=sampler, collate_fn=collate_fn) 145 | self.loader = loader 146 | 147 | def myiterator(): 148 | 149 | for batch in self.loader: 150 | index = batch['index'] 151 | sentences = batch['sents'] 152 | obj_feats = batch['obj_feats'] 153 | boxes = batch['boxes'] 154 | obj_cates = batch['obj_cates'] 155 | 156 | batch_size, length = sentences.shape 157 | 158 | neg_samples = None 159 | if negative_sampler is not None: 160 | neg_samples = self.choose_negative_samples(negative_sampler, k_neg) 161 | 162 | if cuda: 163 | sentences = sentences.cuda() 164 | obj_feats = obj_feats.cuda() 165 | boxes = boxes.cuda() 166 | obj_cates = obj_cates.cuda() 167 | if cuda and neg_samples is not None: 168 | neg_samples = neg_samples.cuda() 169 | 170 | batch_map = {} 171 | batch_map['sentences'] = sentences 172 | batch_map['neg_samples'] = neg_samples 173 | batch_map['batch_size'] = batch_size 174 | batch_map['length'] = length 175 | batch_map['obj_feats'] = obj_feats 176 | batch_map['boxes'] = boxes 177 | 
batch_map['obj_cates'] = obj_cates 178 | 179 | for k, v in self.extra.items(): 180 | batch_map[k] = batch[k] 181 | 182 | yield batch_map 183 | 184 | return myiterator() 185 | 186 | -------------------------------------------------------------------------------- /cliora/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import torch 4 | from torch.utils.data import Sampler 5 | 6 | import numpy as np 7 | import pickle as pkl 8 | import json 9 | from cliora.logging.configuration import get_logger 10 | 11 | class FixedLengthBatchSampler(Sampler): 12 | 13 | def __init__(self, data_source, batch_size, include_partial=False, rng=None, maxlen=None, 14 | length_to_size=None): 15 | self.data_source = data_source 16 | self.active = False 17 | if rng is None: 18 | rng = np.random.RandomState(seed=11) 19 | self.rng = rng 20 | self.batch_size = batch_size 21 | self.maxlen = maxlen 22 | self.include_partial = include_partial 23 | self.length_to_size = length_to_size 24 | self._batch_size_cache = { 0: self.batch_size } 25 | self.logger = get_logger() 26 | 27 | def get_batch_size(self, length): 28 | if self.length_to_size is None: 29 | return self.batch_size 30 | if length in self._batch_size_cache: 31 | return self._batch_size_cache[length] 32 | start = max(self._batch_size_cache.keys()) 33 | batch_size = self._batch_size_cache[start] 34 | for n in range(start+1, length+1): 35 | if n in self.length_to_size: 36 | batch_size = self.length_to_size[n] 37 | self._batch_size_cache[n] = batch_size 38 | return batch_size 39 | 40 | def reset(self): 41 | """ 42 | Create a map of {length: List[example_id]} and maintain how much of 43 | each list has been seen. 44 | 45 | If include_partial is False, then do not provide batches that are below 46 | the batch_size. 47 | 48 | If length_to_size is set, then batch size is determined by length. 49 | 50 | """ 51 | 52 | # Record the lengths of each example. 53 | length_map = {} 54 | for i in range(len(self.data_source)): 55 | x = self.data_source.dataset[i] 56 | length = len(x) 57 | 58 | if self.maxlen is not None and self.maxlen > 0 and length > self.maxlen: 59 | continue 60 | 61 | length_map.setdefault(length, []).append(i) 62 | 63 | # Shuffle the order. 64 | for length in length_map.keys(): 65 | self.rng.shuffle(length_map[length]) 66 | 67 | # Initialize state. 68 | state = {} 69 | for length, arr in length_map.items(): 70 | batch_size = self.get_batch_size(length) 71 | nbatches = len(arr) // batch_size 72 | surplus = nbatches * batch_size < len(arr) 73 | state[length] = dict(nbatches=nbatches, surplus=surplus, position=-1) 74 | 75 | # Batch order, in terms of length. 76 | order = [] 77 | for length, v in state.items(): 78 | order += [length] * v['nbatches'] 79 | 80 | ## Optionally, add partial batches. 
81 | if self.include_partial: 82 | for length, v in state.items(): 83 | if v['surplus']: 84 | order += [length] 85 | 86 | self.rng.shuffle(order) 87 | 88 | self.length_map = length_map 89 | self.state = state 90 | self.order = order 91 | self.index = -1 92 | 93 | def get_next_batch(self): 94 | index = self.index + 1 95 | 96 | length = self.order[index] 97 | batch_size = self.get_batch_size(length) 98 | position = self.state[length]['position'] + 1 99 | start = position * batch_size 100 | batch_index = self.length_map[length][start:start+batch_size] 101 | 102 | self.state[length]['position'] = position 103 | self.index = index 104 | 105 | return batch_index 106 | 107 | def __iter__(self): 108 | self.reset() 109 | for _ in range(len(self)): 110 | yield self.get_next_batch() 111 | 112 | def __len__(self): 113 | return len(self.order) 114 | 115 | 116 | class SimpleDataset(torch.utils.data.Dataset): 117 | 118 | def __init__(self, dataset): 119 | self.dataset = dataset 120 | 121 | def __getitem__(self, index): 122 | item = self.dataset[index] 123 | return index, item, np.zeros(1), np.zeros(1), np.zeros(1) 124 | 125 | def __len__(self): 126 | return len(self.dataset) 127 | 128 | 129 | class COCODataset(torch.utils.data.Dataset): 130 | 131 | def __init__(self, dataset, img_ids=None): 132 | self.dataset = dataset 133 | self.img_ids = img_ids 134 | self.data_path = './coco_data' 135 | 136 | def __getitem__(self, index): 137 | item = self.dataset[index] 138 | # img_id = self.img_ids[index] 139 | # obj_data = np.load(os.path.join(self.data_path, 'det_feats/{}.npy'.format(img_id)), allow_pickle=False) 140 | # obj_feats = obj_data[:, :-4] 141 | # boxes = obj_data[:, -4:] 142 | 143 | obj_feats = np.zeros(1).astype(np.int32) - 1 144 | boxes = np.zeros(1).astype(np.int32) - 1 145 | obj_cates = np.zeros(1).astype(np.int32) - 1 146 | return index, item, obj_feats, boxes, obj_cates 147 | 148 | def __len__(self): 149 | return len(self.dataset) 150 | 151 | 152 | # class FlickrDataset(torch.utils.data.Dataset): 153 | # 154 | # def __init__(self, dataset, img_ids=None, mode='train'): 155 | # self.dataset = dataset 156 | # self.img_ids = img_ids 157 | # self.mode = mode 158 | # self.data_path = './flickr_data/' 159 | # 160 | # def __getitem__(self, index): 161 | # item = self.dataset[index] 162 | # img_id = self.img_ids[index] 163 | # 164 | # obj_feats = np.zeros([36, 2048]).astype(np.float32) 165 | # boxes = np.zeros([36, 4]).astype(np.float32) - 1 166 | # obj_cates = np.zeros([36]).astype(np.int32) - 1 167 | # 168 | # if self.pkl: 169 | # obj_data = self.det_res[img_id] 170 | # num_box = min(len(obj_data['bbox']), 36) 171 | # obj_feats[:num_box] = obj_data['feats'][:num_box] 172 | # boxes[:num_box] = obj_data['bbox'][:num_box] 173 | # obj_cates[:num_box] = obj_data['class'][:num_box] 174 | # else: 175 | # obj_data = np.load(self.data_path+'flickr_feat/{}.npy'.format(img_id), allow_pickle=False) 176 | # num_box = min(obj_data.shape[0], 36) 177 | # 178 | # obj_feats[:num_box] = obj_data[:num_box, :2048].astype(np.float32) 179 | # boxes[:num_box] = obj_data[:num_box, 2048:-1].astype(np.float32) 180 | # obj_cates[:num_box] = obj_data[:num_box, -1].astype(np.int32) 181 | # 182 | # return index, item, obj_feats, boxes, obj_cates 183 | # 184 | # def __len__(self): 185 | # return len(self.dataset) 186 | 187 | 188 | class FlickrDataset(torch.utils.data.Dataset): 189 | 190 | def __init__(self, dataset, img_ids=None, mode='train'): 191 | self.dataset = dataset 192 | self.img_ids = img_ids 193 | self.mode = mode 194 | 
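        # MAF visual features: `{mode}_imgid2idx.pkl` maps an image id to its row
        # block in the HDF5 arrays below, and `{mode}_detection_dict.json` stores
        # per-image detected class names (mapped to indexes via objects_vocab.txt).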
self.data_path = './flickr_data/flickr_feat_maf/' 195 | self.imgid2idx = pkl.load(open(self.data_path+f"{mode}_imgid2idx.pkl", "rb")) 196 | self.detection_dict = json.load(open(self.data_path+f"{mode}_detection_dict.json")) 197 | obj_vocab = open(self.data_path+"objects_vocab.txt").readlines() 198 | self.obj2ind = {obj.strip():idx for idx,obj in enumerate(obj_vocab)} 199 | with h5py.File(self.data_path+f"{mode}_features_compress.hdf5", "r") as hdf5_file: 200 | self.features = np.array(hdf5_file.get("features")) 201 | self.predicted_boxes = np.array(hdf5_file.get("bboxes")) 202 | self.indexes = np.array(hdf5_file.get("pos_bboxes")) 203 | 204 | def __getitem__(self, index): 205 | item = self.dataset[index] 206 | 207 | img_id = self.img_ids[index] 208 | feat_index = self.imgid2idx[int(img_id)] 209 | start_end_index = self.indexes[feat_index] 210 | num_box = min(start_end_index[1] - start_end_index[0], 36) 211 | # Get boxes 212 | boxes = np.zeros([36, 4]).astype(np.float32) - 1 213 | boxes[:num_box] = self.predicted_boxes[start_end_index[0] : start_end_index[1]][:num_box] 214 | # Get features 215 | obj_feats = np.zeros([36, 2048]).astype(np.float32) 216 | obj_feats[:num_box] = self.features[start_end_index[0] : start_end_index[1]][:num_box] 217 | # Get classes 218 | obj_cates = np.zeros([36]).astype(np.int32) - 1 219 | obj_cates[:num_box] = np.array([self.obj2ind.get(i) for i in 220 | self.detection_dict[img_id]["classes"]]).astype(np.int32)[:num_box] 221 | 222 | return index, item, obj_feats, boxes, obj_cates 223 | 224 | def __len__(self): 225 | return len(self.dataset) 226 | -------------------------------------------------------------------------------- /cliora/data/dataset.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import pickle 3 | import torch 4 | import numpy as np 5 | 6 | from tqdm import tqdm 7 | 8 | from cliora.data.reading import NLIReader, PlainTextReader, ConllReader, JSONLReader, PTBReader, COCOReader, FlickrReader 9 | from cliora.data.batch_iterator import BatchIterator 10 | from cliora.data.embeddings import EmbeddingsReader, UNK_TOKEN 11 | from cliora.data.preprocessing import indexify, build_text_vocab 12 | from cliora.data.preprocessing import synthesize_training_data 13 | from cliora.logging.configuration import get_logger 14 | from cliora.blocks.negative_sampler import NegativeSampler, calculate_freq_dist 15 | 16 | class ConsolidateDatasets(object): 17 | """ 18 | A class for consolidating many datasets. 
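    Builds one master word2idx across datasets, remaps their embeddings into a
    single matrix, and reindexes each dataset's sentences in place (see run()).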
19 | """ 20 | 21 | def __init__(self, datasets): 22 | self.datasets = datasets 23 | 24 | def reindex(self, sentences, inverse_mapping): 25 | def fn(s): 26 | for idx in s: 27 | yield inverse_mapping[idx] 28 | def queue(lst): 29 | q = deque(lst) 30 | while len(q) > 0: 31 | yield q.popleft() 32 | return [list(fn(s)) for s in tqdm(queue(sentences), desc='reindex')] 33 | 34 | def remap_embeddings(self, datasets, inverse_mapping_lst, master_word2idx): 35 | size = datasets[0]['embeddings'].shape[1] 36 | embeddings = np.zeros((len(master_word2idx), size), dtype=np.float32) 37 | for dset, old2master in zip(datasets, inverse_mapping_lst): 38 | idx_from, idx_to = zip(*old2master.items()) 39 | embeddings[np.asarray(idx_to)] = dset['embeddings'][np.asarray(idx_from)] 40 | return embeddings 41 | 42 | def consolidate_word2idx(self, word2idx_lst): 43 | master_word2idx = {} 44 | inverse_mapping_lst = [] 45 | 46 | for w2i in word2idx_lst: 47 | old2master = {} 48 | for w, idx in w2i.items(): 49 | if w not in master_word2idx: 50 | master_word2idx[w] = len(master_word2idx) 51 | old2master[idx] = master_word2idx[w] 52 | inverse_mapping_lst.append(old2master) 53 | 54 | return master_word2idx, inverse_mapping_lst 55 | 56 | def run(self): 57 | word2idx_lst = [x['word2idx'] for x in self.datasets] 58 | master_word2idx, inverse_mapping_lst = self.consolidate_word2idx(word2idx_lst) 59 | embeddings = self.remap_embeddings(self.datasets, inverse_mapping_lst, master_word2idx) 60 | for dset, inverse_mapping in zip(self.datasets, inverse_mapping_lst): 61 | dset['sentences'] = self.reindex(dset['sentences'], inverse_mapping) 62 | dset['word2idx'] = master_word2idx 63 | dset['embeddings'] = embeddings 64 | 65 | 66 | class ReaderManager(object): 67 | def __init__(self, reader): 68 | super(ReaderManager, self).__init__() 69 | self.reader = reader 70 | self.logger = get_logger() 71 | 72 | def run(self, options, text_path, embeddings_path): 73 | reader = self.reader 74 | logger = self.logger 75 | 76 | logger.info('Reading text: {}'.format(text_path)) 77 | reader_result = reader.read(text_path) 78 | sentences = reader_result['sentences'] 79 | extra = reader_result['extra'] 80 | metadata = reader_result.get('metadata', {}) 81 | logger.info('len(sentences)={}'.format(len(sentences))) 82 | 83 | if 'word2idx' in metadata: 84 | word2idx = metadata['word2idx'] 85 | else: 86 | word2idx = build_text_vocab(sentences) 87 | logger.info('len(vocab)={}'.format(len(word2idx))) 88 | 89 | if 'embeddings' in metadata: 90 | logger.info('Using embeddings from metadata.') 91 | embeddings = metadata['embeddings'] 92 | del metadata['embeddings'] 93 | else: 94 | logger.info('Reading embeddings.') 95 | embeddings, word2idx = EmbeddingsReader().get_embeddings( 96 | options, embeddings_path, word2idx) 97 | 98 | unk_index = word2idx.get(UNK_TOKEN, None) 99 | logger.info('Converting tokens to indexes (unk_index={}).'.format(unk_index)) 100 | sentences = indexify(sentences, word2idx, unk_index) 101 | 102 | return { 103 | "sentences": sentences, 104 | "embeddings": embeddings, 105 | "word2idx": word2idx, 106 | "extra": extra, 107 | "metadata": metadata, 108 | } 109 | 110 | 111 | class ReconstructDataset(object): 112 | 113 | def initialize(self, options, text_path=None, embeddings_path=None, filter_length=0, data_type=None): 114 | if data_type == 'coco': 115 | reader = COCOReader(lowercase=options.lowercase, filter_length=filter_length) 116 | elif data_type == 'flickr': 117 | reader = FlickrReader(lowercase=options.lowercase, filter_length=filter_length) 118 
| else: 119 | raise NotImplementedError 120 | 121 | manager = ReaderManager(reader) 122 | result = manager.run(options, text_path, embeddings_path) 123 | 124 | return result 125 | 126 | 127 | def make_batch_iterator(options, dset, mode=None, shuffle=True, include_partial=False, filter_length=0, 128 | batch_size=None, length_to_size=None): 129 | sentences = dset['sentences'] 130 | word2idx = dset['word2idx'] 131 | extra = dset['extra'] 132 | # metadata = dset['metadata'] 133 | 134 | cuda = options.cuda 135 | multigpu = options.multigpu 136 | ngpus = 1 137 | if cuda and multigpu: 138 | ngpus = torch.cuda.device_count() 139 | 140 | vocab_size = len(word2idx) 141 | 142 | negative_sampler = None 143 | if options.reconstruct_mode in ('margin', 'softmax', 'vl_feat_softmax', 'vl_belief_softmax'): 144 | freq_dist = calculate_freq_dist(sentences, vocab_size) 145 | negative_sampler = NegativeSampler(freq_dist=freq_dist, dist_power=options.freq_dist_power) 146 | vocab_lst = [w for w, _ in sorted(word2idx.items(), key=lambda x: x[1])] 147 | 148 | batch_iterator = BatchIterator( 149 | sentences, extra=extra, mode=mode, use_obj=options.obj_feats, data_type=options.data_type, shuffle=shuffle, include_partial=include_partial, 150 | filter_length=filter_length, batch_size=batch_size, rank=options.local_rank, 151 | cuda=cuda, ngpus=ngpus, negative_sampler=negative_sampler, 152 | vocab=vocab_lst, k_neg=options.k_neg, 153 | options_path=options.elmo_options_path, 154 | weights_path=options.elmo_weights_path, 155 | length_to_size=length_to_size 156 | ) 157 | 158 | # DIRTY HACK: Makes it easier to print examples later. Should really wrap this within the class. 159 | batch_iterator.word2idx = word2idx 160 | 161 | return batch_iterator -------------------------------------------------------------------------------- /cliora/data/embeddings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | from collections import OrderedDict 4 | 5 | from cliora.logging.configuration import get_logger 6 | 7 | import torch 8 | import numpy as np 9 | from cliora.external.standalone_elmo import batch_to_ids, ElmoCharacterEncoder, remove_sentence_boundaries 10 | from tqdm import tqdm 11 | import pickle 12 | 13 | # With loaded embedding matrix, the padding vector will be initialized to zero 14 | # and will not be trained. Hopefully this isn't a problem. It seems better than 15 | # random initialization... 16 | PADDING_TOKEN = "_PAD" 17 | 18 | UNK_TOKEN = "_" 19 | 20 | EXISTING_VOCAB_TOKEN = "unused-token-a7g39i" 21 | 22 | 23 | def maybe_download(remote_url, cache_dir): 24 | path = os.path.join(cache_dir, os.path.basename(remote_url)) 25 | if not os.path.exists(path): 26 | os.system(f'curl {remote_url} -o {path} -L') 27 | return path 28 | 29 | 30 | class ElmoEmbedder(object): 31 | def __init__(self, options_file, weights_file, cache_dir, cuda=False): 32 | logger = get_logger() 33 | logger.info('Initialize ELMo Model.') 34 | 35 | self.char_embedder = ElmoCharacterEncoder( 36 | options_file=maybe_download(options_file, cache_dir=cache_dir), 37 | weight_file=maybe_download(weights_file, cache_dir=cache_dir), 38 | requires_grad=False) 39 | 40 | if cuda: 41 | self.char_embedder.cuda() 42 | 43 | self.cuda = cuda 44 | self.cache_dir = cache_dir 45 | 46 | def __call__(self, word2idx): 47 | """ 48 | 1. Sort tokens alphabetically by `word` from `word2idx`. 49 | 2. Embed the newly sorted tokens. 50 | 3. Re-order embeddings according to `idx` from `word2idx`. 
51 | Will skip step (2) if there is a previously cached version of embeddings. 52 | """ 53 | 54 | logger = get_logger() 55 | 56 | def sort_by_tok(item): 57 | tok, idx = item 58 | return tok 59 | 60 | def sort_by_idx(item): 61 | tok, idx = item 62 | return idx 63 | 64 | size = 512 65 | batch_size = 1024 66 | 67 | # 1. Sort tokens alphabetically by `word` from `word2idx`. 68 | tokens = [tok for tok, idx in sorted(word2idx.items(), key=sort_by_tok)] 69 | 70 | # 2. Embed the newly sorted tokens. 71 | vocab_identifier = hash_tokens(tokens) 72 | embeddings_file = os.path.join(self.cache_dir, f'elmo_{vocab_identifier}.npy') 73 | shape = (len(tokens), size) 74 | if os.path.exists(embeddings_file): 75 | logger.info('Loading cached elmo vectors: {}'.format(embeddings_file)) 76 | embeddings = np.load(embeddings_file) 77 | assert embeddings.shape == shape 78 | 79 | else: 80 | logger.info('Begin caching vectors. shape = {}, cuda = {}'.format(shape, self.cuda)) 81 | 82 | embeddings = np.zeros(shape, dtype=np.float32) 83 | 84 | for start in tqdm(range(0, len(tokens), batch_size), desc='embed'): 85 | end = min(start + batch_size, len(tokens)) 86 | batch = tokens[start:end] 87 | batch_ids = batch_to_ids([[x] for x in batch]) 88 | if self.cuda: 89 | batch_ids = batch_ids.cuda() 90 | output = self.char_embedder(batch_ids) 91 | vec = remove_sentence_boundaries(output['token_embedding'], output['mask'])[0].squeeze(1) 92 | 93 | embeddings[start:end] = vec.cpu().numpy() 94 | 95 | # Cache embeddings. 96 | logger.info('Saving cached elmo vectors: {}'.format(embeddings_file)) 97 | np.save(embeddings_file, embeddings) 98 | 99 | # 3. Re-order embeddings according to `idx` from `word2idx`. 100 | sorted_word2idx = {tok: idx for idx, tok in enumerate(tokens)} 101 | index = [sorted_word2idx[tok] for tok, idx in sorted(word2idx.items(), key=sort_by_idx)] 102 | old_embeddings = embeddings 103 | embeddings = embeddings[index] 104 | 105 | # Duplicate embeddings. This is meant to mirror behavior in elmo, which has separate embeddings 106 | # for forward and backward LSTMs. 
107 | embeddings = np.concatenate([embeddings, embeddings], 1) 108 | 109 | return embeddings 110 | 111 | class EmbeddingsReader(object): 112 | 113 | def read_glove(self, *args, **kwargs): 114 | return read_glove(*args, **kwargs) 115 | 116 | def get_emb_w2v(self, options, embeddings_path, word2idx): 117 | embeddings, word2idx = self.read_glove(embeddings_path, word2idx) 118 | return embeddings, word2idx 119 | 120 | def get_emb_elmo(self, options, embeddings_path, word2idx): 121 | elmo_encoder = ElmoEmbedder( 122 | options_file=options.elmo_options_path, 123 | weights_file=options.elmo_weights_path, 124 | cache_dir=options.elmo_cache_dir, 125 | cuda=options.cuda) 126 | embeddings = elmo_encoder(word2idx) 127 | return embeddings, word2idx 128 | 129 | def get_emb_skip(self, options, embeddings_path, word2idx): 130 | all_embeddings = pickle.load(open(embeddings_path, 'rb')) 131 | embeddings = np.zeros((len(word2idx), 620), dtype=np.float32) 132 | pad_emb = all_embeddings.get('a') 133 | for w, idx in word2idx.items(): 134 | embeddings[idx] = all_embeddings.get(w, pad_emb) 135 | return embeddings, word2idx 136 | 137 | def get_emb_both(self, options, embeddings_path, word2idx): 138 | e_w2v, w2i_w2v = self.get_emb_w2v(options, embeddings_path, word2idx) 139 | e_elmo, w2i_elmo = self.get_emb_elmo(options, embeddings_path, word2idx) 140 | 141 | vec_size = e_w2v.shape[1] + e_elmo.shape[1] 142 | vocab = [w for w, i in sorted(w2i_w2v.items(), key=lambda x: x[1]) if w in w2i_elmo] 143 | vocab_size = len(vocab) 144 | 145 | embeddings = np.zeros((vocab_size, vec_size), dtype=np.float32) 146 | word2idx = {w: i for i, w in enumerate(vocab)} 147 | 148 | for w, i in word2idx.items(): 149 | embeddings[i, :e_w2v.shape[1]] = e_w2v[w2i_w2v[w]] 150 | embeddings[i, e_w2v.shape[1]:] = e_elmo[w2i_elmo[w]] 151 | 152 | return embeddings, word2idx 153 | 154 | def get_embeddings(self, options, embeddings_path, word2idx): 155 | if options.emb == 'w2v': 156 | out = self.get_emb_w2v(options, embeddings_path, word2idx) 157 | elif options.emb == 'skip': 158 | out = self.get_emb_skip(options, embeddings_path, word2idx) 159 | elif options.emb == 'elmo': 160 | out = self.get_emb_elmo(options, embeddings_path, word2idx) 161 | elif options.emb == 'both': 162 | out = self.get_emb_both(options, embeddings_path, word2idx) 163 | elif options.emb == 'none': 164 | out = (torch.nn.Embedding(len(word2idx), 1024), word2idx) 165 | else: 166 | raise NotImplementedError 167 | return out 168 | 169 | 170 | def read_glove(filename, word2idx): 171 | """ 172 | Two cases: 173 | 174 | 1. The word2idx has already been filtered according to embedding vocabulary. 175 | 2. The word2idx is derived solely from the raw text data. 
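    In case (2), a new mapping is created whose first entries are the padding
    token, the unknown token, and the existing-vocab marker; words without a
    GloVe vector keep their zero-initialized rows.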
176 | 
177 |     """
178 |     logger = get_logger()
179 | 
180 |     glove_vocab = set()
181 |     size = None
182 | 
183 |     validate_word2idx(word2idx)
184 | 
185 |     logger.info('Reading Glove Vocab.')
186 | 
187 |     with open(filename) as f:
188 |         for i, line in enumerate(f):
189 |             word, vec = line.split(' ', 1)
190 |             glove_vocab.add(word)
191 | 
192 |             if i == 0:
193 |                 size = len(vec.strip().split(' '))
194 | 
195 |     new_vocab = set.intersection(set(word2idx.keys()), glove_vocab)
196 |     new_vocab.discard(PADDING_TOKEN)
197 |     new_vocab.discard(UNK_TOKEN)
198 | 
199 |     if word2idx.get(EXISTING_VOCAB_TOKEN, None) == 2:
200 |         new_word2idx = word2idx.copy()
201 | 
202 |         logger.info('Using existing vocab mapping.')
203 |     else:
204 |         new_word2idx = OrderedDict()
205 |         new_word2idx[PADDING_TOKEN] = len(new_word2idx)
206 |         new_word2idx[UNK_TOKEN] = len(new_word2idx)
207 |         new_word2idx[EXISTING_VOCAB_TOKEN] = len(new_word2idx)
208 | 
209 |         for w, _ in word2idx.items():
210 |             if w in new_word2idx:
211 |                 continue
212 |             new_word2idx[w] = len(new_word2idx)
213 | 
214 |         logger.info('Creating new mapping.')
215 | 
216 |     logger.info('glove-vocab-size={} vocab-size={} intersection-size={} (-{})'.format(
217 |         len(glove_vocab), len(word2idx), len(new_vocab), len(word2idx) - len(new_vocab)))
218 | 
219 |     embeddings = np.zeros((len(new_word2idx), size), dtype=np.float32)
220 | 
221 |     logger.info('Reading Glove Embeddings.')
222 | 
223 |     with open(filename) as f:
224 |         for line in f:
225 |             word, vec = line.strip().split(' ', 1)
226 | 
227 |             if word == PADDING_TOKEN or word == UNK_TOKEN:
228 |                 continue
229 | 
230 |             if word in new_vocab and word not in new_word2idx:
231 |                 raise ValueError
232 | 
233 |             if word not in new_word2idx:
234 |                 continue
235 | 
236 |             word_id = new_word2idx[word]
237 |             vec = np.asarray(vec.split(), dtype=np.float32)
238 |             embeddings[word_id] = vec
239 | 
240 |     validate_word2idx(new_word2idx)
241 | 
242 |     return embeddings, new_word2idx
243 | 
244 | 
245 | def validate_word2idx(word2idx):
246 |     vocab = [w for w, i in sorted(word2idx.items(), key=lambda x: x[1])]
247 |     for i, w in enumerate(vocab):
248 |         assert word2idx[w] == i
249 | 
250 | def validate_word_order(tokens):
251 |     """
252 |     Verify tokens are in sorted order.
253 | """ 254 | for w0, w1 in zip(tokens, sorted(tokens)): 255 | assert w0 == w1 256 | 257 | def hash_tokens(tokens): 258 | validate_word_order(tokens) 259 | 260 | m = hashlib.sha256() 261 | for w in tokens: 262 | m.update(str.encode(w)) 263 | return m.hexdigest() 264 | 265 | 266 | -------------------------------------------------------------------------------- /cliora/data/preprocessing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from collections import Counter, OrderedDict 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | DEFAULT_UNK_INDEX = 1 10 | 11 | 12 | def set_random_seed(seed): 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | 16 | 17 | def build_text_vocab(sentences, word2idx=None): 18 | word2idx = OrderedDict() if word2idx is None else word2idx.copy() 19 | for s in sentences: 20 | for w in s: 21 | if w not in word2idx: 22 | word2idx[w] = len(word2idx) 23 | return word2idx 24 | 25 | 26 | def indexify(sentences, word2idx, unk_index=None): 27 | def fn(s): 28 | for w in s: 29 | if w not in word2idx and unk_index is None: 30 | raise ValueError 31 | yield word2idx.get(w, unk_index) 32 | return [list(fn(s)) for s in tqdm(sentences, desc='indexify')] 33 | 34 | 35 | def batchify(examples, batch_size): 36 | sorted_examples = list(sorted(examples, key=lambda x: len(x))) 37 | num_batches = int(math.ceil(len(examples) / batch_size)) 38 | batches = [] 39 | 40 | for i in range(num_batches): 41 | start = i * batch_size 42 | end = start + batch_size 43 | batch = sorted_examples[start:end] 44 | batches.append(pad(batch)) 45 | 46 | return batches 47 | 48 | 49 | def pad(examples, padding_token=0): 50 | def convert2numpy(batch): 51 | # Note that this is tranposed to have dimensions (batch_size, sentence_length). 52 | return np.array(batch, dtype=np.int32).T 53 | 54 | maxlength = np.max([len(x) for x in examples]) 55 | batch = [] 56 | 57 | for x in examples: 58 | diff = maxlength - len(x) 59 | padded = [0] * diff + x 60 | batch.append(padded) 61 | 62 | return convert2numpy(batch) 63 | 64 | 65 | def batch_iterator(dataset, batch_size, seed=None, drop_last=False): 66 | if seed is not None: 67 | set_random_seed(seed) 68 | 69 | nexamples = len(dataset) 70 | nbatches = math.ceil(nexamples/batch_size) 71 | index = random.sample(range(nexamples), nexamples) 72 | 73 | for i in range(nbatches): 74 | start = i * batch_size 75 | end = start + batch_size 76 | if end > nexamples and drop_last: 77 | break 78 | 79 | batch = [dataset[i] for i in index[start:end]] 80 | yield batch 81 | 82 | 83 | def prepare_batch(batch): 84 | return batch 85 | 86 | 87 | def synthesize_training_data(nexamples, vocab_size, min_length=10, max_length=30, seed=None): 88 | if seed is not None: 89 | set_random_seed(seed) 90 | 91 | dataset = [] 92 | 93 | for i in range(nexamples): 94 | length = np.random.randint(min_length, max_length) 95 | example = np.random.randint(0, vocab_size, size=length).tolist() 96 | dataset.append(example) 97 | 98 | return dataset 99 | -------------------------------------------------------------------------------- /cliora/data/reading.py: -------------------------------------------------------------------------------- 1 | """ 2 | Each reader should return: 3 | 4 | - sentences - This is the primary input (raw text) to the model. Not tokenized. 5 | - extra - Additional model input such as entity or sentence labels. 6 | - metadata - Info about the data that is not specific to examples / batches. 
7 | 
8 | """
9 | 
10 | import os
11 | import json
12 | import pickle
13 | from tqdm import tqdm
14 | import numpy as np
15 | import torch
16 | from cliora.data.preprocessing import synthesize_training_data
17 | 
18 | def pick(lst, k):
19 |     return [d[k] for d in lst]
20 | 
21 | def flatten_tree(tr):
22 |     def func(tr):
23 |         if not isinstance(tr, (list, tuple)):
24 |             return [tr]
25 |         result = []
26 |         for x in tr:
27 |             result += func(x)
28 |         return result
29 |     return func(tr)
30 | 
31 | 
32 | def convert_binary_bracketing(parse, lowercase=True):
33 |     transitions = []
34 |     tokens = []
35 | 
36 |     for word in parse.split(' '):
37 |         if word[0] != "(":
38 |             if word == ")":
39 |                 transitions.append(1)
40 |             else:
41 |                 # Downcase all words to match GloVe.
42 |                 if lowercase:
43 |                     tokens.append(word.lower())
44 |                 else:
45 |                     tokens.append(word)
46 |                 transitions.append(0)
47 |     return tokens, transitions
48 | 
49 | 
50 | def build_tree(tokens, transitions):
51 |     stack = []
52 |     buf = tokens[::-1]
53 | 
54 |     for t in transitions:
55 |         if t == 0:
56 |             stack.append(buf.pop())
57 |         elif t == 1:
58 |             right = stack.pop()
59 |             left = stack.pop()
60 |             stack.append((left, right))
61 | 
62 |     assert len(stack) == 1
63 | 
64 |     return stack[0]
65 | 
66 | 
67 | def get_spans_and_siblings(tree):
68 |     def helper(tr, idx=0, name='root'):
69 |         if isinstance(tr, (str, int)):
70 |             return 1, [(idx, idx+1)], []
71 | 
72 |         l_size, l_spans, l_sibs = helper(tr[0], name='l', idx=idx)
73 |         r_size, r_spans, r_sibs = helper(tr[1], name='r', idx=idx+l_size)
74 | 
75 |         size = l_size + r_size
76 | 
77 |         # Siblings.
78 |         spans = [(idx, idx+size)] + l_spans + r_spans
79 |         siblings = [(l_spans[0], r_spans[0], name)] + l_sibs + r_sibs
80 | 
81 |         return size, spans, siblings
82 | 
83 |     _, spans, siblings = helper(tree)
84 | 
85 |     return spans, siblings
86 | 
87 | 
88 | def get_spans(tree):
89 |     def helper(tr, idx=0):
90 |         if isinstance(tr, (str, int)):
91 |             return 1, []
92 | 
93 |         spans = []
94 |         sofar = idx
95 | 
96 |         for subtree in tr:
97 |             size, subspans = helper(subtree, idx=sofar)
98 |             spans += subspans
99 |             sofar += size
100 | 
101 |         size = sofar - idx
102 |         spans += [(idx, sofar)]
103 | 
104 |         return size, spans
105 | 
106 |     _, spans = helper(tree)
107 | 
108 |     return spans
109 | 
110 | 
111 | class BaseTextReader(object):
112 |     def __init__(self, lowercase=True, filter_length=0, include_id=False):
113 |         self.lowercase = lowercase
114 |         self.filter_length = filter_length if filter_length is not None else 0
115 |         self.include_id = include_id
116 | 
117 |     def read(self, filename):
118 |         return self.read_sentences(filename)
119 | 
120 |     def read_sentences(self, filename):
121 |         sentences = []
122 |         extra = dict()
123 | 
124 |         example_ids = []
125 | 
126 |         with open(filename) as f:
127 |             for line in tqdm(f, desc='read'):
128 |                 for s in self.read_line(line):
129 |                     if self.filter_length > 0 and len(s) > self.filter_length:
130 |                         continue
131 |                     if self.include_id:
132 |                         example_id = s[0]
133 |                         s = s[1:]
134 |                     else:
135 |                         example_id = len(sentences)
136 |                     if self.lowercase:
137 |                         s = [w.lower() for w in s]
138 |                     example_ids.append(example_id)
139 |                     sentences.append(s)
140 | 
141 |         extra['example_ids'] = example_ids
142 | 
143 |         return {
144 |             "sentences": sentences,
145 |             "extra": extra
146 |         }
147 | 
148 |     def read_line(self, line):
149 |         raise NotImplementedError
150 | 
151 | 
152 | class PlainTextReader(BaseTextReader):
153 |     def __init__(self, lowercase=True, filter_length=0, delim=' ', include_id=False):
154 |         super(PlainTextReader, self).__init__(lowercase=lowercase, filter_length=filter_length, include_id=include_id)
155 |         self.delim = delim
156 | 
157 |     def read_line(self, line):
158 |         s = line.strip().split(self.delim)
159 |         if self.lowercase:
160 |             s = [w.lower() for w in s]
161 |         yield s
162 | 
163 | 
164 | class JSONLReader(object):
165 |     def __init__(self, lowercase=True, filter_length=0, delim=' ', include_id=False):
166 |         self.lowercase = lowercase
167 |         self.filter_length = filter_length if filter_length is not None else 0
168 | 
169 |     def read(self, filename):
170 |         sentences = []
171 | 
172 |         # extra
173 |         extra = dict()
174 |         example_ids = []
175 |         trees = []
176 | 
177 |         # read
178 |         with open(filename) as f:
179 |             for line in tqdm(f, desc='read'):
180 |                 ex = json.loads(line)
181 |                 example_id = ex['example_id']
182 |                 tr = ex['tree']
183 |                 if 'sentence' not in ex:
184 |                     ex['sentence'] = flatten_tree(tr)
185 |                 s = ex['sentence']
186 | 
187 |                 if self.filter_length > 0 and len(s) > self.filter_length:
188 |                     continue
189 |                 if self.lowercase:
190 |                     s = [w.lower() for w in s]
191 | 
192 |                 example_ids.append(example_id)
193 |                 sentences.append(s)
194 |                 trees.append(tr)
195 | 
196 |         extra['example_ids'] = example_ids
197 |         extra['trees'] = trees
198 | 
199 |         return {
200 |             "sentences": sentences,
201 |             "extra": extra
202 |         }
203 | 
204 | 
205 | class NLIReader(object):
206 | 
207 |     LABEL_MAP = {
208 |         "entailment": 0,
209 |         "neutral": 1,
210 |         "contradiction": 2
211 |     }
212 | 
213 |     def __init__(self, lowercase=True, filter_length=0):
214 |         self.lowercase = lowercase
215 |         self.filter_length = filter_length if filter_length is not None else 0
216 | 
217 |     @staticmethod
218 |     def build(lowercase=True, filter_length=0):
219 |         return NLISentenceReader(lowercase=lowercase, filter_length=filter_length)
220 | 
221 |     def read(self, filename):
222 |         return self.read_sentences(filename)
223 | 
224 |     def read_sentences(self, filename):
225 |         raise NotImplementedError
226 | 
227 |     def read_line(self, line):
228 |         example = json.loads(line)
229 | 
230 |         try:
231 |             label = self.read_label(example['gold_label'])
232 |         except KeyError:
233 |             return None
234 | 
235 |         s1, t1 = convert_binary_bracketing(example['sentence1_binary_parse'], lowercase=self.lowercase)
236 |         s2, t2 = convert_binary_bracketing(example['sentence2_binary_parse'], lowercase=self.lowercase)
237 |         example_id = example['pairID']
238 | 
239 |         return dict(s1=s1, label=label, s2=s2, t1=t1, t2=t2, example_id=example_id)
240 | 
241 |     def read_label(self, label):
242 |         return self.LABEL_MAP[label]
243 | 
244 | 
245 | class NLISentenceReader(NLIReader):
246 |     def read_sentences(self, filename):
247 |         sentences = []
248 |         extra = {}
249 |         example_ids = []
250 | 
251 |         with open(filename) as f:
252 |             for line in tqdm(f, desc='read'):
253 |                 smap = self.read_line(line)
254 |                 if smap is None:
255 |                     continue
256 | 
257 |                 s1, s2, label = smap['s1'], smap['s2'], smap['label']
258 |                 example_id = smap['example_id']
259 |                 skip_s1 = self.filter_length > 0 and len(s1) > self.filter_length
260 |                 skip_s2 = self.filter_length > 0 and len(s2) > self.filter_length
261 | 
262 |                 if not skip_s1:
263 |                     example_ids.append(example_id + '_1')
264 |                     sentences.append(s1)
265 |                 if not skip_s2:
266 |                     example_ids.append(example_id + '_2')
267 |                     sentences.append(s2)
268 | 
269 |         extra['example_ids'] = example_ids
270 | 
271 |         return {
272 |             "sentences": sentences,
273 |             "extra": extra,
274 |         }
275 | 
276 | 
277 | class ConllReader(object):
278 |     def __init__(self, lowercase=True, filter_length=0):
279 |         self.lowercase = lowercase
280 |         self.filter_length = filter_length if filter_length is not None else 0
281 | 
282 |     def read(self, filename):
283 |         sentences = []
284 |         extra = {}
285 |         example_ids = []
286 |         entity_labels = []
287 | 
288 |         with open(filename) as f:
289 |             for line in tqdm(f, desc='read'):
290 |                 data = json.loads(line)
291 |                 s = data['sentence']
292 | 
293 |                 # skip long sentences
294 |                 if self.filter_length > 0 and len(s) > self.filter_length:
295 |                     continue
296 | 
297 |                 sentences.append(s)
298 |                 example_ids.append(data['example_id'])
299 |                 entity_labels.append(data['entities'])
300 | 
301 |         extra['example_ids'] = example_ids
302 |         extra['entity_labels'] = entity_labels
303 | 
304 |         return {
305 |             "sentences": sentences,
306 |             "extra": extra,
307 |         }
308 | 
309 | 
310 | class SyntheticReader(object):
311 |     def __init__(self, nexamples=100, embedding_size=10, vocab_size=14, seed=11, minlen=10,
312 |                  maxlen=20, length=None):
313 |         super(SyntheticReader, self).__init__()
314 |         self.nexamples = nexamples
315 |         self.embedding_size = embedding_size
316 |         self.vocab_size = vocab_size
317 |         self.seed = seed
318 |         self.minlen = minlen
319 |         self.maxlen = maxlen
320 |         self.length = length
321 | 
322 |     def read(self, filename=None):
323 |         min_length = self.minlen
324 |         max_length = self.maxlen
325 | 
326 |         if self.length is not None:
327 |             min_length = self.length
328 |             max_length = min_length + 1
329 | 
330 |         sentences = synthesize_training_data(self.nexamples, self.vocab_size,
331 |             min_length=min_length, max_length=max_length, seed=self.seed)
332 |         extra = {}
333 |         metadata = {}
334 |         metadata['embeddings'] = np.random.randn(self.vocab_size, self.embedding_size).astype(np.float32)
335 | 
336 |         return {
337 |             "sentences": sentences,
338 |             "extra": extra,
339 |             "metadata": metadata
340 |         }
341 | 
342 | 
343 | class PTBReader(object):
344 |     def __init__(self, lowercase=True, filter_length=0, delim=' '):
345 |         self.delim = delim
346 |         self.lowercase = lowercase
347 |         self.filter_length = filter_length if filter_length is not None else 0
348 | 
349 |     def read(self, filename):
350 |         sentences = []
351 | 
352 |         # extra
353 |         extra = dict()
354 |         example_ids = []
355 |         gts = []
356 | 
357 |         file = pickle.load(open(filename, 'rb'))
358 |         datas = file['other_data']
359 |         word2idx = file['word2idx']
360 | 
361 |         # read
362 |         for idx, data in enumerate(datas):
363 |             sent = data[0]
364 |             gt = data[5]
365 |             s = sent.strip().split(self.delim)
366 |             if self.filter_length > 0 and len(s) > self.filter_length:
367 |                 continue
368 |             if self.lowercase:
369 |                 s = [w.lower() for w in s]
370 |             s = [w if w in word2idx else '<unk>' for w in s]
371 | 
372 |             example_ids.append(idx)
373 |             sentences.append(s)
374 |             gts.append(gt)
375 | 
376 |         extra['example_ids'] = example_ids
377 |         extra['GT'] = gts
378 |         metadata = {}
379 |         metadata['word2idx'] = word2idx
380 | 
381 |         return {
382 |             "sentences": sentences,
383 |             "extra": extra,
384 |             "metadata": metadata
385 |         }
386 | 
387 | 
388 | class COCOReader(object):
389 |     def __init__(self, lowercase=True, filter_length=0, delim=' '):
390 |         self.delim = delim
391 |         self.lowercase = lowercase
392 |         self.filter_length = filter_length if filter_length is not None else 0
393 | 
394 |     def read(self, filename):
395 |         sentences = []
396 |         extra = dict()
397 | 
398 |         example_ids = []
399 |         gts = []
400 |         vis_feats = []
401 |         word2idx = json.load(open(filename.replace(filename.split('/')[-1], 'coco.dict.json'), 'r'))
402 | 
403 |         if 'train' in filename:
404 |             split = 'train'
405 |         elif 'val' in filename:
406 |             split = 'val'
407 |         elif 'test' in filename:
408 |             split = 'test'
409 |         else:
410 |             raise NotImplementedError
411 | 
412 |         with
open(filename.replace(filename.split('/')[-1], 'id_list/{}.txt'.format(split)), 'r') as f: 413 | origin_img_ids = f.readlines() 414 | origin_img_ids = np.array([int(i.strip('.jpg\n').split('_')[-1]) for i in origin_img_ids]).repeat(5) 415 | 416 | if split == 'test': 417 | image_feats = np.zeros([len(origin_img_ids), 2048]) 418 | else: 419 | image_feats = np.load(filename.replace(filename.split('/')[-1], '{}_ims.npy'.format(split))) 420 | image_feats = image_feats.repeat(5, 0) 421 | 422 | with open(filename) as f: 423 | lines = f.readlines() 424 | 425 | assert len(origin_img_ids) == len(lines) 426 | assert len(lines) == len(image_feats) 427 | 428 | for idx, line in tqdm(enumerate(lines), desc='read'): 429 | (sent, gt, _, _) = json.loads(line.strip()) 430 | s = sent.strip().split(self.delim) 431 | 432 | if self.filter_length > 0 and len(s) > self.filter_length: 433 | continue 434 | if self.lowercase: 435 | s = [w.lower() for w in s] 436 | s = [w if w in word2idx else '' for w in s] 437 | example_ids.append(origin_img_ids[idx]) 438 | sentences.append(s) 439 | gts.append([tuple(i) for i in gt]) 440 | vis_feats.append(image_feats[idx]) 441 | 442 | extra['example_ids'] = example_ids 443 | extra['image_feats'] = vis_feats 444 | extra['GT'] = gts 445 | metadata = {} 446 | metadata['word2idx'] = word2idx 447 | 448 | return { 449 | "sentences": sentences, 450 | "extra": extra, 451 | "metadata": metadata 452 | } 453 | 454 | 455 | class FlickrReader(object): 456 | def __init__(self, lowercase=True, filter_length=0, delim=' '): 457 | self.delim = delim 458 | self.lowercase = lowercase 459 | self.filter_length = filter_length if filter_length is not None else 0 460 | 461 | def read(self, filename): 462 | sentences = [] 463 | extra = dict() 464 | 465 | example_ids = [] 466 | gts = [] 467 | vg_gts = [] 468 | vis_feats = [] 469 | 470 | word2idx = json.load(open(filename.replace(filename.split('/')[-1], 'flickr.dic.json'), 'r')) 471 | 472 | if 'train' in filename: 473 | split = 'train' 474 | elif 'val' in filename: 475 | split = 'val' 476 | elif 'test' in filename: 477 | split = 'test' 478 | else: 479 | raise NotImplementedError 480 | 481 | with open(filename.replace(filename.split('/')[-1], '{}.txt'.format(split)), 'r') as f: 482 | origin_img_sent_ids = f.readlines() 483 | # origin_vg_gts = json.load(open(filename.replace(filename.split('/')[-1], 'sent_anno.json'))) 484 | if split in ['val', 'test']: 485 | origin_vg_gts = pickle.load(open(filename.replace(filename.split('/')[-1], 'gt_anno_{}.pkl'.format(split)), 'rb')) 486 | else: 487 | origin_vg_gts = None 488 | with open(filename) as f: 489 | lines = f.readlines() 490 | # image_feats = np.zeros([len(lines), 2048]) 491 | 492 | assert len(origin_img_sent_ids) == len(lines) 493 | 494 | for idx, line in tqdm(enumerate(lines), desc='read'): 495 | (sent, gt) = json.loads(line.strip()) 496 | s = sent.strip().split(self.delim) 497 | 498 | if self.filter_length > 0 and len(s) > self.filter_length: 499 | continue 500 | if self.lowercase: 501 | s = [w.lower() for w in s] 502 | s = [w if w in word2idx else '' for w in s] 503 | 504 | im_id, sent_id = origin_img_sent_ids[idx].strip().split('\t') 505 | example_ids.append(im_id) 506 | if origin_vg_gts is not None: 507 | vg_gt = origin_vg_gts.get(im_id + '_' + sent_id, [{}, None]) 508 | vg_gts.append(vg_gt) 509 | else: 510 | vg_gts.append([{}, None]) 511 | # vg_gts.append(origin_vg_gts[im_id][int(sent_id)]) 512 | sentences.append(s) 513 | gts.append([tuple(i) for i in gt]) 514 | vis_feats.append(np.zeros(1)) 515 | 516 | 
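# At this point each per-example list is aligned with `sentences`: example_ids holds the
# Flickr image id for each caption, gts the gold constituency spans as (start, end)
# tuples, and vg_gts the phrase-grounding annotations (real annotations only for
# val/test; empty placeholders for train). vis_feats stores 1-d zero placeholders,
# since the MAF object features appear to be loaded elsewhere in the pipeline.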
extra['example_ids'] = example_ids 517 | extra['image_feats'] = vis_feats 518 | extra['GT'] = gts 519 | extra['VG_GT'] = vg_gts 520 | metadata = {} 521 | metadata['word2idx'] = word2idx 522 | # metadata['embeddings'] = embeddings 523 | 524 | return { 525 | "sentences": sentences, 526 | "extra": extra, 527 | "metadata": metadata 528 | } 529 | 530 | -------------------------------------------------------------------------------- /cliora/external/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bobwan1995/cliora/b064bdf967d4ccc4f3327183efd888b927bfb4fb/cliora/external/__init__.py -------------------------------------------------------------------------------- /cliora/logging/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bobwan1995/cliora/b064bdf967d4ccc4f3327183efd888b927bfb4fb/cliora/logging/__init__.py -------------------------------------------------------------------------------- /cliora/logging/accumulator.py: -------------------------------------------------------------------------------- 1 | class MeanAccumulator(object): 2 | def __init__(self): 3 | self.reset() 4 | 5 | def record(self, val): 6 | if self.count == 0: 7 | self.count += 1 8 | self.val += val 9 | return 10 | 11 | count = self.count + 1 12 | self.val = ((self.val * self.count) + val) / count 13 | self.count = count 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.count = 0 18 | 19 | 20 | class Accumulator(object): 21 | def __init__(self): 22 | self.table = {} 23 | 24 | def record(self, key, val): 25 | if not key in self.table: 26 | self.table[key] = MeanAccumulator() 27 | self.table[key].record(val) 28 | 29 | def has(self, key): 30 | return key in self.table 31 | 32 | def get_mean(self, key, default=0): 33 | if key in self.table: 34 | val = self.table[key].val 35 | else: 36 | val = default 37 | return val 38 | 39 | def reset(self, key=None): 40 | if key is None: 41 | keys = list(self.table.keys()) 42 | else: 43 | keys = [key] 44 | 45 | for key in keys: 46 | del self.table[key] 47 | 48 | -------------------------------------------------------------------------------- /cliora/logging/configuration.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from cliora.utils.fs import mkdir_p 5 | 6 | 7 | LOGGING_NAMESPACE = 'cliora' 8 | 9 | 10 | def configure_experiment(experiment_path, rank=None): 11 | mkdir_p(experiment_path) 12 | if rank is None: 13 | log_file = os.path.join(experiment_path, 'experiment.log') 14 | else: 15 | log_file = os.path.join(experiment_path, 'experiment.log.{}'.format(rank)) 16 | configure_logger(log_file) 17 | 18 | 19 | def configure_logger(log_file): 20 | """ 21 | Simple logging configuration. 22 | """ 23 | 24 | # Create logger. 25 | logger = logging.getLogger(LOGGING_NAMESPACE) 26 | logger.setLevel(logging.INFO) 27 | 28 | # Create file handler. 29 | fh = logging.FileHandler(log_file) 30 | fh.setLevel(logging.INFO) 31 | 32 | # Also log to console. 33 | ch = logging.StreamHandler() 34 | ch.setLevel(logging.INFO) 35 | 36 | # create formatter and add it to the handlers 37 | formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s') 38 | fh.setFormatter(formatter) 39 | ch.setFormatter(formatter) 40 | 41 | # add the handlers to the logger 42 | logger.addHandler(fh) 43 | logger.addHandler(ch) 44 | 45 | # HACK: Weird fix that counteracts other libraries (i.e. 
allennlp) modifying 46 | # the global logger. 47 | if len(logger.parent.handlers) > 0: 48 | logger.parent.handlers.pop() 49 | 50 | return logger 51 | 52 | 53 | def get_logger(): 54 | return logging.getLogger(LOGGING_NAMESPACE) 55 | -------------------------------------------------------------------------------- /cliora/misc/convert_conll_to_jsonl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | 6 | def pick(lst, k): 7 | return [d[k] for d in lst] 8 | 9 | 10 | class ConllReader(object): 11 | def __init__(self, word_index=None, tag_index=None, delim=' '): 12 | super(ConllReader, self).__init__() 13 | self.word_index = word_index 14 | self.tag_index = tag_index 15 | self.delim = delim 16 | 17 | self.example_counter = None 18 | 19 | def reset(self): 20 | self.example_counter = 0 21 | 22 | def get_word(self, parts): 23 | return parts[self.word_index] 24 | 25 | def get_tag_and_labels(self, parts): 26 | x = parts[self.tag_index] 27 | 28 | def get_labels(y): 29 | return y 30 | 31 | if x.startswith('O'): 32 | return 'O', None 33 | if x.startswith('I'): 34 | return 'I', get_labels(x.split('-', 1)[1]) 35 | if x.startswith('B'): 36 | return 'B', get_labels(x.split('-', 1)[1]) 37 | 38 | raise ValueError('Not a BIO tag: {}'.format(x)) 39 | 40 | def convert_records_to_example(self, records): 41 | example_id = self.example_counter 42 | 43 | word_lst = pick(records, 'word') 44 | tag_lst = pick(records, 'tag') 45 | labels_lst = pick(records, 'labels') 46 | 47 | entity_lst = [] 48 | warning_lst = [] 49 | 50 | for i, tag in enumerate(tag_lst): 51 | # Adjust tags if needed. 52 | if tag == 'I' and len(entity_lst) == 0: 53 | warning = '[warning] Converting I to B. I appears at beginning of sentence. i = {}'.format(i) 54 | warning_lst.append(warning) 55 | tag = 'B' 56 | 57 | if tag == 'I' and len(entity_lst) == 0: 58 | warning = '[warning] Converting I to B. I appears before any B tags. i = {}'.format(i) 59 | warning_lst.append(warning) 60 | tag = 'B' 61 | 62 | if tag == 'I' and len(entity_lst) > 0: 63 | pos = entity_lst[-1][-2] 64 | size = entity_lst[-1][-1] 65 | 66 | if pos + size != i: 67 | warning = '[warning] Converting I to B. I appears after O. i = {}'.format(i) 68 | warning_lst.append(warning) 69 | tag = 'B' 70 | 71 | # Record entity. 
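# An entity is recorded as [label, start_position, size]: a B tag opens a new entity
# and an I tag extends the most recent one. For example, the tag sequence
# B-NP I-NP O B-VP yields entity_lst == [['NP', 0, 2], ['VP', 3, 1]].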
72 | if tag == 'O': 73 | continue 74 | if tag == 'B': 75 | labels = labels_lst[i] 76 | 77 | # entity = (labels, position, size) 78 | entity = [labels, i, 1] 79 | entity_lst.append(entity) 80 | 81 | assert labels is not None and isinstance(labels, str) 82 | if tag == 'I': 83 | # increment size 84 | entity_lst[-1][-1] += 1 85 | 86 | pos = entity_lst[-1][-2] 87 | size = entity_lst[-1][-1] 88 | 89 | assert pos + size - 1 == i 90 | 91 | # Build Example 92 | example = {} 93 | example['example_id'] = '{}_{}'.format(options.name, example_id) 94 | example['entities'] = entity_lst 95 | example['sentence'] = word_lst 96 | 97 | if len(warning_lst) > 0: 98 | example['warnings'] = warning_lst 99 | 100 | # Cleanup 101 | self.example_counter += 1 102 | 103 | return example 104 | 105 | def _read(self, filename): 106 | 107 | lst = [] 108 | 109 | with open(filename) as f: 110 | for i, line in enumerate(f): 111 | line = line.rstrip() 112 | 113 | # Skip Empty Lines 114 | if len(line) == 0: 115 | if len(lst) > 0: 116 | yield lst 117 | lst = [] 118 | continue 119 | 120 | parts = line.split(self.delim) 121 | 122 | word = self.get_word(parts) 123 | tag, labels = self.get_tag_and_labels(parts) 124 | 125 | record = dict() 126 | record['word'] = word 127 | record['tag'] = tag 128 | record['labels'] = labels 129 | 130 | lst.append(record) 131 | 132 | # In case final does not end in newline. 133 | if len(lst) > 0: 134 | yield lst 135 | 136 | def read(self, filename): 137 | self.reset() 138 | 139 | for record_lst in self._read(filename): 140 | example = self.convert_records_to_example(record_lst) 141 | yield example 142 | 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument('--path', default='./train.txt', type=str) 147 | parser.add_argument('--delim', default=' ', type=str) 148 | parser.add_argument('--i_word', default=0, type=int) 149 | parser.add_argument('--i_tag', default=2, type=int) 150 | parser.add_argument('--name', default='conll2000', type=str) 151 | options = parser.parse_args() 152 | 153 | reader = ConllReader(tag_index=options.i_tag, word_index=options.i_word, delim=options.delim) 154 | for example in reader.read(options.path): 155 | print(json.dumps(example)) 156 | -------------------------------------------------------------------------------- /cliora/net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bobwan1995/cliora/b064bdf967d4ccc4f3327183efd888b927bfb4fb/cliora/net/__init__.py -------------------------------------------------------------------------------- /cliora/net/cliora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from cliora.net.utils import * 5 | 6 | class Chart(object): 7 | def __init__(self, batch_size, length, size, dtype=None, cuda=False): 8 | super(Chart, self).__init__() 9 | 10 | ncells = int(length * (1 + length) / 2) 11 | 12 | device = torch.cuda.current_device() if cuda else None 13 | 14 | ## Inside. 15 | self.inside_h = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 16 | self.inside_c = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 17 | self.inside_s = torch.full((batch_size, ncells, 1), 0, dtype=dtype, device=device) 18 | 19 | ## Outside. 
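# Same layout as the inside chart: one row per span, ncells = length * (1 + length) / 2
# cells ordered level by level (level = span width - 1), with h/c the composed hidden
# and cell states and s the per-span compatibility score.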
20 | self.outside_h = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 21 | self.outside_c = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 22 | self.outside_s = torch.full((batch_size, ncells, 1), 0, dtype=dtype, device=device) 23 | 24 | ## Visual 25 | self.vis_aggragate = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 26 | 27 | 28 | class AttentionHead(nn.Module): 29 | def __init__(self, q_dim, k_dim, v_dim, h_dim): 30 | super(AttentionHead, self).__init__() 31 | self.h_dim = h_dim 32 | self.dropout = nn.Dropout(0.1) 33 | 34 | 35 | def forward(self, h_q, h_k, h_v, temp=1.0): 36 | 37 | all_atten_score = torch.einsum('abx,cdx->acbd', h_q, h_k) 38 | # atten_score = torch.diagonal(all_atten_score, 0, 0, 1).permute(2, 0, 1) / self.h_dim**0.5 39 | atten_score = torch.diagonal(all_atten_score/temp, 0, 0, 1).permute(2, 0, 1) 40 | atten_prob = self.dropout(torch.softmax(atten_score, dim=-1)) 41 | cxt = torch.bmm(atten_prob, h_v) 42 | return cxt 43 | 44 | 45 | # Composition Functions 46 | class VLComposeMLP(nn.Module): 47 | def __init__(self, size, ninput=2, leaf=False): 48 | super(VLComposeMLP, self).__init__() 49 | 50 | self.size = size 51 | self.ninput = ninput 52 | if leaf: 53 | self.leaf_fc = nn.Linear(self.size, self.size) 54 | self.h_fcs = nn.Sequential( 55 | nn.Linear(2 * self.size, self.size), 56 | nn.ReLU(), 57 | nn.Linear(self.size, self.size), 58 | nn.ReLU() 59 | ) 60 | 61 | @property 62 | def device(self): 63 | return next(self.parameters()).device 64 | 65 | @property 66 | def is_cuda(self): 67 | device = self.device 68 | return device.index is not None and device.index >= 0 69 | 70 | 71 | def leaf_transform(self, x, obj, atten_head, normalize_func): 72 | h = torch.tanh(self.leaf_fc(x)) 73 | h = normalize_func(h) # TODO 74 | 75 | cxt = atten_head(h, obj, obj) 76 | # cxt = obj.mean(1).unsqueeze(1).expand(h.shape) 77 | 78 | # visual as residule 79 | h = h + cxt 80 | return h, cxt 81 | 82 | def forward(self, hs, cs): 83 | input_h = torch.cat(hs, 1) 84 | h = self.h_fcs(input_h) 85 | c = torch.full(h.shape, 0, dtype=torch.float32, device=h.device) 86 | return h, c 87 | 88 | 89 | # Score Functions 90 | 91 | class Bilinear(nn.Module): 92 | def __init__(self, size): 93 | super(Bilinear, self).__init__() 94 | self.size = size 95 | self.mat = nn.Parameter(torch.FloatTensor(self.size, self.size)) 96 | 97 | def forward(self, vector1, vector2): 98 | bma = torch.matmul(vector1, self.mat).unsqueeze(1) 99 | ba = torch.matmul(bma, vector2.unsqueeze(2)).view(-1, 1) 100 | return ba 101 | 102 | 103 | # Inside 104 | 105 | def inside_fill_chart(batch_info, chart, index, h, c, s): 106 | L = batch_info.length - batch_info.level 107 | 108 | offset = index.get_offset(batch_info.length)[batch_info.level] 109 | 110 | chart.inside_h[:, offset:offset+L] = h 111 | chart.inside_c[:, offset:offset+L] = c 112 | chart.inside_s[:, offset:offset+L] = s 113 | 114 | 115 | def get_inside_states(batch_info, chart, index, size): 116 | lidx, ridx = index.get_inside_index(batch_info.length, batch_info.level) 117 | 118 | ls = chart.index_select(index=lidx, dim=1).view(-1, size) 119 | rs = chart.index_select(index=ridx, dim=1).view(-1, size) 120 | 121 | return ls, rs 122 | 123 | 124 | def inside_compose(compose_func, hs, cs): 125 | return compose_func(hs, cs) 126 | 127 | 128 | def inside_score(score_func, batch_info, hs, cs, ss): 129 | B = batch_info.batch_size 130 | L = batch_info.length - batch_info.level 131 | N = batch_info.level 132 | 133 | s = 
score_func(hs[0], hs[1]) + ss[0] + ss[1] 134 | s = s.view(B, L, N, 1) 135 | p = torch.softmax(s, dim=2) 136 | 137 | return s, p 138 | 139 | 140 | def inside_aggregate(batch_info, h, c, s, p, obj, normalize_func, atten_head): 141 | B = batch_info.batch_size 142 | L = batch_info.length - batch_info.level 143 | N = batch_info.level 144 | 145 | h_agg = torch.sum(h.view(B, L, N, -1) * p, 2) 146 | c_agg = torch.sum(c.view(B, L, N, -1) * p, 2) 147 | s_agg = torch.sum(s * p, 2) 148 | 149 | h_agg = normalize_func(h_agg) # TODO 150 | # cxt = obj.mean(1).unsqueeze(1).expand(h_agg.shape) 151 | cxt = atten_head(h_agg, obj, obj) 152 | h_agg = h_agg + cxt 153 | 154 | h_agg = normalize_func(h_agg) 155 | c_agg = normalize_func(c_agg) # ignore 156 | 157 | return h_agg, c_agg, s_agg 158 | 159 | 160 | # Outside 161 | 162 | def outside_fill_chart(batch_info, chart, index, h, c, s): 163 | L = batch_info.length - batch_info.level 164 | 165 | offset = index.get_offset(batch_info.length)[batch_info.level] 166 | 167 | chart.outside_h[:, offset:offset+L] = h 168 | chart.outside_c[:, offset:offset+L] = c 169 | chart.outside_s[:, offset:offset+L] = s 170 | 171 | 172 | def get_outside_states(batch_info, pchart, schart, index, size): 173 | pidx, sidx = index.get_outside_index(batch_info.length, batch_info.level) 174 | 175 | ps = pchart.index_select(index=pidx, dim=1).view(-1, size) 176 | ss = schart.index_select(index=sidx, dim=1).view(-1, size) 177 | 178 | return ps, ss 179 | 180 | 181 | def outside_compose(compose_func, hs, cs): 182 | return compose_func(hs, cs) 183 | 184 | 185 | def outside_score(score_func, batch_info, hs, cs, ss): 186 | B = batch_info.batch_size 187 | L = batch_info.length - batch_info.level 188 | 189 | s = score_func(hs[0], hs[1]) + ss[0] + ss[1] 190 | s = s.view(B, -1, L, 1) 191 | p = torch.softmax(s, dim=1) 192 | 193 | return s, p 194 | 195 | 196 | def outside_aggregate(batch_info, h, c, s, p, normalize_func): 197 | B = batch_info.batch_size 198 | L = batch_info.length - batch_info.level 199 | N = s.shape[1] 200 | 201 | h_agg = torch.sum(h.view(B, N, L, -1) * p, 1) 202 | c_agg = torch.sum(c.view(B, N, L, -1) * p, 1) 203 | s_agg = torch.sum(s * p, 1) 204 | 205 | h_agg = normalize_func(h_agg) 206 | c_agg = normalize_func(c_agg) 207 | 208 | return h_agg, c_agg, s_agg 209 | 210 | 211 | # Base 212 | 213 | class DioraBase(nn.Module): 214 | r"""DioraBase 215 | 216 | """ 217 | 218 | def __init__(self, size, outside=True, normalize='unit', compress=False, share=True): 219 | super(DioraBase, self).__init__() 220 | assert normalize in ('none', 'unit'), 'Does not support "{}".'.format(normalize) 221 | 222 | self.share = share 223 | self.size = size 224 | self.outside = outside 225 | self.inside_normalize_func = NormalizeFunc(normalize) 226 | self.outside_normalize_func = NormalizeFunc(normalize) 227 | self.compress = compress 228 | self.atten_head = AttentionHead(size, size, size, size) 229 | self.ninput = 2 230 | 231 | self.index = None 232 | self.charts = None 233 | 234 | self.init_parameters() 235 | self.reset_parameters() 236 | self.reset() 237 | 238 | def init_parameters(self): 239 | raise NotImplementedError 240 | 241 | def reset_parameters(self): 242 | params = [p for p in self.parameters() if p.requires_grad] 243 | for i, param in enumerate(params): 244 | param.data.normal_() 245 | 246 | @property 247 | def device(self): 248 | return next(self.parameters()).device 249 | 250 | @property 251 | def inside_h(self): 252 | return self.chart.inside_h 253 | 254 | @property 255 | def inside_c(self): 256 | 
return self.chart.inside_c 257 | 258 | @property 259 | def inside_s(self): 260 | return self.chart.inside_s 261 | 262 | @property 263 | def outside_h(self): 264 | return self.chart.outside_h 265 | 266 | @property 267 | def outside_c(self): 268 | return self.chart.outside_c 269 | 270 | @property 271 | def outside_s(self): 272 | return self.chart.outside_s 273 | 274 | @property 275 | def is_cuda(self): 276 | device = self.device 277 | return device.index is not None and device.index >= 0 278 | 279 | def cuda(self): 280 | super(DioraBase, self).cuda() 281 | if self.index is not None: 282 | self.index.cuda = True # TODO: Should support to/from cpu/gpu. 283 | 284 | def get(self, chart, level): 285 | length = self.length 286 | L = length - level 287 | offset = self.index.get_offset(length)[level] 288 | return chart[:, offset:offset+L] 289 | 290 | def leaf_transform(self, x, obj_embed): 291 | normalize_func = self.inside_normalize_func 292 | transform_func = self.inside_compose_func.leaf_transform 293 | atten_head = self.atten_head 294 | 295 | input_shape = x.shape[:-1] 296 | h, c = transform_func(x, obj_embed, atten_head, normalize_func) 297 | 298 | h = normalize_func(h.view(*input_shape, self.size)) 299 | c = normalize_func(c.view(*input_shape, self.size)) # ignore 300 | 301 | return h, c 302 | 303 | # Inside 304 | def inside_func(self, compose_func, score_func, atten_func, obj_embed, batch_info, chart, index, normalize_func): 305 | lh, rh = get_inside_states(batch_info, chart.inside_h, index, batch_info.size) 306 | lc, rc = get_inside_states(batch_info, chart.inside_c, index, batch_info.size) 307 | ls, rs = get_inside_states(batch_info, chart.inside_s, index, 1) 308 | 309 | hlst = [lh, rh] 310 | clst = [lc, rc] 311 | slst = [ls, rs] 312 | 313 | h, c = inside_compose(compose_func, hlst, clst) 314 | s, p = inside_score(score_func, batch_info, hlst, clst, slst) 315 | hbar, cbar, sbar = inside_aggregate(batch_info, h, c, s, p, obj_embed, normalize_func, atten_func) 316 | 317 | inside_fill_chart(batch_info, chart, index, hbar, cbar, sbar) 318 | 319 | return h, c, s 320 | 321 | def inside_pass(self, obj_embed): 322 | compose_func = self.inside_compose_func 323 | score_func = self.inside_score_func 324 | index = self.index 325 | chart = self.chart 326 | normalize_func = self.inside_normalize_func 327 | atten_func = self.atten_head 328 | 329 | for level in range(1, self.length): 330 | 331 | batch_info = BatchInfo( 332 | batch_size=self.batch_size, 333 | length=self.length, 334 | size=self.size, 335 | level=level, 336 | ) 337 | 338 | h, c, s = self.inside_func(compose_func, score_func, atten_func, obj_embed, 339 | batch_info, chart, index, normalize_func=normalize_func) 340 | 341 | self.inside_hook(level, h, c, s) 342 | 343 | 344 | def inside_hook(self, level, h, c, s): 345 | pass 346 | 347 | def outside_hook(self, level, h, c, s): 348 | pass 349 | 350 | def initialize_outside_root(self): 351 | B = self.batch_size 352 | D = self.size 353 | normalize_func = self.outside_normalize_func 354 | 355 | if self.compress: 356 | h = torch.matmul(self.inside_h[:, -1:], self.root_mat_out) 357 | else: 358 | h = self.root_vector_out_h.view(1, 1, D).expand(B, 1, D) 359 | if self.root_vector_out_c is None: 360 | device = torch.cuda.current_device() if self.is_cuda else None 361 | c = torch.full(h.shape, 0, dtype=torch.float32, device=device) 362 | else: 363 | c = self.root_vector_out_c.view(1, 1, D).expand(B, 1, D) 364 | 365 | h = normalize_func(h) 366 | c = normalize_func(c) 367 | 368 | self.chart.outside_h[:, -1:] = 
h 369 | self.chart.outside_c[:, -1:] = c 370 | 371 | 372 | def outside_func(self, compose_func, score_func, batch_info, chart, index, normalize_func): 373 | ph, sh = get_outside_states( 374 | batch_info, chart.outside_h, chart.inside_h, index, batch_info.size) 375 | pc, sc = get_outside_states( 376 | batch_info, chart.outside_c, chart.inside_c, index, batch_info.size) 377 | ps, ss = get_outside_states( 378 | batch_info, chart.outside_s, chart.inside_s, index, 1) 379 | 380 | hlst = [sh, ph] 381 | clst = [sc, pc] 382 | slst = [ss, ps] 383 | 384 | h, c = outside_compose(compose_func, hlst, clst) 385 | s, p = outside_score(score_func, batch_info, hlst, clst, slst) 386 | hbar, cbar, sbar = outside_aggregate(batch_info, h, c, s, p, normalize_func) 387 | 388 | # TODO: add attention here 389 | outside_fill_chart(batch_info, chart, index, hbar, cbar, sbar) 390 | 391 | return h, c, s 392 | 393 | 394 | def outside_pass(self): 395 | self.initialize_outside_root() 396 | 397 | compose_func = self.outside_compose_func 398 | score_func = self.outside_score_func 399 | index = self.index 400 | chart = self.chart 401 | normalize_func = self.outside_normalize_func 402 | 403 | for level in range(self.length - 2, -1, -1): 404 | batch_info = BatchInfo( 405 | batch_size=self.batch_size, 406 | length=self.length, 407 | size=self.size, 408 | level=level, 409 | ) 410 | 411 | h, c, s = self.outside_func(compose_func, score_func, 412 | batch_info, chart, index, normalize_func=normalize_func) 413 | 414 | self.outside_hook(level, h, c, s) 415 | 416 | # Initialization 417 | def init_with_batch(self, h, c): 418 | size = self.size 419 | batch_size, length, _ = h.shape 420 | 421 | self.batch_size = batch_size 422 | self.length = length 423 | 424 | self.chart = Chart(batch_size, length, size, dtype=torch.float32, cuda=self.is_cuda) 425 | self.chart.inside_h[:, :self.length] = h 426 | self.chart.inside_c[:, :self.length] = c 427 | 428 | def reset(self): 429 | self.batch_size = None 430 | self.length = None 431 | self.chart = None 432 | self.all_atten_score = None 433 | self.atten_score = None 434 | 435 | def get_chart_wrapper(self): 436 | return self 437 | 438 | def forward(self, x_span, x_word, obj_embed_span=None, obj_embed_word=None): 439 | if self.index is None: 440 | self.index = Index(cuda=self.is_cuda) 441 | 442 | self.reset() 443 | 444 | h, c = self.leaf_transform(x_span, obj_embed_span) 445 | 446 | self.init_with_batch(h, c) 447 | 448 | self.inside_pass(obj_embed_span) 449 | 450 | if self.outside: 451 | self.outside_pass() 452 | 453 | # TODO: COCO 454 | # self.all_atten_score = orch.einsum('abx,cx->acb', self.chart.inside_h + self.chart.outside_h, obj_embed_span) 455 | 456 | # Flickr 457 | self.all_atten_score = torch.einsum('abx,cdx->acbd', self.chart.inside_h + self.chart.outside_h, obj_embed_span) 458 | 459 | if self.training: 460 | self.vg_atten_score_word = torch.einsum('abx,cdx->acbd', x_word, obj_embed_word) 461 | self.vg_atten_score = self.vg_atten_score_word 462 | else: 463 | self.vg_atten_score_word = torch.einsum('abx,cdx->acbd', self.inside_normalize_func(x_word), obj_embed_word) 464 | self.vg_atten_score = self.all_atten_score[:, :, :x_span.size(1)] + self.vg_atten_score_word 465 | 466 | self.atten_score = torch.diagonal(self.vg_atten_score, 0, 0, 1).permute(2, 0, 1) 467 | 468 | return None 469 | 470 | 471 | class DioraMLP(DioraBase): 472 | 473 | def init_parameters(self): 474 | self.inside_score_func = Bilinear(self.size) 475 | self.inside_compose_func = VLComposeMLP(self.size, leaf=True) 476 | if 
self.share: 477 | self.outside_score_func = self.inside_score_func 478 | self.outside_compose_func = self.inside_compose_func 479 | else: 480 | self.outside_score_func = Bilinear(self.size) 481 | self.outside_compose_func = VLComposeMLP(self.size) 482 | 483 | if self.compress: 484 | self.root_mat_out = nn.Parameter(torch.FloatTensor(self.size, self.size)) 485 | else: 486 | self.root_vector_out_h = nn.Parameter(torch.FloatTensor(self.size)) 487 | 488 | self.root_vector_out_c = None 489 | -------------------------------------------------------------------------------- /cliora/net/diora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from cliora.net.utils import * 5 | 6 | 7 | class Chart(object): 8 | def __init__(self, batch_size, length, size, dtype=None, cuda=False): 9 | super(Chart, self).__init__() 10 | 11 | ncells = int(length * (1 + length) / 2) 12 | 13 | device = torch.cuda.current_device() if cuda else None 14 | 15 | ## Inside. 16 | self.inside_h = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 17 | self.inside_c = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 18 | self.inside_s = torch.full((batch_size, ncells, 1), 0, dtype=dtype, device=device) 19 | 20 | ## Outside. 21 | self.outside_h = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 22 | self.outside_c = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 23 | self.outside_s = torch.full((batch_size, ncells, 1), 0, dtype=dtype, device=device) 24 | 25 | 26 | class ComposeMLP(nn.Module): 27 | def __init__(self, size, ninput=2, leaf=False): 28 | super(ComposeMLP, self).__init__() 29 | 30 | self.size = size 31 | self.ninput = ninput 32 | 33 | if leaf: 34 | self.leaf_fc = nn.Linear(self.size, self.size) 35 | self.h_fcs = nn.Sequential( 36 | nn.Linear(2 * self.size, self.size), 37 | nn.ReLU(), 38 | nn.Linear(self.size, self.size), 39 | nn.ReLU() 40 | ) 41 | # self.reset_parameters() 42 | 43 | @property 44 | def device(self): 45 | return next(self.parameters()).device 46 | 47 | @property 48 | def is_cuda(self): 49 | device = self.device 50 | return device.index is not None and device.index >= 0 51 | 52 | # def reset_parameters(self): 53 | # # TODO: Init with diagonal. 
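# (Left commented out: DioraBase.reset_parameters below re-initializes every
# trainable parameter with normal_(), so a per-module reset would be redundant.)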
54 | # params = [p for p in self.parameters() if p.requires_grad] 55 | # for i, param in enumerate(params): 56 | # param.data.normal_() 57 | 58 | def leaf_transform(self, x): 59 | # h = self.leaf_fc(x) 60 | h = torch.tanh(self.leaf_fc(x)) 61 | c = torch.full(h.shape, 0, dtype=torch.float32, device=h.device) 62 | 63 | return h, c 64 | 65 | def forward(self, hs, cs=None, constant=1.0): 66 | input_h = torch.cat(hs, 1) 67 | h = self.h_fcs(input_h) 68 | 69 | # device = torch.cuda.current_device() if self.is_cuda else None 70 | c = torch.full(h.shape, 0, dtype=torch.float32, device=h.device) 71 | 72 | return h, c 73 | 74 | 75 | # Score Functions 76 | 77 | class Bilinear(nn.Module): 78 | def __init__(self, size): 79 | super(Bilinear, self).__init__() 80 | self.size = size 81 | self.mat = nn.Parameter(torch.FloatTensor(self.size, self.size)) 82 | # self.reset_parameters() 83 | 84 | # def reset_parameters(self): 85 | # params = [p for p in self.parameters() if p.requires_grad] 86 | # for i, param in enumerate(params): 87 | # param.data.normal_() 88 | 89 | def forward(self, vector1, vector2): 90 | # bilinear 91 | # a = 1 (in a more general bilinear function, a is any positive integer) 92 | # vector1.shape = (b, m) 93 | # matrix.shape = (m, n) 94 | # vector2.shape = (b, n) 95 | bma = torch.matmul(vector1, self.mat).unsqueeze(1) 96 | ba = torch.matmul(bma, vector2.unsqueeze(2)).view(-1, 1) 97 | return ba 98 | 99 | 100 | # Inside 101 | 102 | def inside_fill_chart(batch_info, chart, index, h, c, s): 103 | L = batch_info.length - batch_info.level 104 | 105 | offset = index.get_offset(batch_info.length)[batch_info.level] 106 | 107 | chart.inside_h[:, offset:offset+L] = h 108 | chart.inside_c[:, offset:offset+L] = c 109 | chart.inside_s[:, offset:offset+L] = s 110 | 111 | 112 | def get_inside_states(batch_info, chart, index, size): 113 | lidx, ridx = index.get_inside_index(batch_info.length, batch_info.level) 114 | 115 | ls = chart.index_select(index=lidx, dim=1).view(-1, size) 116 | rs = chart.index_select(index=ridx, dim=1).view(-1, size) 117 | 118 | return ls, rs 119 | 120 | 121 | def inside_compose(compose_func, hs, cs): 122 | return compose_func(hs, cs) 123 | 124 | 125 | def inside_score(score_func, batch_info, hs, ss): 126 | B = batch_info.batch_size 127 | L = batch_info.length - batch_info.level 128 | N = batch_info.level 129 | 130 | s = score_func(hs[0], hs[1]) + ss[0] + ss[1] 131 | s = s.view(B, L, N, 1) 132 | p = torch.softmax(s, dim=2) 133 | 134 | return s, p 135 | 136 | 137 | def inside_aggregate(batch_info, h, c, s, p, normalize_func): 138 | B = batch_info.batch_size 139 | L = batch_info.length - batch_info.level 140 | N = batch_info.level 141 | 142 | h_agg = torch.sum(h.view(B, L, N, -1) * p, 2) 143 | c_agg = torch.sum(c.view(B, L, N, -1) * p, 2) 144 | s_agg = torch.sum(s * p, 2) 145 | 146 | h_agg = normalize_func(h_agg) 147 | c_agg = normalize_func(c_agg) 148 | 149 | return h_agg, c_agg, s_agg 150 | 151 | 152 | # Outside 153 | 154 | def outside_fill_chart(batch_info, chart, index, h, c, s): 155 | L = batch_info.length - batch_info.level 156 | 157 | offset = index.get_offset(batch_info.length)[batch_info.level] 158 | 159 | chart.outside_h[:, offset:offset+L] = h 160 | chart.outside_c[:, offset:offset+L] = c 161 | chart.outside_s[:, offset:offset+L] = s 162 | 163 | 164 | def get_outside_states(batch_info, pchart, schart, index, size): 165 | pidx, sidx = index.get_outside_index(batch_info.length, batch_info.level) 166 | 167 | ps = pchart.index_select(index=pidx, dim=1).view(-1, size) 168 | ss = 
schart.index_select(index=sidx, dim=1).view(-1, size) 169 | 170 | return ps, ss 171 | 172 | 173 | def outside_compose(compose_func, hs, cs): 174 | return compose_func(hs, cs, 0) 175 | 176 | 177 | def outside_score(score_func, batch_info, hs, ss): 178 | B = batch_info.batch_size 179 | L = batch_info.length - batch_info.level 180 | 181 | s = score_func(hs[0], hs[1]) + ss[0] + ss[1] 182 | s = s.view(B, -1, L, 1) 183 | p = torch.softmax(s, dim=1) 184 | 185 | return s, p 186 | 187 | 188 | def outside_aggregate(batch_info, h, c, s, p, normalize_func): 189 | B = batch_info.batch_size 190 | L = batch_info.length - batch_info.level 191 | N = s.shape[1] 192 | 193 | h_agg = torch.sum(h.view(B, N, L, -1) * p, 1) 194 | c_agg = torch.sum(c.view(B, N, L, -1) * p, 1) 195 | s_agg = torch.sum(s * p, 1) 196 | 197 | h_agg = normalize_func(h_agg) 198 | c_agg = normalize_func(c_agg) 199 | 200 | return h_agg, c_agg, s_agg 201 | 202 | 203 | # Base 204 | 205 | class DioraBase(nn.Module): 206 | r"""DioraBase 207 | 208 | """ 209 | 210 | def __init__(self, size, word_mat=None, cate_mat=None, outside=True, normalize='unit', compress=False, share=True): 211 | super(DioraBase, self).__init__() 212 | assert normalize in ('none', 'unit', 'xavier'), 'Does not support "{}".'.format(normalize) 213 | 214 | self.size = size 215 | self.share = share 216 | self.outside = outside 217 | self.inside_normalize_func = NormalizeFunc(normalize) 218 | self.outside_normalize_func = NormalizeFunc(normalize) 219 | # self.inside_normalize_func = nn.LayerNorm(size) 220 | # self.outside_normalize_func = nn.LayerNorm(size) 221 | self.compress = compress 222 | self.ninput = 2 223 | 224 | self.index = None 225 | self.charts = None 226 | 227 | self.init_parameters() 228 | self.reset_parameters() # reset all submodules 229 | self.reset() 230 | 231 | def init_parameters(self): 232 | raise NotImplementedError 233 | 234 | def reset_parameters(self): 235 | params = [p for p in self.parameters() if p.requires_grad] 236 | for i, param in enumerate(params): 237 | param.data.normal_() 238 | 239 | @property 240 | def device(self): 241 | return next(self.parameters()).device 242 | 243 | @property 244 | def inside_h(self): 245 | return self.chart.inside_h 246 | 247 | @property 248 | def inside_c(self): 249 | return self.chart.inside_c 250 | 251 | @property 252 | def inside_s(self): 253 | return self.chart.inside_s 254 | 255 | @property 256 | def outside_h(self): 257 | return self.chart.outside_h 258 | 259 | @property 260 | def outside_c(self): 261 | return self.chart.outside_c 262 | 263 | @property 264 | def outside_s(self): 265 | return self.chart.outside_s 266 | 267 | @property 268 | def is_cuda(self): 269 | device = self.device 270 | return device.index is not None and device.index >= 0 271 | 272 | def cuda(self): 273 | super(DioraBase, self).cuda() 274 | if self.index is not None: 275 | self.index.cuda = True # TODO: Should support to/from cpu/gpu. 
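# `get` below slices one level out of the flat chart using the same offset scheme as
# get_offset_cache: cells are packed level by level, so the span covering positions
# [pos, pos + level] lives at flat index offset_cache[level] + pos. A small sanity
# check (a sketch, using only modules already in this repo):
#
#     from cliora.net.offset_cache import get_offset_cache
#     offsets = get_offset_cache(4)   # {0: 0, 1: 4, 2: 7, 3: 9}
#     cell = offsets[1] + 2           # two-word span starting at token 2 -> cell 6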
276 | 277 | def get(self, chart, level): 278 | length = self.length 279 | L = length - level 280 | offset = self.index.get_offset(length)[level] 281 | return chart[:, offset:offset+L] 282 | 283 | def leaf_transform(self, x): 284 | normalize_func = self.inside_normalize_func 285 | transform_func = self.inside_compose_func.leaf_transform 286 | 287 | input_shape = x.shape[:-1] 288 | h, c = transform_func(x) 289 | h = normalize_func(h.view(*input_shape, self.size)) 290 | c = normalize_func(c.view(*input_shape, self.size)) 291 | 292 | return h, c 293 | 294 | # Inside 295 | def inside_func(self, compose_func, score_func, batch_info, chart, index, normalize_func): 296 | lh, rh = get_inside_states(batch_info, chart.inside_h, index, batch_info.size) 297 | lc, rc = get_inside_states(batch_info, chart.inside_c, index, batch_info.size) 298 | ls, rs = get_inside_states(batch_info, chart.inside_s, index, 1) 299 | 300 | hlst = [lh, rh] 301 | clst = [lc, rc] 302 | slst = [ls, rs] 303 | 304 | h, c = inside_compose(compose_func, hlst, clst) 305 | s, p = inside_score(score_func, batch_info, hlst, slst) 306 | hbar, cbar, sbar = inside_aggregate(batch_info, h, c, s, p, normalize_func) 307 | 308 | inside_fill_chart(batch_info, chart, index, hbar, cbar, sbar) 309 | 310 | return h, c, s 311 | 312 | def inside_pass(self): 313 | compose_func = self.inside_compose_func 314 | score_func = self.inside_score_func 315 | index = self.index 316 | chart = self.chart 317 | normalize_func = self.inside_normalize_func 318 | 319 | for level in range(1, self.length): 320 | 321 | batch_info = BatchInfo( 322 | batch_size=self.batch_size, 323 | length=self.length, 324 | size=self.size, 325 | level=level, 326 | ) 327 | 328 | h, c, s = self.inside_func(compose_func, score_func, batch_info, chart, index, 329 | normalize_func=normalize_func) 330 | 331 | self.inside_hook(level, h, c, s) 332 | 333 | def inside_hook(self, level, h, c, s): 334 | pass 335 | 336 | # Outside 337 | def initialize_outside_root(self): 338 | B = self.batch_size 339 | D = self.size 340 | normalize_func = self.outside_normalize_func 341 | 342 | if self.compress: 343 | h = torch.matmul(self.inside_h[:, -1:], self.root_mat_out) 344 | else: 345 | h = self.root_vector_out_h.view(1, 1, D).expand(B, 1, D) 346 | if self.root_vector_out_c is None: 347 | device = torch.cuda.current_device() if self.is_cuda else None 348 | c = torch.full(h.shape, 0, dtype=torch.float32, device=device) 349 | else: 350 | c = self.root_vector_out_c.view(1, 1, D).expand(B, 1, D) 351 | 352 | h = normalize_func(h) 353 | c = normalize_func(c) 354 | 355 | self.chart.outside_h[:, -1:] = h 356 | self.chart.outside_c[:, -1:] = c 357 | 358 | def outside_func(self, compose_func, score_func, batch_info, chart, index, normalize_func): 359 | ph, sh = get_outside_states( 360 | batch_info, chart.outside_h, chart.inside_h, index, batch_info.size) 361 | pc, sc = get_outside_states( 362 | batch_info, chart.outside_c, chart.inside_c, index, batch_info.size) 363 | ps, ss = get_outside_states( 364 | batch_info, chart.outside_s, chart.inside_s, index, 1) 365 | 366 | hlst = [sh, ph] 367 | clst = [sc, pc] 368 | slst = [ss, ps] 369 | 370 | h, c = outside_compose(compose_func, hlst, clst) 371 | s, p = outside_score(score_func, batch_info, hlst, slst) 372 | hbar, cbar, sbar = outside_aggregate(batch_info, h, c, s, p, normalize_func) 373 | 374 | outside_fill_chart(batch_info, chart, index, hbar, cbar, sbar) 375 | 376 | return h, c, s 377 | 378 | def outside_pass(self): 379 | self.initialize_outside_root() 380 | 381 | 
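# The outside pass mirrors the inside pass in reverse: after seeding the root cell,
# each level from length - 2 down to 0 combines a parent's outside state with the
# sibling's inside state to score and aggregate every way of embedding a span in
# its surrounding context.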
compose_func = self.outside_compose_func 382 | score_func = self.outside_score_func 383 | index = self.index 384 | chart = self.chart 385 | normalize_func = self.outside_normalize_func 386 | 387 | for level in range(self.length - 2, -1, -1): 388 | batch_info = BatchInfo( 389 | batch_size=self.batch_size, 390 | length=self.length, 391 | size=self.size, 392 | level=level, 393 | ) 394 | 395 | h, c, s = self.outside_func(compose_func, score_func, batch_info, chart, index, 396 | normalize_func=normalize_func) 397 | 398 | self.outside_hook(level, h, c, s) 399 | 400 | def outside_hook(self, level, h, c, s): 401 | pass 402 | 403 | # Initialization 404 | def init_with_batch(self, h, c): 405 | size = self.size 406 | batch_size, length, _ = h.shape 407 | 408 | self.batch_size = batch_size 409 | self.length = length 410 | 411 | self.chart = Chart(batch_size, length, size, dtype=torch.float32, cuda=self.is_cuda) 412 | self.chart.inside_h[:, :self.length] = h 413 | self.chart.inside_c[:, :self.length] = c 414 | 415 | def reset(self): 416 | self.batch_size = None 417 | self.length = None 418 | self.chart = None 419 | self.atten_score = None 420 | self.all_atten_score = None 421 | self.vg_atten_score = None 422 | 423 | 424 | def forward(self, x_span, x_word, obj_embed_span=None, obj_embed_word=None): 425 | if self.index is None: 426 | self.index = Index(cuda=self.is_cuda) 427 | 428 | self.reset() 429 | 430 | # h: normalized word embeds, c: zero vec with the same shape 431 | h, c = self.leaf_transform(x_span) 432 | 433 | self.init_with_batch(h, c) 434 | 435 | self.inside_pass() 436 | 437 | if self.outside: 438 | self.outside_pass() 439 | 440 | # Word Level Visual Grounding 441 | # if self.training: 442 | # self.vg_atten_score_word = torch.einsum('abx,cdx->acbd', x_word, obj_embed_word) 443 | # self.vg_atten_score = self.vg_atten_score_word 444 | # else: 445 | # self.vg_atten_score_word = torch.einsum('abx,cdx->acbd', self.inside_normalize_func(x_word), obj_embed_word) 446 | # self.vg_atten_score = self.vg_atten_score_word 447 | # 448 | # self.atten_score = torch.diagonal(self.vg_atten_score, 0, 0, 1).permute(2, 0, 1) # TODO 449 | 450 | return None 451 | 452 | 453 | class DioraMLP(DioraBase): 454 | 455 | def init_parameters(self): 456 | self.inside_score_func = Bilinear(self.size) 457 | self.inside_compose_func = ComposeMLP(self.size, leaf=True) 458 | 459 | if self.share: 460 | self.outside_score_func = self.inside_score_func 461 | self.outside_compose_func = self.inside_compose_func 462 | else: 463 | self.outside_score_func = Bilinear(self.size) 464 | self.outside_compose_func = ComposeMLP(self.size) 465 | 466 | if self.compress: 467 | self.root_mat_out = nn.Parameter(torch.FloatTensor(self.size, self.size)) 468 | else: 469 | self.root_vector_out_h = nn.Parameter(torch.FloatTensor(self.size)) 470 | self.root_vector_out_c = None 471 | # self.root_vector_out_c = nn.Parameter(torch.FloatTensor(self.size)) 472 | -------------------------------------------------------------------------------- /cliora/net/experiment_logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from collections import Counter 4 | 5 | from cliora.logging.accumulator import Accumulator 6 | from cliora.logging.configuration import get_logger 7 | 8 | 9 | class ExperimentLogger(object): 10 | def __init__(self): 11 | super(ExperimentLogger, self).__init__() 12 | self.logger = get_logger() 13 | self.A = None 14 | self.c = Counter() 15 | 16 | def str_length_distribution(self): 17 | 
result = '' 18 | keys = sorted(self.c.keys()) 19 | for i, k in enumerate(keys): 20 | if i > 0: 21 | result += ' ' 22 | result += '{}:{}'.format(k, self.c[k]) 23 | return result 24 | 25 | def record(self, result): 26 | if self.A is None: 27 | self.A = Accumulator() 28 | A = self.A 29 | 30 | self.c[result['length']] += 1 31 | 32 | for k, v in result.items(): 33 | if 'loss' in k: 34 | A.record(k, v) 35 | if 'acc' in k: 36 | A.record(k, v) 37 | 38 | def log_batch(self, epoch, step, batch_idx, batch_size=1): 39 | A = self.A 40 | logger = self.logger 41 | 42 | log_out = 'Epoch/Step/Batch={}/{}/{}'.format(epoch, step, batch_idx) 43 | 44 | for k in A.table.keys(): 45 | if 'loss' in k: 46 | log_out += ' {}={:.3f}'.format(k, A.get_mean(k)) 47 | if 'acc' in k: 48 | log_out += ' {}={:.3f}'.format(k, A.get_mean(k)) 49 | 50 | logger.info(log_out) 51 | 52 | # Average sentence length from previous batches 53 | total_length = sum(k * v for k, v in self.c.items()) 54 | total_batches = sum(self.c.values()) 55 | average_length = total_length / total_batches 56 | logger.info('Average-Length={}'.format(average_length)) 57 | logger.info('Length-Distribution={}'.format(self.str_length_distribution())) 58 | 59 | A.reset() 60 | self.c.clear() 61 | 62 | def log_epoch(self, epoch, step): 63 | logger = self.logger 64 | logger.info('Epoch/Step={}/{} (End-Of-Epoch)'.format(epoch, step)) 65 | 66 | def log_eval(self, loss, metric): 67 | logger = self.logger 68 | logger.info('Eval Loss={} Metric={}.'.format(loss, metric)) 69 | -------------------------------------------------------------------------------- /cliora/net/inside_index.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from cliora.net.offset_cache import get_offset_cache 3 | 4 | 5 | class InsideIndex(object): 6 | def get_pairs(self, level, i): 7 | pairs = [] 8 | for constituent_num in range(0, level): 9 | l_level = constituent_num 10 | l_i = i - level + constituent_num 11 | r_level = level - 1 - constituent_num 12 | r_i = i 13 | pair = ((l_level, l_i), (r_level, r_i)) 14 | pairs.append(pair) 15 | return pairs 16 | 17 | def get_all_pairs(self, level, n): 18 | pairs = [] 19 | for i in range(level, n): 20 | pairs += self.get_pairs(level, i) 21 | return pairs 22 | 23 | 24 | class InsideIndexCheck(object): 25 | def __init__(self, length, spans, siblings): 26 | sib_map = {} 27 | for x, y, n in siblings: 28 | sib_map[x] = (y, n) 29 | sib_map[y] = (x, n) 30 | 31 | check = {} 32 | for sibling, (target, name) in sib_map.items(): 33 | xlength = target[1] - target[0] 34 | xlevel = xlength - 1 35 | xpos = target[0] 36 | tgt = (xlevel, xpos) 37 | 38 | slength = sibling[1] - sibling[0] 39 | slevel = slength - 1 40 | spos = sibling[0] 41 | sis = (slevel, spos) 42 | 43 | check[(tgt, sis)] = True 44 | self.check = check 45 | 46 | def is_valid(self, tgt, sis): 47 | return (tgt, sis) in self.check 48 | 49 | 50 | # def get_inside_index(length, level, offset_cache=None, cuda=False): 51 | # if offset_cache is None: 52 | # offset_cache = get_offset_cache(length) 53 | # index = InsideIndex() 54 | # pairs = index.get_all_pairs(level, length) 55 | # 56 | # L = length - level 57 | # n_constituents = len(pairs) // L 58 | # idx_l, idx_r = [], [] 59 | # 60 | # for i in range(n_constituents): 61 | # index_l, index_r = [], [] 62 | # 63 | # lvl_l = i 64 | # lvl_r = level - i - 1 65 | # lstart, lend = 0, L 66 | # rstart, rend = length - L - lvl_r, length - lvl_r 67 | # 68 | # if lvl_l < 0: 69 | # lvl_l = length + lvl_l 70 | # if lvl_r < 0: 
71 | # lvl_r = length + lvl_r 72 | # 73 | # for pos in range(lstart, lend): 74 | # offset = offset_cache[lvl_l] 75 | # idx = offset + pos 76 | # index_l.append(idx) 77 | # 78 | # for pos in range(rstart, rend): 79 | # offset = offset_cache[lvl_r] 80 | # idx = offset + pos 81 | # index_r.append(idx) 82 | # 83 | # idx_l.append(index_l) 84 | # idx_r.append(index_r) 85 | # 86 | # device = torch.cuda.current_device() if cuda else None 87 | # idx_l = torch.tensor(idx_l, dtype=torch.int64, device=device 88 | # ).transpose(0, 1).contiguous().flatten() 89 | # idx_r = torch.tensor(idx_r, dtype=torch.int64, device=device 90 | # ).transpose(0, 1).contiguous().flatten() 91 | # 92 | # return idx_l, idx_r 93 | 94 | 95 | def get_inside_index_unique(length, level, offset_cache=None, cuda=False): 96 | if offset_cache is None: 97 | offset_cache = get_offset_cache(length) 98 | index = InsideIndex() 99 | pairs = index.get_all_pairs(level, length) 100 | 101 | L = length - level 102 | n_constituents = len(pairs) // L 103 | idx_set = set() 104 | 105 | for i in range(n_constituents): 106 | lvl_l = i 107 | lvl_r = level - i - 1 108 | lstart, lend = 0, L 109 | rstart, rend = length - L - lvl_r, length - lvl_r 110 | 111 | if lvl_l < 0: 112 | lvl_l = length + lvl_l 113 | if lvl_r < 0: 114 | lvl_r = length + lvl_r 115 | 116 | for pos in range(lstart, lend): 117 | offset = offset_cache[lvl_l] 118 | idx = offset + pos 119 | idx_set.add(idx) 120 | 121 | for pos in range(rstart, rend): 122 | offset = offset_cache[lvl_r] 123 | idx = offset + pos 124 | idx_set.add(idx) 125 | 126 | device = torch.cuda.current_device() if cuda else None 127 | idx_lst = torch.tensor(list(idx_set), dtype=torch.int64, device=device).flatten() 128 | return idx_lst 129 | 130 | 131 | def get_inside_components(length, level, offset_cache=None): 132 | if offset_cache is None: 133 | offset_cache = get_offset_cache(length) 134 | index = InsideIndex() 135 | pairs = index.get_all_pairs(level, length) 136 | 137 | L = length - level 138 | n_constituents = len(pairs) // L 139 | output = [] 140 | 141 | for i in range(n_constituents): 142 | index_l, index_r = [], [] 143 | span_x, span_l, span_r = [], [], [] 144 | 145 | l_level = i 146 | r_level = level - l_level - 1 147 | 148 | l_start = 0 149 | l_end = L 150 | 151 | r_start = length - L - r_level 152 | r_end = length - r_level 153 | 154 | if l_level < 0: 155 | l_level = length + l_level 156 | if r_level < 0: 157 | r_level = length + r_level 158 | 159 | # The span being targeted. 160 | for pos in range(l_start, l_end): 161 | span_x.append((level, pos)) 162 | 163 | # The left child. 164 | for pos in range(l_start, l_end): 165 | offset = offset_cache[l_level] 166 | idx = offset + pos 167 | index_l.append(idx) 168 | span_l.append((l_level, pos)) 169 | 170 | # The right child. 
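# (Mirrors the left-child loop above: component i pairs left children at level i with
# right children at level `level - 1 - i`, so index_l/index_r collect the flat chart
# indices of every (left, right) split while span_* keep the (level, pos) coordinates.)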
171 |         for pos in range(r_start, r_end): 172 |             offset = offset_cache[r_level] 173 |             idx = offset + pos 174 |             index_r.append(idx) 175 |             span_r.append((r_level, pos)) 176 | 177 |         output.append((index_l, index_r, span_x, span_l, span_r)) 178 | 179 |     return output 180 | 181 | 182 | def get_inside_index(length, level, offset_cache=None, cuda=False): 183 |     components = get_inside_components(length, level, offset_cache) 184 | 185 |     idx_l, idx_r = [], [] 186 | 187 |     for i, (index_l, index_r, _, _, _) in enumerate(components): 188 |         idx_l.append(index_l) 189 |         idx_r.append(index_r) 190 | 191 |     device = torch.cuda.current_device() if cuda else None 192 |     idx_l = torch.tensor(idx_l, dtype=torch.int64, device=device 193 |         ).transpose(0, 1).contiguous().flatten() 194 |     idx_r = torch.tensor(idx_r, dtype=torch.int64, device=device 195 |         ).transpose(0, 1).contiguous().flatten() 196 | 197 |     return idx_l, idx_r -------------------------------------------------------------------------------- /cliora/net/offset_cache.py: -------------------------------------------------------------------------------- 1 | def get_offset_cache(length): 2 |     offset_cache = {} 3 |     ncells = int(length * (1 + length) / 2) 4 |     for lvl in range(length): 5 |         level_length = length - lvl 6 |         ncells_less = int(level_length * (1 + level_length) / 2) 7 |         offset_cache[lvl] = ncells - ncells_less 8 |     return offset_cache -------------------------------------------------------------------------------- /cliora/net/outside_index.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from cliora.net.offset_cache import get_offset_cache 3 | 4 | 5 | class OutsideIndex(object): 6 |     def get_pairs(self, level, i, n): 7 |         """ 8 |         Returns all (parent, sibling) coordinate pairs that 9 |         are used to construct a node at coordinates 10 |         (level, i) where there are n leaf nodes.
11 | 12 | """ 13 | pairs = [] 14 | 15 | for level_ in range(level + 1, i + 1): 16 | p_level = level_ 17 | p_i = i 18 | s_level = level_ - level - 1 19 | s_i = i - level - 1 20 | 21 | pairs.append([(p_level, p_i), (s_level, s_i)]) 22 | 23 | for i_ in range(i + 1, n): 24 | p_level = level + i_ - i 25 | p_i = i_ 26 | s_level = i_ - i - 1 27 | s_i = i_ 28 | 29 | pairs.append([(p_level, p_i), (s_level, s_i)]) 30 | 31 | return pairs 32 | 33 | def xget_all_pairs(self, level, n): 34 | pairs = [] 35 | for i in range(level, n): 36 | pairs += self.get_pairs(level, i, n) 37 | return pairs 38 | 39 | def get_all_pairs(self, level, n): 40 | L = n - level 41 | N = L - 1 42 | 43 | pairs = [] 44 | 45 | for i in range(N): 46 | jseen = 0 47 | for j in range(L): 48 | if j < N - i: 49 | s_level = n - i - 1 50 | s_i = N - i - j - 1 51 | p_level = s_level 52 | p_i = s_level - j 53 | else: 54 | s_level = j - 1 55 | s_i = jseen 56 | p_level = n - (N - s_level) 57 | p_i = n - (N - s_i) 58 | jseen += 1 59 | pair = [(p_i, p_level), (s_i, s_level)] 60 | pairs.append(pair) 61 | 62 | return pairs 63 | 64 | 65 | class OutsideIndexCheck(object): 66 | def __init__(self, length, spans, siblings): 67 | sib_map = {} 68 | for x, y, n in siblings: 69 | sib_map[x] = (y, n) 70 | sib_map[y] = (x, n) 71 | 72 | check = {} 73 | for sibling, (target, name) in sib_map.items(): 74 | xlength = target[1] - target[0] 75 | xlevel = xlength - 1 76 | xpos = target[0] 77 | tgt = (xlevel, xpos) 78 | 79 | slength = sibling[1] - sibling[0] 80 | slevel = slength - 1 81 | spos = sibling[0] 82 | sis = (slevel, spos) 83 | 84 | par = (sis[0] + tgt[0] + 1, min(sis[1], tgt[1])) 85 | 86 | check[(par, sis)] = True 87 | self.check = check 88 | 89 | def is_valid(self, par, sis): 90 | return (par, sis) in self.check 91 | 92 | 93 | def get_outside_index(length, level, offset_cache=None, cuda=False): 94 | if offset_cache is None: 95 | offset_cache = get_offset_cache(length) 96 | index = OutsideIndex() 97 | pairs = index.get_all_pairs(level, length) 98 | 99 | par_lvl, par_pos = [], [] 100 | sis_lvl, sis_pos = [], [] 101 | 102 | for pair in pairs: 103 | par, sis = pair 104 | par_lvl.append(par[0]) 105 | par_pos.append(par[1] - par[0]) 106 | sis_lvl.append(sis[0]) 107 | sis_pos.append(sis[1] - sis[0]) 108 | 109 | device = torch.cuda.current_device() if cuda else None 110 | 111 | # Parent 112 | index = [] 113 | for lvl, pos in zip(par_lvl, par_pos): 114 | offset = offset_cache[lvl] 115 | idx = offset + pos 116 | index.append(idx) 117 | par_index = torch.tensor(index, dtype=torch.int64, device=device) 118 | 119 | # Sibling 120 | index = [] 121 | for lvl, pos in zip(sis_lvl, sis_pos): 122 | offset = offset_cache[lvl] 123 | idx = offset + pos 124 | index.append(idx) 125 | sis_index = torch.tensor(index, dtype=torch.int64, device=device) 126 | 127 | return par_index, sis_index 128 | 129 | 130 | def get_outside_components(length, level, offset_cache=None): 131 | index = OutsideIndex() 132 | pairs = index.get_all_pairs(level, length) 133 | output = [] 134 | 135 | for pair in pairs: 136 | par, sis = pair 137 | par_lvl = par[0] 138 | par_pos = par[1] - par[0] 139 | par_span = (par_lvl, par_pos) 140 | sis_lvl = sis[0] 141 | sis_pos = sis[1] - sis[0] 142 | sis_span = (sis_lvl, sis_pos) 143 | 144 | output.append((par_span, sis_span)) 145 | 146 | return output 147 | 148 | 149 | def get_topk_outside_index(length, level, K, offset_cache=None, cuda=False): 150 | if offset_cache is None: 151 | offset_cache = get_offset_cache(length) 152 | 153 | L = length - level 154 | # N = 
length - level - 1 155 | 156 | components = get_outside_components(length, level, offset_cache) 157 | 158 | p_info, s_info = [], [] 159 | for i, (p_span, s_span) in enumerate(components): 160 | p_level, p_pos = p_span 161 | s_level, s_pos = s_span 162 | n_idx = i // L 163 | x_pos = i % L 164 | p_idx = offset_cache[p_level] + p_pos 165 | s_idx = offset_cache[s_level] + s_pos 166 | 167 | p_info.append((x_pos, n_idx, p_level, p_pos, p_idx)) 168 | s_info.append((x_pos, n_idx, s_level, s_pos, s_idx)) 169 | 170 | def sort_key(x): 171 | x_pos, n_idx, inp_level, inp_pos, inp_idx = x 172 | return (x_pos, n_idx) 173 | 174 | def get_val(x): 175 | x_pos, n_idx, inp_level, inp_pos, inp_idx = x 176 | return inp_idx 177 | 178 | p_info = sorted(p_info, key=sort_key) 179 | s_info = sorted(s_info, key=sort_key) 180 | 181 | device = torch.cuda.current_device() if cuda else None 182 | 183 | p_index = torch.tensor([get_val(x) for x in p_info], dtype=torch.long, device=device) 184 | s_index = torch.tensor([get_val(x) for x in s_info], dtype=torch.long, device=device) 185 | 186 | return p_index, p_info, s_index, s_info -------------------------------------------------------------------------------- /cliora/net/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from scipy.special import factorial 4 | 5 | from cliora.net.outside_index import get_outside_index, get_topk_outside_index, get_outside_components 6 | from cliora.net.inside_index import get_inside_index, get_inside_components, get_inside_index_unique 7 | from cliora.net.offset_cache import get_offset_cache 8 | 9 | from collections import OrderedDict 10 | 11 | TINY = 1e-8 12 | class UnitNorm(object): 13 | def __call__(self, x, p=2, eps=TINY): 14 | return x / x.norm(p=p, dim=-1, keepdim=True).clamp(min=eps) 15 | 16 | 17 | class NormalizeFunc(nn.Module): 18 | def __init__(self, mode='none'): 19 | super(NormalizeFunc, self).__init__() 20 | self.mode = mode 21 | 22 | def forward(self, x): 23 | mode = self.mode 24 | if mode == 'none': 25 | return x 26 | elif mode == 'unit': 27 | return UnitNorm()(x) 28 | 29 | 30 | class BatchInfo(object): 31 | def __init__(self, **kwargs): 32 | super(BatchInfo, self).__init__() 33 | for k, v in kwargs.items(): 34 | setattr(self, k, v) 35 | 36 | 37 | class ImageEncoder(nn.Module): 38 | def __init__(self, input_size, size): 39 | super(ImageEncoder, self).__init__() 40 | 41 | self.fc = nn.Linear(input_size, size) 42 | self.fc_vis = nn.Linear(input_size, size) 43 | self.reset_parameters() 44 | 45 | def reset_parameters(self): 46 | # zero norm keep same with MAF 47 | params = [p for p in self.parameters() if p.requires_grad] 48 | for i, param in enumerate(params): 49 | # param.data.normal_() 50 | param.data.zero_() 51 | 52 | def forward(self, obj_feats): 53 | features_span = self.fc(obj_feats.float()) 54 | features_word = self.fc_vis(obj_feats.float()) 55 | return features_span, features_word 56 | 57 | 58 | def get_catalan(n): 59 | if n > 10: 60 | return 5000 # HACK: We only use this to check number of trees, and this avoids overflow. 
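# Below: the Catalan number C_{n-1} = choose(2(n-1), n-1) / n, i.e. the number of
# distinct binary trees over n leaves. For example, n = 5 gives choose(8, 4) / 5
# = 70 / 5 = 14.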
61 | n = n - 1 62 | def choose(n, p): 63 | return factorial(n) / (factorial(p) * factorial(n-p)) 64 | return int(choose(2 * n, n) // (n + 1)) 65 | 66 | 67 | class Index(object): 68 | def __init__(self, cuda=False, enable_caching=True): 69 | super(Index, self).__init__() 70 | self.cuda = cuda 71 | self.cache = {} 72 | self.inside_index_cache = {} 73 | self.inside_index_unique_cache = {} 74 | self.outside_index_cache = {} 75 | self.outside_encoded_index_cache = {} 76 | self.offset_cache = {} 77 | self.enable_caching = enable_caching 78 | 79 | def cached_lookup(self, func, name, key): 80 | if name not in self.cache: 81 | self.cache[name] = {} 82 | cache = self.cache[name] 83 | if self.enable_caching: 84 | if key not in cache: 85 | cache[key] = func() 86 | return cache[key] 87 | else: 88 | return func() 89 | 90 | def get_catalan(self, n): 91 | name = 'catalan' 92 | key = n 93 | def func(): 94 | return get_catalan(n) 95 | return self.cached_lookup(func, name, key) 96 | 97 | def get_offset(self, length): 98 | name = 'offset_cache' 99 | key = length 100 | def func(): 101 | return get_offset_cache(length) 102 | return self.cached_lookup(func, name, key) 103 | 104 | def get_inside_index(self, length, level): 105 | name = 'inside_index_cache' 106 | key = (length, level) 107 | def func(): 108 | return get_inside_index(length, level, 109 | self.get_offset(length), cuda=self.cuda) 110 | return self.cached_lookup(func, name, key) 111 | 112 | def get_inside_index_unique(self, length, level): 113 | name = 'inside_index_unique_cache' 114 | key = (length, level) 115 | def func(): 116 | return get_inside_index_unique(length, level, 117 | self.get_offset(length), cuda=self.cuda) 118 | return self.cached_lookup(func, name, key) 119 | 120 | def get_outside_index(self, length, level): 121 | name = 'outside_index_cache' 122 | key = (length, level) 123 | def func(): 124 | return get_outside_index(length, level, 125 | self.get_offset(length), cuda=self.cuda) 126 | return self.cached_lookup(func, name, key) 127 | 128 | def get_topk_outside_index(self, length, level, K): 129 | name = 'topk_outside_index_cache' 130 | key = (length, level, K) 131 | def func(): 132 | return get_topk_outside_index(length, level, K, 133 | self.get_offset(length), cuda=self.cuda) 134 | return self.cached_lookup(func, name, key) -------------------------------------------------------------------------------- /cliora/net/vg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from cliora.net.utils import * 5 | 6 | 7 | class Chart(object): 8 | def __init__(self, batch_size, length, size, dtype=None, cuda=False): 9 | super(Chart, self).__init__() 10 | 11 | ncells = int(length * (1 + length) / 2) 12 | 13 | device = torch.cuda.current_device() if cuda else None 14 | 15 | ## Inside. 16 | self.inside_h = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 17 | self.inside_c = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 18 | self.inside_s = torch.full((batch_size, ncells, 1), 0, dtype=dtype, device=device) 19 | 20 | ## Outside. 
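        # The outside chart mirrors the inside one: ncells = length*(length+1)/2
        # rows, one per span, addressed flat as get_offset(length)[level] + position.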
21 | self.outside_h = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 22 | self.outside_c = torch.full((batch_size, ncells, size), 0, dtype=dtype, device=device) 23 | self.outside_s = torch.full((batch_size, ncells, 1), 0, dtype=dtype, device=device) 24 | 25 | 26 | # Composition Functions 27 | 28 | # class TreeLSTM(nn.Module): 29 | # def __init__(self, size, ninput=2, leaf=False): 30 | # super(TreeLSTM, self).__init__() 31 | # 32 | # self.size = size 33 | # self.ninput = ninput 34 | # 35 | # if leaf: 36 | # self.W = nn.Parameter(torch.FloatTensor(3 * self.size, self.size)) 37 | # self.U = nn.Parameter(torch.FloatTensor(5 * self.size, self.ninput * self.size)) 38 | # self.B = nn.Parameter(torch.FloatTensor(5 * self.size)) 39 | # self.reset_parameters() 40 | # 41 | # def reset_parameters(self): 42 | # params = [p for p in self.parameters() if p.requires_grad] 43 | # for i, param in enumerate(params): 44 | # param.data.normal_() 45 | # 46 | # def leaf_transform(self, x): 47 | # W, B = self.W, self.B[:3*self.size] 48 | # 49 | # activations = torch.matmul(x, W.t()) + B 50 | # a_lst = torch.chunk(activations, 3, dim=-1) 51 | # u = torch.tanh(a_lst[0]) 52 | # i = torch.sigmoid(a_lst[1]) 53 | # o = torch.sigmoid(a_lst[2]) 54 | # 55 | # c = i * u 56 | # h = o * torch.tanh(c) 57 | # 58 | # return h, c 59 | # 60 | # def forward(self, hs, cs, constant=1.0): 61 | # U, B = self.U, self.B 62 | # 63 | # input_h = torch.cat(hs, 1) 64 | # 65 | # activations = torch.matmul(input_h, U.t()) + B 66 | # a_lst = torch.chunk(activations, 5, dim=1) 67 | # u = torch.tanh(a_lst[0]) 68 | # i = torch.sigmoid(a_lst[1]) 69 | # o = torch.sigmoid(a_lst[2]) 70 | # f0 = torch.sigmoid(a_lst[3] + constant) 71 | # f1 = torch.sigmoid(a_lst[4] + constant) 72 | # 73 | # c = f0 * cs[0] + f1 * cs[1] + i * u 74 | # h = o * torch.tanh(c) 75 | # 76 | # return h, c 77 | 78 | 79 | class ComposeMLP(nn.Module): 80 | def __init__(self, size, ninput=2, leaf=False): 81 | super(ComposeMLP, self).__init__() 82 | 83 | self.size = size 84 | self.ninput = ninput 85 | 86 | if leaf: 87 | self.leaf_fc = nn.Linear(self.size, self.size) 88 | self.h_fcs = nn.Sequential( 89 | nn.Linear(2 * self.size, self.size), 90 | nn.ReLU(), 91 | nn.Linear(self.size, self.size), 92 | nn.ReLU() 93 | ) 94 | # self.reset_parameters() 95 | 96 | @property 97 | def device(self): 98 | return next(self.parameters()).device 99 | 100 | @property 101 | def is_cuda(self): 102 | device = self.device 103 | return device.index is not None and device.index >= 0 104 | 105 | # def reset_parameters(self): 106 | # # TODO: Init with diagonal. 
107 | # params = [p for p in self.parameters() if p.requires_grad] 108 | # for i, param in enumerate(params): 109 | # param.data.normal_() 110 | 111 | def leaf_transform(self, x): 112 | # h = self.leaf_fc(x) 113 | h = torch.tanh(self.leaf_fc(x)) 114 | c = torch.full(h.shape, 0, dtype=torch.float32, device=h.device) 115 | 116 | return h, c 117 | 118 | def forward(self, hs, cs=None, constant=1.0): 119 | input_h = torch.cat(hs, 1) 120 | h = self.h_fcs(input_h) 121 | 122 | # device = torch.cuda.current_device() if self.is_cuda else None 123 | c = torch.full(h.shape, 0, dtype=torch.float32, device=h.device) 124 | 125 | return h, c 126 | 127 | 128 | # Score Functions 129 | 130 | class Bilinear(nn.Module): 131 | def __init__(self, size): 132 | super(Bilinear, self).__init__() 133 | self.size = size 134 | self.mat = nn.Parameter(torch.FloatTensor(self.size, self.size)) 135 | # self.reset_parameters() 136 | 137 | # def reset_parameters(self): 138 | # params = [p for p in self.parameters() if p.requires_grad] 139 | # for i, param in enumerate(params): 140 | # param.data.normal_() 141 | 142 | def forward(self, vector1, vector2): 143 | # bilinear 144 | # a = 1 (in a more general bilinear function, a is any positive integer) 145 | # vector1.shape = (b, m) 146 | # matrix.shape = (m, n) 147 | # vector2.shape = (b, n) 148 | bma = torch.matmul(vector1, self.mat).unsqueeze(1) 149 | ba = torch.matmul(bma, vector2.unsqueeze(2)).view(-1, 1) 150 | return ba 151 | 152 | 153 | # Inside 154 | 155 | def inside_fill_chart(batch_info, chart, index, h, c, s): 156 | L = batch_info.length - batch_info.level 157 | 158 | offset = index.get_offset(batch_info.length)[batch_info.level] 159 | 160 | chart.inside_h[:, offset:offset+L] = h 161 | chart.inside_c[:, offset:offset+L] = c 162 | chart.inside_s[:, offset:offset+L] = s 163 | 164 | 165 | def get_inside_states(batch_info, chart, index, size): 166 | lidx, ridx = index.get_inside_index(batch_info.length, batch_info.level) 167 | 168 | ls = chart.index_select(index=lidx, dim=1).view(-1, size) 169 | rs = chart.index_select(index=ridx, dim=1).view(-1, size) 170 | 171 | return ls, rs 172 | 173 | 174 | def inside_compose(compose_func, hs, cs): 175 | return compose_func(hs, cs) 176 | 177 | 178 | def inside_score(score_func, batch_info, hs, ss): 179 | B = batch_info.batch_size 180 | L = batch_info.length - batch_info.level 181 | N = batch_info.level 182 | 183 | s = score_func(hs[0], hs[1]) + ss[0] + ss[1] 184 | s = s.view(B, L, N, 1) 185 | p = torch.softmax(s, dim=2) 186 | 187 | return s, p 188 | 189 | 190 | def inside_aggregate(batch_info, h, c, s, p, normalize_func): 191 | B = batch_info.batch_size 192 | L = batch_info.length - batch_info.level 193 | N = batch_info.level 194 | 195 | h_agg = torch.sum(h.view(B, L, N, -1) * p, 2) 196 | c_agg = torch.sum(c.view(B, L, N, -1) * p, 2) 197 | s_agg = torch.sum(s * p, 2) 198 | 199 | h_agg = normalize_func(h_agg) 200 | c_agg = normalize_func(c_agg) 201 | 202 | return h_agg, c_agg, s_agg 203 | 204 | 205 | # Outside 206 | 207 | def outside_fill_chart(batch_info, chart, index, h, c, s): 208 | L = batch_info.length - batch_info.level 209 | 210 | offset = index.get_offset(batch_info.length)[batch_info.level] 211 | 212 | chart.outside_h[:, offset:offset+L] = h 213 | chart.outside_c[:, offset:offset+L] = c 214 | chart.outside_s[:, offset:offset+L] = s 215 | 216 | 217 | def get_outside_states(batch_info, pchart, schart, index, size): 218 | pidx, sidx = index.get_outside_index(batch_info.length, batch_info.level) 219 | 220 | ps = 
pchart.index_select(index=pidx, dim=1).view(-1, size) 221 | ss = schart.index_select(index=sidx, dim=1).view(-1, size) 222 | 223 | return ps, ss 224 | 225 | 226 | def outside_compose(compose_func, hs, cs): 227 | return compose_func(hs, cs, 0) 228 | 229 | 230 | def outside_score(score_func, batch_info, hs, ss): 231 | B = batch_info.batch_size 232 | L = batch_info.length - batch_info.level 233 | 234 | s = score_func(hs[0], hs[1]) + ss[0] + ss[1] 235 | s = s.view(B, -1, L, 1) 236 | p = torch.softmax(s, dim=1) 237 | 238 | return s, p 239 | 240 | 241 | def outside_aggregate(batch_info, h, c, s, p, normalize_func): 242 | B = batch_info.batch_size 243 | L = batch_info.length - batch_info.level 244 | N = s.shape[1] 245 | 246 | h_agg = torch.sum(h.view(B, N, L, -1) * p, 1) 247 | c_agg = torch.sum(c.view(B, N, L, -1) * p, 1) 248 | s_agg = torch.sum(s * p, 1) 249 | 250 | h_agg = normalize_func(h_agg) 251 | c_agg = normalize_func(c_agg) 252 | 253 | return h_agg, c_agg, s_agg 254 | 255 | 256 | # Base 257 | 258 | class DioraBase(nn.Module): 259 | r"""DioraBase 260 | 261 | """ 262 | 263 | def __init__(self, size, word_mat=None, cate_mat=None, outside=True, normalize='unit', compress=False, share=True): 264 | super(DioraBase, self).__init__() 265 | assert normalize in ('none', 'unit', 'xavier'), 'Does not support "{}".'.format(normalize) 266 | 267 | self.size = size 268 | self.share = share 269 | self.outside = outside 270 | self.inside_normalize_func = NormalizeFunc(normalize) 271 | self.outside_normalize_func = NormalizeFunc(normalize) 272 | # self.inside_normalize_func = nn.LayerNorm(size) 273 | # self.outside_normalize_func = nn.LayerNorm(size) 274 | self.compress = compress 275 | self.ninput = 2 276 | 277 | self.index = None 278 | self.charts = None 279 | 280 | self.init_parameters() 281 | self.reset_parameters() # reset all submodules 282 | self.reset() 283 | 284 | def init_parameters(self): 285 | raise NotImplementedError 286 | 287 | def reset_parameters(self): 288 | params = [p for p in self.parameters() if p.requires_grad] 289 | for i, param in enumerate(params): 290 | param.data.normal_() 291 | 292 | @property 293 | def device(self): 294 | return next(self.parameters()).device 295 | 296 | @property 297 | def inside_h(self): 298 | return self.chart.inside_h 299 | 300 | @property 301 | def inside_c(self): 302 | return self.chart.inside_c 303 | 304 | @property 305 | def inside_s(self): 306 | return self.chart.inside_s 307 | 308 | @property 309 | def outside_h(self): 310 | return self.chart.outside_h 311 | 312 | @property 313 | def outside_c(self): 314 | return self.chart.outside_c 315 | 316 | @property 317 | def outside_s(self): 318 | return self.chart.outside_s 319 | 320 | @property 321 | def is_cuda(self): 322 | device = self.device 323 | return device.index is not None and device.index >= 0 324 | 325 | def cuda(self): 326 | super(DioraBase, self).cuda() 327 | if self.index is not None: 328 | self.index.cuda = True # TODO: Should support to/from cpu/gpu. 
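    # Flat chart addressing, worked example (a sketch): for length=4 the chart
    # has 10 cells; levels 0..3 occupy 4, 3, 2, 1 cells with offsets 0, 4, 7, 9,
    # so the 2-token span starting at position 2 (level=1, pos=2) is cell 4+2=6.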
329 | 330 | def get(self, chart, level): 331 | length = self.length 332 | L = length - level 333 | offset = self.index.get_offset(length)[level] 334 | return chart[:, offset:offset+L] 335 | 336 | def leaf_transform(self, x): 337 | normalize_func = self.inside_normalize_func 338 | transform_func = self.inside_compose_func.leaf_transform 339 | 340 | input_shape = x.shape[:-1] 341 | h, c = transform_func(x) 342 | h = normalize_func(h.view(*input_shape, self.size)) 343 | c = normalize_func(c.view(*input_shape, self.size)) 344 | 345 | return h, c 346 | 347 | # Inside 348 | def inside_func(self, compose_func, score_func, batch_info, chart, index, normalize_func): 349 | lh, rh = get_inside_states(batch_info, chart.inside_h, index, batch_info.size) 350 | lc, rc = get_inside_states(batch_info, chart.inside_c, index, batch_info.size) 351 | ls, rs = get_inside_states(batch_info, chart.inside_s, index, 1) 352 | 353 | hlst = [lh, rh] 354 | clst = [lc, rc] 355 | slst = [ls, rs] 356 | 357 | h, c = inside_compose(compose_func, hlst, clst) 358 | s, p = inside_score(score_func, batch_info, hlst, slst) 359 | hbar, cbar, sbar = inside_aggregate(batch_info, h, c, s, p, normalize_func) 360 | 361 | inside_fill_chart(batch_info, chart, index, hbar, cbar, sbar) 362 | 363 | return h, c, s 364 | 365 | def inside_pass(self): 366 | compose_func = self.inside_compose_func 367 | score_func = self.inside_score_func 368 | index = self.index 369 | chart = self.chart 370 | normalize_func = self.inside_normalize_func 371 | 372 | for level in range(1, self.length): 373 | 374 | batch_info = BatchInfo( 375 | batch_size=self.batch_size, 376 | length=self.length, 377 | size=self.size, 378 | level=level, 379 | ) 380 | 381 | h, c, s = self.inside_func(compose_func, score_func, batch_info, chart, index, 382 | normalize_func=normalize_func) 383 | 384 | self.inside_hook(level, h, c, s) 385 | 386 | def inside_hook(self, level, h, c, s): 387 | pass 388 | 389 | # Outside 390 | def initialize_outside_root(self): 391 | B = self.batch_size 392 | D = self.size 393 | normalize_func = self.outside_normalize_func 394 | 395 | if self.compress: 396 | h = torch.matmul(self.inside_h[:, -1:], self.root_mat_out) 397 | else: 398 | h = self.root_vector_out_h.view(1, 1, D).expand(B, 1, D) 399 | if self.root_vector_out_c is None: 400 | device = torch.cuda.current_device() if self.is_cuda else None 401 | c = torch.full(h.shape, 0, dtype=torch.float32, device=device) 402 | else: 403 | c = self.root_vector_out_c.view(1, 1, D).expand(B, 1, D) 404 | 405 | h = normalize_func(h) 406 | c = normalize_func(c) 407 | 408 | self.chart.outside_h[:, -1:] = h 409 | self.chart.outside_c[:, -1:] = c 410 | 411 | def outside_func(self, compose_func, score_func, batch_info, chart, index, normalize_func): 412 | ph, sh = get_outside_states( 413 | batch_info, chart.outside_h, chart.inside_h, index, batch_info.size) 414 | pc, sc = get_outside_states( 415 | batch_info, chart.outside_c, chart.inside_c, index, batch_info.size) 416 | ps, ss = get_outside_states( 417 | batch_info, chart.outside_s, chart.inside_s, index, 1) 418 | 419 | hlst = [sh, ph] 420 | clst = [sc, pc] 421 | slst = [ss, ps] 422 | 423 | h, c = outside_compose(compose_func, hlst, clst) 424 | s, p = outside_score(score_func, batch_info, hlst, slst) 425 | hbar, cbar, sbar = outside_aggregate(batch_info, h, c, s, p, normalize_func) 426 | 427 | outside_fill_chart(batch_info, chart, index, hbar, cbar, sbar) 428 | 429 | return h, c, s 430 | 431 | def outside_pass(self): 432 | self.initialize_outside_root() 433 | 434 | 
compose_func = self.outside_compose_func 435 | score_func = self.outside_score_func 436 | index = self.index 437 | chart = self.chart 438 | normalize_func = self.outside_normalize_func 439 | 440 | for level in range(self.length - 2, -1, -1): 441 | batch_info = BatchInfo( 442 | batch_size=self.batch_size, 443 | length=self.length, 444 | size=self.size, 445 | level=level, 446 | ) 447 | 448 | h, c, s = self.outside_func(compose_func, score_func, batch_info, chart, index, 449 | normalize_func=normalize_func) 450 | 451 | self.outside_hook(level, h, c, s) 452 | 453 | def outside_hook(self, level, h, c, s): 454 | pass 455 | 456 | # Initialization 457 | def init_with_batch(self, h, c): 458 | size = self.size 459 | batch_size, length, _ = h.shape 460 | 461 | self.batch_size = batch_size 462 | self.length = length 463 | 464 | self.chart = Chart(batch_size, length, size, dtype=torch.float32, cuda=self.is_cuda) 465 | self.chart.inside_h[:, :self.length] = h 466 | self.chart.inside_c[:, :self.length] = c 467 | 468 | def reset(self): 469 | self.batch_size = None 470 | self.length = None 471 | self.chart = None 472 | self.atten_score = None 473 | self.all_atten_score = None 474 | self.vg_atten_score = None 475 | 476 | 477 | def forward(self, x, obj_embed=None, obj_embed_v=None, prior_sim=None): 478 | 479 | self.vg_atten_score = torch.einsum('abx,cdx->acbd', x, obj_embed) 480 | self.atten_score = torch.diagonal(self.vg_atten_score, 0, 0, 1).permute(2, 0, 1) 481 | 482 | return None 483 | 484 | 485 | class DioraMLP(DioraBase): 486 | 487 | def init_parameters(self): 488 | self.inside_score_func = Bilinear(self.size) 489 | self.inside_compose_func = ComposeMLP(self.size, leaf=True) 490 | 491 | if self.share: 492 | self.outside_score_func = self.inside_score_func 493 | self.outside_compose_func = self.inside_compose_func 494 | else: 495 | self.outside_score_func = Bilinear(self.size) 496 | self.outside_compose_func = ComposeMLP(self.size) 497 | 498 | if self.compress: 499 | self.root_mat_out = nn.Parameter(torch.FloatTensor(self.size, self.size)) 500 | else: 501 | self.root_vector_out_h = nn.Parameter(torch.FloatTensor(self.size)) 502 | self.root_vector_out_c = None 503 | # self.root_vector_out_c = nn.Parameter(torch.FloatTensor(self.size)) 504 | -------------------------------------------------------------------------------- /cliora/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bobwan1995/cliora/b064bdf967d4ccc4f3327183efd888b927bfb4fb/cliora/scripts/__init__.py -------------------------------------------------------------------------------- /cliora/scripts/parse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import json 4 | 5 | import torch 6 | import torchvision.ops as torchops 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from train import argument_parser, parse_args, configure 11 | from train import get_validation_dataset, get_validation_iterator 12 | from train import build_net 13 | 14 | from cliora.logging.configuration import get_logger 15 | 16 | from cliora.analysis.cky import ParsePredictor as CKY 17 | from cliora.analysis.utils import * 18 | import copy 19 | 20 | punctuation_words = set([x.lower() for x in ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', 21 | '``', '--', ';', '-', '?', '!', '...', '-LCB-', '-RCB-']]) 22 | 23 | 24 | def remove_using_flat_mask(tr, mask): 25 | kept, removed = [], [] 26 | def func(tr, pos=0): 27 | 
if not isinstance(tr, (list, tuple)):
28 |             if mask[pos] == False:
29 |                 removed.append(tr)
30 |                 return None, 1
31 |             kept.append(tr)
32 |             return tr, 1
33 | 
34 |         size = 0
35 |         node = []
36 | 
37 |         for subtree in tr:
38 |             x, xsize = func(subtree, pos=pos + size)
39 |             if x is not None:
40 |                 node.append(x)
41 |             size += xsize
42 | 
43 |         if len(node) == 1:
44 |             node = node[0]
45 |         elif len(node) == 0:
46 |             return None, size
47 |         return node, size
48 |     new_tree, _ = func(tr)
49 |     return new_tree, kept, removed
50 | 
51 | 
52 | def flatten_tree(tr):
53 |     def func(tr):
54 |         if not isinstance(tr, (list, tuple)):
55 |             return [tr]
56 |         result = []
57 |         for x in tr:
58 |             result += func(x)
59 |         return result
60 |     return func(tr)
61 | 
62 | 
63 | def postprocess(tr, tokens=None):
64 |     if tokens is None:
65 |         tokens = flatten_tree(tr)
66 | 
67 |     # Don't remove the last token. It's not punctuation.
68 |     if tokens[-1].lower() not in punctuation_words:
69 |         return tr
70 | 
71 |     mask = [True] * (len(tokens) - 1) + [False]
72 |     tr, kept, removed = remove_using_flat_mask(tr, mask)
73 |     assert len(kept) == len(tokens) - 1, 'Incorrect tokens left. Output = {}, Kept = {}'.format(
74 |         tr, kept)
75 |     assert len(kept) > 0, 'No tokens left. Original = {}'.format(tokens)
76 |     assert len(removed) == 1
77 |     tr = (tr, tokens[-1])
78 | 
79 |     return tr
80 | 
81 | 
82 | def replace_leaves(tree, leaves):
83 |     def func(tr, pos=0):
84 |         if not isinstance(tr, (list, tuple)):
85 |             return 1, leaves[pos]
86 | 
87 |         newtree = []
88 |         sofar = 0
89 |         for node in tr:
90 |             size, newnode = func(node, pos+sofar)
91 |             sofar += size
92 |             newtree += [newnode]
93 | 
94 |         return sofar, newtree
95 | 
96 |     _, newtree = func(tree)
97 | 
98 |     return newtree
99 | 
100 | 
101 | def run(options):
102 |     logger = get_logger()
103 | 
104 |     validation_dataset = get_validation_dataset(options)
105 |     validation_iterator = get_validation_iterator(options, validation_dataset)
106 |     word2idx = validation_dataset['word2idx']
107 |     embeddings = validation_dataset['embeddings']
108 | 
109 |     idx2word = {v: k for k, v in word2idx.items()}
110 | 
111 |     logger.info('Initializing model.')
112 |     trainer = build_net(options, embeddings)
113 | 
114 |     # Parse
115 | 
116 |     diora = trainer.net.diora
117 | 
118 |     ## Monkey patch parsing specific methods.
119 |     override_init_with_batch(diora)
120 |     override_inside_hook(diora)
121 | 
122 |     ## Keep the outside pass enabled (the commented line below turns it off).
123 |     trainer.net.diora.outside = True
124 |     # trainer.net.cliora.outside = False
125 | 
126 |     ## Eval mode.
127 |     trainer.net.eval()
128 | 
129 |     ## Parse predictor.
130 |     parse_predictor = CKY(net=diora)
131 | 
132 |     batches = validation_iterator.get_iterator(random_seed=options.seed)
133 | 
134 |     output_path = os.path.abspath(os.path.join(options.experiment_path, 'parse.jsonl'))
135 | 
136 |     logger.info('Beginning.')
137 |     logger.info('Writing output to = {}'.format(output_path))
138 | 
139 |     total_num = 0.
140 |     recall_num = 0.
141 |     ccr_num = 0.
142 |     corpus_f1 = [0., 0., 0.]
143 |     sent_f1 = []
144 |     f = open(output_path, 'w')
145 | 
146 |     recon_loss = 0
147 |     vg_loss = 0
148 |     contr_loss = 0
149 |     total_loss = 0
150 |     num_data = 0
151 | 
152 |     with torch.no_grad():
153 |         for i, batch_map in tqdm(enumerate(batches)):
154 |             sentences = batch_map['sentences']
155 |             # batch_size = sentences.shape[0]
156 |             length = sentences.shape[1]
157 | 
158 |             # Skip very short sentences.
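            # (a 2-token sentence has only one possible binary tree, so there
            # is nothing for the parser to decide; the same guard appears in
            # the other eval scripts).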
159 | if length <= 2: 160 | continue 161 | 162 | result = trainer.step(batch_map, idx2word, train=False, compute_loss=True) 163 | recon_loss += result['reconstruction_softmax_loss'] 164 | vg_loss += result.get('vg_loss', 0) 165 | contr_loss += result.get('contrastive_loss', 0) 166 | total_loss += result['total_loss'] 167 | num_data += 1 168 | 169 | if diora.all_atten_score is not None: 170 | all_atten_score = torch.diagonal(diora.all_atten_score.cpu(), 0, 0, 1).permute(2, 0, 1) 171 | else: 172 | all_atten_score = None 173 | 174 | batch_ground_res = None 175 | if diora.atten_score is not None: 176 | targets = batch_map['VG_GT'] 177 | batch_size = len(targets) 178 | attenion_scores = diora.atten_score.cpu() 179 | precomp_boxes = batch_map['boxes'].cpu() 180 | batch_ground_res = [] 181 | for bid in range(batch_size): 182 | target_bid, noun_mask = targets[bid] 183 | precomp_boxes_bid = precomp_boxes[bid] 184 | attenion_scores_bid = attenion_scores[bid] 185 | # span_atten_scores = all_atten_score[bid] 186 | 187 | ground_res = [] 188 | for _, gt_anno in target_bid.items(): 189 | start_id, end_id, gt_box = gt_anno 190 | words_scores = attenion_scores_bid[start_id:end_id] 191 | max_word_scores, _ = words_scores.max(1) 192 | select_wid = max_word_scores.max(0)[1] 193 | word2phr_atten = words_scores[select_wid] 194 | 195 | # s, e = start_id, end_id-1 196 | # k = e-s 197 | # index = int(k*length - k*(k-1)/2 + s) 198 | # span_atten = span_atten_scores[index] 199 | 200 | select_box_ids = word2phr_atten.max(0)[1] 201 | # select_box_ids = (word2phr_atten+0.1*span_atten).max(0)[1] 202 | pred_box = precomp_boxes_bid[select_box_ids] 203 | 204 | iou = torchops.box_iou(pred_box[None, :], torch.Tensor([gt_box])) 205 | if iou.max() > 0.5: 206 | recall_num += 1 207 | ground_res.append(((start_id, end_id-1), 1)) 208 | else: 209 | ground_res.append(((start_id, end_id-1), 0)) 210 | total_num += 1 211 | 212 | batch_ground_res.append(ground_res) 213 | 214 | trees = parse_predictor.parse_batch(batch_map) 215 | for bid, tr in enumerate(trees): 216 | # CorpusF1 217 | gold_spans = set(batch_map['GT'][bid][:-1]) 218 | pred_actions = get_actions(str(tr)) 219 | pred_spans = set(get_spans(pred_actions)[:-1]) 220 | tp, fp, fn = get_stats(pred_spans, gold_spans) 221 | corpus_f1[0] += tp 222 | corpus_f1[1] += fp 223 | corpus_f1[2] += fn 224 | 225 | # SentF1 226 | overlap = pred_spans.intersection(gold_spans) 227 | prec = float(len(overlap)) / (len(pred_spans) + 1e-8) 228 | reca = float(len(overlap)) / (len(gold_spans) + 1e-8) 229 | if len(gold_spans) == 0: 230 | reca = 1. 231 | if len(pred_spans) == 0: 232 | prec = 1. 
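                # Smoothed harmonic mean of precision and recall; the 1e-8
                # keeps the division defined when both span sets are empty.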
233 | f1 = 2 * prec * reca / (prec + reca + 1e-8) 234 | sent_f1.append(f1) 235 | 236 | # Ground Spans 237 | pred_boxes = [] 238 | if all_atten_score is not None: 239 | span_atten_scores = all_atten_score[bid] 240 | word_atten_scores = attenion_scores[bid] 241 | precomp_boxes_bid = precomp_boxes[bid] 242 | for span in pred_spans: 243 | s, e = span 244 | k = e-s 245 | index = int(k*length - k*(k-1)/2 + s) 246 | span_atten = span_atten_scores[index] 247 | 248 | word_atten = word_atten_scores[s:e+1] 249 | max_word_scores, _ = word_atten.max(1) 250 | select_wid = max_word_scores.max(0)[1] 251 | word2span_atten = word_atten[select_wid] 252 | 253 | select_box_ids = word2span_atten.max(0)[1] 254 | # select_box_ids = (word2span_atten+0.1*span_atten).max(0)[1] 255 | pred_box = precomp_boxes_bid[select_box_ids] 256 | pred_boxes.append(pred_box.tolist()) 257 | 258 | # CCRA 259 | if batch_ground_res is not None: 260 | ground_res = batch_ground_res[bid] 261 | for res in ground_res: 262 | phr = res[0] 263 | if res[1]: 264 | if phr[1] == phr[0]: 265 | ccr_num += 1 266 | elif phr in pred_spans: 267 | ccr_num += 1 268 | 269 | # write results 270 | example_id = batch_map['example_ids'][bid] 271 | s = [idx2word[idx] for idx in sentences[bid].tolist()] 272 | tr_index_conll = copy.deepcopy(tr) 273 | tr = replace_leaves(tr, s) 274 | if options.postprocess: 275 | tr = postprocess(tr, s) 276 | # o = collections.OrderedDict(example_id=str(example_id), tree=tr) 277 | o = collections.OrderedDict(example_id=str(example_id), tree=tr, tree_index_conll=tr_index_conll, 278 | sentence=s, gold_spans=list(gold_spans), pred_spans=list(pred_spans), 279 | pred_boxes=pred_boxes) 280 | f.write(json.dumps(o) + '\n') 281 | 282 | f.close() 283 | ground_acc = recall_num / (total_num + 1e-8) 284 | ccra = ccr_num / (total_num + 1e-8) 285 | # print('grounding acc:{}'.format(ground_acc)) 286 | tp, fp, fn = corpus_f1 287 | prec = tp / (tp + fp) 288 | recall = tp / (tp + fn) 289 | corpus_f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0. 
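    # Corpus F1 pools tp/fp/fn across all sentences before taking the harmonic
    # mean, while sentence F1 (below) averages per-sentence scores, so the two
    # metrics can diverge when sentence lengths are skewed.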
290 | sent_f1 = np.mean(np.array(sent_f1)) 291 | print('corpus_f1:{} \t sent_f1:{} \t grounding acc:{} \t ccra:{}'.format(corpus_f1, sent_f1, ground_acc, ccra)) 292 | print('recon_loss: {} ; vg_loss: {}; contr_loss: {}; total_loss: {}'.format( 293 | recon_loss/num_data, vg_loss/num_data, contr_loss/num_data, total_loss/num_data)) 294 | 295 | 296 | if __name__ == '__main__': 297 | parser = argument_parser() 298 | options = parse_args(parser) 299 | configure(options) 300 | 301 | run(options) 302 | -------------------------------------------------------------------------------- /cliora/scripts/parse_diora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import json 4 | 5 | import torch 6 | import torchvision.ops as torchops 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from train import argument_parser, parse_args, configure 11 | from train import get_validation_dataset, get_validation_iterator 12 | from train import build_net 13 | 14 | from cliora.logging.configuration import get_logger 15 | 16 | from cliora.analysis.cky import ParsePredictor as CKY 17 | from cliora.analysis.utils import * 18 | import copy 19 | 20 | punctuation_words = set([x.lower() for x in ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', 21 | '``', '--', ';', '-', '?', '!', '...', '-LCB-', '-RCB-']]) 22 | 23 | 24 | def remove_using_flat_mask(tr, mask): 25 | kept, removed = [], [] 26 | def func(tr, pos=0): 27 | if not isinstance(tr, (list, tuple)): 28 | if mask[pos] == False: 29 | removed.append(tr) 30 | return None, 1 31 | kept.append(tr) 32 | return tr, 1 33 | 34 | size = 0 35 | node = [] 36 | 37 | for subtree in tr: 38 | x, xsize = func(subtree, pos=pos + size) 39 | if x is not None: 40 | node.append(x) 41 | size += xsize 42 | 43 | if len(node) == 1: 44 | node = node[0] 45 | elif len(node) == 0: 46 | return None, size 47 | return node, size 48 | new_tree, _ = func(tr) 49 | return new_tree, kept, removed 50 | 51 | 52 | def flatten_tree(tr): 53 | def func(tr): 54 | if not isinstance(tr, (list, tuple)): 55 | return [tr] 56 | result = [] 57 | for x in tr: 58 | result += func(x) 59 | return result 60 | return func(tr) 61 | 62 | 63 | def postprocess(tr, tokens=None): 64 | if tokens is None: 65 | tokens = flatten_tree(tr) 66 | 67 | # Don't remove the last token. It's not punctuation. 68 | if tokens[-1].lower() not in punctuation_words: 69 | return tr 70 | 71 | mask = [True] * (len(tokens) - 1) + [False] 72 | tr, kept, removed = remove_using_flat_mask(tr, mask) 73 | assert len(kept) == len(tokens) - 1, 'Incorrect tokens left. Output = {}, Kept = {}'.format( 74 | tr, kept) 75 | assert len(kept) > 0, 'No tokens left. 
Original = {}'.format(tokens)
76 |     assert len(removed) == 1
77 |     tr = (tr, tokens[-1])
78 | 
79 |     return tr
80 | 
81 | 
82 | def replace_leaves(tree, leaves):
83 |     def func(tr, pos=0):
84 |         if not isinstance(tr, (list, tuple)):
85 |             return 1, leaves[pos]
86 | 
87 |         newtree = []
88 |         sofar = 0
89 |         for node in tr:
90 |             size, newnode = func(node, pos+sofar)
91 |             sofar += size
92 |             newtree += [newnode]
93 | 
94 |         return sofar, newtree
95 | 
96 |     _, newtree = func(tree)
97 | 
98 |     return newtree
99 | 
100 | 
101 | def run(options):
102 |     logger = get_logger()
103 | 
104 |     validation_dataset = get_validation_dataset(options)
105 |     validation_iterator = get_validation_iterator(options, validation_dataset)
106 |     word2idx = validation_dataset['word2idx']
107 |     embeddings = validation_dataset['embeddings']
108 | 
109 |     idx2word = {v: k for k, v in word2idx.items()}
110 | 
111 |     logger.info('Initializing model.')
112 |     trainer = build_net(options, embeddings)
113 | 
114 |     # Parse
115 | 
116 |     diora = trainer.net.diora
117 | 
118 |     ## Monkey patch parsing specific methods.
119 |     override_init_with_batch(diora)
120 |     override_inside_hook(diora)
121 | 
122 |     ## Keep the outside pass enabled (the commented line below turns it off).
123 |     trainer.net.diora.outside = True
124 |     # trainer.net.cliora.outside = False
125 | 
126 |     ## Eval mode.
127 |     trainer.net.eval()
128 | 
129 |     ## Parse predictor.
130 |     parse_predictor = CKY(net=diora)
131 | 
132 |     batches = validation_iterator.get_iterator(random_seed=options.seed)
133 | 
134 |     output_path = os.path.abspath(os.path.join(options.experiment_path, 'parse.jsonl'))
135 | 
136 |     logger.info('Beginning.')
137 |     logger.info('Writing output to = {}'.format(output_path))
138 | 
139 |     corpus_f1 = [0., 0., 0.]
140 |     sent_f1 = []
141 |     f = open(output_path, 'w')
142 | 
143 |     recon_loss = 0
144 |     vg_loss = 0
145 |     contr_loss = 0
146 |     total_loss = 0
147 |     num_data = 0
148 | 
149 |     with torch.no_grad():
150 |         for i, batch_map in tqdm(enumerate(batches)):
151 |             sentences = batch_map['sentences']
152 |             # batch_size = sentences.shape[0]
153 |             length = sentences.shape[1]
154 | 
155 |             # Skip very short sentences.
156 |             if length <= 2:
157 |                 continue
158 | 
159 |             result = trainer.step(batch_map, idx2word, train=False, compute_loss=True)
160 |             recon_loss += result['reconstruction_softmax_loss']
161 |             vg_loss += result.get('vg_loss', 0)
162 |             contr_loss += result.get('contrastive_loss', 0)
163 |             total_loss += result['total_loss']
164 |             num_data += 1
165 | 
166 |             trees = parse_predictor.parse_batch(batch_map)
167 |             for bid, tr in enumerate(trees):
168 |                 # CorpusF1
169 |                 gold_spans = set(batch_map['GT'][bid][:-1])
170 |                 pred_actions = get_actions(str(tr))
171 |                 pred_spans = set(get_spans(pred_actions)[:-1])
172 |                 tp, fp, fn = get_stats(pred_spans, gold_spans)
173 |                 corpus_f1[0] += tp
174 |                 corpus_f1[1] += fp
175 |                 corpus_f1[2] += fn
176 | 
177 |                 # SentF1
178 |                 overlap = pred_spans.intersection(gold_spans)
179 |                 prec = float(len(overlap)) / (len(pred_spans) + 1e-8)
180 |                 reca = float(len(overlap)) / (len(gold_spans) + 1e-8)
181 |                 if len(gold_spans) == 0:
182 |                     reca = 1.
183 |                 if len(pred_spans) == 0:
184 |                     prec = 1.
185 |                 f1 = 2 * prec * reca / (prec + reca + 1e-8)
186 |                 sent_f1.append(f1)
187 | 
188 | 
189 |                 # write results
190 |                 example_id = batch_map['example_ids'][bid]
191 |                 s = [idx2word[idx] for idx in sentences[bid].tolist()]
192 |                 tr_index_conll = copy.deepcopy(tr)
193 |                 tr = replace_leaves(tr, s)
194 |                 if options.postprocess:
195 |                     tr = postprocess(tr, s)
196 |                 # o = collections.OrderedDict(example_id=str(example_id), tree=tr)
197 |                 o = collections.OrderedDict(example_id=str(example_id), tree=tr, tree_index_conll=tr_index_conll,
198 |                                             sentence=s, gold_spans=list(gold_spans), pred_spans=list(pred_spans),
199 |                                             )
200 |                 f.write(json.dumps(o) + '\n')
201 | 
202 |     f.close()
203 |     # print('grounding acc:{}'.format(ground_acc))
204 |     tp, fp, fn = corpus_f1
205 |     prec = tp / (tp + fp)
206 |     recall = tp / (tp + fn)
207 |     corpus_f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.
208 |     sent_f1 = np.mean(np.array(sent_f1))
209 |     print('corpus_f1:{} \t sent_f1:{}'.format(corpus_f1, sent_f1))
210 |     print('recon_loss: {} ; vg_loss: {}; contr_loss: {}; total_loss: {}'.format(
211 |         recon_loss/num_data, vg_loss/num_data, contr_loss/num_data, total_loss/num_data))
212 | 
213 | 
214 | if __name__ == '__main__':
215 |     parser = argument_parser()
216 |     options = parse_args(parser)
217 |     configure(options)
218 | 
219 |     run(options)
220 | 
--------------------------------------------------------------------------------
/cliora/scripts/phrase_embed.py:
--------------------------------------------------------------------------------
1 | """
2 | A script to embed every phrase in a dataset as a dense vector, then
3 | to find the top-k neighbors of each phrase according to cosine
4 | similarity.
5 | 
6 | 1. Install missing dependencies.
7 | 
8 |     # More details: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md
9 |     conda install faiss-cpu -c pytorch
10 | 
11 | 2. Prepare data. For example, the chunking dataset from CoNLL 2000.
12 | 
13 |     wget https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz
14 |     gunzip train.txt.gz
15 |     python cliora/misc/convert_conll_to_jsonl.py --path train.txt > conll-train.jsonl
16 | 
17 | 3. Run this script.
18 | 
19 |     python cliora/scripts/phrase_embed.py \
20 |         --batch_size 10 \
21 |         --emb w2v \
22 |         --embeddings_path ~/data/glove.6B/glove.6B.50d.txt \
23 |         --hidden_dim 50 \
24 |         --log_every_batch 100 \
25 |         --save_after 1000 \
26 |         --data_type conll_jsonl \
27 |         --validation_path ./conll-train.jsonl \
28 |         --validation_filter_length 10
29 | 
30 | Can control the number of neighbors to show with the `--k_top` flag.
31 | 
32 | Can control the number of candidates to consider with the `--k_candidates` flag.
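
Each query is printed with its nearest neighbors, roughly of the form
(ids, phrases, and scores here are illustrative, not real output):

    [query] example_id=123 phrase=the red car
    rank=0 score=0.912 example_id=456 phrase=a blue truck
    rank=1 score=0.887 example_id=789 phrase=an old bike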
33 | 34 | """ 35 | 36 | 37 | import json 38 | import types 39 | import itertools 40 | 41 | import torch 42 | 43 | import numpy as np 44 | 45 | from train import argument_parser, parse_args, configure 46 | from train import get_validation_dataset, get_validation_iterator 47 | from train import build_net 48 | 49 | from cliora.logging.configuration import get_logger 50 | 51 | try: 52 | import faiss 53 | from faiss import normalize_L2 54 | except: 55 | print('Could not import `faiss`, which is used to find nearest neighbors.') 56 | 57 | 58 | def get_cell_index(entity_labels, i_label=0, i_pos=1, i_size=2): 59 | def helper(): 60 | for i, lst in enumerate(entity_labels): 61 | for el in lst: 62 | if el is None: 63 | continue 64 | pos = el[i_pos] 65 | size = el[i_size] 66 | label = el[i_label] 67 | yield (i, pos, size, label) 68 | lst = list(helper()) 69 | if len(lst) == 0: 70 | return None, [] 71 | batch_index = [x[0] for x in lst] 72 | positions = [x[1] for x in lst] 73 | sizes = [x[2] for x in lst] 74 | labels = [x[3] for x in lst] 75 | 76 | return batch_index, positions, sizes, labels 77 | 78 | 79 | def get_many_cells(diora, chart, batch_index, positions, sizes): 80 | cells = [] 81 | length = diora.length 82 | 83 | idx = [] 84 | for bi, pos, size in zip(batch_index, positions, sizes): 85 | level = size - 1 86 | offset = diora.index.get_offset(length)[level] 87 | absolute_pos = offset + pos 88 | idx.append(absolute_pos) 89 | 90 | cells = chart[batch_index, idx] 91 | 92 | return cells 93 | 94 | 95 | def get_many_phrases(batch, batch_index, positions, sizes): 96 | batch = batch.tolist() 97 | lst = [] 98 | for bi, pos, size in zip(batch_index, positions, sizes): 99 | phrase = tuple(batch[bi][pos:pos+size]) 100 | lst.append(phrase) 101 | return lst 102 | 103 | 104 | class BatchRecorder(object): 105 | def __init__(self, dtype={}): 106 | super(BatchRecorder, self).__init__() 107 | self.cache = {} 108 | self.dtype = dtype 109 | self.dtype2flatten = { 110 | 'list': self._flatten_list, 111 | 'np': self._flatten_np, 112 | 'torch': self._flatten_torch, 113 | } 114 | 115 | def _flatten_list(self, v): 116 | return list(itertools.chain(*v)) 117 | 118 | def _flatten_np(self, v): 119 | return np.concatenate(v, axis=0) 120 | 121 | def _flatten_torch(self, v): 122 | return torch.cat(v, 0).cpu().data.numpy() 123 | 124 | def get_flattened_result(self): 125 | def helper(): 126 | for k, v in self.cache.items(): 127 | flatten = self.dtype2flatten[self.dtype.get(k, 'list')] 128 | yield k, flatten(v) 129 | return {k: v for k, v in helper()} 130 | 131 | def record(self, **kwargs): 132 | for k, v in kwargs.items(): 133 | self.cache.setdefault(k, []).append(v) 134 | 135 | 136 | class Index(object): 137 | def __init__(self, dim=None): 138 | super(Index, self).__init__() 139 | self.D, self.I = None, None 140 | self.index = faiss.IndexFlatIP(dim) 141 | 142 | def add(self, vecs): 143 | self.index.add(vecs) 144 | 145 | def cache(self, vecs, k): 146 | self.D, self.I = self.index.search(vecs, k) 147 | 148 | def topk(self, q, k): 149 | for j in range(k): 150 | idx = self.I[q][j] 151 | dist = self.D[q][j] 152 | yield idx, dist 153 | 154 | 155 | class NearestNeighborsLookup(object): 156 | def __init__(self): 157 | super(NearestNeighborsLookup, self).__init__() 158 | 159 | 160 | def run(options): 161 | logger = get_logger() 162 | 163 | validation_dataset = get_validation_dataset(options) 164 | validation_iterator = get_validation_iterator(options, validation_dataset) 165 | word2idx = validation_dataset['word2idx'] 166 | embeddings = 
validation_dataset['embeddings'] 167 | 168 | idx2word = {v: k for k, v in word2idx.items()} 169 | 170 | logger.info('Initializing model.') 171 | trainer = build_net(options, embeddings) 172 | diora = trainer.net.diora 173 | 174 | # 1. Get all relevant phrase vectors. 175 | 176 | dtype = { 177 | 'example_ids': 'list', 178 | 'labels': 'list', 179 | 'positions': 'list', 180 | 'sizes': 'list', 181 | 'phrases': 'list', 182 | 'inside': 'torch', 183 | 'outside': 'torch', 184 | } 185 | batch_recorder = BatchRecorder(dtype=dtype) 186 | 187 | ## Eval mode. 188 | trainer.net.eval() 189 | 190 | batches = validation_iterator.get_iterator(random_seed=options.seed) 191 | 192 | logger.info('Beginning to embed phrases.') 193 | 194 | with torch.no_grad(): 195 | for i, batch_map in enumerate(batches): 196 | sentences = batch_map['sentences'] 197 | batch_size = sentences.shape[0] 198 | length = sentences.shape[1] 199 | 200 | # Skips very short examples. 201 | if length <= 2: 202 | continue 203 | 204 | _ = trainer.step(batch_map, train=False, compute_loss=False) 205 | 206 | entity_labels = batch_map['entity_labels'] 207 | batch_index, positions, sizes, labels = get_cell_index(entity_labels) 208 | 209 | # Skip short phrases. 210 | batch_index = [x for x, y in zip(batch_index, sizes) if y >= 2] 211 | positions = [x for x, y in zip(positions, sizes) if y >= 2] 212 | labels = [x for x, y in zip(labels, sizes) if y >= 2] 213 | sizes = [y for y in sizes if y >= 2] 214 | 215 | cell_index = (batch_index, positions, sizes) 216 | 217 | batch_result = {} 218 | batch_result['example_ids'] = [batch_map['example_ids'][idx] for idx in cell_index[0]] 219 | batch_result['labels'] = labels 220 | batch_result['positions'] = cell_index[1] 221 | batch_result['sizes'] = cell_index[2] 222 | batch_result['phrases'] = get_many_phrases(sentences, *cell_index) 223 | batch_result['inside'] = get_many_cells(diora, diora.inside_h, *cell_index) 224 | batch_result['outside'] = get_many_cells(diora, diora.outside_h, *cell_index) 225 | 226 | batch_recorder.record(**batch_result) 227 | 228 | result = batch_recorder.get_flattened_result() 229 | 230 | # 2. Build an index of nearest neighbors. 231 | 232 | vectors = np.concatenate([result['inside'], result['outside']], axis=1) 233 | normalize_L2(vectors) 234 | 235 | index = Index(dim=vectors.shape[1]) 236 | index.add(vectors) 237 | index.cache(vectors, options.k_candidates) 238 | 239 | # 3. Print a summary. 240 | 241 | example_ids = result['example_ids'] 242 | phrases = result['phrases'] 243 | 244 | assert len(example_ids) == len(phrases) 245 | assert len(example_ids) == vectors.shape[0] 246 | 247 | def stringify(phrase): 248 | return ' '.join([idx2word[idx] for idx in phrase]) 249 | 250 | for i in range(vectors.shape[0]): 251 | topk = [] 252 | 253 | for j, score in index.topk(i, options.k_candidates): 254 | # Skip same example. 255 | if example_ids[i] == example_ids[j]: 256 | continue 257 | # Skip string match. 258 | if phrases[i] == phrases[j]: 259 | continue 260 | topk.append((j, score)) 261 | if len(topk) == options.k_top: 262 | break 263 | assert len(topk) == options.k_top, 'Did not find enough valid candidates.' 264 | 265 | # Print. 
266 |         print('[query] example_id={} phrase={}'.format(
267 |             example_ids[i], stringify(phrases[i])))
268 |         for rank, (j, score) in enumerate(topk):
269 |             print('rank={} score={:.3f} example_id={} phrase={}'.format(
270 |                 rank, score, example_ids[j], stringify(phrases[j])))
271 | 
272 | 
273 | if __name__ == '__main__':
274 |     parser = argument_parser()
275 |     parser.add_argument('--k_candidates', default=100, type=int)
276 |     parser.add_argument('--k_top', default=3, type=int)
277 |     options = parse_args(parser)
278 |     configure(options)
279 | 
280 |     run(options)
281 | 
--------------------------------------------------------------------------------
/cliora/scripts/phrase_embed_simple.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | import json
4 | import types
5 | 
6 | import torch
7 | import numpy as np
8 | from tqdm import tqdm
9 | 
10 | from train import argument_parser, parse_args, configure
11 | from train import get_validation_dataset, get_validation_iterator
12 | from train import build_net
13 | 
14 | from cliora.logging.configuration import get_logger
15 | 
16 | from cliora.analysis.cky import ParsePredictor as CKY
17 | 
18 | 
19 | punctuation_words = set([x.lower() for x in ['.', ',', ':', '-LRB-', '-RRB-', '\'\'',
20 |     '``', '--', ';', '-', '?', '!', '...', '-LCB-', '-RCB-']])
21 | 
22 | 
23 | def remove_using_flat_mask(tr, mask):
24 |     kept, removed = [], []
25 |     def func(tr, pos=0):
26 |         if not isinstance(tr, (list, tuple)):
27 |             if mask[pos] == False:
28 |                 removed.append(tr)
29 |                 return None, 1
30 |             kept.append(tr)
31 |             return tr, 1
32 | 
33 |         size = 0
34 |         node = []
35 | 
36 |         for subtree in tr:
37 |             x, xsize = func(subtree, pos=pos + size)
38 |             if x is not None:
39 |                 node.append(x)
40 |             size += xsize
41 | 
42 |         if len(node) == 1:
43 |             node = node[0]
44 |         elif len(node) == 0:
45 |             return None, size
46 |         return node, size
47 |     new_tree, _ = func(tr)
48 |     return new_tree, kept, removed
49 | 
50 | 
51 | def flatten_tree(tr):
52 |     def func(tr):
53 |         if not isinstance(tr, (list, tuple)):
54 |             return [tr]
55 |         result = []
56 |         for x in tr:
57 |             result += func(x)
58 |         return result
59 |     return func(tr)
60 | 
61 | 
62 | def tree_to_spans(tree):
63 |     spans = []
64 |     def helper(tr, pos=0):
65 |         if isinstance(tr, str) or len(tr) == 1 and isinstance(tr[0], str):
66 |             return 1
67 |         if len(tr) == 1:
68 |             return helper(tr[0], pos)
69 |         size = 0
70 |         for x in tr:
71 |             xsize = helper(x, pos+size)
72 |             size += xsize
73 |         spans.append((pos, size))
74 |         return size
75 |     _ = helper(tree)
76 |     return spans
77 | 
78 | 
79 | def postprocess(tr, tokens=None):
80 |     if tokens is None:
81 |         tokens = flatten_tree(tr)
82 | 
83 |     # Don't remove the last token. It's not punctuation.
84 |     if tokens[-1].lower() not in punctuation_words:
85 |         return tr
86 | 
87 |     mask = [True] * (len(tokens) - 1) + [False]
88 |     tr, kept, removed = remove_using_flat_mask(tr, mask)
89 |     assert len(kept) == len(tokens) - 1, 'Incorrect tokens left. Original = {}, Output = {}, Kept = {}'.format(
90 |         tokens, tr, kept)
91 |     assert len(kept) > 0, 'No tokens left. 
Original = {}'.format(tokens) 92 | assert len(removed) == 1 93 | tr = (tr, tokens[-1]) 94 | 95 | return tr 96 | 97 | 98 | def override_init_with_batch(var): 99 | init_with_batch = var.init_with_batch 100 | 101 | def func(self, *args, **kwargs): 102 | init_with_batch(*args, **kwargs) 103 | self.saved_scalars = {i: {} for i in range(self.length)} 104 | self.saved_scalars_out = {i: {} for i in range(self.length)} 105 | 106 | var.init_with_batch = types.MethodType(func, var) 107 | 108 | 109 | def override_inside_hook(var): 110 | def func(self, level, h, c, s): 111 | length = self.length 112 | B = self.batch_size 113 | L = length - level 114 | 115 | assert s.shape[0] == B 116 | assert s.shape[1] == L 117 | # assert s.shape[2] == N 118 | assert s.shape[3] == 1 119 | assert len(s.shape) == 4 120 | smax = s.max(2, keepdim=True)[0] 121 | s = s - smax 122 | 123 | for pos in range(L): 124 | self.saved_scalars[level][pos] = s[:, pos, :] 125 | 126 | var.inside_hook = types.MethodType(func, var) 127 | 128 | 129 | def replace_leaves(tree, leaves): 130 | def func(tr, pos=0): 131 | if not isinstance(tr, (list, tuple)): 132 | return 1, leaves[pos] 133 | 134 | newtree = [] 135 | sofar = 0 136 | for node in tr: 137 | size, newnode = func(node, pos+sofar) 138 | sofar += size 139 | newtree += [newnode] 140 | 141 | return sofar, newtree 142 | 143 | _, newtree = func(tree) 144 | 145 | return newtree 146 | 147 | 148 | class TreeHelper(object): 149 | def __init__(self, diora, word2idx): 150 | self.diora = diora 151 | self.word2idx = word2idx 152 | self.idx2word = {idx: w for w, idx in self.word2idx.items()} 153 | 154 | def init(self, options): 155 | if options.parse_mode == 'latent': 156 | self.parse_predictor = CKY(net=self.diora, word2idx=self.word2idx) 157 | ## Monkey patch parsing specific methods. 
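            # These overrides make the inside pass record its per-level span
            # scores in `saved_scalars`, which the CKY predictor reads back;
            # the DIORA module is otherwise left unchanged.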
158 | override_init_with_batch(self.diora) 159 | override_inside_hook(self.diora) 160 | 161 | def get_trees_for_batch(self, batch_map, options): 162 | sentences = batch_map['sentences'] 163 | batch_size = sentences.shape[0] 164 | length = sentences.shape[1] 165 | 166 | # trees 167 | if options.parse_mode == 'all-spans': 168 | raise Exception('Does not support this mode.') 169 | elif options.parse_mode == 'latent': 170 | trees = self.parse_predictor.parse_batch(batch_map) 171 | elif options.parse_mode == 'given': 172 | trees = batch_map['trees'] 173 | 174 | # spans 175 | spans = [] 176 | for ii, tr in enumerate(trees): 177 | s = [self.idx2word[idx] for idx in sentences[ii].tolist()] 178 | tr = replace_leaves(tr, s) 179 | if options.postprocess: 180 | tr = postprocess(tr, s) 181 | spans.append(tree_to_spans(tr)) 182 | 183 | return trees, spans 184 | 185 | 186 | class CSVHelper(object): 187 | def __init__(self): 188 | self.header = ['example_id', 'position', 'size'] 189 | 190 | def write_header(self, f): 191 | f.write(','.join(self.header) + '\n') 192 | 193 | def write_row(self, f, data): 194 | row = ','.join([data[k] for k in self.header]) 195 | f.write(row + '\n') 196 | 197 | 198 | def run(options): 199 | logger = get_logger() 200 | 201 | validation_dataset = get_validation_dataset(options) 202 | validation_iterator = get_validation_iterator(options, validation_dataset) 203 | word2idx = validation_dataset['word2idx'] 204 | embeddings = validation_dataset['embeddings'] 205 | 206 | idx2word = {v: k for k, v in word2idx.items()} 207 | 208 | logger.info('Initializing model.') 209 | trainer = build_net(options, embeddings, validation_iterator) 210 | diora = trainer.net.diora 211 | tree_helper = TreeHelper(diora, word2idx) 212 | tree_helper.init(options) 213 | csv_helper = CSVHelper() 214 | 215 | ## Eval mode. 216 | trainer.net.eval() 217 | 218 | batches = validation_iterator.get_iterator(random_seed=options.seed) 219 | 220 | meta_output_path = os.path.abspath(os.path.join(options.experiment_path, 'vectors.csv')) 221 | vec_output_path = os.path.abspath(os.path.join(options.experiment_path, 'vectors.npy')) 222 | 223 | logger.info('Beginning.') 224 | logger.info('Writing vectors to = {}'.format(vec_output_path)) 225 | logger.info('Writing metadata to = {}'.format(meta_output_path)) 226 | 227 | f_csv = open(meta_output_path, 'w') 228 | f_vec = open(vec_output_path, 'ab') 229 | csv_helper.write_header(f_csv) 230 | 231 | with torch.no_grad(): 232 | for i, batch_map in tqdm(enumerate(batches)): 233 | sentences = batch_map['sentences'] 234 | batch_size = sentences.shape[0] 235 | length = sentences.shape[1] 236 | 237 | # Skip very short sentences. 
238 | if length <= 2: 239 | continue 240 | 241 | _ = trainer.step(batch_map, train=False, compute_loss=False) 242 | 243 | if options.parse_mode == 'all-spans': 244 | for ii in range(batch_size): 245 | example_id = batch_map['example_ids'][ii] 246 | for level in range(length): 247 | size = level + 1 248 | for pos in range(length - level): 249 | # metadata 250 | csv_helper.write_row(f_csv, 251 | collections.OrderedDict( 252 | example_id=example_id, 253 | position=str(pos), 254 | size=str(size) 255 | )) 256 | inside_vectors = diora.inside_h.view(-1, options.hidden_dim) 257 | outside_vectors = diora.outside_h.view(-1, options.hidden_dim) 258 | 259 | else: 260 | trees, spans = tree_helper.get_trees_for_batch(batch_map, options) 261 | 262 | batch_index = [] 263 | cell_index = [] 264 | offset_cache = diora.index.get_offset(length) 265 | 266 | for ii, sp_lst in enumerate(spans): 267 | example_id = batch_map['example_ids'][ii] 268 | for pos, size in sp_lst: 269 | # metadata 270 | csv_helper.write_row(f_csv, 271 | collections.OrderedDict( 272 | example_id=example_id, 273 | position=str(pos), 274 | size=str(size) 275 | )) 276 | # for vectors 277 | level = size - 1 278 | cell = offset_cache[level] + pos 279 | batch_index.append(ii) 280 | cell_index.append(cell) 281 | 282 | inside_vectors = diora.inside_h[batch_index, cell_index] 283 | assert inside_vectors.shape == (len(batch_index), options.hidden_dim) 284 | outside_vectors = diora.outside_h[batch_index, cell_index] 285 | assert outside_vectors.shape == (len(batch_index), options.hidden_dim) 286 | 287 | vectors = np.concatenate([inside_vectors, outside_vectors], axis=1) 288 | np.savetxt(f_vec, vectors) 289 | 290 | f_csv.close() 291 | f_vec.close() 292 | 293 | # X = np.loadtxt(vec_output_path) 294 | # print(X.shape) 295 | 296 | 297 | if __name__ == '__main__': 298 | parser = argument_parser() 299 | parser.add_argument('--parse_mode', default='latent', choices=('all-spans', 'latent', 'given'), help= 300 | 'Save vectors for...\n- `all-spans`: the whole chart,\n- `latent`: the latent tree,\n- `given`: a given tree.') 301 | options = parse_args(parser) 302 | configure(options) 303 | 304 | run(options) 305 | -------------------------------------------------------------------------------- /cliora/scripts/right_branch.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from train import argument_parser, parse_args, configure 7 | from train import get_validation_dataset, get_validation_iterator 8 | 9 | from cliora.analysis.utils import * 10 | 11 | def run(options): 12 | validation_dataset = get_validation_dataset(options) 13 | validation_iterator = get_validation_iterator(options, validation_dataset) 14 | batches = validation_iterator.get_iterator(random_seed=options.seed) 15 | print('Beginning.') 16 | corpus_f1 = [0., 0., 0.] 17 | sent_f1 = [] 18 | 19 | with torch.no_grad(): 20 | for i, batch_map in tqdm(enumerate(batches)): 21 | sentences = batch_map['sentences'] 22 | batch_size = sentences.shape[0] 23 | length = sentences.shape[1] 24 | 25 | # Skip very short sentences. 
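            # Unlike the model eval scripts (which skip length <= 2), this
            # baseline keeps 2-token sentences; their right-branching tree is
            # trivially defined.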
26 | if length < 2: 27 | continue 28 | 29 | for bid in range(batch_size): 30 | # CorpusF1 31 | # gold_spans = set(batch_map['GT'][bid][5][:-1]) 32 | gold_spans = set(batch_map['GT'][bid][:-1]) 33 | # right branch 34 | pred_span = [(i, length-1) for i in range(length-1)] 35 | pred_spans = set(pred_span[1:]) 36 | # left branch 37 | # pred_spans = set([(0, i) for i in range(1, length-1)]) 38 | # tp, fp, fn = get_stats(pred_spans, gold_spans) 39 | tp = len(gold_spans); fp = len(pred_spans) - tp; fn = 0 40 | corpus_f1[0] += tp 41 | corpus_f1[1] += fp 42 | corpus_f1[2] += fn 43 | 44 | # SentF1 45 | overlap = pred_spans.intersection(gold_spans) 46 | prec = float(len(overlap)) / (len(pred_spans) + 1e-8) 47 | reca = float(len(overlap)) / (len(gold_spans) + 1e-8) 48 | if len(gold_spans) == 0: 49 | reca = 1. 50 | if len(pred_spans) == 0: 51 | prec = 1. 52 | f1 = 2 * prec * reca / (prec + reca + 1e-8) 53 | sent_f1.append(f1) 54 | 55 | 56 | tp, fp, fn = corpus_f1 57 | prec = tp / (tp + fp) 58 | recall = tp / (tp + fn) 59 | corpus_f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0. 60 | sent_f1 = np.mean(np.array(sent_f1)) 61 | print('corpus_f1:{} \t sent_f1:{}'.format(corpus_f1, sent_f1)) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argument_parser() 66 | options = parse_args(parser) 67 | configure(options) 68 | 69 | run(options) 70 | -------------------------------------------------------------------------------- /cliora/scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | import random 5 | import uuid 6 | 7 | import torch 8 | import torchvision.ops as torchops 9 | import numpy as np 10 | from tqdm import tqdm 11 | import sys 12 | from cliora.data.dataset import ConsolidateDatasets, ReconstructDataset, make_batch_iterator 13 | 14 | from cliora.utils.path import package_path 15 | from cliora.logging.configuration import configure_experiment, get_logger 16 | from cliora.utils.flags import stringify_flags, init_with_flags_file, save_flags 17 | from cliora.utils.checkpoint import save_experiment 18 | 19 | from cliora.net.experiment_logger import ExperimentLogger 20 | from cliora.analysis.cky import ParsePredictor as CKY 21 | from cliora.analysis.diora_tree import TreesFromDiora 22 | from cliora.analysis.utils import * 23 | 24 | data_types_choices = ('coco', 'flickr') 25 | 26 | 27 | def count_params(net): 28 | return sum([x.numel() for x in net.parameters() if x.requires_grad]) 29 | 30 | 31 | def build_net(options, embeddings): 32 | from cliora.net.trainer import build_net 33 | 34 | trainer = build_net(options, embeddings, random_seed=options.seed) 35 | 36 | logger = get_logger() 37 | logger.info('# of params = {}'.format(count_params(trainer.net))) 38 | 39 | return trainer 40 | 41 | 42 | def generate_seeds(n, seed=11): 43 | random.seed(seed) 44 | seeds = [random.randint(0, 2**16) for _ in range(n)] 45 | return seeds 46 | 47 | 48 | def run_train(options, train_iterator, trainer, validation_iterator): 49 | logger = get_logger() 50 | experiment_logger = ExperimentLogger() 51 | save_emb = options.emb == 'none' 52 | 53 | logger.info('Running train.') 54 | 55 | seeds = generate_seeds(options.max_epoch, options.seed) 56 | word2idx = train_iterator.word2idx 57 | idx2word = {v: k for k, v in word2idx.items()} 58 | 59 | step = 0 60 | best_f1 = 0. 
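    # best_f1 tracks the highest validation corpus F1 so far; a checkpoint is
    # still written every epoch, best_f1 only gates the log message below.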
61 | # run_eval(options, trainer, validation_iterator) 62 | if not options.multigpu or options.local_rank == 0: 63 | if options.arch == 'hard': 64 | run_eval(options, trainer, validation_iterator) 65 | 66 | for epoch, seed in zip(range(options.max_epoch), seeds): 67 | # --- Train--- # 68 | 69 | # seed = seeds[epoch] 70 | 71 | logger.info('epoch={} seed={}'.format(epoch, seed)) 72 | 73 | def myiterator(): 74 | it = train_iterator.get_iterator(random_seed=seed) 75 | 76 | count = 0 77 | 78 | for batch_map in it: 79 | # TODO: Skip short examples (optionally). 80 | if batch_map['length'] <= 2: 81 | continue 82 | 83 | yield count, batch_map 84 | count += 1 85 | 86 | for batch_idx, batch_map in myiterator(): 87 | # if options.finetune and step >= options.finetune_after: 88 | # trainer.freeze_diora() 89 | 90 | # trainer.py 359 91 | result = trainer.step(batch_map, idx2word) 92 | 93 | experiment_logger.record(result) 94 | 95 | if step % options.log_every_batch == 0: 96 | experiment_logger.log_batch(epoch, step, batch_idx, batch_size=options.batch_size) 97 | 98 | del result 99 | 100 | step += 1 101 | 102 | experiment_logger.log_epoch(epoch, step) 103 | 104 | # Epoch Eval and Checkpoints -- # 105 | if not options.multigpu or options.local_rank == 0: 106 | trainer.save_model(save_emb, os.path.join(options.experiment_path, 'model.epoch_{}.pt'.format(epoch))) 107 | save_experiment(os.path.join(options.experiment_path, 'experiment.epoch_{}.json'.format(epoch)), step) 108 | 109 | corpus_f1 = run_eval(options, trainer, validation_iterator) 110 | if corpus_f1 > best_f1: 111 | best_f1 = corpus_f1 112 | logger.info('Saving model epoch {}, corpus_f1: {}, best_f1: {}.'.format(epoch, corpus_f1, best_f1)) 113 | 114 | if options.max_step is not None and step >= options.max_step: 115 | logger.info('Max-Step={} Quitting.'.format(options.max_step)) 116 | sys.exit() 117 | 118 | 119 | def run_eval(options, trainer, validation_iterator): 120 | logger = get_logger() 121 | 122 | # Eval mode. 123 | trainer.net.eval() 124 | if options.multigpu: 125 | diora = trainer.net.module.diora 126 | else: 127 | diora = trainer.net.diora 128 | # cliora.outside = False 129 | # cliora.outside = True # TODO 130 | diora.outside = options.obj_feats 131 | 132 | if options.arch == 'hard': 133 | diora.safe_set_K(2) 134 | parse_predictor = TreesFromDiora(net=diora) 135 | else: 136 | override_init_with_batch(diora) 137 | override_inside_hook(diora) 138 | parse_predictor = CKY(net=diora) 139 | batches = validation_iterator.get_iterator(random_seed=options.seed) 140 | 141 | logger.info('####### Beginning Eval #######') 142 | 143 | total_num = 0 144 | recall_num = 0 145 | corpus_f1 = [0., 0., 0.] 146 | sent_f1 = [] 147 | with torch.no_grad(): 148 | for i, batch_map in tqdm(enumerate(batches)): 149 | sentences = batch_map['sentences'] 150 | length = sentences.shape[1] 151 | 152 | # Skip very short sentences. 
153 |             if length <= 2:
154 |                 continue
155 | 
156 |             _ = trainer.step(batch_map, train=False, compute_loss=False)
157 | 
158 |             # Grounding eval
159 |             if diora.atten_score is not None:
160 |                 targets = batch_map['VG_GT']
161 |                 batch_size = len(targets)
162 |                 attention_scores = diora.atten_score.cpu()
163 |                 precomp_boxes = batch_map['boxes'].cpu()
164 |                 for bid in range(batch_size):
165 |                     target_bid, noun_mask = targets[bid]
166 |                     precomp_boxes_bid = precomp_boxes[bid]
167 |                     attention_scores_bid = attention_scores[bid]
168 |                     select_scores, select_box_ids = attention_scores_bid.max(1)
169 |                     pred_boxes = precomp_boxes_bid[select_box_ids]
170 | 
171 |                     for _, gt_anno in target_bid.items():
172 |                         start_id, end_id, gt_box = gt_anno
173 |                         pred_box = pred_boxes[start_id:end_id]
174 |                         select_score = select_scores[start_id:end_id]
175 |                         select_id = select_score.max(0)[1]
176 |                         iou = torchops.box_iou(pred_box[select_id][None, :], torch.Tensor([gt_box]))
177 |                         if iou.max() > 0.5:  # a phrase counts as grounded if its best-attended box overlaps the gold box with IoU > 0.5
178 |                             recall_num += 1
179 |                         total_num += 1
180 | 
181 |             # Parsing eval
182 |             trees = parse_predictor.parse_batch(batch_map)
183 | 
184 |             for bid, tr in enumerate(trees):
185 |                 # CorpusF1
186 |                 # gold_spans = set(batch_map['GT'][bid][5][:-1])
187 |                 gold_spans = set(batch_map['GT'][bid][:-1])
188 |                 pred_actions = get_actions(str(tr))
189 |                 pred_spans = set(get_spans(pred_actions)[:-1])
190 |                 tp, fp, fn = get_stats(pred_spans, gold_spans)
191 |                 corpus_f1[0] += tp
192 |                 corpus_f1[1] += fp
193 |                 corpus_f1[2] += fn
194 | 
195 |                 # SentF1
196 |                 overlap = pred_spans.intersection(gold_spans)
197 |                 prec = float(len(overlap)) / (len(pred_spans) + 1e-8)
198 |                 reca = float(len(overlap)) / (len(gold_spans) + 1e-8)
199 |                 if len(gold_spans) == 0:
200 |                     reca = 1.
201 |                 if len(pred_spans) == 0:
202 |                     prec = 1.
203 |                 f1 = 2 * prec * reca / (prec + reca + 1e-8)
204 |                 sent_f1.append(f1)
205 | 
206 |     ground_acc = recall_num / (total_num + 1e-8)
207 |     # logger.info('grounding acc:{}'.format(ground_acc))
208 |     # return ground_acc
209 |     tp, fp, fn = corpus_f1
210 |     prec = tp / (tp + fp)
211 |     recall = tp / (tp + fn)
212 |     corpus_f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.
213 |     sent_f1 = np.mean(np.array(sent_f1))
214 |     logger.info('corpus_f1:{} \t sent_f1:{} \t grounding acc:{}'.format(corpus_f1, sent_f1, ground_acc))
215 | 
216 |     # Train mode.
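    # (run_eval set diora.outside = options.obj_feats for evaluation; the
    # outside pass is re-enabled unconditionally before training resumes.)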
217 | diora.outside = True 218 | trainer.net.train() 219 | return corpus_f1 220 | 221 | 222 | def get_train_dataset(options): 223 | return ReconstructDataset().initialize(options, text_path=options.train_path, 224 | embeddings_path=options.embeddings_path, filter_length=options.train_filter_length, 225 | data_type=options.train_data_type) 226 | 227 | 228 | def get_train_iterator(options, dataset): 229 | return make_batch_iterator(options, dataset, mode='train', shuffle=True, 230 | include_partial=False, filter_length=options.train_filter_length, 231 | batch_size=options.batch_size, length_to_size=options.length_to_size) 232 | 233 | 234 | def get_validation_dataset(options): 235 | return ReconstructDataset().initialize(options, text_path=options.validation_path, 236 | embeddings_path=options.embeddings_path, filter_length=options.validation_filter_length, 237 | data_type=options.validation_data_type) 238 | 239 | 240 | def get_validation_iterator(options, dataset): 241 | return make_batch_iterator(options, dataset, mode='test', shuffle=False, 242 | include_partial=True, filter_length=options.validation_filter_length, 243 | batch_size=options.validation_batch_size, length_to_size=options.length_to_size) 244 | 245 | 246 | def get_train_and_validation(options): 247 | train_dataset = get_train_dataset(options) 248 | validation_dataset = get_validation_dataset(options) 249 | 250 | # Modifies datasets. Unifying word mappings, embeddings, etc. 251 | if options.data_type not in ['coco', 'flickr']: 252 | ConsolidateDatasets([train_dataset, validation_dataset]).run() 253 | 254 | return train_dataset, validation_dataset 255 | 256 | 257 | def run(options): 258 | logger = get_logger() 259 | # experiment_logger = ExperimentLogger() 260 | 261 | train_dataset, validation_dataset = get_train_and_validation(options) 262 | if options.debug: 263 | train_iterator = get_validation_iterator(options, validation_dataset) 264 | else: 265 | train_iterator = get_train_iterator(options, train_dataset) 266 | validation_iterator = get_validation_iterator(options, validation_dataset) 267 | embeddings = train_dataset['embeddings'] 268 | 269 | logger.info('Initializing model.') 270 | trainer = build_net(options, embeddings) 271 | logger.info('Model:') 272 | for name, p in trainer.net.named_parameters(): 273 | logger.info('{} {} {}'.format(name, p.shape, p.requires_grad)) 274 | 275 | run_train(options, train_iterator, trainer, validation_iterator) 276 | 277 | 278 | def argument_parser(): 279 | parser = argparse.ArgumentParser() 280 | 281 | # Debug. 
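    # (--debug swaps the validation iterator in as the training iterator,
    # see run() above, for quick smoke tests.)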
282 | parser.add_argument('--debug', action='store_true') 283 | parser.add_argument('--seed', default=11, type=int) 284 | parser.add_argument('--git_sha', default=None, type=str) 285 | parser.add_argument('--git_branch_name', default=None, type=str) 286 | parser.add_argument('--git_dirty', default=None, type=str) 287 | parser.add_argument('--uuid', default=None, type=str) 288 | parser.add_argument('--model_flags', default=None, type=str, 289 | help='Load model settings from a flags file.') 290 | parser.add_argument('--flags', default=None, type=str, 291 | help='Load any settings from a flags file.') 292 | 293 | parser.add_argument('--master_addr', default='127.0.0.1', type=str) 294 | parser.add_argument('--master_port', default='29500', type=str) 295 | parser.add_argument('--world_size', default=None, type=int) 296 | 297 | # Pytorch 298 | parser.add_argument('--cuda', action='store_true') 299 | parser.add_argument('--multigpu', action='store_true') 300 | parser.add_argument("--local_rank", default=None, type=int) # for distributed-data-parallel 301 | 302 | # Logging. 303 | parser.add_argument('--default_experiment_directory', default=os.path.join(package_path(), '..', 'log'), type=str) 304 | parser.add_argument('--experiment_name', default=None, type=str) 305 | parser.add_argument('--experiment_path', default=None, type=str) 306 | parser.add_argument('--log_every_batch', default=10, type=int) 307 | parser.add_argument('--save_latest', default=1000, type=int) 308 | parser.add_argument('--save_distinct', default=5000, type=int) 309 | parser.add_argument('--save_after', default=1000, type=int) 310 | 311 | # Loading. 312 | parser.add_argument('--load_model_path', default=None, type=str) 313 | 314 | # Data. 315 | parser.add_argument('--data_type', default='nli', choices=data_types_choices) 316 | parser.add_argument('--train_data_type', default=None, choices=data_types_choices) 317 | parser.add_argument('--validation_data_type', default=None, choices=data_types_choices) 318 | parser.add_argument('--train_path', default=os.path.expanduser('~/data/snli_1.0/snli_1.0_train.jsonl'), type=str) 319 | parser.add_argument('--validation_path', default=os.path.expanduser('~/data/snli_1.0/snli_1.0_dev.jsonl'), type=str) 320 | parser.add_argument('--embeddings_path', default=os.path.expanduser('~/data/glove/glove.6B.300d.txt'), type=str) 321 | 322 | # Data (synthetic). 323 | parser.add_argument('--synthetic-nexamples', default=1000, type=int) 324 | parser.add_argument('--synthetic-vocabsize', default=1000, type=int) 325 | parser.add_argument('--synthetic-embeddingsize', default=1024, type=int) 326 | parser.add_argument('--synthetic-minlen', default=20, type=int) 327 | parser.add_argument('--synthetic-maxlen', default=21, type=int) 328 | parser.add_argument('--synthetic-seed', default=11, type=int) 329 | parser.add_argument('--synthetic-length', default=None, type=int) 330 | parser.add_argument('--use-synthetic-embeddings', action='store_true') 331 | 332 | # Data (preprocessing). 333 | parser.add_argument('--uppercase', action='store_true') 334 | parser.add_argument('--train_filter_length', default=50, type=int) 335 | parser.add_argument('--validation_filter_length', default=0, type=int) 336 | 337 | # Model. 
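    # (Note: --data_type above defaults to 'nli', which is not among
    # data_types_choices; argparse does not check defaults against choices,
    # and get_train_and_validation treats any non-coco/flickr value as the
    # legacy text-only path that runs ConsolidateDatasets. Also, --share
    # below uses store_false, so sharing is enabled by default and passing
    # the flag disables it.)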
338 |     parser.add_argument('--arch', default='mlp', choices=('mlp', 'hard'))
339 |     parser.add_argument('--share', action='store_false')
340 |     parser.add_argument('--hidden_dim', default=400, type=int)
341 |     parser.add_argument('--normalize', default='unit', choices=('none', 'unit'))
342 |     parser.add_argument('--compress', action='store_true',
343 |                         help='If true, then copy root from inside chart for outside. ' + \
344 |                              'Otherwise, learn outside root as bias.')
345 | 
346 |     # Model (Objective).
347 |     parser.add_argument('--reconstruct_mode', default='softmax',
348 |                         choices=('softmax',))
349 | 
350 |     # Model (Embeddings).
351 |     parser.add_argument('--emb', default='w2v', choices=('w2v', 'skip', 'elmo', 'both', 'none'))
352 | 
353 |     # Model (Negative Sampler).
354 |     parser.add_argument('--margin', default=1, type=float)
355 |     parser.add_argument('--k_neg', default=100, type=int)
356 |     parser.add_argument('--freq_dist_power', default=0.75, type=float)
357 | 
358 |     # ELMo
359 |     parser.add_argument('--elmo_options_path', default='https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json', type=str)
360 |     parser.add_argument('--elmo_weights_path', default='https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5', type=str)
361 |     parser.add_argument('--elmo_cache_dir', default='./log/elmo', type=str,
362 |                         help='If set, then context-insensitive word embeddings will be cached ' + \
363 |                              '(identified by a hash of the vocabulary).')
364 | 
365 |     # Training.
366 |     parser.add_argument('--batch_size', default=10, type=int)
367 |     parser.add_argument('--length_to_size', default=None, type=str,
368 |                         help='Easily specify a mapping of length to batch_size. ' + \
369 |                              'For instance, 10:32,20:16 means that all batches ' + \
370 |                              'of length 10-19 will have batch size 32, 20 or greater ' + \
371 |                              'will have batch size 16, and less than 10 will have batch size ' + \
372 |                              'equal to the batch_size arg. Only applies to training.')
373 |     parser.add_argument('--train_dataset_size', default=None, type=int)
374 |     parser.add_argument('--validation_dataset_size', default=None, type=int)
375 |     parser.add_argument('--validation_batch_size', default=None, type=int)
376 |     parser.add_argument('--max_epoch', default=5, type=int)
377 |     parser.add_argument('--max_step', default=None, type=int)
378 |     parser.add_argument('--finetune', action='store_true')
379 |     parser.add_argument('--finetune_after', default=0, type=int)
380 | 
381 |     # Parsing.
382 |     parser.add_argument('--postprocess', action='store_true')
383 |     parser.add_argument('--visualize', action='store_true')
384 | 
385 |     # Optimization.
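    # (The 2e-3 default below is overridden by the provided scripts:
    # train_diora.sh passes --lr 5e-4 and train_cliora.sh passes --lr 1e-5.)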
386 | parser.add_argument('--lr', default=2e-3, type=float) 387 | 388 | # Vis feature 389 | parser.add_argument('--alpha_contr', type=float, default=1.0) 390 | parser.add_argument('--obj_feats', action='store_true') 391 | parser.add_argument('--vl_margin', default=0.2, type=float) 392 | parser.add_argument('--use_contr', action='store_true') 393 | parser.add_argument('--use_contr_ce', action='store_true') 394 | parser.add_argument('--vg_loss', action='store_true') 395 | parser.add_argument('--alpha_vg', type=float, default=1.0) 396 | parser.add_argument('--alpha_kl', type=float, default=1.0) 397 | 398 | # S-DIORA 399 | parser.add_argument('--hinge_margin', default=1, type=float) 400 | 401 | return parser 402 | 403 | 404 | def parse_args(parser): 405 | options, other_args = parser.parse_known_args() 406 | 407 | # Set default flag values (data). 408 | options.train_data_type = options.data_type if options.train_data_type is None else options.train_data_type 409 | options.validation_data_type = options.data_type if options.validation_data_type is None else options.validation_data_type 410 | options.validation_batch_size = options.batch_size if options.validation_batch_size is None else options.validation_batch_size 411 | 412 | # Set default flag values (config). 413 | if not options.git_branch_name: 414 | options.git_branch_name = os.popen( 415 | 'git rev-parse --abbrev-ref HEAD').read().strip() 416 | 417 | if not options.git_sha: 418 | options.git_sha = os.popen('git rev-parse HEAD').read().strip() 419 | 420 | if not options.git_dirty: 421 | options.git_dirty = os.popen("git diff --quiet && echo 'clean' || echo 'dirty'").read().strip() 422 | 423 | if not options.uuid: 424 | options.uuid = str(uuid.uuid4()) 425 | 426 | if not options.experiment_name: 427 | options.experiment_name = '{}'.format(options.uuid[:8]) 428 | 429 | if not options.experiment_path: 430 | options.experiment_path = os.path.join(options.default_experiment_directory, options.experiment_name) 431 | 432 | if options.length_to_size is not None: 433 | parts = [x.split(':') for x in options.length_to_size.split(',')] 434 | options.length_to_size = {int(x[0]): int(x[1]) for x in parts} 435 | 436 | options.lowercase = not options.uppercase 437 | 438 | for k, v in options.__dict__.items(): 439 | if type(v) == str and v.startswith('~'): 440 | options.__dict__[k] = os.path.expanduser(v) 441 | 442 | # Load model settings from a flags file. 443 | if options.model_flags is not None: 444 | flags_to_use = [] 445 | flags_to_use += ['arch'] 446 | flags_to_use += ['compress'] 447 | flags_to_use += ['emb'] 448 | flags_to_use += ['hidden_dim'] 449 | flags_to_use += ['normalize'] 450 | flags_to_use += ['reconstruct_mode'] 451 | 452 | options = init_with_flags_file(options, options.model_flags, flags_to_use) 453 | 454 | # Load any setting from a flags file. 455 | if options.flags is not None: 456 | options = init_with_flags_file(options, options.flags) 457 | 458 | return options 459 | 460 | 461 | def configure(options): 462 | # Configure output paths for this experiment. 463 | configure_experiment(options.experiment_path, rank=options.local_rank) 464 | 465 | # Get logger. 466 | logger = get_logger() 467 | 468 | # Print flags. 
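    # (save_flags writes the resolved options to flags.json under the
    # experiment path; that file can later be fed back through --flags or
    # --model_flags, which read it via init_with_flags_file.)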
469 | logger.info(stringify_flags(options)) 470 | save_flags(options, options.experiment_path) 471 | 472 | 473 | if __name__ == '__main__': 474 | parser = argument_parser() 475 | options = parse_args(parser) 476 | configure(options) 477 | 478 | run(options) 479 | -------------------------------------------------------------------------------- /cliora/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bobwan1995/cliora/b064bdf967d4ccc4f3327183efd888b927bfb4fb/cliora/utils/__init__.py -------------------------------------------------------------------------------- /cliora/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def save_experiment(experiment_file, step): 5 | data = dict(step=step) 6 | data = json.dumps(data, indent=4, sort_keys=True) 7 | with open(experiment_file, 'w') as f: 8 | f.write(data) 9 | 10 | 11 | def load_experiment(experiment_file): 12 | with open(experiment_file, 'r') as f: 13 | data = json.loads(f.read()) 14 | return data 15 | -------------------------------------------------------------------------------- /cliora/utils/flags.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | 5 | 6 | def read_flags(fn): 7 | with open(fn) as f: 8 | flags = json.loads(f.read()) 9 | return flags 10 | 11 | 12 | def override_with_flags(options, flags, flags_to_use=None): 13 | """ 14 | If `flags_to_use` is None, then override all flags, otherwise, 15 | only consider flags from `flags_to_use`. 16 | 17 | """ 18 | if flags_to_use is None: 19 | for k, v in flags.items(): 20 | setattr(options, k, v) 21 | else: 22 | for k in flags_to_use: 23 | setattr(options, k, flags.get(k)) 24 | return options 25 | 26 | 27 | def init_with_flags_file(options, flags_file, flags_to_use=None): 28 | flags = read_flags(flags_file) 29 | options = override_with_flags(options, flags, flags_to_use) 30 | return options 31 | 32 | 33 | def stringify_flags(options): 34 | # Ignore negative boolean flags. 
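    # (As written, nothing is filtered: every entry of options.__dict__,
    # including any negative boolean flags, is serialized as-is.)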
35 | flags = {k: v for k, v in options.__dict__.items()} 36 | return json.dumps(flags, indent=4, sort_keys=True) 37 | 38 | 39 | def save_flags(options, experiment_path): 40 | flags = stringify_flags(options) 41 | target_file = os.path.join(experiment_path, 'flags.json') 42 | with open(target_file, 'w') as f: 43 | f.write(flags) 44 | -------------------------------------------------------------------------------- /cliora/utils/fs.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | 4 | 5 | def mkdir_p(path): 6 | try: 7 | os.makedirs(path) 8 | except OSError as exc: # Python >2.5 9 | if exc.errno == errno.EEXIST and os.path.isdir(path): 10 | pass 11 | else: 12 | raise 13 | -------------------------------------------------------------------------------- /cliora/utils/path.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def package_path(): 5 | my_directory = os.path.dirname(os.path.abspath(__file__)) 6 | my_package_directory = os.path.join(my_directory, '..', '..') 7 | return os.path.abspath(my_package_directory) 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | opencv-python 4 | tqdm 5 | h5py 6 | -------------------------------------------------------------------------------- /test_cliora.sh: -------------------------------------------------------------------------------- 1 | #export PYTHONPATH=$(pwd):$PYTHONPATH 2 | EXP_PATH="./outputs/flickr/flickr_cliora_1e5_mlpshare_bs32_skip_seed1234" 3 | #python cliora/scripts/right_branch.py \ 4 | #for ckpt in "0.pt" "1.pt" "2.pt" "3.pt" "4.pt" 5 | #do 6 | ckpt="1.pt" 7 | export CUDA_VISIBLE_DEVICES=0 8 | python cliora/scripts/parse.py \ 9 | --cuda \ 10 | --arch mlp \ 11 | --batch_size 64 \ 12 | --emb skip \ 13 | --embeddings_path ./flickr_data/skip_thoughts_dict.pkl \ 14 | --hidden_dim 400 \ 15 | --k_neg 100 \ 16 | --log_every_batch 100 \ 17 | --normalize unit \ 18 | --reconstruct_mode softmax \ 19 | --data_type flickr \ 20 | --train_path ./flickr_data/flickr_train.json \ 21 | --validation_path ./flickr_data/flickr_test.json \ 22 | --experiment_path $EXP_PATH \ 23 | --obj_feats \ 24 | --use_contr \ 25 | --vg_loss \ 26 | --load_model_path $EXP_PATH/model.epoch_$ckpt 27 | #done -------------------------------------------------------------------------------- /test_diora.sh: -------------------------------------------------------------------------------- 1 | #export PYTHONPATH=$(pwd):$PYTHONPATH 2 | EXP_PATH="./outputs/flickr/flickr_diora_5e4_mlpshare_bs32_RandInit_seed1234" 3 | export CUDA_VISIBLE_DEVICES=0 4 | python cliora/scripts/parse.py \ 5 | --cuda \ 6 | --arch mlp \ 7 | --batch_size 64 \ 8 | --emb none \ 9 | --embeddings_path none \ 10 | --hidden_dim 400 \ 11 | --k_neg 100 \ 12 | --log_every_batch 100 \ 13 | --normalize unit \ 14 | --reconstruct_mode softmax \ 15 | --data_type flickr \ 16 | --train_path ./flickr_data/flickr_train.json \ 17 | --validation_path ./flickr_data/flickr_test.json \ 18 | --experiment_path $EXP_PATH \ 19 | --load_model_path $EXP_PATH/model.epoch_29.pt -------------------------------------------------------------------------------- /train_cliora.sh: -------------------------------------------------------------------------------- 1 | # ***************** CLIORA ******************* 2 | 3 | # Finetune CLIORA based on DIORA on Flickr30K 
with word embedding from pretrained DIORA 4 | EXP_PATH="./outputs/flickr/flickr_cliora_1e5_mlpshare_bs32_RandInit_seed1234_valid" 5 | export CUDA_VISIBLE_DEVICES=0,1,2,3 6 | python -m torch.distributed.launch --nproc_per_node=4 cliora/scripts/train.py \ 7 | --cuda --multigpu \ 8 | --max_epoch 10 \ 9 | --seed 1234 \ 10 | --arch mlp \ 11 | --batch_size 32 \ 12 | --emb none \ 13 | --embeddings_path none \ 14 | --hidden_dim 400 \ 15 | --k_neg 100 \ 16 | --log_every_batch 100 \ 17 | --lr 1e-5 \ 18 | --normalize unit \ 19 | --reconstruct_mode softmax \ 20 | --train_filter_length 40 \ 21 | --data_type flickr \ 22 | --train_path ./flickr_data/flickr_train.json \ 23 | --validation_path ./flickr_data/flickr_test.json \ 24 | --experiment_path $EXP_PATH \ 25 | --obj_feats \ 26 | --use_contr --alpha_contr 1.0 \ 27 | --vg_loss --alpha_vg 1.0 \ 28 | --load_model_path ./outputs/flickr/flickr_diora_5e4_mlpshare_bs32_RandInit_seed1234/model.epoch_29.pt 29 | 30 | 31 | # Finetune CLIORA based on DIORA on Flickr30K with skip-thoughts initialized word embedding 32 | EXP_PATH="./outputs/flickr/flickr_cliora_1e5_mlpshare_bs32_skip_seed1234" 33 | export CUDA_VISIBLE_DEVICES=0,1,2,3 34 | python -m torch.distributed.launch --nproc_per_node=4 cliora/scripts/train.py \ 35 | --cuda --multigpu \ 36 | --max_epoch 10 \ 37 | --seed 1234 \ 38 | --master_port 12345 \ 39 | --arch mlp \ 40 | --batch_size 32 \ 41 | --emb skip \ 42 | --embeddings_path ./flickr_data/skip_thoughts_dict.pkl \ 43 | --hidden_dim 400 \ 44 | --k_neg 100 \ 45 | --log_every_batch 100 \ 46 | --lr 1e-5 \ 47 | --normalize unit \ 48 | --reconstruct_mode softmax \ 49 | --train_filter_length 40 \ 50 | --data_type flickr \ 51 | --train_path ./flickr_data/flickr_train.json \ 52 | --validation_path ./flickr_data/flickr_test.json \ 53 | --experiment_path $EXP_PATH \ 54 | --obj_feats \ 55 | --use_contr --alpha_contr 1.0 \ 56 | --vg_loss --alpha_vg 1.0 \ 57 | --load_model_path ./outputs/flickr/flickr_diora_5e4_mlpshare_bs64_skip_seed1234/model.epoch_29.pt 58 | -------------------------------------------------------------------------------- /train_diora.sh: -------------------------------------------------------------------------------- 1 | # ***************** DIORA ******************* 2 | # Train original DIORA on Flickr30K with randomly-initialized word embedding 3 | # For randomly-initialized word embedding, bs 32 gets better results than bs64 4 | EXP_PATH="./outputs/flickr/flickr_diora_5e4_mlpshare_bs32_RandInit_seed1234" 5 | 6 | export CUDA_VISIBLE_DEVICES=0,1,2,3 7 | python -m torch.distributed.launch --nproc_per_node=4 cliora/scripts/train.py \ 8 | --cuda --multigpu \ 9 | --max_epoch 30 \ 10 | --seed 1234 \ 11 | --arch mlp \ 12 | --batch_size 32 \ 13 | --emb none \ 14 | --embeddings_path none \ 15 | --hidden_dim 400 \ 16 | --k_neg 100 \ 17 | --log_every_batch 100 \ 18 | --lr 5e-4 \ 19 | --normalize unit \ 20 | --reconstruct_mode softmax \ 21 | --train_filter_length 40 \ 22 | --data_type flickr \ 23 | --train_path ./flickr_data/flickr_train.json \ 24 | --validation_path ./flickr_data/flickr_test.json \ 25 | --experiment_path $EXP_PATH 26 | 27 | 28 | # Train original DIORA on Flickr30K with skip-thoughts initialized word embedding 29 | EXP_PATH="./outputs/flickr/flickr_diora_5e4_mlpshare_bs64_skip_seed1234" 30 | 31 | export CUDA_VISIBLE_DEVICES=0,1,2,3 32 | python -m torch.distributed.launch --nproc_per_node=4 cliora/scripts/train.py \ 33 | --cuda --multigpu \ 34 | --max_epoch 30 \ 35 | --seed 1234 \ 36 | --arch mlp \ 37 | --batch_size 64 \ 38 | --emb skip \ 39 
| --embeddings_path ./flickr_data/skip_thoughts_dict.pkl \ 40 | --hidden_dim 400 \ 41 | --k_neg 100 \ 42 | --log_every_batch 100 \ 43 | --lr 5e-4 \ 44 | --normalize unit \ 45 | --reconstruct_mode softmax \ 46 | --train_filter_length 40 \ 47 | --data_type flickr \ 48 | --train_path ./flickr_data/flickr_train.json \ 49 | --validation_path ./flickr_data/flickr_test.json \ 50 | --experiment_path $EXP_PATH --------------------------------------------------------------------------------