├── .gitignore ├── README.md ├── create_gt_visual_words.py ├── download.sh ├── requirements.txt ├── rule_mining.py ├── rule_utils.py ├── test_mining.py └── vqa.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | .vscode/ 141 | logs/ 142 | data/ 143 | 144 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Rule-Mining for shortcut discovery (VQA-CE) 3 | 4 | vqa-ce 5 | 6 | 7 | This repo contains the rule mining pipeline described in the article : 8 | **Beyond Question-Based Biases: Assessing Multimodal Shortcut Learning in Visual Question Answering** by Corentin Dancette, Rémi Cadène, Damien Teney and 9 | Matthieu Cord (https://arxiv.org/abs/2104.03149). 10 | It also provides the VQA-CE dataset. 11 | 12 | Website here: https://cdancette.fr/projects/vqa-ce/ 13 | 14 | This code was developed with python 3.7 and pytorch 1.7.0.
15 | 16 | ### VQA-CE 17 | The VQA-CE counterexamples and hard subsets can be downloaded here: 18 | - counterexamples: https://github.com/cdancette/detect-shortcuts/releases/download/v1.0/counterexamples.json 19 | - hard: https://github.com/cdancette/detect-shortcuts/releases/download/v1.0/hard.json 20 | 21 | The "easy" subset can be obtained by subtracting counterexamples and hard from all question_ids. 22 | 23 | ## Usage 24 | 25 | ### Installing requirements 26 | First, you need to install gminer. Follow instructions at https://github.com/cdancette/GMiner. 27 | 28 | For python requirements, run `pip install -r requirements.txt`. This will install pytorch, numpy and tqdm. `vqa.py` also needs `torchtext` (used for question tokenization). 29 | ### Visual Question Answering (VQA) 30 | 31 | #### Download VQA and COCO data 32 | 33 | First, run `./download.sh`. Data will be downloaded in the `./data` directory. 34 | 35 | #### Run the rule mining pipeline 36 | 37 | Then run `python vqa.py --gminer_path <path/to/GMiner> --save_dir logs/vqa2` to run our pipeline on the VQA v2 dataset (`--save_dir` is required). 38 | You can change the parameters, see the end of the `vqa.py` file or run `python vqa.py --help`. 39 | 40 | This will save in logs/vqa2 various files containing the rules found in the dataset, 41 | the question_ids for easy and counterexamples splits, and the predictions made by the rule model. 42 | 43 | To evaluate predictions, you can use the [multimodal](https://github.com/cdancette/multimodal) library: 44 | 45 | ```bash 46 | pip install multimodal 47 | python -m multimodal vqa2-eval -p logs/vqa2/rules_predictions.json --split val 48 | ``` 49 | 50 | 51 | ### Other tasks 52 | 53 | 54 | #### fit 55 | You can use our library to extract rules from any other dataset. 56 | 57 | To do so, you can use the `fit` function in `rule_mining.py`. 58 | It takes the following arguments: 59 | `fit(dataset, answer_ids, gminer_support=0.01, gminer_max_length=0, gminer_path=None)`, where: 60 | 61 | - `dataset` is a list of transactions. Each transaction is a list of integers describing tokens.
62 | - `answer_ids` is a list of integers, describing answer ids. They should be contained between 0 and max answer id. 63 | - `gminer_support` is the minimum support used to mine frequent itemsets. 64 | - `gminer_max_length`: maximum length of a mined itemset (passed to GMiner as `-l`). By default (0) there is no maximum length. 65 | - `gminer_path`: path to the gminer binary you compiled (see top of the readme). 66 | 67 | 68 | The function returns a list of rules, contained in namedtuples: `Rule = namedtuple("Rule", ["itemset", "ans", "sup", "conf"])`. 69 | 70 | The itemset contains the input token ids, ans is the answer id, sup is the support of the itemset (the rule's antecedent, without the answer) and conf is the confidence of the rule itemset -> ans. 71 | 72 | #### match_rules 73 | 74 | We provide a function to get, for each example in your dataset, all rules matching its input. 75 | 76 | `match_rules(dataset, rules, answers=None, bsize=500)` 77 | 78 | This will return `(matching_rules, correct_rules)`, where `matching_rules` is a list of the same length as the dataset, giving for each example the matching rules, and `correct_rules` gives for each example the matching rules whose answer equals that example's entry in `answers` (so pass `answers` to use it). 79 | 80 | You can use this to build your counterexamples subset (examples that have matching rules, all of them incorrect), or your easy subset (where at least one rule is correct).
81 | -------------------------------------------------------------------------------- /create_gt_visual_words.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from tqdm import tqdm 4 | 5 | 6 | with open("/data/common-data/coco/annotations/instances_train2014.json") as f: 7 | data_train = json.load(f) 8 | with open("/data/common-data/coco/annotations/instances_val2014.json") as f: 9 | data_val = json.load(f) 10 | 11 | result_path = "data/image_to_gt_vocab.json" 12 | 13 | result = defaultdict(lambda: {"classes": [], "scores": []}) 14 | 15 | categories = {c["id"]: c for c in data_train["categories"]} 16 | 17 | for annot in tqdm(data_train["annotations"]): 18 | image_id = annot["image_id"] 19 | category = categories[annot["category_id"]]["name"] 20 | if category not in result[image_id]["classes"]: 21 | result[image_id]["classes"].append(category) 22 | result[image_id]["scores"].append(1.0) 23 | print(len(result)) 24 | 25 | for annot in tqdm(data_val["annotations"]): 26 | image_id = annot["image_id"] 27 | category = categories[annot["category_id"]]["name"] 28 | if category not in result[image_id]["classes"]: 29 | result[image_id]["classes"].append(category) 30 | result[image_id]["scores"].append(1.0) 31 | print(len(result)) 32 | 33 | with open(result_path, "w") as f: 34 | json.dump(result, f) 35 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data/vqa2 2 | cd data/vqa2 3 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip 4 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip 5 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip 6 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip 7 | 8 | unzip "*mscoco.zip" 9 | 
rm *mscoco.zip 10 | 11 | mkdir -p ../data/coco 12 | cd ../data/coco 13 | wget https://github.com/cdancette/detect-shortcuts/releases/download/v1.0/image_to_detection.json 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | tqdm -------------------------------------------------------------------------------- /rule_mining.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import namedtuple 3 | import sys 4 | import json 5 | from typing import List 6 | from tempfile import TemporaryDirectory 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | import torch 11 | 12 | from rule_utils import superset_filtering 13 | 14 | def loadjson(path): 15 | with open(path) as f: 16 | return json.load(f) 17 | 18 | 19 | Rule = namedtuple("Rule", ["itemset", "ans", "sup", "conf"]) 20 | 21 | def run_gminer(transactions, support, max_length=0, gminer_path=None): 22 | with TemporaryDirectory() as tempdir: 23 | path_gminer_in = tempdir + f"/gminer_in.txt" 24 | path_gminer_out = tempdir + f"/gminer_out" 25 | 26 | # convert trans_by_ans to GMiner format 27 | print("Converting transactions to GMiner input format") 28 | if not os.path.exists(path_gminer_in): 29 | with open(path_gminer_in, "w") as f: 30 | for trans in tqdm(transactions): 31 | trans = " ".join([str(x) for x in trans]) 32 | f.write(trans + "\n") 33 | 34 | print("Running GMiner") 35 | print(f"Number of transactions : {len(transactions)}") 36 | print(f"Number of items : {max([max(t) for t in transactions])}") 37 | 38 | if support * len(transactions) < 1: 39 | min_support = 1 / len(transactions) 40 | print( 41 | f"Warning: Number of transactions * support = {support * len(transactions)} is below 1. 
" 42 | f"Minimum support is {min_support}", 43 | ) 44 | sys.exit(1) 45 | if gminer_path is None: 46 | gminer_path = "./GMiner" 47 | command = ( 48 | f"{gminer_path} -i {path_gminer_in} -o {path_gminer_out} -s {support} -w 1" 49 | ) 50 | if max_length != 0: 51 | command += f" -l {max_length}" 52 | 53 | print("Running Gminer:", command) 54 | out = os.system(command) 55 | if out != 0: 56 | os.remove(path_gminer_out) 57 | sys.exit(1) 58 | print("Done running gminer") 59 | 60 | itemsets = [] 61 | print("Parsing Gminer output", path_gminer_out) 62 | with open(path_gminer_out, "r") as f: 63 | for line in tqdm(f): 64 | line = line.strip() 65 | tmp = line.split(" ") 66 | itemset = [int(x) for x in tmp[:-1]] 67 | supp = float(tmp[-1][:-1][1:]) 68 | itemsets.append((itemset, supp)) 69 | return itemsets 70 | 71 | 72 | def fit( 73 | dataset, 74 | answer_ids, 75 | gminer_support=0.01, 76 | gminer_max_length=0, 77 | gminer_path=None, 78 | ): 79 | """ 80 | train_dataset: list of token ids 81 | train_answers: list of answer ids 82 | """ 83 | 84 | max_token_id = max(max(t) for t in dataset) 85 | answer_ids = [t + max_token_id + 1 for t in answer_ids] 86 | item_id_to_ans_id = {t: t - max_token_id - 1 for t in answer_ids} 87 | 88 | transactions = [ 89 | items + [ans_id] for (items, ans_id) in zip(dataset, answer_ids) 90 | ] 91 | 92 | print(f"Minimum number of examples per rule: {gminer_support * len(transactions)}") 93 | 94 | itemsets = run_gminer( 95 | transactions, 96 | support=gminer_support, 97 | max_length=gminer_max_length, 98 | gminer_path=gminer_path, 99 | ) 100 | 101 | supports_by_itemset = {} 102 | supports_by_itemset[()] = 1.0 # initialize empty tuple 103 | for itemset, support in itemsets: 104 | itemset = tuple(sorted(itemset)) 105 | supports_by_itemset[itemset] = support 106 | 107 | # Extracting rules (itemsets with answers) 108 | pre_rules = [] 109 | for (itemset, support_with_ans) in itemsets: 110 | for i, it in enumerate(itemset): 111 | if it in item_id_to_ans_id: 112 
| del itemset[i] 113 | pre_rules.append( 114 | (tuple(sorted(itemset)), item_id_to_ans_id[it], support_with_ans) 115 | ) 116 | break 117 | print(f"Number of rules : {len(pre_rules)}") 118 | # Computing confidence on training set 119 | print("Computing confidences on training set") 120 | rules: List[Rule] = [] 121 | for rule in tqdm(pre_rules): 122 | itemset, ans, support_with_ans = rule 123 | if len(itemset) == 0: 124 | confidence = support 125 | elif itemset in supports_by_itemset: 126 | # add confidence 127 | confidence = support_with_ans / supports_by_itemset[itemset] 128 | else: 129 | print(f"Missing data for itemset {itemset}...") 130 | rule = Rule( 131 | itemset=itemset, ans=ans, sup=supports_by_itemset[itemset], conf=confidence 132 | ) 133 | rules.append(rule) 134 | 135 | ########################## 136 | # SUPERSET Filtering 137 | ########################## 138 | # Here, we remove an itemset if 139 | # there was a previous itemset that is 140 | # a subset of it, and had better conf 141 | # and the same answer 142 | # Also, if the itemset was previously in the rules, 143 | # then we discard it (it means that there was another 144 | # rule with another answer which has a better confidence). 145 | print("Performing superset filtering") 146 | rules = superset_filtering(rules) 147 | 148 | rules = sorted( 149 | rules, key=lambda r: (-r.conf, -r.sup, len(r.itemset)) 150 | ) # conf, support, length 151 | 152 | print(f"Number of rules obtained from training set : {len(rules)}") 153 | return rules 154 | 155 | 156 | def match_rules( 157 | dataset, 158 | rules: List[Rule], 159 | answers=None, 160 | bsize=500, 161 | stop_all_have_rules=False, 162 | stop_all_correct_rules=False, 163 | ): 164 | """ 165 | This function will return lists of all rules that match a given example in the dataset. 
166 | Args: 167 | dataset: list of list of token ids 168 | rules (List[Rule]): list of Rules 169 | answers: List[int] 170 | """ 171 | # filling transaction matrix 172 | max_word_id = max(max(d) for d in dataset) 173 | transactions_matrix = np.zeros((len(dataset), max_word_id + 1), dtype=bool) 174 | for i, d in enumerate(dataset): 175 | transactions_matrix[i, d] = True 176 | 177 | transactions_matrix = torch.from_numpy(transactions_matrix).bool().cuda() 178 | pad_index = transactions_matrix.shape[1] 179 | N = transactions_matrix.shape[0] 180 | 181 | # pad index 182 | transactions_matrix = torch.cat( 183 | (transactions_matrix, torch.ones(N, 1).bool().cuda()), dim=1, 184 | ) 185 | 186 | best_rules = dict() 187 | best_correct_rule = dict() 188 | all_rules = [[] for _ in range(len(transactions_matrix))] 189 | correct_rules = [[] for _ in range(len(transactions_matrix))] 190 | 191 | # Progress bars and iterables 192 | pbar = tqdm(total=len(transactions_matrix)) 193 | pbar.set_description("Total rules found ") 194 | pbar_correct = tqdm(total=len(transactions_matrix)) 195 | pbar_correct.set_description("Correct rules found") 196 | 197 | for i in tqdm(range(0, len(rules), bsize), desc="Rules processed"): 198 | rs = rules[i : i + bsize] 199 | itemsets = [r.itemset for r in rs] 200 | max_length = max([len(r) for r in itemsets]) 201 | itemsets = [list(r) + [pad_index] * (max_length - len(r)) for r in itemsets] 202 | indexes_concerned = ( 203 | (transactions_matrix[:, itemsets].all(dim=2).nonzero()) 204 | .detach() 205 | .cpu() 206 | .numpy() 207 | ) # (N * 2) where 2 = (trans_id, rule_id) 208 | transactions_for_rule = [[] for _ in range(len(rs))] 209 | 210 | num_trans_found = 0 211 | num_correct_trans_found = 0 212 | 213 | for j in range(len(indexes_concerned)): 214 | trans_id, rule_id = indexes_concerned[j] 215 | rule_id = rule_id + i 216 | rule = rules[rule_id] 217 | transactions_for_rule[rule_id - i].append(trans_id) 218 | if trans_id not in best_rules: 219 | 
num_trans_found += 1 220 | best_rules[trans_id] = rule 221 | all_rules[trans_id].append(rule) 222 | if rule.ans == answers[trans_id]: 223 | if trans_id not in best_correct_rule: 224 | best_correct_rule[trans_id] = rule 225 | num_correct_trans_found += 1 226 | correct_rules[trans_id].append(rule) 227 | 228 | pbar.update(num_trans_found) 229 | pbar_correct.update(num_correct_trans_found) 230 | 231 | if stop_all_have_rules and len(best_rules) == len(transactions_matrix): 232 | break 233 | if stop_all_correct_rules and len(best_correct_rule) == len( 234 | transactions_matrix 235 | ): 236 | break 237 | pbar.close() 238 | pbar_correct.close() 239 | del transactions_matrix 240 | 241 | return ( 242 | all_rules, 243 | correct_rules, 244 | ) 245 | -------------------------------------------------------------------------------- /rule_utils.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | from tqdm import tqdm 3 | 4 | qid_to_annot = dict() 5 | 6 | def superset_filtering(rules): 7 | """ 8 | Two goals: 9 | - remove duplicate rules (ie, rules that have the same itemset, but different answers). 10 | We keep only the rule with the best confidence. 11 | - remove rules that are useless, because they are a superset of a previous rule (so they are more constrained, thus 12 | they have a smaller support), but also have a smaller confidence. 
13 | """ 14 | rules_sorted = sorted( 15 | rules, key=lambda r: (len(r.itemset), -r.conf) 16 | ) # sorted by length (up), confidence (down) 17 | rules = [] 18 | rule_by_itemset = dict() # itemset -> list of rules 19 | # rules_discarded = defaultdict(set) 20 | for rule in tqdm(rules_sorted): 21 | itemset, aid, support, conf = rule 22 | itemset = frozenset(itemset) 23 | discard_rule = False 24 | if itemset in rule_by_itemset: 25 | continue 26 | else: 27 | rule_by_itemset[itemset] = rule 28 | 29 | if len(itemset) > 0: 30 | for it in combinations(itemset, len(itemset) - 1): 31 | it = frozenset(it) 32 | if it in rule_by_itemset: 33 | old_r = rule_by_itemset[it] 34 | if old_r.conf >= conf and old_r.ans == aid: 35 | discard_rule = True 36 | break 37 | if not discard_rule: 38 | rules.append(rule) 39 | print( 40 | f"After discarding rules, going from {len(rules_sorted)} to {len(rules)} rules." 41 | ) 42 | return rules 43 | 44 | 45 | def test_superset_filtering(): 46 | # various test cases that we want to manage 47 | assert superset_filtering([[(), 0.1, 10, 0.1]]) == [[(), 0.1, 10, 0.1]] 48 | 49 | rules = [ 50 | [(), 0.1, 10, 0.1], 51 | [(0,), 0.1, 0, 0.5], 52 | [(0, 1), 0.1, 0, 0.5], 53 | [(0, 1, 2), 0.1, 0, 0.5], 54 | ] 55 | assert superset_filtering(rules) == [[(), 0.1, 10, 0.1], [(0,), 0.1, 0, 0.5]] 56 | 57 | rules = [ 58 | [(), 0.1, 10, 0.1], 59 | [(0,), 0.1, 0, 0.5], 60 | [(0, 1), 0.1, 0, 0.3], 61 | [(0, 1, 2), 0.1, 0, 0.2], 62 | [(0, 1, 5), 0.1, 3, 0.2], # additional to keep 63 | ] 64 | 65 | assert superset_filtering(rules) == [ 66 | [(), 0.1, 10, 0.1], 67 | [(0,), 0.1, 0, 0.5], 68 | [(0, 1, 5), 0.1, 3, 0.2], 69 | ] 70 | 71 | # TODO this test fails.. it is quite bad, because it could allow 72 | # us to discard a lot of useless rules... 
73 | rules = [ 74 | [(), 0.1, 0, 0.5], 75 | [(0,), 0.1, 0, 0.3], 76 | [(0, 1), 0.1, 0, 0.4], 77 | ] 78 | # assert superset_filtering(rules) == [ 79 | # [(), 0.1, 0, 0.5], 80 | # ] 81 | 82 | rules = [ 83 | [(), 0.1, 10, 0.1], 84 | [(0,), 0.1, 0, 0.5], 85 | [(0, 1), 0.1, 0, 0.6], 86 | [(0, 1, 2), 0.1, 0, 0.7], 87 | ] 88 | 89 | assert superset_filtering(rules) == [ 90 | [(), 0.1, 10, 0.1], 91 | [(0,), 0.1, 0, 0.5], 92 | [(0, 1), 0.1, 0, 0.6], 93 | [(0, 1, 2), 0.1, 0, 0.7], 94 | ] 95 | 96 | # same itemset, different answers, one is better (confidence) 97 | # We keep only the best. 98 | rules = [ 99 | [(), 0.1, 10, 0.2], 100 | [(), 0.1, 5, 0.1], 101 | ] 102 | 103 | assert superset_filtering(rules) == [[(), 0.1, 10, 0.2]] 104 | -------------------------------------------------------------------------------- /test_mining.py: -------------------------------------------------------------------------------- 1 | from tempfile import TemporaryDirectory 2 | from rule_mining import fit, Rule, match_rules 3 | # from rule_utils import match_rules 4 | 5 | def test_fit(): 6 | dataset = [ 7 | [0, 1, 2], 8 | [0, 1, 2], 9 | [0, 1, 3], 10 | [0, 1, 3], 11 | [0, 1, 4], 12 | [0, 1, 4], 13 | ] 14 | answers = [0, 0, 1, 1, 2, 2] 15 | rules = fit(dataset, answers, support_gminer=0.2) 16 | 17 | assert len(rules) == 4 18 | for k, (itemset, ans) in enumerate([[(2,), 0], [(3,), 1], [(4,), 2], [(), 0],]): 19 | assert rules[k].itemset == itemset 20 | assert rules[k].ans == ans 21 | 22 | # match_rules(dataset, rules) 23 | all_rules, correct_rules = match_rules(dataset, rules, answers=answers, bsize=10) 24 | print(all_rules) 25 | 26 | # item 0 27 | assert len(all_rules[0]) == 2 28 | assert all_rules[0][0].itemset == (2,) 29 | assert all_rules[0][1].itemset == () 30 | assert len(correct_rules[0]) == 2 31 | assert correct_rules[0][0].itemset == (2,) 32 | assert correct_rules[0][1].itemset == () 33 | 34 | # item 2 35 | assert len(all_rules[2]) == 2 36 | assert all_rules[2][0].itemset == (3,) 37 | assert 
all_rules[2][1].itemset == () 38 | assert len(correct_rules[2]) == 1 39 | assert correct_rules[2][0].itemset == (3,) 40 | -------------------------------------------------------------------------------- /vqa.py: -------------------------------------------------------------------------------- 1 | from rule_mining import Rule, fit, match_rules 2 | import pickle 3 | import os 4 | from collections import Counter 5 | import json 6 | from typing import List 7 | from tqdm import tqdm 8 | from torchtext.data.utils import get_tokenizer 9 | 10 | 11 | def loadjson(path): 12 | with open(path) as f: 13 | return json.load(f) 14 | 15 | 16 | def create_dataset( 17 | questions, 18 | visual_words="data/image_to_detection.json", 19 | annotations=None, 20 | textual=True, 21 | visual=True, 22 | visual_threshold=0.5, 23 | proportion=1.0, 24 | most_common_answers=None, 25 | ): 26 | print("Creating VQA binary dataset") 27 | print(f"Loading visual words at {visual_words}") 28 | visual_words = loadjson(visual_words) 29 | 30 | tokenizer = get_tokenizer("basic_english") 31 | 32 | total_len = int(len(questions) * proportion) 33 | transactions = [] 34 | answers = [] 35 | indexes = [] 36 | 37 | skipped = 0 38 | 39 | ############# 40 | # Regular textual 41 | ############# 42 | for i in tqdm(range(total_len)): 43 | transaction = [] 44 | 45 | if textual: 46 | tokens = tokenizer(questions[i]["question"]) 47 | transaction.extend(tokens) 48 | 49 | if visual: 50 | image_id = str(questions[i]["image_id"]) 51 | if image_id in visual_words: 52 | vwords = visual_words[image_id] 53 | classes = vwords["classes"] 54 | scores = vwords["scores"] 55 | if visual_threshold != 0: 56 | classes = [ 57 | c 58 | for (i, c) in enumerate(classes) 59 | if scores[i] >= visual_threshold 60 | ] 61 | classes = ["V_" + c for c in classes] # visual marker 62 | transaction.extend(classes) 63 | transactions.append(transaction) 64 | indexes.append(i) 65 | if annotations is not None: 66 | 
answers.append(annotations[i]["multiple_choice_answer"]) 67 | 68 | assert len(transactions) == len(answers) 69 | 70 | if annotations is not None and most_common_answers is not None: 71 | occurences = Counter(answers).most_common(most_common_answers) 72 | keep_answers = set(a for (a, _) in occurences) 73 | 74 | new_transactions = [] 75 | new_answers = [] 76 | new_indexes = [] 77 | for k in range(len(transactions)): 78 | if answers[k] in keep_answers: 79 | new_transactions.append(transactions[k]) 80 | new_answers.append(answers[k]) 81 | new_indexes.append(indexes[k]) 82 | transactions, answers, indexes = new_transactions, new_answers, new_indexes 83 | 84 | if annotations is not None: 85 | return transactions, answers, indexes 86 | 87 | return transactions, indexes 88 | 89 | 90 | def vqa( 91 | textual=True, 92 | visual=True, 93 | train_questions_path="data/vqa2/v2_OpenEnded_mscoco_train2014_questions.json", 94 | train_annotations_path="data/vqa2/v2_mscoco_train2014_annotations.json", 95 | val_questions_path=None, 96 | val_annotations_path=None, 97 | visual_threshold=0.5, 98 | support_gminer=2e-5, 99 | gminer_path=None, 100 | min_conf=0.3, 101 | max_length=5, 102 | version="vqa2", 103 | save_dir=None, 104 | keep_all_rules_train_predictions=False, 105 | visual_words="data/image_to_detection.json", 106 | ): 107 | 108 | train_questions = loadjson(train_questions_path) 109 | train_annotations = loadjson(train_annotations_path) 110 | if type(train_questions) == dict and "questions" in train_questions: 111 | train_questions = train_questions["questions"] 112 | train_annotations = train_annotations["annotations"] 113 | 114 | 115 | os.makedirs(save_dir, exist_ok=True) 116 | train_dataset, train_answers, train_indexes = create_dataset( 117 | train_questions, 118 | annotations=train_annotations, 119 | proportion=1.0, 120 | most_common_answers=3000, 121 | textual=textual, 122 | visual=visual, 123 | visual_threshold=visual_threshold, 124 | visual_words=visual_words, 125 | ) 126 | 
tokens = list(set(t for transaction in train_dataset for t in transaction)) 127 | token_to_id = {t: i for (i, t) in enumerate(tokens)} 128 | train_transactions = [ 129 | [token_to_id[t] for t in transaction] for transaction in train_dataset 130 | ] 131 | all_answers = list(set(train_answers)) 132 | ans_to_id = {ans: i for (i, ans) in enumerate(all_answers)} 133 | train_answers_ids = [ans_to_id[ans] for ans in train_answers] 134 | 135 | # rule mining 136 | rules: List[Rule] = fit( 137 | train_transactions, 138 | train_answers_ids, 139 | gminer_support=support_gminer, 140 | gminer_max_length=max_length, 141 | gminer_path=gminer_path, 142 | ) 143 | 144 | # - keep only rules with confidence > min_conf 145 | rules = [r for r in rules if r.conf >= min_conf] 146 | 147 | # show the best 20 rules 148 | for r in rules[:20]: 149 | print([tokens[tid] for tid in r.itemset], all_answers[r.ans], r.sup, r.conf) 150 | 151 | # match rules with examples 152 | matching_rules_train, matching_correct_rules_train = match_rules( 153 | train_transactions, rules, answers=train_answers_ids 154 | ) 155 | 156 | # val 157 | val_questions = loadjson(val_questions_path) 158 | val_annotations = loadjson(val_annotations_path) 159 | 160 | if type(val_questions) == dict and "questions" in val_questions: 161 | val_questions = val_questions["questions"] 162 | val_annotations = val_annotations["annotations"] 163 | 164 | val_dataset, val_answers, val_indexes = create_dataset( 165 | val_questions, 166 | annotations=val_annotations, 167 | proportion=1.0, 168 | textual=textual, 169 | visual=visual, 170 | visual_threshold=visual_threshold, 171 | visual_words=visual_words, 172 | ) 173 | 174 | val_transactions = [ 175 | [token_to_id[t] for t in transaction if t in token_to_id] 176 | for transaction in val_dataset 177 | ] 178 | val_answers_ids = [ans_to_id.get(ans, -1) for ans in val_answers] 179 | 180 | matching_rules_val, matching_correct_rules_val = match_rules( 181 | val_transactions, rules, val_answers_ids 
182 | ) 183 | 184 | # - create hard evaluations set 185 | # we load annotations, because we'll consider every answer, not only the top answer. 186 | qid_counterexamples = [] 187 | qid_easy = [] 188 | qid_hard = [] 189 | for annot, rs in zip(val_annotations, matching_rules_val): 190 | possible_answers = set(ans["answer"] for ans in annot["answers"]) 191 | rules_answers = set(all_answers[r.ans] for r in rs) 192 | if len(set.intersection(rules_answers, possible_answers)) == 0 and len(rs) != 0: 193 | # goes into counterexamples 194 | qid_counterexamples.append(annot["question_id"]) 195 | elif len(rs) == 0: 196 | qid_hard.append(annot["question_id"]) 197 | else: 198 | qid_easy.append(annot["question_id"]) 199 | 200 | # keep_rules: 201 | # we keep only one correct rule per training example 202 | if not keep_all_rules_train_predictions: 203 | keep_rules = set() 204 | for rs in matching_correct_rules_train: 205 | if rs: 206 | keep_rules.add(rs[0]) 207 | print("Rules kept after keeping only one per training example:", len(keep_rules)) 208 | else: 209 | keep_rules = None 210 | 211 | 212 | # build predictions on validation set. 
213 | n_missing_rules = 0 214 | predictions = [] 215 | for i, rs in enumerate(matching_rules_val): 216 | qid = val_questions[val_indexes[i]]["question_id"] 217 | if keep_rules is not None: 218 | rs = [r for r in rs if r in keep_rules] 219 | if rs: 220 | ans = all_answers[rs[0].ans] 221 | else: 222 | n_missing_rules += 1 223 | ans = "yes" 224 | predictions.append( 225 | {"question_id": qid, "answer": ans,} 226 | ) 227 | print("missing rules for val predictions", n_missing_rules, "which is (%)", 100*n_missing_rules / len(matching_rules_val)) 228 | 229 | # save predictions and rules 230 | with open(os.path.join(save_dir, "rules_predictions.json"), "w") as f: 231 | json.dump(predictions, f) 232 | with open(os.path.join(save_dir, "rules.pickle"), "bw") as f: 233 | rules_tuple = [tuple(r) for r in rules] 234 | pickle.dump(rules_tuple, f) 235 | with open(os.path.join(save_dir, "counterexamples.json"), "w") as f: 236 | json.dump(qid_counterexamples, f) 237 | with open(os.path.join(save_dir, "easy.json"), "w") as f: 238 | json.dump(qid_easy, f) 239 | with open(os.path.join(save_dir, "hard.json"), "w") as f: 240 | json.dump(qid_hard, f) 241 | 242 | return rules, qid_easy, qid_counterexamples, qid_hard 243 | 244 | 245 | if __name__ == "__main__": 246 | import argparse 247 | 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument("--save_dir", required=True) 250 | parser.add_argument("--support", default=2.1e-5, type=float) 251 | parser.add_argument("--max_length", default=5, type=int) 252 | parser.add_argument("--min_conf", default=0.3, type=float) 253 | parser.add_argument("--gminer_path") 254 | parser.add_argument("--visual_words", default="data/image_to_detection.json") 255 | parser.add_argument("--train_questions_path", default="data/vqa2/v2_OpenEnded_mscoco_train2014_questions.json") 256 | parser.add_argument("--train_annotations_path", default="data/vqa2/v2_mscoco_train2014_annotations.json") 257 | parser.add_argument("--val_questions_path", 
default="data/vqa2/v2_OpenEnded_mscoco_val2014_questions.json") 258 | parser.add_argument("--val_annotations_path", default="data/vqa2/v2_mscoco_val2014_annotations.json") 259 | parser.add_argument("--keep_all_rules_train_predictions", action="store_true", help="keep all rules instead of just one correct rule per training example. Only used for predictions.") 260 | args = parser.parse_args() 261 | 262 | (rules, qid_easy, qid_counterexamples, qid_hard) = vqa( 263 | support_gminer=args.support, 264 | max_length=args.max_length, 265 | min_conf=args.min_conf, 266 | save_dir=args.save_dir, 267 | gminer_path=args.gminer_path, 268 | visual_words=args.visual_words, 269 | train_questions_path=args.train_questions_path, 270 | train_annotations_path=args.train_annotations_path, 271 | val_questions_path=args.val_questions_path, 272 | val_annotations_path=args.val_annotations_path, 273 | keep_all_rules_train_predictions=args.keep_all_rules_train_predictions 274 | ) 275 | --------------------------------------------------------------------------------