├── .gitignore
├── README.md
├── create_gt_visual_words.py
├── download.sh
├── requirements.txt
├── rule_mining.py
├── rule_utils.py
├── test_mining.py
└── vqa.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | .vscode/
141 | logs/
142 | data/
143 |
144 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Rule-Mining for shortcut discovery (VQA-CE)
3 |
4 |
5 |
6 |
7 | This repo contains the rule mining pipeline described in the article:
8 | **Beyond Question-Based Biases: Assessing Multimodal Shortcut Learning in Visual Question Answering** by Corentin Dancette, Rémi Cadène, Damien Teney and
9 | Matthieu Cord (https://arxiv.org/abs/2104.03149).
10 | It also provides the VQA-CE dataset.
11 |
12 | Website here: https://cdancette.fr/projects/vqa-ce/
13 |
14 | This code was developed with Python 3.7 and PyTorch 1.7.0.
15 |
16 | ### VQA-CE
17 | The VQA-CE counterexamples subset can be downloaded here:
18 | - counterexamples: https://github.com/cdancette/detect-shortcuts/releases/download/v1.0/counterexamples.json
19 | - hard: https://github.com/cdancette/detect-shortcuts/releases/download/v1.0/hard.json
20 |
21 | The "easy" subset can be obtained by subtracting the counterexamples and hard question ids from the set of all question ids, as sketched below.
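
For instance, a minimal sketch (the local file names are assumptions, adjust them to wherever you saved the downloads):

```python
import json

# question ids of the two downloaded subsets
with open("counterexamples.json") as f:
    counterexamples = set(json.load(f))
with open("hard.json") as f:
    hard = set(json.load(f))

# all val2014 question ids, taken from the VQA v2 annotations
with open("data/vqa2/v2_mscoco_val2014_annotations.json") as f:
    all_qids = set(a["question_id"] for a in json.load(f)["annotations"])

easy = all_qids - counterexamples - hard
```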
22 |
23 | ## Usage
24 |
25 | ### Installing requirements
26 | First, you need to install gminer. Follow instructions at https://github.com/cdancette/GMiner.
27 |
28 | For python requirements, run `pip install -r requirements.txt`. This will install pytorch, torchtext, numpy, and tqdm.
29 | ### Visual Question Answering (VQA)
30 |
31 | #### Download VQA and COCO data
32 |
33 | First, run `./download.sh`. Data will be downloaded in the `./data` directory.
34 |
35 | #### Run the rule mining pipeline
36 |
37 | Then run `python vqa.py --gminer_path <path_to_gminer> --save_dir logs/vqa2` to run our pipeline on the VQA v2 dataset.
38 | You can change the parameters; see the end of `vqa.py`, or run `python vqa.py --help`.
39 |
40 | This will save various files in `logs/vqa2`: the rules found in the dataset,
41 | the question ids for the easy, hard and counterexamples splits, and the predictions made by the rule model.
42 |
43 | To evaluate predictions, you can use the [multimodal](https://github.com/cdancette/multimodal) library:
44 |
45 | ```bash
46 | pip install multimodal
47 | python -m multimodal vqa2-eval -p logs/vqa2/rules_predictions.json --split val
48 | ```
49 |
50 |
51 | ### Other tasks
52 |
53 |
54 | #### fit
55 | You can use our library to extract rules from any other dataset.
56 |
57 | To do so, use the `fit` function in `rule_mining.py`.
58 | It takes the following arguments:
59 | `fit(dataset, answer_ids, gminer_support=0.01, gminer_max_length=0, gminer_path=None)`, where:
60 |
61 | - `dataset` is a list of transactions. Each transaction is a list of integers describing tokens.
62 | - `answer_ids` is a list of integers describing answer ids. They should lie between 0 and the max answer id.
63 | - `gminer_support` is the minimum support used to mine frequent itemsets.
64 | - `gminer_max_length`: maximum length of an itemset. By default (0), no maximum length.
65 | - `gminer_path`: path to the GMiner binary you compiled (see the top of the readme).
66 |
67 |
68 | The function returns a list of rules, contained in namedtuples: `Rule = namedtuple("Rule", ["itemset", "ans", "sup", "conf"])`.
69 |
70 | The `itemset` contains the input token ids, `ans` is the answer id, and `sup` and `conf` are the support and the confidence of the rule.
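
For example, a toy run mirroring `test_mining.py` (the `gminer_path` value is a placeholder for wherever you compiled the binary):

```python
from rule_mining import fit

dataset = [
    [0, 1, 2], [0, 1, 2],
    [0, 1, 3], [0, 1, 3],
    [0, 1, 4], [0, 1, 4],
]
answer_ids = [0, 0, 1, 1, 2, 2]

rules = fit(dataset, answer_ids, gminer_support=0.2, gminer_path="./GMiner")
for rule in rules:
    print(rule.itemset, rule.ans, rule.sup, rule.conf)
# among others: itemset (2,) -> answer 0, with confidence 1.0
```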
71 |
72 | #### match_rules
73 |
74 | We provide a function to get, for each example in your dataset, all the rules matching its input.
75 |
76 | `match_rules(dataset, rules, answers=None, bsize=500)`
77 |
78 | This will return `(matching_rules, correct_rules)`, where `matching_rules` is a list of the same length as the dataset giving, for each example, the rules matching its input; `correct_rules` lists only the matching rules whose answer is correct (it requires passing `answers`).
79 |
80 | You can use this to build your counterexamples subset (examples matched by rules that are all incorrect), your easy subset (at least one correct matching rule), or your hard subset (no matching rule at all), as sketched below.
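
A minimal sketch, reusing `dataset`, `answer_ids` and `rules` from the `fit` example above (this mirrors how `vqa.py` builds its splits, simplified to one ground-truth answer per example):

```python
from rule_mining import match_rules

matching_rules, correct_rules = match_rules(dataset, rules, answers=answer_ids)

easy = [i for i in range(len(dataset)) if correct_rules[i]]
counterexamples = [
    i for i in range(len(dataset)) if matching_rules[i] and not correct_rules[i]
]
hard = [i for i in range(len(dataset)) if not matching_rules[i]]
```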
81 |
--------------------------------------------------------------------------------
/create_gt_visual_words.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | from tqdm import tqdm
4 |
5 |
6 | with open("/data/common-data/coco/annotations/instances_train2014.json") as f:
7 | data_train = json.load(f)
8 | with open("/data/common-data/coco/annotations/instances_val2014.json") as f:
9 | data_val = json.load(f)
10 |
11 | result_path = "data/image_to_gt_vocab.json"
12 |
13 | result = defaultdict(lambda: {"classes": [], "scores": []})
14 |
15 | categories = {c["id"]: c for c in data_train["categories"]}
16 |
17 | # the two splits share the same category ids, taken from the train file;
18 | # for each image, collect the set of ground-truth object class names,
19 | # each with a score of 1.0 (to mimic the detection-file format)
20 | for data in (data_train, data_val):
21 |     for annot in tqdm(data["annotations"]):
22 |         image_id = annot["image_id"]
23 |         category = categories[annot["category_id"]]["name"]
24 |         if category not in result[image_id]["classes"]:
25 |             result[image_id]["classes"].append(category)
26 |             result[image_id]["scores"].append(1.0)
27 |     print(len(result))
28 |
29 | with open(result_path, "w") as f:
30 |     json.dump(result, f)
31 |
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | mkdir -p data/vqa2
3 | cd data/vqa2
4 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip
5 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip
6 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip
7 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip
8 |
9 | unzip "*mscoco.zip"
10 | rm *mscoco.zip
11 |
12 | # vqa.py expects the precomputed detections at data/image_to_detection.json
13 | cd ..
14 | wget https://github.com/cdancette/detect-shortcuts/releases/download/v1.0/image_to_detection.json
15 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchtext
3 | numpy
4 | tqdm
--------------------------------------------------------------------------------
/rule_mining.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import namedtuple
3 | import sys
4 | import json
5 | from typing import List
6 | from tempfile import TemporaryDirectory
7 |
8 | from tqdm import tqdm
9 | import numpy as np
10 | import torch
11 |
12 | from rule_utils import superset_filtering
13 |
14 | def loadjson(path):
15 | with open(path) as f:
16 | return json.load(f)
17 |
18 |
19 | Rule = namedtuple("Rule", ["itemset", "ans", "sup", "conf"])
20 |
21 | def run_gminer(transactions, support, max_length=0, gminer_path=None):
22 | with TemporaryDirectory() as tempdir:
23 |         path_gminer_in = os.path.join(tempdir, "gminer_in.txt")
24 |         path_gminer_out = os.path.join(tempdir, "gminer_out")
25 |
26 |         # convert the transactions to the GMiner input format:
27 |         # one transaction per line, items separated by spaces
28 |         print("Converting transactions to GMiner input format")
29 |         with open(path_gminer_in, "w") as f:
30 |             for trans in tqdm(transactions):
31 |                 line = " ".join(str(x) for x in trans)
32 |                 f.write(line + "\n")
33 |
34 |         print("Running GMiner")
35 |         print(f"Number of transactions: {len(transactions)}")
36 |         print(f"Max item id: {max(max(t) for t in transactions)}")
37 |
38 |         if support * len(transactions) < 1:
39 |             min_support = 1 / len(transactions)
40 |             print(
41 |                 f"Error: support * number of transactions = {support * len(transactions)} is below 1. "
42 |                 f"The minimum usable support is {min_support}.",
43 |             )
44 |             sys.exit(1)
45 | if gminer_path is None:
46 | gminer_path = "./GMiner"
47 | command = (
48 | f"{gminer_path} -i {path_gminer_in} -o {path_gminer_out} -s {support} -w 1"
49 | )
50 | if max_length != 0:
51 | command += f" -l {max_length}"
52 |
53 |         print("Running GMiner:", command)
54 |         out = os.system(command)
55 |         if out != 0:
56 |             # GMiner may have failed before creating the output file
57 |             sys.exit(f"GMiner failed with exit status {out}")
58 |         print("Done running GMiner")
59 |
60 | itemsets = []
61 | print("Parsing Gminer output", path_gminer_out)
62 | with open(path_gminer_out, "r") as f:
63 | for line in tqdm(f):
64 |                 # each line is "item1 item2 ... (support)"
65 |                 tmp = line.strip().split(" ")
66 |                 itemset = [int(x) for x in tmp[:-1]]
67 |                 supp = float(tmp[-1][1:-1])  # strip the parentheses around the support
68 | itemsets.append((itemset, supp))
69 | return itemsets
70 |
71 |
72 | def fit(
73 | dataset,
74 | answer_ids,
75 | gminer_support=0.01,
76 | gminer_max_length=0,
77 | gminer_path=None,
78 | ):
79 |     """
80 |     dataset: list of transactions, each a list of token ids
81 |     answer_ids: list of answer ids, one per transaction
82 |     """
83 |
84 |     max_token_id = max(max(t) for t in dataset)
85 |     answer_ids = [t + max_token_id + 1 for t in answer_ids]  # shift answers past token ids
86 |     item_id_to_ans_id = {t: t - max_token_id - 1 for t in answer_ids}
87 |
88 | transactions = [
89 | items + [ans_id] for (items, ans_id) in zip(dataset, answer_ids)
90 | ]
91 |
92 | print(f"Minimum number of examples per rule: {gminer_support * len(transactions)}")
93 |
94 | itemsets = run_gminer(
95 | transactions,
96 | support=gminer_support,
97 | max_length=gminer_max_length,
98 | gminer_path=gminer_path,
99 | )
100 |
101 |     supports_by_itemset = {}
102 |     supports_by_itemset[()] = 1.0  # the empty itemset matches every transaction
103 | for itemset, support in itemsets:
104 | itemset = tuple(sorted(itemset))
105 | supports_by_itemset[itemset] = support
106 |
107 |     # Extract rules: each itemset that contains an answer item yields one rule
108 |     pre_rules = []
109 | for (itemset, support_with_ans) in itemsets:
110 | for i, it in enumerate(itemset):
111 | if it in item_id_to_ans_id:
112 | del itemset[i]
113 | pre_rules.append(
114 | (tuple(sorted(itemset)), item_id_to_ans_id[it], support_with_ans)
115 | )
116 | break
117 |     print(f"Number of rules: {len(pre_rules)}")
118 | # Computing confidence on training set
119 | print("Computing confidences on training set")
120 |     rules: List[Rule] = []
121 |     for itemset, ans, support_with_ans in tqdm(pre_rules):
122 |         if itemset not in supports_by_itemset:
123 |             # the antecedent itemset was not mined on its own, so we
124 |             # cannot compute a confidence for this rule: skip it
125 |             print(f"Missing data for itemset {itemset}, skipping...")
126 |             continue
127 |         # conf(itemset -> ans) = sup(itemset + ans) / sup(itemset)
128 |         # (the empty itemset has support 1.0)
129 |         confidence = support_with_ans / supports_by_itemset[itemset]
130 |         rule = Rule(
131 |             itemset=itemset, ans=ans, sup=supports_by_itemset[itemset], conf=confidence
132 |         )
133 |         rules.append(rule)
134 |
135 | ##########################
136 | # SUPERSET Filtering
137 | ##########################
138 | # Here, we remove an itemset if
139 | # there was a previous itemset that is
140 | # a subset of it, and had better conf
141 | # and the same answer
142 | # Also, if the itemset was previously in the rules,
143 | # then we discard it (it means that there was another
144 | # rule with another answer which has a better confidence).
145 | print("Performing superset filtering")
146 | rules = superset_filtering(rules)
147 |
148 | rules = sorted(
149 | rules, key=lambda r: (-r.conf, -r.sup, len(r.itemset))
150 | ) # conf, support, length
151 |
152 |     print(f"Number of rules obtained from training set: {len(rules)}")
153 | return rules
154 |
155 |
156 | def match_rules(
157 | dataset,
158 | rules: List[Rule],
159 | answers=None,
160 | bsize=500,
161 | stop_all_have_rules=False,
162 | stop_all_correct_rules=False,
163 | ):
164 |     """
165 |     Returns, for each example in the dataset, the list of all rules matching its input.
166 |     Args:
167 |         dataset: list of lists of token ids
168 |         rules (List[Rule]): list of Rules
169 |         answers: List[int], optional; required to compute the correct rules
170 |     """
171 | # filling transaction matrix
172 | max_word_id = max(max(d) for d in dataset)
173 | transactions_matrix = np.zeros((len(dataset), max_word_id + 1), dtype=bool)
174 | for i, d in enumerate(dataset):
175 | transactions_matrix[i, d] = True
176 |
177 |     device = "cuda" if torch.cuda.is_available() else "cpu"  # GPU if available
178 |     transactions_matrix = torch.from_numpy(transactions_matrix).bool().to(device)
179 |     pad_index = transactions_matrix.shape[1]
180 |     N = transactions_matrix.shape[0]
181 |     # append an always-true padding column at index `pad_index`
182 |     transactions_matrix = torch.cat(
183 |         (transactions_matrix, torch.ones(N, 1, dtype=torch.bool, device=device)), dim=1,
184 |     )
185 |
186 | best_rules = dict()
187 | best_correct_rule = dict()
188 | all_rules = [[] for _ in range(len(transactions_matrix))]
189 | correct_rules = [[] for _ in range(len(transactions_matrix))]
190 |
191 | # Progress bars and iterables
192 | pbar = tqdm(total=len(transactions_matrix))
193 | pbar.set_description("Total rules found ")
194 | pbar_correct = tqdm(total=len(transactions_matrix))
195 | pbar_correct.set_description("Correct rules found")
196 |
197 | for i in tqdm(range(0, len(rules), bsize), desc="Rules processed"):
198 | rs = rules[i : i + bsize]
199 | itemsets = [r.itemset for r in rs]
200 | max_length = max([len(r) for r in itemsets])
201 | itemsets = [list(r) + [pad_index] * (max_length - len(r)) for r in itemsets]
202 | indexes_concerned = (
203 | (transactions_matrix[:, itemsets].all(dim=2).nonzero())
204 | .detach()
205 | .cpu()
206 | .numpy()
207 | ) # (N * 2) where 2 = (trans_id, rule_id)
208 | transactions_for_rule = [[] for _ in range(len(rs))]
209 |
210 | num_trans_found = 0
211 | num_correct_trans_found = 0
212 |
213 | for j in range(len(indexes_concerned)):
214 | trans_id, rule_id = indexes_concerned[j]
215 | rule_id = rule_id + i
216 | rule = rules[rule_id]
217 | transactions_for_rule[rule_id - i].append(trans_id)
218 |             if trans_id not in best_rules:
219 |                 num_trans_found += 1
220 |                 best_rules[trans_id] = rule
221 |             all_rules[trans_id].append(rule)
222 |             if answers is not None and rule.ans == answers[trans_id]:
223 |                 if trans_id not in best_correct_rule:
224 |                     best_correct_rule[trans_id] = rule
225 |                     num_correct_trans_found += 1
226 |                 correct_rules[trans_id].append(rule)
227 |
228 | pbar.update(num_trans_found)
229 | pbar_correct.update(num_correct_trans_found)
230 |
231 | if stop_all_have_rules and len(best_rules) == len(transactions_matrix):
232 | break
233 | if stop_all_correct_rules and len(best_correct_rule) == len(
234 | transactions_matrix
235 | ):
236 | break
237 | pbar.close()
238 | pbar_correct.close()
239 | del transactions_matrix
240 |
241 | return (
242 | all_rules,
243 | correct_rules,
244 | )
245 |
--------------------------------------------------------------------------------
/rule_utils.py:
--------------------------------------------------------------------------------
1 | from itertools import combinations
2 | from tqdm import tqdm
3 |
5 |
6 | def superset_filtering(rules):
7 |     """
8 |     Two goals:
9 |     - remove duplicate rules (i.e., rules with the same itemset but different answers);
10 |       we keep only the rule with the best confidence.
11 |     - remove useless rules: those whose itemset is a superset of an earlier rule's itemset
12 |       (thus more constrained, with smaller support), with a smaller confidence, for the same answer.
13 |     """
14 | rules_sorted = sorted(
15 | rules, key=lambda r: (len(r.itemset), -r.conf)
16 | ) # sorted by length (up), confidence (down)
17 | rules = []
18 | rule_by_itemset = dict() # itemset -> list of rules
19 | # rules_discarded = defaultdict(set)
20 | for rule in tqdm(rules_sorted):
21 | itemset, aid, support, conf = rule
22 | itemset = frozenset(itemset)
23 |         discard_rule = False
24 |         # a rule with the same itemset (and a better confidence) was already kept
25 |         if itemset in rule_by_itemset:
26 |             continue
27 |         rule_by_itemset[itemset] = rule
28 |
29 | if len(itemset) > 0:
30 | for it in combinations(itemset, len(itemset) - 1):
31 | it = frozenset(it)
32 | if it in rule_by_itemset:
33 | old_r = rule_by_itemset[it]
34 | if old_r.conf >= conf and old_r.ans == aid:
35 | discard_rule = True
36 | break
37 | if not discard_rule:
38 | rules.append(rule)
39 | print(
40 | f"After discarding rules, going from {len(rules_sorted)} to {len(rules)} rules."
41 | )
42 | return rules
43 |
44 |
45 | def test_superset_filtering():
46 |     # Rule is imported lazily to avoid a circular import
47 |     # (rule_mining itself imports superset_filtering from this module).
48 |     # Rule fields are (itemset, ans, sup, conf).
49 |     from rule_mining import Rule as R
50 |
51 |     # various test cases that we want to manage
52 |     assert superset_filtering([R((), 10, 0.1, 0.1)]) == [R((), 10, 0.1, 0.1)]
53 |
54 |     rules = [
55 |         R((), 10, 0.1, 0.1),
56 |         R((0,), 0, 0.1, 0.5),
57 |         R((0, 1), 0, 0.1, 0.5),
58 |         R((0, 1, 2), 0, 0.1, 0.5),
59 |     ]
60 |     assert superset_filtering(rules) == [R((), 10, 0.1, 0.1), R((0,), 0, 0.1, 0.5)]
61 |
62 |     rules = [
63 |         R((), 10, 0.1, 0.1),
64 |         R((0,), 0, 0.1, 0.5),
65 |         R((0, 1), 0, 0.1, 0.3),
66 |         R((0, 1, 2), 0, 0.1, 0.2),
67 |         R((0, 1, 5), 3, 0.1, 0.2),  # different answer: additional rule to keep
68 |     ]
69 |
70 |     assert superset_filtering(rules) == [
71 |         R((), 10, 0.1, 0.1),
72 |         R((0,), 0, 0.1, 0.5),
73 |         R((0, 1, 5), 3, 0.1, 0.2),
74 |     ]
75 |
76 |     # TODO this test fails.. it is quite bad, because it could allow
77 |     # us to discard a lot of useless rules...
78 |     rules = [
79 |         R((), 0, 0.1, 0.5),
80 |         R((0,), 0, 0.1, 0.3),
81 |         R((0, 1), 0, 0.1, 0.4),
82 |     ]
83 |     # assert superset_filtering(rules) == [
84 |     #     R((), 0, 0.1, 0.5),
85 |     # ]
86 |
87 |     rules = [
88 |         R((), 10, 0.1, 0.1),
89 |         R((0,), 0, 0.1, 0.5),
90 |         R((0, 1), 0, 0.1, 0.6),
91 |         R((0, 1, 2), 0, 0.1, 0.7),
92 |     ]
93 |
94 |     assert superset_filtering(rules) == [
95 |         R((), 10, 0.1, 0.1),
96 |         R((0,), 0, 0.1, 0.5),
97 |         R((0, 1), 0, 0.1, 0.6),
98 |         R((0, 1, 2), 0, 0.1, 0.7),
99 |     ]
100 |
101 |     # same itemset, different answers, one with a better confidence:
102 |     # we keep only the best one.
103 |     rules = [
104 |         R((), 10, 0.1, 0.2),
105 |         R((), 5, 0.1, 0.1),
106 |     ]
107 |
108 |     assert superset_filtering(rules) == [R((), 10, 0.1, 0.2)]
--------------------------------------------------------------------------------
/test_mining.py:
--------------------------------------------------------------------------------
1 | # note: this test needs the compiled GMiner binary (looked up at ./GMiner by default)
2 | from rule_mining import fit, match_rules
3 |
4 |
5 | def test_fit():
6 | dataset = [
7 | [0, 1, 2],
8 | [0, 1, 2],
9 | [0, 1, 3],
10 | [0, 1, 3],
11 | [0, 1, 4],
12 | [0, 1, 4],
13 | ]
14 | answers = [0, 0, 1, 1, 2, 2]
15 |     rules = fit(dataset, answers, gminer_support=0.2)
16 |
17 | assert len(rules) == 4
18 | for k, (itemset, ans) in enumerate([[(2,), 0], [(3,), 1], [(4,), 2], [(), 0],]):
19 | assert rules[k].itemset == itemset
20 | assert rules[k].ans == ans
21 |
22 | # match_rules(dataset, rules)
23 | all_rules, correct_rules = match_rules(dataset, rules, answers=answers, bsize=10)
24 | print(all_rules)
25 |
26 | # item 0
27 | assert len(all_rules[0]) == 2
28 | assert all_rules[0][0].itemset == (2,)
29 | assert all_rules[0][1].itemset == ()
30 | assert len(correct_rules[0]) == 2
31 | assert correct_rules[0][0].itemset == (2,)
32 | assert correct_rules[0][1].itemset == ()
33 |
34 | # item 2
35 | assert len(all_rules[2]) == 2
36 | assert all_rules[2][0].itemset == (3,)
37 | assert all_rules[2][1].itemset == ()
38 | assert len(correct_rules[2]) == 1
39 | assert correct_rules[2][0].itemset == (3,)
40 |
--------------------------------------------------------------------------------
/vqa.py:
--------------------------------------------------------------------------------
1 | from rule_mining import Rule, fit, match_rules
2 | import pickle
3 | import os
4 | from collections import Counter
5 | import json
6 | from typing import List
7 | from tqdm import tqdm
8 | from torchtext.data.utils import get_tokenizer
9 |
10 |
11 | def loadjson(path):
12 | with open(path) as f:
13 | return json.load(f)
14 |
15 |
16 | def create_dataset(
17 | questions,
18 | visual_words="data/image_to_detection.json",
19 | annotations=None,
20 | textual=True,
21 | visual=True,
22 | visual_threshold=0.5,
23 | proportion=1.0,
24 | most_common_answers=None,
25 | ):
26 | print("Creating VQA binary dataset")
27 | print(f"Loading visual words at {visual_words}")
28 | visual_words = loadjson(visual_words)
29 |
30 | tokenizer = get_tokenizer("basic_english")
31 |
32 | total_len = int(len(questions) * proportion)
33 | transactions = []
34 | answers = []
35 | indexes = []
36 |
37 |     #############
38 |     # Build one transaction per example:
39 |     # question words and/or visual words
40 |     #############
41 |
42 | for i in tqdm(range(total_len)):
43 | transaction = []
44 |
45 | if textual:
46 | tokens = tokenizer(questions[i]["question"])
47 | transaction.extend(tokens)
48 |
49 | if visual:
50 | image_id = str(questions[i]["image_id"])
51 | if image_id in visual_words:
52 | vwords = visual_words[image_id]
53 | classes = vwords["classes"]
54 | scores = vwords["scores"]
55 | if visual_threshold != 0:
56 | classes = [
57 | c
58 | for (i, c) in enumerate(classes)
59 | if scores[i] >= visual_threshold
60 | ]
61 | classes = ["V_" + c for c in classes] # visual marker
62 | transaction.extend(classes)
63 | transactions.append(transaction)
64 | indexes.append(i)
65 | if annotations is not None:
66 | answers.append(annotations[i]["multiple_choice_answer"])
67 |
68 |     assert annotations is None or len(transactions) == len(answers)
69 |
70 | if annotations is not None and most_common_answers is not None:
71 | occurences = Counter(answers).most_common(most_common_answers)
72 | keep_answers = set(a for (a, _) in occurences)
73 |
74 | new_transactions = []
75 | new_answers = []
76 | new_indexes = []
77 | for k in range(len(transactions)):
78 | if answers[k] in keep_answers:
79 | new_transactions.append(transactions[k])
80 | new_answers.append(answers[k])
81 | new_indexes.append(indexes[k])
82 | transactions, answers, indexes = new_transactions, new_answers, new_indexes
83 |
84 | if annotations is not None:
85 | return transactions, answers, indexes
86 |
87 | return transactions, indexes
88 |
89 |
90 | def vqa(
91 | textual=True,
92 | visual=True,
93 | train_questions_path="data/vqa2/v2_OpenEnded_mscoco_train2014_questions.json",
94 | train_annotations_path="data/vqa2/v2_mscoco_train2014_annotations.json",
95 | val_questions_path=None,
96 | val_annotations_path=None,
97 | visual_threshold=0.5,
98 | support_gminer=2e-5,
99 | gminer_path=None,
100 | min_conf=0.3,
101 | max_length=5,
102 | version="vqa2",
103 | save_dir=None,
104 | keep_all_rules_train_predictions=False,
105 | visual_words="data/image_to_detection.json",
106 | ):
107 |
108 | train_questions = loadjson(train_questions_path)
109 | train_annotations = loadjson(train_annotations_path)
110 |     if isinstance(train_questions, dict) and "questions" in train_questions:
111 | train_questions = train_questions["questions"]
112 | train_annotations = train_annotations["annotations"]
113 |
114 |
115 | os.makedirs(save_dir, exist_ok=True)
116 | train_dataset, train_answers, train_indexes = create_dataset(
117 | train_questions,
118 | annotations=train_annotations,
119 | proportion=1.0,
120 | most_common_answers=3000,
121 | textual=textual,
122 | visual=visual,
123 | visual_threshold=visual_threshold,
124 | visual_words=visual_words,
125 | )
126 | tokens = list(set(t for transaction in train_dataset for t in transaction))
127 | token_to_id = {t: i for (i, t) in enumerate(tokens)}
128 | train_transactions = [
129 | [token_to_id[t] for t in transaction] for transaction in train_dataset
130 | ]
131 | all_answers = list(set(train_answers))
132 | ans_to_id = {ans: i for (i, ans) in enumerate(all_answers)}
133 | train_answers_ids = [ans_to_id[ans] for ans in train_answers]
134 |
135 | # rule mining
136 | rules: List[Rule] = fit(
137 | train_transactions,
138 | train_answers_ids,
139 | gminer_support=support_gminer,
140 | gminer_max_length=max_length,
141 | gminer_path=gminer_path,
142 | )
143 |
144 | # - keep only rules with confidence > min_conf
145 | rules = [r for r in rules if r.conf >= min_conf]
146 |
147 | # show the best 20 rules
148 | for r in rules[:20]:
149 | print([tokens[tid] for tid in r.itemset], all_answers[r.ans], r.sup, r.conf)
150 |
151 | # match rules with examples
152 | matching_rules_train, matching_correct_rules_train = match_rules(
153 | train_transactions, rules, answers=train_answers_ids
154 | )
155 |
156 | # val
157 | val_questions = loadjson(val_questions_path)
158 | val_annotations = loadjson(val_annotations_path)
159 |
160 |     if isinstance(val_questions, dict) and "questions" in val_questions:
161 | val_questions = val_questions["questions"]
162 | val_annotations = val_annotations["annotations"]
163 |
164 | val_dataset, val_answers, val_indexes = create_dataset(
165 | val_questions,
166 | annotations=val_annotations,
167 | proportion=1.0,
168 | textual=textual,
169 | visual=visual,
170 | visual_threshold=visual_threshold,
171 | visual_words=visual_words,
172 | )
173 |
174 | val_transactions = [
175 | [token_to_id[t] for t in transaction if t in token_to_id]
176 | for transaction in val_dataset
177 | ]
178 | val_answers_ids = [ans_to_id.get(ans, -1) for ans in val_answers]
179 |
180 | matching_rules_val, matching_correct_rules_val = match_rules(
181 | val_transactions, rules, val_answers_ids
182 | )
183 |
184 |     # - create the evaluation splits: counterexamples / hard / easy
185 |     # we use the full annotations: we consider every annotated answer, not only the top one.
186 | qid_counterexamples = []
187 | qid_easy = []
188 | qid_hard = []
189 | for annot, rs in zip(val_annotations, matching_rules_val):
190 | possible_answers = set(ans["answer"] for ans in annot["answers"])
191 | rules_answers = set(all_answers[r.ans] for r in rs)
192 | if len(set.intersection(rules_answers, possible_answers)) == 0 and len(rs) != 0:
193 | # goes into counterexamples
194 | qid_counterexamples.append(annot["question_id"])
195 | elif len(rs) == 0:
196 | qid_hard.append(annot["question_id"])
197 | else:
198 | qid_easy.append(annot["question_id"])
199 |
200 | # keep_rules:
201 | # we keep only one correct rule per training example
202 | if not keep_all_rules_train_predictions:
203 | keep_rules = set()
204 | for rs in matching_correct_rules_train:
205 | if rs:
206 | keep_rules.add(rs[0])
207 | print("Rules kept after keeping only one per training example:", len(keep_rules))
208 | else:
209 | keep_rules = None
210 |
211 |
212 | # build predictions on validation set.
213 | n_missing_rules = 0
214 | predictions = []
215 | for i, rs in enumerate(matching_rules_val):
216 | qid = val_questions[val_indexes[i]]["question_id"]
217 | if keep_rules is not None:
218 | rs = [r for r in rs if r in keep_rules]
219 | if rs:
220 | ans = all_answers[rs[0].ans]
221 | else:
222 | n_missing_rules += 1
223 | ans = "yes"
224 | predictions.append(
225 | {"question_id": qid, "answer": ans,}
226 | )
227 |     print(f"Missing rules for {n_missing_rules} val predictions ({100 * n_missing_rules / len(matching_rules_val):.2f}%); defaulting to 'yes'")
228 |
229 | # save predictions and rules
230 | with open(os.path.join(save_dir, "rules_predictions.json"), "w") as f:
231 | json.dump(predictions, f)
232 | with open(os.path.join(save_dir, "rules.pickle"), "bw") as f:
233 | rules_tuple = [tuple(r) for r in rules]
234 | pickle.dump(rules_tuple, f)
235 | with open(os.path.join(save_dir, "counterexamples.json"), "w") as f:
236 | json.dump(qid_counterexamples, f)
237 | with open(os.path.join(save_dir, "easy.json"), "w") as f:
238 | json.dump(qid_easy, f)
239 | with open(os.path.join(save_dir, "hard.json"), "w") as f:
240 | json.dump(qid_hard, f)
241 |
242 | return rules, qid_easy, qid_counterexamples, qid_hard
243 |
244 |
245 | if __name__ == "__main__":
246 | import argparse
247 |
248 | parser = argparse.ArgumentParser()
249 | parser.add_argument("--save_dir", required=True)
250 | parser.add_argument("--support", default=2.1e-5, type=float)
251 | parser.add_argument("--max_length", default=5, type=int)
252 | parser.add_argument("--min_conf", default=0.3, type=float)
253 | parser.add_argument("--gminer_path")
254 | parser.add_argument("--visual_words", default="data/image_to_detection.json")
255 | parser.add_argument("--train_questions_path", default="data/vqa2/v2_OpenEnded_mscoco_train2014_questions.json")
256 | parser.add_argument("--train_annotations_path", default="data/vqa2/v2_mscoco_train2014_annotations.json")
257 | parser.add_argument("--val_questions_path", default="data/vqa2/v2_OpenEnded_mscoco_val2014_questions.json")
258 | parser.add_argument("--val_annotations_path", default="data/vqa2/v2_mscoco_val2014_annotations.json")
259 | parser.add_argument("--keep_all_rules_train_predictions", action="store_true", help="keep all rules instead of just one correct rule per training example. Only used for predictions.")
260 | args = parser.parse_args()
261 |
262 | (rules, qid_easy, qid_counterexamples, qid_hard) = vqa(
263 | support_gminer=args.support,
264 | max_length=args.max_length,
265 | min_conf=args.min_conf,
266 | save_dir=args.save_dir,
267 | gminer_path=args.gminer_path,
268 | visual_words=args.visual_words,
269 | train_questions_path=args.train_questions_path,
270 | train_annotations_path=args.train_annotations_path,
271 | val_questions_path=args.val_questions_path,
272 | val_annotations_path=args.val_annotations_path,
273 | keep_all_rules_train_predictions=args.keep_all_rules_train_predictions
274 | )
275 |
--------------------------------------------------------------------------------