22 |
23 | """
24 |
25 | html_footer = """
26 | </table>
27 | </body></html>
28 | """
29 |
30 | row_header = """
31 | <tr>
32 | """
33 |
34 | element_header = """
35 | <td>
36 | """
37 |
38 |
39 | class vqa_html_writer:
40 | def __init__(self, file_path, elements_per_row=4):
41 | self._writer = open(file_path, "w")
42 | self._writer.write(html_header)
43 | self.count = 0
44 | self.elements_per_row = elements_per_row
45 |
46 | def write_element(self, image, **kwarg):
47 | if self.count % self.elements_per_row == 0:
48 | self._writer.write(row_header + "\n")
49 | self._writer.write(element_header)
50 |         self._writer.write('<img src="%s">' % image)
51 | for key, value in kwarg.items():
52 |             self._writer.write("<p> %s : %s </p>" % (key, value))
53 |         self._writer.write("</td>")
54 | self.count += 1
55 | if self.count % self.elements_per_row == 0 and self.count > 0:
56 |             self._writer.write("</tr>")
57 |
58 | def close(self):
59 | if self.count % self.elements_per_row != 0:
60 |             self._writer.write("</tr>")
61 | self._writer.write(html_footer)
62 | self._writer.close()
63 |
64 |
65 | if __name__ == "__main__":
66 | html_writer = vqa_html_writer("/Users/tinayujiang/temp/test.html", 4)
67 | n = 10
68 |     for _ in range(n):
69 | image_path = (
70 | "/Users/tinayujiang/work/VQA/data_analysis/val2014/"
71 | + "COCO_val2014_000000290951.jpg"
72 | )
73 | info = {"question": "abcfs efc?", "answers": " wdds cdsde"}
74 | html_writer.write_element(image_path, **info)
75 |
76 | html_writer.close()
77 |
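# Illustrative note on the reconstructed tags above: write_element emits one
# <td> cell per call (the image followed by key/value <p> lines) and opens a
# new <tr> every elements_per_row cells, so the output renders as a grid with
# elements_per_row entries per row.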
--------------------------------------------------------------------------------
/pythia/legacy/ensemble.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import argparse
10 | import glob
11 | import json
12 |
13 | import numpy as np
14 |
15 | import _pickle as pickle
16 | from train_model.helper import print_result
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--out", type=str, required=True, help="output file name")
22 | parser.add_argument(
23 | "--res_dirs",
24 | nargs="+",
25 |         help="directories for results; NOTE: "
26 |         "all *.pkl files under these dirs will be ensembled",
27 | default=None,
28 | )
29 |     args = parser.parse_args()
30 |
31 |     return args
32 |
33 |
34 | class answer_json:
35 | def __init__(self):
36 | self.answers = []
37 |
38 | def add(self, ques_id, ans):
39 | res = {"question_id": ques_id, "answer": ans}
40 | self.answers.append(res)
41 |
42 |
43 | if __name__ == "__main__":
44 |
45 | args = parse_args()
46 | result_dirs = args.res_dirs
47 | out_file = args.out
48 | question_ids = None
49 | soft_max_result = None
50 | ans_dic = None
51 | cnt = 0
52 | for res_dir in result_dirs:
53 | for file in glob.glob(res_dir + "/**/*.pkl", recursive=True):
54 | with open(file, "rb") as f:
55 | cnt += 1
56 | sm = pickle.load(f)
57 | if soft_max_result is None:
58 | soft_max_result = sm
59 | question_ids = pickle.load(f)
60 | ans_dic = pickle.load(f)
61 | else:
62 | soft_max_result += sm
63 |
64 |     print("ensembled %d models in total" % cnt)
65 |
66 | predicted_answers = np.argmax(soft_max_result, axis=1)
67 |
68 | pkl_file = out_file + ".pkl"
69 |
70 | print_result(question_ids, soft_max_result, ans_dic, out_file, False, pkl_file)
71 |
72 | print("Done")
73 |
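# Illustrative note: summing the per-model softmax matrices before the argmax
# is equivalent to averaging them, since argmax is invariant to scaling by the
# positive constant 1/cnt. A hypothetical invocation, with placeholder paths:
#   python ensemble.py --out ensemble_res --res_dirs results/run1 results/run2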
--------------------------------------------------------------------------------
/pythia/legacy/eval_model/eval_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import json
10 | import sys
11 |
12 | from eval_model.vqaEval import VQAEval
13 |
14 |
15 | def parse_annotation(anno_file):
16 | with open(anno_file, "r") as f:
17 | annotations = json.load(f)["annotations"]
18 |
19 | q_2_anno = dict([(a["question_id"], a) for a in annotations])
20 | return q_2_anno
21 |
22 |
23 | def parse_ans(answ_file):
24 | with open(answ_file, "r") as f:
25 | answers = json.load(f)
26 |
27 | q_2_answ = dict([(a["question_id"], a) for a in answers])
28 | return q_2_answ
29 |
30 |
31 | if __name__ == "__main__":
32 | if len(sys.argv) < 3:
33 | exit(
34 | "USAGE: python eval_model/eval_demo.py \
35 | annotation_json_file answer_json_file"
36 | )
37 |
38 | anno_file = sys.argv[1]
39 | answ_file = sys.argv[2]
40 |
41 | q_2_anno = parse_annotation(anno_file)
42 | q_2_answ = parse_ans(answ_file)
43 |
44 |     vqa_eval = VQAEval(q_2_anno, q_2_answ, 2)
45 |     vqa_eval.evaluate()
46 |     acc = vqa_eval.accuracy
47 | print(
48 | "overall: %.2f" % acc["overall"],
49 |         "yes/no: %.2f" % acc["perAnswerType"]["yes/no"],
50 | "number: %.2f" % acc["perAnswerType"]["number"],
51 | "other: %.2f" % acc["perAnswerType"]["other"],
52 | )
53 |
--------------------------------------------------------------------------------
/pythia/legacy/global_variables/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
--------------------------------------------------------------------------------
/pythia/legacy/global_variables/global_variables.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import torch
10 |
11 | imdb_version = 1
12 | use_cuda = torch.cuda.is_available()
13 |
14 | model_type_gt = "gt_layout"
15 | model_type_scratch = "scratch"
16 | model_type_gt_rl = "gt+rl"
17 | model_type_top_down_bottom_up = "top_down_bottom_up"
18 |
19 |
20 | topdown_concate_attention = "concate_attention"
21 | topdown_project_attention = "project_attention"
22 |
--------------------------------------------------------------------------------
/pythia/legacy/info/code_structure_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanghuixu/AnchorCaptioner/3a49ce5de025087cbea00075ec0636aee0525382/pythia/legacy/info/code_structure_plot.png
--------------------------------------------------------------------------------
/pythia/legacy/info/pythia.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanghuixu/AnchorCaptioner/3a49ce5de025087cbea00075ec0636aee0525382/pythia/legacy/info/pythia.jpg
--------------------------------------------------------------------------------
/pythia/legacy/info/vqa_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanghuixu/AnchorCaptioner/3a49ce5de025087cbea00075ec0636aee0525382/pythia/legacy/info/vqa_example.png
--------------------------------------------------------------------------------
/pythia/legacy/install.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | conda create --name vqa python=3.6
4 | source activate vqa
5 | pip install demjson pyyaml
6 |
7 | pip install http://download.pytorch.org/whl/cu90/torch-0.3.0-cp36-cp36m-linux_x86_64.whl
8 | pip install torchvision
9 | pip install tensorboardX
10 |
11 |
12 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/convert_VG_to_COCO_qa.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import json
10 | import string
11 |
12 | genome_data_file = "question_answers.json"
13 | genome_questions_file = "v2_OpenEnded_mscoco_genome_questions.json"
14 | genome_annotations_file = "v2_mscoco_genome_annotations.json"
15 |
16 | translator = str.maketrans("", "", string.punctuation)
17 | with open(genome_data_file, "r") as f:
18 | genome_data = json.load(f)
19 |
20 | genome_questions = []
21 | genome_annotations = []
22 |
23 | for data in genome_data:
24 | all_qas = data["qas"]
25 | for qas in all_qas:
26 | question = {}
27 | annotation = {}
28 | question["image_id"] = qas["image_id"]
29 | # assume unique question_id for every question answer pair
30 | question["question_id"] = qas["qa_id"]
31 | question["question"] = qas["question"]
32 | genome_questions.append(question)
33 | annotation["image_id"] = qas["image_id"]
34 | annotation["question_id"] = qas["qa_id"]
35 | answertxt = qas["answer"].translate(translator)
36 | answertxt = answertxt.lower()
37 | annotation["multiple_choice_answer"] = answertxt
38 | annotation["answers"] = []
39 | for i in range(10):
40 | answer = {}
41 | answer["answer"] = answertxt
42 |             answer["answer_confidence"] = "yes"
43 | answer["answer_id"] = i + 1
44 | annotation["answers"].append(answer)
45 | genome_annotations.append(annotation)
46 |
47 | genome_data = {}
48 | genome_data["questions"] = genome_questions
49 |
50 | with open(genome_questions_file, "w") as f:
51 | json.dump(genome_data, f)
52 |
53 | genome_data = {}
54 | genome_data["annotations"] = genome_annotations
55 |
56 | with open(genome_annotations_file, "w") as f:
57 | json.dump(genome_data, f)
58 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/convert_tsv_feature_to_indiv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import argparse
10 | import base64
11 | import csv
12 | import os
13 | import sys
14 |
15 | import numpy as np
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--infile", type=str, required=True, help="input file")
19 | parser.add_argument("--label", type=str, required=True, help="label for dataset")
20 | parser.add_argument("--out_dir", type=str, required=True, help="imdb output directory")
21 | args = parser.parse_args()
22 |
23 | out_dir = args.out_dir
24 |
25 |
26 | csv.field_size_limit(sys.maxsize)
27 |
28 | FIELDNAMES = ["image_id", "image_w", "image_h", "num_boxes", "boxes", "features"]
29 | infile = args.infile
30 |
31 | label = args.label
32 |
33 | out_dir = os.path.join(out_dir, label)
34 |
35 | os.makedirs(out_dir, exist_ok=True)
36 |
37 | print("reading tsv...")
38 | with open(infile, "r") as tsv_in_file:
39 | reader = csv.DictReader(tsv_in_file, delimiter="\t", fieldnames=FIELDNAMES)
40 | for item in reader:
41 | item["num_boxes"] = int(item["num_boxes"])
42 | image_id = int(item["image_id"])
43 | image_w = float(item["image_w"])
44 | image_h = float(item["image_h"])
45 |
46 | image_bboxes = np.frombuffer(
47 | base64.b64decode(item["boxes"]), dtype=np.float32
48 | ).reshape((item["num_boxes"], -1))
49 |
50 | image_feat = np.frombuffer(
51 | base64.b64decode(item["features"]), dtype=np.float32
52 | ).reshape((item["num_boxes"], -1))
53 |
54 | image_feat_and_boxes = {"image_bboxes": image_bboxes, "image_feat": image_feat}
55 |
56 | image_file_name = os.path.join(
57 | out_dir, "COCO_" + label + "_%012d.npy" % image_id
58 | )
59 | np.save(image_file_name, image_feat_and_boxes)
60 |
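# Illustrative note: each TSV cell holds a base64-encoded raw float32 buffer;
# np.frombuffer(base64.b64decode(...), dtype=np.float32) recovers the flat
# array, and reshape((num_boxes, -1)) yields one feature row per detected box.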
--------------------------------------------------------------------------------
/pythia/legacy/tools/eval_ensemble_on_val.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import glob
10 | import sys
11 |
12 | import torch
13 | import yaml
14 | from torch.utils.data import DataLoader
15 |
16 | from train_model.dataset_utils import prepare_eval_data_set
17 | from train_model.helper import build_model, run_model
18 |
19 | CONFIG = "config.yaml"
20 | MODELNAME = "best_model.pth"
21 |
22 | if __name__ == "__main__":
23 | if len(sys.argv) < 2:
24 | exit(
25 | "USAGE: python tools/eval_ensemble_on_val.py parent_dir \
26 | [ensemble sizes]"
27 | )
28 |
29 | esbl_sizes = [int(a) for a in sys.argv[2:]]
30 |
31 | parent_dir = sys.argv[1]
32 |
33 | model_pths = [
34 | file for file in glob.glob(parent_dir + "/**/" + MODELNAME, recursive=True)
35 | ]
36 | config_files = [c.replace(MODELNAME, CONFIG) for c in model_pths]
37 |
38 | if len(esbl_sizes) == 0:
39 | esbl_sizes = range(1, len(config_files) + 1)
40 |
41 | config_file = config_files[0]
42 |
43 | with open(config_file, "r") as f:
44 | config = yaml.load(f)
45 |
46 | batch_size = config["data"]["batch_size"]
47 | data_set_test = prepare_eval_data_set(
48 | **config["data"], **config["model"], verbose=True
49 | )
50 | data_reader_test = DataLoader(
51 | data_set_test, shuffle=False, batch_size=batch_size, num_workers=5
52 | )
53 | ans_dic = data_set_test.answer_dict
54 |
55 | accumulated_softmax = None
56 | final_result = {}
57 | n_model = 0
58 | for c_file, model_file in zip(config_files, model_pths):
59 | with open(c_file, "r") as f:
60 | config = yaml.load(f)
61 |
62 | myModel = build_model(config, data_set_test)
63 | myModel.load_state_dict(torch.load(model_file)["state_dict"])
64 |
65 | question_ids, soft_max_result = run_model(
66 | myModel, data_reader_test, ans_dic.UNK_idx
67 | )
68 |
69 | if n_model == 0:
70 | final_result = soft_max_result
71 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/extract_detectron_weights.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import os
10 | import pickle
11 | import sys
12 |
13 | if len(sys.argv) < 4:
14 | exit(
15 | "USAGE: python tools/extract_detectron_weights.py \
16 | weights_file out_dir feat_name [feat_name]"
17 | )
18 |
19 | wgts_file = sys.argv[1]
20 | out_dir = sys.argv[2]
21 |
22 | with open(wgts_file, "rb") as f:
23 | wgts = pickle.load(f, encoding="latin1")["blobs"]
24 |
25 | for i in range(3, len(sys.argv)):
26 | feat_name = sys.argv[i]
27 | wgt = wgts[feat_name]
28 | out_file = os.path.join(out_dir, feat_name + ".pkl")
29 | with open(out_file, "wb") as w:
30 | pickle.dump(wgt, w)
31 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/extract_minival_ids.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import json
10 | import pickle
11 |
12 |
13 | def extract_qid_imid(ques_json_file):
14 | with open(ques_json_file, "r") as f:
15 | info = json.load(f)
16 | questions = info["questions"]
17 |
18 | q_im_ids = []
19 | for q in questions:
20 | im_id = q["image_id"]
21 | q_id = q["question_id"]
22 | q_im_ids.append((im_id, q_id))
23 |
24 | return q_im_ids
25 |
26 |
27 | if __name__ == "__main__":
28 | minival_ques_file = "v2_OpenEnded_mscoco_minival2014_questions.json"
29 |
30 | val2train_ques_file = "v2_OpenEnded_mscoco_val2train2014_questions.json"
31 |
32 | minival_out_file = "data_prep/vqa_v2.0/minival_ids.pkl"
33 | val2train_out_file = "data_prep/vqa_v2.0/val2train_ids.pkl"
34 |
35 | minival_ids = extract_qid_imid(minival_ques_file)
36 | with open(minival_out_file, "wb") as w1:
37 | pickle.dump(minival_ids, w1)
38 |
39 | val2train_ids = extract_qid_imid(val2train_ques_file)
40 | with open(val2train_out_file, "wb") as w2:
41 | pickle.dump(val2train_ids, w2)
42 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/extract_visual_features_vgg_pool5.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import argparse
10 | import os
11 | import sys
12 | from glob import glob
13 |
14 | import numpy as np
15 | import torch
16 | import torch.nn as nn
17 | import torchvision.models as models
18 | from torch.autograd import Variable
19 |
20 | import skimage.color
21 | import skimage.io
22 | from global_variables.global_variables import use_cuda
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--gpu_id", type=int, default=0)
26 | parser.add_argument("--data_dir", type=str, required=True)
27 | parser.add_argument("--out_dir", type=str, required=True)
28 |
29 | args = parser.parse_args()
30 | gpu_id = args.gpu_id # set GPU id to use
31 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
32 | sys.path.append("../../")
33 |
34 | image_basedir = args.data_dir
35 | save_basedir = args.out_dir
36 |
37 | channel_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32)
38 |
39 |
40 | class vgg16_feature_module(nn.Module):
41 | def __init__(self, vgg16_model):
42 | super(vgg16_feature_module, self).__init__()
43 |         self.feature_module = nn.Sequential(*list(vgg16_model.children())[0])
44 |
45 | def forward(self, x):
46 | return self.feature_module(x)
47 |
48 |
49 | vgg16 = models.vgg16(pretrained=True)
50 | vgg16_feature = vgg16_feature_module(vgg16)
51 | vgg16_feature = vgg16_feature.cuda() if use_cuda else vgg16_feature
52 |
53 |
54 | def extract_image_pool5(impath):
55 | im = skimage.io.imread(impath)[..., :3]
56 | im_val = im[np.newaxis, ...] - channel_mean
57 |
58 | # permute to get NCHW
59 | im_val = np.transpose(im_val, axes=(0, 3, 1, 2))
60 | im_val_tensor = torch.FloatTensor(im_val)
61 | im_val_variable = Variable(im_val_tensor)
62 | im_val_variable = im_val_variable.cuda() if use_cuda else im_val_variable
63 |
64 | pool5_val = vgg16_feature(im_val_variable)
65 | return pool5_val.data.cpu().numpy()
66 |
67 |
68 | def extract_dataset_pool5(image_dir, save_dir, ext_filter="*.png"):
69 | image_list = glob(image_dir + "/" + ext_filter)
70 | os.makedirs(save_dir, exist_ok=True)
71 |
72 | for n_im, impath in enumerate(image_list):
73 | if (n_im + 1) % 100 == 0:
74 | print("processing %d / %d" % (n_im + 1, len(image_list)))
75 | image_name = os.path.basename(impath).split(".")[0]
76 | save_path = os.path.join(save_dir, image_name + ".npy")
77 | if not os.path.exists(save_path):
78 | pool5_val = extract_image_pool5(impath)
79 | np.save(save_path, pool5_val)
80 |
81 |
82 | for image_set in ["train", "val", "test"]:
83 | print("Extracting image set " + image_set)
84 | extract_dataset_pool5(
85 | os.path.join(image_basedir, image_set), os.path.join(save_basedir, image_set)
86 | )
87 | print("Done.")
88 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/generate_minival_annotation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import json
10 | import pickle
11 |
12 | if __name__ == "__main__":
13 | val_annotation_file = "v2_mscoco_val2014_annotations.json"
14 | minival_id_file = "data/vqa_v2.0/minival_ids.pkl"
15 | minival_annotation_file = "v2_mscoco_minival2014_annotations.json"
16 |
17 | with open(minival_id_file, "rb") as f:
18 | q_im_ids = pickle.load(f)
19 |
20 |     minival_ids = {x[1] for x in q_im_ids}
21 |
22 | with open(val_annotation_file, "r") as f:
23 | file_info = json.load(f)
24 | annotations = file_info["annotations"]
25 | info = file_info["info"]
26 | data_subtype = file_info["data_subtype"]
27 | license_info = file_info["license"]
28 |
29 | minival_annotations = [a for a in annotations if a["question_id"] in minival_ids]
30 |
31 | minival_info = {
32 | "data_subtype": data_subtype,
33 | "license": license_info,
34 | "info": info,
35 | "annotations": minival_annotations,
36 | }
37 |
38 | with open(minival_annotation_file, "w") as w:
39 | json.dump(minival_info, w)
40 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/mirror_images.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | # All paths need to be updated
9 |
10 | import json
11 | import os
12 | from multiprocessing.dummy import Pool as ThreadPool
13 |
14 | from PIL import Image, ImageOps
15 |
16 | split = "val2014"
17 | image_paths = []
18 |
19 |
20 | def mirror_image(image_path):
21 | img = Image.open(image_path)
22 | mirror_img = ImageOps.mirror(img)
23 | image_name = image_path.split("/")[-1]
24 | fh = "data/" + split
25 | fh = os.path.join(fh, image_name)
26 | mirror_img.save(fh, "JPEG")
27 |
28 |
29 | with open("./COCO/060817/annotations/instances_val2014.json") as f:
30 | data = json.load(f)
31 | for item in data["images"]:
32 | image_id = int(item["id"])
33 | filepath = os.path.join("val2014/", item["file_name"])
34 | image_paths.append(filepath)
35 |
36 | pool = ThreadPool(10)
37 | results = pool.map(mirror_image, image_paths)
38 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/model_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | ##vgg model from https://github.com/jcjohnson/pytorch-vgg
10 |
11 |
12 | vgg16_caffe2 = "https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg16-00b39a1b.pth"
13 | vgg19_caffe2 = "https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth"
14 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/rename_genome_file.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import os
10 | import shutil
11 | import sys
12 |
13 | if len(sys.argv) != 3:
14 | exit("Usage: python tools/rename_genome_file.py [inDir] [outDir]")
15 |
16 | inDir = sys.argv[1]
17 | outDir = sys.argv[2]
18 |
19 | OUT_NAME = "COCO_genome_%012d.npy"
20 |
21 | os.makedirs(outDir, exist_ok=True)
22 |
23 | n = 0
24 | print("BEGIN.....")
25 | for file in os.listdir(inDir):
26 | if file.endswith(".npy"):
27 | n += 1
28 | if n % 5000 == 0:
29 |             print("processed %d files" % n)
30 | image_id = int(file.split(".")[0])
31 | out_name = OUT_NAME % image_id
32 | in_file = os.path.join(inDir, file)
33 | out_file = os.path.join(outDir, out_name)
34 | shutil.copy(in_file, out_file)
35 |
36 | print("processed a total of %d files" % n)
37 | print("DONE.....")
38 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/subset_val.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import argparse
10 | import json
11 | import random
12 |
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--ques_file", type=str)
17 |     return parser.parse_args()
18 |
19 |
20 | if __name__ == "__main__":
21 | val_json_file = "v2_OpenEnded_mscoco_val2014_questions.json"
22 | minival_json_file = "v2_OpenEnded_mscoco_minival2014_questions.json"
23 | val_as_train_json_file = "v2_OpenEnded_mscoco_val2train2014_questions.json"
24 |
25 | with open(val_json_file, "r") as f:
26 | file_info = json.load(f)
27 | questions = file_info["questions"]
28 | info = file_info["info"]
29 | task_type = file_info["task_type"]
30 | data_type = file_info["data_type"]
31 | license = file_info["license"]
32 |         data_subtype = file_info["data_subtype"]
33 |
34 | # collect image_id
35 | image_ids = []
36 | for q in questions:
37 | image_id = q["image_id"]
38 | image_ids.append(image_id)
39 |
40 | # divide image_ids to two parts
41 | random.shuffle(image_ids)
42 |     minival_images = set(image_ids[:10000])
43 | other_images = image_ids[10000:]
44 |
45 | minival_ques = []
46 | other_ques = []
47 |
48 | total_minival = 0
49 | total_others = 0
50 |     # separate question_json_file
51 | for q in questions:
52 | image_id = q["image_id"]
53 |
54 | if image_id in minival_images:
55 | minival_ques.append(q)
56 | total_minival += 1
57 | else:
58 | other_ques.append(q)
59 | total_others += 1
60 |
61 | minival_json = {
62 | "info": info,
63 | "task_type": task_type,
64 | "data_type": data_type,
65 | "license": license,
66 | "data_subtype": "minival2014",
67 | "questions": minival_ques,
68 | }
69 |
70 | other_json = {
71 | "info": info,
72 | "task_type": task_type,
73 | "data_type": data_type,
74 | "license": license,
75 | "data_subtype": "val2train2014",
76 | "questions": other_ques,
77 | }
78 |
79 | with open(minival_json_file, "w") as w1:
80 | json.dump(minival_json, w1)
81 |
82 | with open(val_as_train_json_file, "w") as w2:
83 | json.dump(other_json, w2)
84 |
85 | print(
86 |         "minival_questions: %d, " % total_minival + "other_questions: %d" % total_others
87 | )
88 |
--------------------------------------------------------------------------------
/pythia/legacy/tools/timer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import timeit
3 |
4 |
5 | class Timer:
6 | def __init__(self, unit="s"):
7 | self.s_time = timeit.default_timer()
8 | self.unit = unit
9 |         if self.unit not in ("s", "m", "h"):
10 |             raise NotImplementedError("unknown time unit; use s, m, or h")
11 |
12 | def start(self):
13 | self.s_time = timeit.default_timer()
14 |
15 | def end(self):
16 | self.e_time = timeit.default_timer()
17 | period = self.e_time - self.s_time
18 | if self.unit == "s":
19 | return "%.1f s" % period
20 | elif self.unit == "m":
21 | return "%.2f min" % (period / 60)
22 | else:
23 | return "%.2f h" % (period / 3600)
24 |
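# Usage sketch (hypothetical):
#   t = Timer("m")
#   t.start()
#   ...             # work to be timed
#   print(t.end())  # e.g. "1.25 min"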
--------------------------------------------------------------------------------
/pythia/legacy/top_down_bottom_up/image_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import pickle
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | """
17 | parameters:
18 |
19 | input:
20 | image_feat_variable: [batch_size, num_location, image_feat_dim]
21 | or a list of [num_location, image_feat_dim]
22 | when using adaptive number of objects
23 | question_embedding: [batch_size, txt_embedding_dim]
24 |
25 | output:
26 | image_embedding: [batch_size, image_feat_dim * n_att]
27 |
28 |
29 | """
30 |
31 |
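# Shape walk-through with illustrative numbers (N=32 images, K=100 locations,
# image_feat_dim=2048, n_att=2 glimpses): attention is 32 x 100 x 2,
# att_reshape is 32 x 2 x 100, bmm with image_feat (32 x 100 x 2048) gives
# 32 x 2 x 2048, and the final view flattens it to 32 x 4096.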
32 | class image_embedding(nn.Module):
33 | def __init__(self, image_attention_model):
34 | super(image_embedding, self).__init__()
35 | self.image_attention_model = image_attention_model
36 | self.out_dim = image_attention_model.out_dim
37 |
38 | def forward(self, image_feat_variable, question_embedding, image_dims):
39 | # N x K x n_att
40 | attention = self.image_attention_model(
41 | image_feat_variable, question_embedding, image_dims
42 | )
43 | att_reshape = attention.permute(0, 2, 1)
44 | tmp_embedding = torch.bmm(
45 | att_reshape, image_feat_variable
46 | ) # N x n_att x image_dim
47 | batch_size = att_reshape.size(0)
48 | image_embedding = tmp_embedding.view(batch_size, -1)
49 |
50 | return image_embedding
51 |
52 |
53 | class image_finetune(nn.Module):
54 | def __init__(self, in_dim, weights_file, bias_file):
55 | super(image_finetune, self).__init__()
56 | with open(weights_file, "rb") as w:
57 | weights = pickle.load(w)
58 | with open(bias_file, "rb") as b:
59 | bias = pickle.load(b)
60 | out_dim = bias.shape[0]
61 |
62 | self.lc = nn.Linear(in_dim, out_dim)
63 | self.lc.weight.data.copy_(torch.from_numpy(weights))
64 | self.lc.bias.data.copy_(torch.from_numpy(bias))
65 | self.out_dim = out_dim
66 |
67 | def forward(self, image):
68 | i2 = self.lc(image)
69 | i3 = F.relu(i2)
70 | return i3
71 |
--------------------------------------------------------------------------------
/pythia/legacy/top_down_bottom_up/image_feature_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import os
10 | import pickle
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 | from config.config import cfg
17 |
18 |
19 | def build_image_feature_encoding(method, par, in_dim):
20 | if method == "default_image":
21 | return DefaultImageFeature(in_dim)
22 | elif method == "finetune_faster_rcnn_fpn_fc7":
23 | return FinetuneFasterRcnnFpnFc7(in_dim, **par)
24 | else:
25 | raise NotImplementedError("unknown image feature encoding %s" % method)
26 |
27 |
28 | class DefaultImageFeature(nn.Module):
29 | def __init__(self, in_dim):
30 | super(DefaultImageFeature, self).__init__()
31 | self.in_dim = in_dim
32 | self.out_dim = in_dim
33 |
34 | def forward(self, image):
35 | return image
36 |
37 |
38 | class FinetuneFasterRcnnFpnFc7(nn.Module):
39 | def __init__(self, in_dim, weights_file, bias_file):
40 | super(FinetuneFasterRcnnFpnFc7, self).__init__()
41 | if not os.path.isabs(weights_file):
42 | weights_file = os.path.join(cfg.data.data_root_dir, weights_file)
43 | if not os.path.isabs(bias_file):
44 | bias_file = os.path.join(cfg.data.data_root_dir, bias_file)
45 | with open(weights_file, "rb") as w:
46 | weights = pickle.load(w)
47 | with open(bias_file, "rb") as b:
48 | bias = pickle.load(b)
49 | out_dim = bias.shape[0]
50 |
51 | self.lc = nn.Linear(in_dim, out_dim)
52 | self.lc.weight.data.copy_(torch.from_numpy(weights))
53 | self.lc.bias.data.copy_(torch.from_numpy(bias))
54 | self.out_dim = out_dim
55 |
56 | def forward(self, image):
57 | i2 = self.lc(image)
58 | i3 = F.relu(i2)
59 | return i3
60 |
--------------------------------------------------------------------------------
/pythia/legacy/top_down_bottom_up/intermediate_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import torch.nn as nn
10 |
11 |
12 | class inter_layer(nn.Module):
13 | def __init__(self, dim, n_layer):
14 | super(inter_layer, self).__init__()
15 | layers = []
16 | for i in range(n_layer):
17 | layers.append(nn.Linear(dim, dim))
18 | layers.append(nn.ReLU())
19 |
20 | self.main = nn.Sequential(*layers)
21 |
22 | def forward(self, x):
23 | return self.main(x)
24 |
--------------------------------------------------------------------------------
/pythia/legacy/top_down_bottom_up/nonlinear_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn.utils.weight_norm import weight_norm
13 |
14 |
15 | """
16 | nonlinear_layer: f_a : x \in R^m => y \in R^n
17 |     \tilde{y} = \tanh(W x + b)
18 |     g = \sigma(W' x + b')
19 |     y = \tilde{y} \circ g
20 | input (N, *, in_dim)
21 | output (N, *, out_dim)
22 | """
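# The gated tanh above is what nonlinear_layer_org implements below: as a worked
# form, y = tanh(Wx + b) * sigmoid(W'x + b'), where the sigmoid gate scales each
# tanh activation by a factor in (0, 1). FCNet and nonlinear_layer instead use
# weight-normalized Linear + ReLU stacks.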
23 |
24 |
25 | class nonlinear_layer_org(nn.Module):
26 | def __init__(self, in_dim, out_dim):
27 | super(nonlinear_layer_org, self).__init__()
28 | self.fc1 = nn.Linear(in_dim, out_dim)
29 | self.gate = nn.Linear(in_dim, out_dim)
30 |
31 | def forward(self, x):
32 |         y_tilde = torch.tanh(self.fc1(x))
33 |         g = torch.sigmoid(self.gate(x))
34 |         y = y_tilde * g
35 | return y
36 |
37 |
38 | class FCNet(nn.Module):
39 | """Simple class for non-linear fully connect network
40 | """
41 |
42 | def __init__(self, dims):
43 | super(FCNet, self).__init__()
44 |
45 | layers = []
46 | for i in range(len(dims) - 2):
47 | in_dim = dims[i]
48 | out_dim = dims[i + 1]
49 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
50 | layers.append(nn.ReLU())
51 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
52 | layers.append(nn.ReLU())
53 |
54 | self.main = nn.Sequential(*layers)
55 |
56 | def forward(self, x):
57 | return self.main(x)
58 |
59 |
60 | class nonlinear_layer(nn.Module):
61 | """Simple class for non-linear fully connect network
62 | """
63 |
64 | def __init__(self, in_dim, out_dim):
65 | super(nonlinear_layer, self).__init__()
66 |
67 | layers = []
68 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
69 | layers.append(nn.ReLU())
70 |
71 | self.main = nn.Sequential(*layers)
72 |
73 | def forward(self, x):
74 | return self.main(x)
75 |
--------------------------------------------------------------------------------
/pythia/legacy/top_down_bottom_up/post_combine_transform.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn.utils.weight_norm import weight_norm
13 |
14 |
15 | def build_post_combine_transform(method, par, in_dim):
16 | if method == "linear_transform":
17 | return LinearTransform(in_dim, **par)
18 | elif method == "conv_transform":
19 | return ConvTransform(in_dim, **par)
20 | else:
21 |         raise NotImplementedError("unknown post combine transform type %s" % method)
22 |
23 |
24 | class LinearTransform(nn.Module):
25 | def __init__(self, in_dim, **kwargs):
26 | super(LinearTransform, self).__init__()
27 | self.lc = weight_norm(
28 | nn.Linear(in_features=in_dim, out_features=kwargs["out_dim"]), dim=None
29 | )
30 | self.out_dim = kwargs["out_dim"]
31 |
32 | def forward(self, x):
33 | return self.lc(x)
34 |
35 |
36 | class ConvTransform(nn.Module):
37 | def __init__(self, in_dim, **kwargs):
38 | super(ConvTransform, self).__init__()
39 | self.conv1 = nn.Conv2d(
40 | in_channels=in_dim, out_channels=kwargs["hidden_dim"], kernel_size=1
41 | )
42 | self.conv2 = nn.Conv2d(
43 | in_channels=kwargs["hidden_dim"],
44 | out_channels=kwargs["out_dim"],
45 | kernel_size=1,
46 | )
47 | self.out_dim = kwargs["out_dim"]
48 |
49 | def forward(self, x):
50 |         if len(x.size()) == 3:  # N x k x dim
51 | # N x dim x k x 1
52 | x_reshape = torch.unsqueeze(x.permute(0, 2, 1), 3)
53 | elif len(x.size()) == 2: # N x dim
54 | # N x dim x 1 x 1
55 | x_reshape = torch.unsqueeze(torch.unsqueeze(x, 2), 3)
56 |
57 | iatt_conv1 = self.conv1(x_reshape) # N x hidden_dim x * x 1
58 | iatt_relu = F.relu(iatt_conv1)
59 | iatt_conv2 = self.conv2(iatt_relu) # N x out_dim x * x 1
60 |
61 | if len(x.size()) == 3:
62 | iatt_conv3 = torch.squeeze(iatt_conv2, 3).permute(0, 2, 1)
63 | elif len(x.size()) == 2:
64 | iatt_conv3 = torch.squeeze(torch.squeeze(iatt_conv2, 3), 2)
65 |
66 | return iatt_conv3
67 |
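# Illustrative note: a 1x1 Conv2d over an N x dim x k x 1 tensor applies the
# same linear map independently at each of the k positions, so ConvTransform is
# a position-wise two-layer MLP (Linear -> ReLU -> Linear) written with convs.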
--------------------------------------------------------------------------------
/pythia/legacy/train_model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
--------------------------------------------------------------------------------
/pythia/legacy/train_model/eval_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import os
10 |
11 | import torch
12 | from torch.utils.data import DataLoader
13 |
14 |
15 | def get_final_validation(data_set_val, batch_size, snapshot_dir, eval_model):
16 | final_val_data_reader = DataLoader(
17 | data_set_val, shuffle=False, batch_size=batch_size
18 | )
19 |
20 | files = [
21 | os.path.join(snapshot_dir, file)
22 | for file in os.listdir(snapshot_dir)
23 | if file.startswith("model")
24 | ]
25 |
26 | for model_file in sorted(files, key=os.path.getctime, reverse=True):
27 | current_model = torch.load(model_file)
28 | total_sample = 0
29 | total_score = 0
30 | for i, batch in enumerate(final_val_data_reader):
31 | score, n_sample, _ = eval_model(batch, current_model)
32 | total_sample += n_sample
33 | total_score += score
34 |
35 | acc = total_score / total_sample
36 | print(model_file, ": %.6f" % acc)
37 |
--------------------------------------------------------------------------------
/pythia/legacy/train_model/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 |
9 | import argparse
10 | import os
11 |
12 | import yaml
13 | from torch.utils.data import DataLoader
14 |
15 | from train_model.dataset_utils import prepare_eval_data_set
16 | from train_model.Engineer import one_stage_eval_model
17 | from train_model.eval_utils import get_final_validation
18 | from train_model.model_factory import is_one_stageModel
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--config", type=str, required=True, help="config yaml file")
22 | parser.add_argument("--out_dir", type=str, required=True, help="output directory")
23 | args = parser.parse_args()
24 |
25 | config_file = args.config
26 | out_dir = args.out_dir
27 |
28 | with open(config_file, "r") as f:
29 | config = yaml.load(f)
30 |
31 | # get the potential shared data_config info
32 | data_root_dir = config["data"]["data_root_dir"]
33 | batch_size = config["data"]["batch_size"]
34 | data_set_val = prepare_eval_data_set(**config["data"], **config["model"])
35 | data_reader_val = DataLoader(data_set_val, shuffle=False, batch_size=batch_size)
36 |
37 | snapshot_dir = os.path.join(out_dir, config["output"]["exp_name"])
38 | os.makedirs(snapshot_dir, exist_ok=True)
39 |
40 | model_type = config["model"]["model_type"]
41 | if is_one_stageModel(model_type):
42 | get_final_validation(data_set_val, batch_size, snapshot_dir, one_stage_eval_model)
43 | else:
44 |     pass
45 |
--------------------------------------------------------------------------------
/pythia/legacy/train_model/helper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import json
9 | import sys
10 | import timeit
11 |
12 | import numpy as np
13 |
14 | import _pickle as pickle
15 | from train_model.Engineer import masked_unk_softmax, one_stage_run_model
16 | from train_model.model_factory import prepare_model
17 |
18 |
19 | class answer_json:
20 | def __init__(self):
21 | self.answers = []
22 |
23 | def add(self, ques_id, ans):
24 | res = {"question_id": ques_id, "answer": ans}
25 | self.answers.append(res)
26 |
27 |
28 | def build_model(config, dataset):
29 | num_vocab_txt = dataset.vocab_dict.num_vocab
30 | num_choices = dataset.answer_dict.num_vocab
31 |
32 | num_image_feat = len(config["data"]["image_feat_train"][0].split(","))
33 | my_model = prepare_model(
34 | num_vocab_txt, num_choices, **config["model"], num_image_feat=num_image_feat
35 | )
36 | return my_model
37 |
38 |
39 | def run_model(current_model, data_reader, UNK_idx=0):
40 | softmax_tot = []
41 | q_id_tot = []
42 |
43 | start = timeit.default_timer()
44 | for i, batch in enumerate(data_reader):
45 | if (i + 1) % 100 == 0:
46 | end = timeit.default_timer()
47 | time = end - start
48 | start = timeit.default_timer()
49 |             print("processed batch %d for test in %.1f s" % (i + 1, time))
50 | sys.stdout.flush()
51 |
52 | verbose_info = batch["verbose_info"]
53 | q_ids = verbose_info["question_id"].cpu().numpy().tolist()
54 | logit_res = one_stage_run_model(batch, current_model, eval_mode=True)
55 | softmax_res = masked_unk_softmax(logit_res, dim=1, mask_idx=UNK_idx)
56 | softmax_res = softmax_res.data.cpu().numpy().astype(np.float16)
57 | q_id_tot += q_ids
58 | softmax_tot.append(softmax_res)
59 | softmax_result = np.vstack(softmax_tot)
60 |
61 | return q_id_tot, softmax_result
62 |
63 |
64 | def print_result(
65 | question_ids, soft_max_result, ans_dic, out_file, json_only=True, pkl_res_file=None
66 | ):
67 | predicted_answers = np.argmax(soft_max_result, axis=1)
68 |
69 | if not json_only:
70 | with open(pkl_res_file, "wb") as writeFile:
71 | pickle.dump(soft_max_result, writeFile)
72 | pickle.dump(question_ids, writeFile)
73 | pickle.dump(ans_dic, writeFile)
74 |
75 | ans_json_out = answer_json()
76 | for idx, pred_idx in enumerate(predicted_answers):
77 | question_id = question_ids[idx]
78 | pred_ans = ans_dic.idx2word(pred_idx)
79 | ans_json_out.add(question_id, pred_ans)
80 |
81 | with open(out_file, "w") as f:
82 | json.dump(ans_json_out.answers, f)
83 |
--------------------------------------------------------------------------------
/pythia/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | __all__ = ["TopDownBottomUp", "Pythia", "LoRRA", "BAN"]
3 |
4 | from .top_down_bottom_up import TopDownBottomUp
5 | from .ban import BAN
6 | from .pythia import Pythia
7 | from .lorra import LoRRA
8 |
--------------------------------------------------------------------------------
/pythia/models/lorra.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 |
4 | from pythia.common.registry import registry
5 | from pythia.models.pythia import Pythia
6 | from pythia.modules.layers import ClassifierLayer
7 |
8 |
9 | @registry.register_model("lorra")
10 | class LoRRA(Pythia):
11 | def __init__(self, config):
12 | super().__init__(config)
13 |
14 | def build(self):
15 | self._init_text_embeddings("text")
16 |         # For LoRRA, context feature and text embeddings would be identity,
17 |         # but to keep a unified API we initialize them as well. We also need
18 |         # to build them before building Pythia's other modules, as some of
19 |         # those modules require the context attributes to be set
20 | self._init_text_embeddings("context")
21 | self._init_feature_encoders("context")
22 | self._init_feature_embeddings("context")
23 | super().build()
24 |
25 | def get_optimizer_parameters(self, config):
26 | params = super().get_optimizer_parameters(config)
27 | params += [
28 | {"params": self.context_feature_embeddings_list.parameters()},
29 | {"params": self.context_embeddings.parameters()},
30 | {"params": self.context_feature_encoders.parameters()},
31 | ]
32 |
33 | return params
34 |
35 | def _get_classifier_input_dim(self):
36 | # Now, the classifier's input will be cat of image and context based
37 | # features
38 | return 2 * super()._get_classifier_input_dim()
39 |
40 | def forward(self, sample_list):
41 | sample_list.text = self.word_embedding(sample_list.text)
42 | text_embedding_total = self.process_text_embedding(sample_list)
43 |
44 | image_embedding_total, _ = self.process_feature_embedding(
45 | "image", sample_list, text_embedding_total
46 | )
47 |
48 | context_embedding_total, _ = self.process_feature_embedding(
49 | "context", sample_list, text_embedding_total, ["order_vectors"]
50 | )
51 |
52 | if self.inter_model is not None:
53 | image_embedding_total = self.inter_model(image_embedding_total)
54 |
55 | joint_embedding = self.combine_embeddings(
56 | ["image", "text"],
57 | [image_embedding_total, text_embedding_total, context_embedding_total],
58 | )
59 |
60 | scores = self.calculate_logits(joint_embedding)
61 |
62 | return {"scores": scores}
63 |
--------------------------------------------------------------------------------
/pythia/models/m4c_captioner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from pythia.common.registry import registry
3 | from pythia.models.m4c import M4C
4 |
5 |
6 | @registry.register_model("m4c_captioner")
7 | class M4CCaptioner(M4C):
8 | def __init__(self, config):
9 | super().__init__(config)
10 | self.remove_unk_in_pred = self.config.remove_unk_in_pred
11 |
12 | def _forward_output(self, sample_list, fwd_results):
13 | super()._forward_output(sample_list, fwd_results)
14 |
15 | if (not self.training) and self.remove_unk_in_pred:
16 | # avoid outputting