├── .gitignore
├── LICENSE
├── README.md
├── bench.py
├── checkpoint
└── .gitignore
├── config.py
├── data
└── .gitignore
├── dataset.py
├── image_feature.py
├── log
└── .gitignore
├── merge.py
├── model.py
├── model_gqa.py
├── preprocess.py
├── requirements.txt
├── setup.sh
├── test.py
├── train.py
├── transforms.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Data files
2 | *.hdf5
3 | *.model
4 | log/*.txt
5 | *.pkl
6 | data/**
7 | w2v/**
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | env/
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | .hypothesis/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # dotenv
90 | .env
91 |
92 | # virtualenv
93 | .venv
94 | venv/
95 | ENV/
96 |
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | # mkdocs documentation
105 | /site
106 |
107 | # mypy
108 | .mypy_cache/
109 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Kim Seonghyeon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mac-network-pytorch
2 | A PyTorch implementation of the Memory, Attention and Composition (MAC) Network for CLEVR/GQA, from "Compositional Attention Networks for Machine Reasoning" (https://arxiv.org/abs/1803.03067).
3 | Requirements:
4 |
5 | - Python 3.6
6 | - PyTorch 1.0.1
7 | - torch-vision
8 | - Pillow
9 | - nltk
10 | - tqdm
11 | - block.bootstrap.pytorch and murel.bootstrap.pytorch
12 |
13 | To train:
14 |
15 | - Download and extract either
16 | CLEVR v1.0 dataset from http://cs.stanford.edu/people/jcjohns/clevr/ or
17 | GQA dataset from https://cs.stanford.edu/people/dorarad/gqa/download.html
18 |
19 | For GQA
20 | cd data
21 | mkdir gqa && cd gqa
22 | wget https://nlp.stanford.edu/data/gqa/data1.2.zip
23 | unzip data1.2.zip
24 |
25 | mkdir questions
26 | mv balanced_train_data.json questions/gqa_train_questions.json
27 | mv balanced_val_data.json questions/gqa_val_questions.json
28 | mv balanced_testdev_data.json questions/gqa_testdev_questions.json
29 | cd ..
30 |
31 | wget http://nlp.stanford.edu/data/glove.6B.zip
32 | unzip glove.6B.zip
33 | wget http://nlp.stanford.edu/data/gqa/objectFeatures.zip
34 | unzip objectFeatures.zip
35 | cd ..
36 |
37 |
38 | - Preprocess the question data and extract image features with ResNet-101 (feature extraction is not required for GQA)
39 | For CLEVR
40 | a. Extract image features
41 |
42 | python image_feature.py data/CLEVR_v1.0
43 |
44 | b. Preprocess questions
45 | python preprocess.py CLEVR data/CLEVR_v1.0 en
46 |
47 | For GQA
48 | a. Merge object features (this may take some time)
49 | python merge.py --name objects
50 | mv data/gqa_objects.hdf5 data/gqa_features.hdf5
51 |
52 | b. Preprocess questions
53 | python preprocess.py gqa data/gqa en
54 |
55 | !CAUTION! The file created by image_feature.py is very large. You may enable HDF5 compression, but it will slow down feature extraction; a sketch is shown below.
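A hedged sketch of what enabling compression could look like inside image_feature.py's create_dataset (the `compression` and `chunks` arguments are standard h5py options; the particular compression level and chunk shape chosen here are assumptions):

```python
# sketch: gzip-compressed feature dataset (smaller file, slower writes)
dset = f.create_dataset(
    'data', (size * batch_size, 1024, 14, 14), dtype='f4',
    compression='gzip', compression_opts=4,  # 1 (fastest) .. 9 (smallest)
    chunks=(1, 1024, 14, 14),                # assumed chunk shape: one image per chunk
)
```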
56 |
57 | - Run train.py with the dataset type (gqa or CLEVR) and the question language as arguments
58 |
59 | python train.py gqa en
60 |
61 | CLEVR -> This implementation reaches 95.75% accuracy at epoch 10 and 96.5% accuracy at epoch 20.
62 | Parts of the code borrowed from https://github.com/rosinality/mac-network-pytorch and
63 | https://github.com/stanfordnlp/mac-network.
64 |
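To evaluate the EMA weights that train.py saves, a minimal sketch is shown below (the checkpoint path and `eval.py` name are assumptions; note that model_gqa.MACNetwork loads GloVe vectors through utils.get_or_load_embeddings, which reads the dataset type from sys.argv[1], so the data files used during training must be present):

```python
# sketch: rebuild the network and load the EMA weights saved by train.py
import pickle
import sys

import torch

from model_gqa import MACNetwork

sys.argv = ['eval.py', 'CLEVR']              # get_or_load_embeddings() reads sys.argv[1]
with open('data/CLEVR_dic.pkl', 'rb') as f:  # vocabulary file used by train.py
    dic = pickle.load(f)

net = MACNetwork(len(dic['word_dic']) + 1, 512,  # same arguments train.py uses for CLEVR
                 classes=len(dic['answer_dic']), max_step=4)
net.load_state_dict(torch.load('checkpoint/checkpoint_en.model', map_location='cpu'))
net.eval()
```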
--------------------------------------------------------------------------------
/bench.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pickle
3 | from collections import Counter
4 |
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | from torch import optim
9 | from torch.autograd import Variable
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 |
13 | from dataset import CLEVR, collate_data, transform
14 | from model_gqa import MACNetwork
15 |
16 | batch_size = 64
17 | n_epoch = 20
18 | dim = 512
19 |
20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21 |
22 | def accumulate(model1, model2, decay=0.999):
23 | par1 = dict(model1.named_parameters())
24 | par2 = dict(model2.named_parameters())
25 |
26 | for k in par1.keys():
27 | par1[k].data.mul_(decay).add_(1 - decay, par2[k].data)
28 |
29 | def train():
30 | moving_loss = 0
31 |
32 | net.train(True)
33 |
34 | image = torch.randn(batch_size, 1024, 14, 14, device=device)
35 | question = torch.randint(0, 28, (batch_size, 30), dtype=torch.int64, device=device)
36 | answer = torch.randint(0, 28, (batch_size,), dtype=torch.int64, device=device)
37 | q_len = torch.tensor([30] * batch_size, dtype=torch.int64, device=device)
38 |
39 | for i in range(30):
40 | net.zero_grad()
41 | output = net(image, question, q_len)
42 | loss = criterion(output, answer)
43 | loss.backward()
44 | optimizer.step()
45 | correct = output.detach().argmax(1) \
46 | == answer
47 |         correct = correct.float().sum() \
48 | / batch_size
49 |
50 | if moving_loss == 0:
51 | moving_loss = correct
52 |
53 | else:
54 | moving_loss = moving_loss * 0.99 + correct * 0.01
55 |
56 | accumulate(net_running, net)
57 |
58 | if __name__ == '__main__':
59 | with open('data/dic.pkl', 'rb') as f:
60 | dic = pickle.load(f)
61 |
62 | n_words = len(dic['word_dic']) + 1
63 | n_answers = len(dic['answer_dic'])
64 |
65 | net = MACNetwork(n_words, dim).to(device)
66 | net_running = MACNetwork(n_words, dim).to(device)
67 | accumulate(net_running, net, 0)
68 |
69 | criterion = nn.CrossEntropyLoss()
70 | optimizer = optim.Adam(net.parameters(), lr=1e-4)
71 |
72 | train()
--------------------------------------------------------------------------------
/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | # Data files
2 | *.hdf5
3 | *.model
4 | log/*.txt
5 | *.pkl
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | env/
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # dotenv
89 | .env
90 |
91 | # virtualenv
92 | .venv
93 | venv/
94 | ENV/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # File: config.py
2 | # Author: Ronil Pancholia
3 | # Date: 4/20/19
4 | # Time: 6:37 PM
5 |
6 | ## dropouts
7 | encInputDropout = 0.2 # dropout of the rnn inputs to the Question Input Unit
8 | encStateDropout = 0.0 # dropout of the rnn states of the Question Input Unit
9 | stemDropout = 0.2 # dropout of the Image Input Unit (the stem)
10 | qDropout = 0.08 # dropout on the question vector
11 | qDropoutOut = 0 # dropout on the question vector that goes to the output unit
12 | memoryDropout = 0.15 # dropout on the recurrent memory
13 | readDropout = 0.15 # dropout of the read unit
14 | writeDropout = 1.0 # dropout of the write unit
15 | outputDropout = 0.85 # dropout of the output unit
16 | controlPreDropout = 1.0 # dropout applied before the control unit
17 | controlPostDropout = 1.0 # dropout applied after the control unit
18 | wordEmbDropout = 1.0 # dropout of the word embeddings
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | # Data files
2 | *.hdf5
3 | *.model
4 | log/*.txt
5 | *.pkl
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | env/
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # dotenv
89 | .env
90 |
91 | # virtualenv
92 | .venv
93 | venv/
94 | ENV/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 |
5 | import numpy as np
6 | from PIL import Image
7 | import torch
8 | from torch.utils.data import Dataset
9 | from torchvision import transforms
10 | import h5py
11 |
12 | from transforms import Scale
13 |
14 | class CLEVR(Dataset):
15 | def __init__(self, root, split='train', transform=None, lang='en'):
16 | with open(f'data/CLEVR_{split}_{lang}.pkl', 'rb') as f:
17 | self.data = pickle.load(f)
18 |
19 | # self.transform = transform
20 | self.root = root
21 | self.split = split
22 | with open(f'mini_CLEVR_{split}_questions_translated.json') as f:
23 | data = json.load(f)
24 | self.img_idx_map = {}
25 | i=0
26 | for question in data['questions']:
27 | if question['image_index'] not in self.img_idx_map.keys():
28 | self.img_idx_map[question['image_index']] = i
29 | i+=1
30 | self.idx_img_map = {v:k for k,v in self.img_idx_map.items()}
31 |         self.h = h5py.File('data/CLEVR_features.hdf5', 'r')
32 | self.img = self.h['data']
33 |
34 | def close(self):
35 | self.h.close()
36 |
37 | def __getitem__(self, index):
38 | imgfile, question, answer = self.data[index]
39 | # img = Image.open(os.path.join(self.root, 'images',
40 | # self.split, imgfile)).convert('RGB')
41 |
42 | # img = self.transform(img)
43 | id = int(imgfile.rsplit('_', 1)[1][:-4])
44 | img = torch.from_numpy(self.img[self.img_idx_map[id]])
45 |
46 | return img, question, len(question), answer
47 |
48 | def __len__(self):
49 | return len(self.data)
50 |
51 | class GQA(Dataset):
52 | def __init__(self, root, split='train', transform=None):
53 | with open(f'data/gqa_{split}.pkl', 'rb') as f:
54 | self.data = pickle.load(f)
55 |
56 | self.root = root
57 | self.split = split
58 |
59 |         self.h = h5py.File('data/gqa_features.hdf5', 'r')
60 | self.img = self.h['features']
61 | self.img_info = json.load(open('data/gqa_objects_merged_info.json', 'r'))
62 |
63 | def close(self):
64 | self.h.close()
65 |
66 | def __getitem__(self, index):
67 | imgfile, question, answer = self.data[index]
68 | idx = int(self.img_info[imgfile]['index'])
69 | img = torch.from_numpy(self.img[idx])
70 | return img, question, len(question), answer
71 |
72 | def __len__(self):
73 | return len(self.data)
74 |
75 | transform = transforms.Compose([
76 | Scale([224, 224]),
77 | transforms.Pad(4),
78 | transforms.RandomCrop([224, 224]),
79 | transforms.ToTensor(),
80 | transforms.Normalize(mean=[0.5, 0.5, 0.5],
81 | std=[0.5, 0.5, 0.5])
82 | ])
83 |
84 | def collate_data(batch):
85 | images, lengths, answers = [], [], []
86 | batch_size = len(batch)
87 |
88 | max_len = max(map(lambda x: len(x[1]), batch))
89 |
90 | questions = np.zeros((batch_size, max_len), dtype=np.int64)
91 | sort_by_len = sorted(batch, key=lambda x: len(x[1]), reverse=True)
92 |
93 | for i, b in enumerate(sort_by_len):
94 | image, question, length, answer = b
95 | images.append(image)
96 | length = len(question)
97 | questions[i, :length] = question
98 | lengths.append(length)
99 | answers.append(answer)
100 |
101 | return torch.stack(images), torch.from_numpy(questions), \
102 | lengths, torch.LongTensor(answers)
103 |
--------------------------------------------------------------------------------
/image_feature.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import torch
3 | from torchvision.models.resnet import ResNet, resnet101
4 | from torchvision import transforms
5 | from torch.utils.data import Dataset, DataLoader
6 | from transforms import Scale
7 | import sys
8 | import os
9 | from PIL import Image
10 | from tqdm import tqdm
11 | from torch.autograd import Variable
12 | import json
13 |
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 |
16 | def forward(self, x):
17 | x = self.conv1(x)
18 | x = self.bn1(x)
19 | x = self.relu(x)
20 | x = self.maxpool(x)
21 |
22 | x = self.layer1(x)
23 | x = self.layer2(x)
24 | x = self.layer3(x)
25 |
26 | return x
27 |
28 | transform = transforms.Compose([
29 | Scale([224, 224]),
30 | transforms.ToTensor(),
31 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225])])
33 |
34 | class CLEVR(Dataset):
35 | def __init__(self, root, split='train'):
36 | self.root = root
37 | self.split = split
38 | self.length = len(os.listdir(os.path.join(root,
39 | 'images', split)))
40 |
41 | with open(f'mini_CLEVR_{split}_questions.json') as f:
42 | data = json.load(f)
43 | self.img_idx_map = {}
44 | i=0
45 | for question in data['questions']:
46 | if question['image_index'] not in self.img_idx_map.keys():
47 | self.img_idx_map[question['image_index']] = i
48 | i+=1
49 | self.idx_img_map = {v:k for k,v in self.img_idx_map.items()}
50 |
51 | def __getitem__(self, index):
52 | img = os.path.join(self.root, 'images',
53 | self.split,
54 | 'CLEVR_{}_{}.png'.format(self.split,
55 | str(self.idx_img_map[index]).zfill(6)))
56 | img = Image.open(img).convert('RGB')
57 | return transform(img)
58 |
59 | def __len__(self):
60 | return self.length
61 |
62 | batch_size = 50
63 |
64 | resnet = resnet101(True).to(device)
65 | resnet.eval()
66 | resnet.forward = forward.__get__(resnet, ResNet)
67 |
68 | def create_dataset(split):
69 | dataloader = DataLoader(CLEVR(sys.argv[1], split), batch_size=batch_size,
70 | num_workers=4)
71 |
72 | size = len(dataloader)
73 |
74 | print(split, 'total', size * batch_size)
75 |
76 |     f = h5py.File('data/CLEVR_features.hdf5', 'w', libver='latest')
77 | dset = f.create_dataset('data', (size * batch_size, 1024, 14, 14),
78 | dtype='f4')
79 |
80 | with torch.no_grad():
81 | for i, image in tqdm(enumerate(dataloader)):
82 | image = image.to(device)
83 | features = resnet(image).detach().cpu().numpy()
84 | try:
85 | dset[i * batch_size:(i + 1) * batch_size] = features
86 | except:
87 | dset[i * batch_size:i * batch_size+features.shape[0]] = features
88 |
89 | f.close()
90 |
91 | create_dataset('val')
92 | create_dataset('train')
--------------------------------------------------------------------------------
/log/.gitignore:
--------------------------------------------------------------------------------
1 | # Data files
2 | *.hdf5
3 | *.model
4 | log/*.txt
5 | *.pkl
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | env/
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # dotenv
89 | .env
90 |
91 | # virtualenv
92 | .venv
93 | venv/
94 | ENV/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
--------------------------------------------------------------------------------
/merge.py:
--------------------------------------------------------------------------------
1 | # File: merge.py
2 | # Author: Ronil Pancholia
3 | # Date: 4/21/19
4 | # Time: 9:20 PM
5 |
6 | # Script to merge hdf5 chunk files to one and update info.json accordingly
7 |
8 | import warnings
9 |
10 | warnings.filterwarnings("ignore", category=FutureWarning)
11 | warnings.filterwarnings("ignore", message="size changed")
12 |
13 | from tqdm import tqdm
14 | import argparse
15 | import h5py
16 | import json
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--name', type=str, help="features directory name")
20 | parser.add_argument('--chunksNum', type=int, default=16, help="number of file chunks")
21 | parser.add_argument('--chunkSize', type=int, default=10000, help="file chunk size")
22 | args = parser.parse_args()
23 |
24 | print("Merging features file for gqa_{}. This may take a while (and may be 0 for some time).".format(args.name))
25 |
26 | # Format specification for features files
27 | spec = {
28 | "spatial": {"features": (148855, 2048, 7, 7)},
29 | "objects": {"features": (148855, 100, 2048),
30 | "bboxes": (148855, 100, 4)}
31 | }
32 |
33 | # Merge hdf5 files
34 | lengths = [0]
35 | with h5py.File("data/gqa_{name}.hdf5".format(name=args.name), "w") as out:
36 | datasets = {}
37 | for dname in spec[args.name]:
38 | datasets[dname] = out.create_dataset(dname, spec[args.name][dname])
39 |
40 | low = 0
41 | for i in tqdm(range(args.chunksNum)):
42 |         with h5py.File("data/{name}/gqa_{name}_{index}.h5".format(name=args.name, index=i), "r") as chunk:
43 | high = low + chunk["features"].shape[0]
44 |
45 | for dname in spec[args.name]:
46 | # low = i * args.chunkSize
47 | # high = (i + 1) * args.chunkSize if i < args.chunksNum -1 else spec[args.name][dname][0]
48 | datasets[dname][low:high] = chunk[dname][:]
49 |
50 | low = high
51 | lengths.append(high)
52 |
53 | # Update info file
54 | with open("data/{name}/gqa_{name}_info.json".format(name=args.name)) as infoIn:
55 | info = json.load(infoIn)
56 | for imageId in info:
57 | info[imageId]["index"] = lengths[info[imageId]["file"]] + info[imageId]["idx"]
58 | # info[imageId]["index"] = info[imageId]["file"] * args.chunkSize + info[imageId]["idx"]
59 | del info[imageId]["idx"]
60 | del info[imageId]["file"]
61 |
62 | with open("data/gqa_{name}_merged_info.json".format(name=args.name), "w") as infoOut:
63 | json.dump(info, infoOut)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torch import nn
4 | from torch.nn.init import kaiming_uniform_, xavier_uniform_, normal
5 | import torch.nn.functional as F
6 |
7 | def linear(in_dim, out_dim, bias=True):
8 | lin = nn.Linear(in_dim, out_dim, bias=bias)
9 | xavier_uniform_(lin.weight)
10 | if bias:
11 | lin.bias.data.zero_()
12 |
13 | return lin
14 |
15 | class ControlUnit(nn.Module):
16 | def __init__(self, dim, max_step):
17 | super().__init__()
18 |
19 | self.position_aware = nn.ModuleList()
20 | for i in range(max_step):
21 | self.position_aware.append(linear(dim * 2, dim))
22 |
23 | self.control_question = linear(dim * 2, dim)
24 | self.attn = linear(dim, 1)
25 |
26 | self.dim = dim
27 |
28 | def forward(self, step, context, question, control):
29 | position_aware = self.position_aware[step](question)
30 |
31 | control_question = torch.cat([control, position_aware], 1)
32 | control_question = self.control_question(control_question)
33 | control_question = control_question.unsqueeze(1)
34 |
35 | context_prod = control_question * context
36 | attn_weight = self.attn(context_prod)
37 |
38 | attn = F.softmax(attn_weight, 1)
39 |
40 | next_control = (attn * context).sum(1)
41 |
42 | return next_control
43 |
44 |
45 | class ReadUnit(nn.Module):
46 | def __init__(self, dim):
47 | super().__init__()
48 |
49 | self.mem = linear(dim, dim)
50 | self.concat = linear(dim * 2, dim)
51 | self.attn = linear(dim, 1)
52 |
53 | def forward(self, memory, know, control):
54 | mem = self.mem(memory[-1]).unsqueeze(2)
55 | concat = self.concat(torch.cat([mem * know, know], 1) \
56 | .permute(0, 2, 1))
57 | attn = concat * control[-1].unsqueeze(1)
58 | attn = self.attn(attn).squeeze(2)
59 | attn = F.softmax(attn, 1).unsqueeze(1)
60 |
61 | read = (attn * know).sum(2)
62 |
63 | return read
64 |
65 |
66 | class WriteUnit(nn.Module):
67 | def __init__(self, dim, self_attention=False, memory_gate=False):
68 | super().__init__()
69 |
70 | self.concat = linear(dim * 2, dim)
71 |
72 | if self_attention:
73 | self.attn = linear(dim, 1)
74 | self.mem = linear(dim, dim)
75 |
76 | if memory_gate:
77 | self.control = linear(dim, 1)
78 |
79 | self.self_attention = self_attention
80 | self.memory_gate = memory_gate
81 |
82 | def forward(self, memories, retrieved, controls):
83 | prev_mem = memories[-1]
84 | concat = self.concat(torch.cat([retrieved, prev_mem], 1))
85 | next_mem = concat
86 |
87 | if self.self_attention:
88 | controls_cat = torch.stack(controls[:-1], 2)
89 | attn = controls[-1].unsqueeze(2) * controls_cat
90 | attn = self.attn(attn.permute(0, 2, 1))
91 | attn = F.softmax(attn, 1).permute(0, 2, 1)
92 |
93 | memories_cat = torch.stack(memories, 2)
94 | attn_mem = (attn * memories_cat).sum(2)
95 | next_mem = self.mem(attn_mem) + concat
96 |
97 | if self.memory_gate:
98 | control = self.control(controls[-1])
99 |             gate = torch.sigmoid(control)
100 | next_mem = gate * prev_mem + (1 - gate) * next_mem
101 |
102 | return next_mem
103 |
104 |
105 | class MACUnit(nn.Module):
106 | def __init__(self, dim, max_step=12,
107 | self_attention=False, memory_gate=False,
108 | dropout=0.15):
109 | super().__init__()
110 |
111 | self.control = ControlUnit(dim, max_step)
112 | self.read = ReadUnit(dim)
113 | self.write = WriteUnit(dim, self_attention, memory_gate)
114 |
115 | self.mem_0 = nn.Parameter(torch.zeros(1, dim))
116 | self.control_0 = nn.Parameter(torch.zeros(1, dim))
117 |
118 | self.dim = dim
119 | self.max_step = max_step
120 | self.dropout = dropout
121 |
122 | def get_mask(self, x, dropout):
123 | mask = torch.empty_like(x).bernoulli_(1 - dropout)
124 | mask = mask / (1 - dropout)
125 |
126 | return mask
127 |
128 | def forward(self, context, question, knowledge):
129 | b_size = question.size(0)
130 |
131 | control = self.control_0.expand(b_size, self.dim)
132 | memory = self.mem_0.expand(b_size, self.dim)
133 |
134 | if self.training:
135 | control_mask = self.get_mask(control, self.dropout)
136 | memory_mask = self.get_mask(memory, self.dropout)
137 | control = control * control_mask
138 | memory = memory * memory_mask
139 |
140 | controls = [control]
141 | memories = [memory]
142 |
143 | for i in range(self.max_step):
144 | control = self.control(i, context, question, control)
145 | if self.training:
146 | control = control * control_mask
147 | controls.append(control)
148 |
149 | read = self.read(memories, knowledge, controls)
150 | memory = self.write(memories, read, controls)
151 | if self.training:
152 | memory = memory * memory_mask
153 | memories.append(memory)
154 |
155 | return memory
156 |
157 |
158 | class MACNetwork(nn.Module):
159 | def __init__(self, n_vocab, dim, embed_hidden=300,
160 | max_step=12, self_attention=False, memory_gate=False,
161 | classes=28, dropout=0.15):
162 | super().__init__()
163 |
164 | self.conv = nn.Sequential(nn.Conv2d(1024, dim, 3, padding=1),
165 | nn.ELU(),
166 | nn.Conv2d(dim, dim, 3, padding=1),
167 | nn.ELU())
168 |
169 | self.embed = nn.Embedding(n_vocab, embed_hidden)
170 | self.lstm = nn.LSTM(embed_hidden, dim,
171 | batch_first=True, bidirectional=True)
172 | self.lstm_proj = nn.Linear(dim * 2, dim)
173 |
174 | self.mac = MACUnit(dim, max_step,
175 | self_attention, memory_gate, dropout)
176 |
177 |
178 | self.classifier = nn.Sequential(linear(dim * 3, dim),
179 | nn.ELU(),
180 | linear(dim, classes))
181 |
182 | self.max_step = max_step
183 | self.dim = dim
184 |
185 | self.reset()
186 |
187 | def reset(self):
188 | self.embed.weight.data.uniform_(0, 1)
189 |
190 | kaiming_uniform_(self.conv[0].weight)
191 | self.conv[0].bias.data.zero_()
192 | kaiming_uniform_(self.conv[2].weight)
193 | self.conv[2].bias.data.zero_()
194 |
195 | kaiming_uniform_(self.classifier[0].weight)
196 |
197 | def forward(self, image, question, question_len, dropout=0.15):
198 | b_size = question.size(0)
199 |
200 | img = self.conv(image)
201 | img = img.view(b_size, self.dim, -1)
202 |
203 | embed = self.embed(question)
204 | embed = nn.utils.rnn.pack_padded_sequence(embed, question_len,
205 | batch_first=True)
206 | lstm_out, (h, _) = self.lstm(embed)
207 | lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out,
208 | batch_first=True)
209 | lstm_out = self.lstm_proj(lstm_out)
210 | h = h.permute(1, 0, 2).contiguous().view(b_size, -1)
211 |
212 | memory = self.mac(lstm_out, h, img)
213 |
214 | out = torch.cat([memory, h], 1)
215 | out = self.classifier(out)
216 |
217 | return out
--------------------------------------------------------------------------------
/model_gqa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from block.models.networks.fusions.fusions import Tucker
4 | from torch import nn
5 | from torch.nn.init import kaiming_uniform_, xavier_uniform_
6 |
7 | import config
8 | from utils import get_or_load_embeddings
9 |
10 |
11 | def linear(in_dim, out_dim, bias=True):
12 | lin = nn.Linear(in_dim, out_dim, bias=bias)
13 | xavier_uniform_(lin.weight)
14 | if bias:
15 | lin.bias.data.zero_()
16 |
17 | return lin
18 |
19 | class ControlUnit(nn.Module):
20 | def __init__(self, dim, max_step):
21 | super().__init__()
22 |
23 | self.position_aware = nn.ModuleList()
24 | for i in range(max_step):
25 | self.position_aware.append(linear(dim * 2, dim))
26 |
27 | self.control_question = linear(dim * 2, dim)
28 | self.attn = linear(dim, 1)
29 |
30 | self.dim = dim
31 |
32 | def forward(self, step, context, question, control):
33 | position_aware = self.position_aware[step](question)
34 |
35 | control_question = torch.cat([control, position_aware], 1)
36 | control_question = self.control_question(control_question)
37 | control_question = control_question.unsqueeze(1)
38 |
39 | context_prod = control_question * context
40 | attn_weight = self.attn(context_prod)
41 |
42 | attn = F.softmax(attn_weight, 1)
43 |
44 | next_control = (attn * context).sum(1)
45 |
46 | return next_control
47 |
48 |
49 | class ReadUnit(nn.Module):
50 | def __init__(self, dim):
51 | super().__init__()
52 |
53 | self.mem = linear(dim, dim)
54 | self.concat = linear(dim * 2, dim)
55 | self.attn = linear(dim, 1)
56 | self.tucker = Tucker((2048, 2048), 1, mm_dim=50, shared=True)
57 |
58 | def forward(self, memory, know, control):
59 | mem = self.mem(memory[-1]).unsqueeze(2)
60 | s_matrix = (mem * know)
61 | s_matrix = s_matrix.view(-1, 2048)
62 | attn = self.tucker([s_matrix, control[-1].repeat(know.size(2), 1)]).view(know.size(2), know.size(0))
63 | attn = attn.transpose(0, 1)
64 | attn = F.softmax(attn, 1).unsqueeze(1)
65 | read = (attn * know).sum(2)
66 |
67 | return read
68 |
69 |
70 | class WriteUnit(nn.Module):
71 | def __init__(self, dim, self_attention=False, memory_gate=False):
72 | super().__init__()
73 |
74 | self.concat = linear(dim * 2, dim)
75 |
76 | if self_attention:
77 | self.attn = linear(dim, 1)
78 | self.mem = linear(dim, dim)
79 |
80 | if memory_gate:
81 | self.control = linear(dim, 1)
82 |
83 | self.self_attention = self_attention
84 | self.memory_gate = memory_gate
85 |
86 | def forward(self, memories, retrieved, controls):
87 | prev_mem = memories[-1]
88 | concat = self.concat(torch.cat([retrieved, prev_mem], 1))
89 | next_mem = concat
90 |
91 | if self.self_attention:
92 | controls_cat = torch.stack(controls[:-1], 2)
93 | attn = controls[-1].unsqueeze(2) * controls_cat
94 | attn = self.attn(attn.permute(0, 2, 1))
95 | attn = F.softmax(attn, 1).permute(0, 2, 1)
96 |
97 | memories_cat = torch.stack(memories, 2)
98 | attn_mem = (attn * memories_cat).sum(2)
99 | next_mem = self.mem(attn_mem) + concat
100 |
101 | if self.memory_gate:
102 | control = self.control(controls[-1])
103 |             gate = torch.sigmoid(control)
104 | next_mem = gate * prev_mem + (1 - gate) * next_mem
105 |
106 | return next_mem
107 |
108 |
109 | class MACUnit(nn.Module):
110 | def __init__(self, dim, max_step=12,
111 | self_attention=False, memory_gate=False,
112 | dropout=0.15):
113 | super().__init__()
114 |
115 | self.control = ControlUnit(dim, max_step)
116 | self.read = ReadUnit(dim)
117 | self.write = WriteUnit(dim, self_attention, memory_gate)
118 |
119 | self.mem_0 = nn.Parameter(torch.zeros(1, dim))
120 | self.control_0 = nn.Parameter(torch.zeros(1, dim))
121 |
122 | self.dim = dim
123 | self.max_step = max_step
124 | self.dropout = dropout
125 |         self.dropouts = {}
126 |         self.dropouts["encInput"] = config.encInputDropout
127 |         self.dropouts["encState"] = config.encStateDropout
128 |         self.dropouts["stem"] = config.stemDropout
129 |         self.dropouts["question"] = config.qDropout
130 |         self.dropouts["memory"] = config.memoryDropout
131 |         self.dropouts["read"] = config.readDropout
132 |         self.dropouts["write"] = config.writeDropout
133 |         self.dropouts["output"] = config.outputDropout
134 |         self.dropouts["controlPre"] = config.controlPreDropout
135 |         self.dropouts["controlPost"] = config.controlPostDropout
136 |         self.dropouts["wordEmb"] = config.wordEmbDropout
137 |         # wordDp, vocabDp, objectDp and wordStandardDp are not defined in config.py,
138 |         # so these entries are disabled to avoid an AttributeError:
139 |         # self.dropouts["word"] = config.wordDp; self.dropouts["vocab"] = config.vocabDp
140 |         # self.dropouts["object"] = config.objectDp; self.dropouts["wordStandard"] = config.wordStandardDp
141 |
142 | def get_mask(self, x, dropout):
143 | mask = torch.empty_like(x).bernoulli_(1 - dropout)
144 | mask = mask / (1 - dropout)
145 |
146 | return mask
147 |
148 | def forward(self, context, question, knowledge):
149 | b_size = question.size(0)
150 |
151 | control = self.control_0.expand(b_size, self.dim)
152 | memory = self.mem_0.expand(b_size, self.dim)
153 |
154 | if self.training:
155 | control_mask = self.get_mask(control, self.dropout)
156 | memory_mask = self.get_mask(memory, self.dropout)
157 | control = control * control_mask
158 | memory = memory * memory_mask
159 |
160 | controls = [control]
161 | memories = [memory]
162 |
163 | for i in range(self.max_step):
164 | control = self.control(i, context, question, control)
165 | if self.training:
166 | control = control * control_mask
167 | controls.append(control)
168 |
169 | read = self.read(memories, knowledge, controls)
170 | memory = self.write(memories, read, controls)
171 | if self.training:
172 | memory = memory * memory_mask
173 | memories.append(memory)
174 |
175 | return memory
176 |
177 |
178 | class MACNetwork(nn.Module):
179 | def __init__(self, n_vocab, dim, embed_hidden=300,
180 | max_step=12, self_attention=False, memory_gate=False,
181 | classes=28, dropout=0.15):
182 | super().__init__()
183 |
184 | self.conv = nn.Sequential(nn.Conv2d(1024, dim, 3, padding=1),
185 | nn.ELU(),
186 | nn.Conv2d(dim, dim, 3, padding=1),
187 | nn.ELU())
188 |
189 | self.embed = nn.Embedding(n_vocab, embed_hidden)
190 | self.embed.weight.data = torch.Tensor(get_or_load_embeddings())
191 | self.embed.weight.requires_grad = False
192 | self.lstm = nn.LSTM(embed_hidden, dim,
193 | batch_first=True, bidirectional=True)
194 | self.lstm_proj = nn.Linear(dim * 2, dim)
195 |
196 | self.mac = MACUnit(dim, max_step,
197 | self_attention, memory_gate, dropout)
198 |
199 |
200 | self.classifier = nn.Sequential(linear(dim * 3, dim),
201 | nn.ELU(),
202 | linear(dim, classes))
203 |
204 | self.max_step = max_step
205 | self.dim = dim
206 |
207 | self.reset()
208 |
209 | def reset(self):
210 |         # the word embeddings are pretrained GloVe vectors loaded and frozen in __init__, so they are not re-initialized here
211 |
212 | kaiming_uniform_(self.conv[0].weight)
213 | self.conv[0].bias.data.zero_()
214 | kaiming_uniform_(self.conv[2].weight)
215 | self.conv[2].bias.data.zero_()
216 |
217 | kaiming_uniform_(self.classifier[0].weight)
218 |
219 | def forward(self, image, question, question_len, dropout=0.15):
220 | b_size = question.size(0)
221 |
222 | img = image
223 | img = img.view(b_size, self.dim, -1)
224 |
225 | embed = self.embed(question)
226 | embed = nn.utils.rnn.pack_padded_sequence(embed, question_len,
227 | batch_first=True)
228 | lstm_out, (h, _) = self.lstm(embed)
229 | lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
230 | lstm_out = self.lstm_proj(lstm_out)
231 | h = h.permute(1, 0, 2).contiguous().view(b_size, -1)
232 |
233 | memory = self.mac(lstm_out, h, img)
234 |
235 | out = torch.cat([memory, h], 1)
236 | out = self.classifier(out)
237 |
238 | return out
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import pickle
5 | import jieba
6 | import nltk
7 | import tqdm
8 | from torchvision import transforms
9 | from PIL import Image
10 | from transforms import Scale
11 | import MeCab
12 | from konlpy.tag import Okt
13 | image_index = {'CLEVR': 'image_filename',
14 | 'gqa': 'imageId'}
15 |
16 |
17 | def process_question(root, split, word_dic=None, answer_dic=None, dataset_type='CLEVR',lang='en'):
18 | if word_dic is None:
19 | word_dic = {}
20 |
21 | if answer_dic is None:
22 | answer_dic = {}
23 |
24 | with open(os.path.join(root, 'questions', f'mini_{dataset_type}_{split}_questions.json'), encoding='utf-8') as f:
25 | data = json.load(f)
26 |
27 | result = []
28 | word_index = 1
29 | answer_index = 0
30 |
31 | for question in tqdm.tqdm(data['questions']):
32 | if lang == 'en':
33 | words = nltk.word_tokenize(question['question'])
34 | elif lang =='vi':
35 | words = nltk.word_tokenize(question['translated']['vietnamese'])
36 | elif lang == 'zh':
37 | words = jieba.lcut(question['translated']['chinese'])
38 | elif lang == 'ja':
39 | mecab = MeCab.Tagger("-Owakati")
40 | words = mecab.parse(question['translated']['japanese']).strip().split()
41 | elif lang == 'ko':
42 | okt = Okt()
43 | words = okt.morphs(question['translated']['korean'])
44 | question_token = []
45 |
46 | for word in words:
47 | try:
48 | question_token.append(word_dic[word])
49 |
50 | except:
51 | question_token.append(word_index)
52 | word_dic[word] = word_index
53 | word_index += 1
54 |
55 | answer_word = question['answer']
56 |
57 | try:
58 | answer = answer_dic[answer_word]
59 | except:
60 | answer = answer_index
61 | answer_dic[answer_word] = answer_index
62 | answer_index += 1
63 |
64 | result.append((question[image_index[dataset_type]], question_token, answer))
65 |
66 | with open(f'data/{dataset_type}_{split}_{lang}.pkl', 'wb') as f:
67 | pickle.dump(result, f)
68 |
69 | return word_dic, answer_dic
70 |
71 |
72 | if __name__ == '__main__':
73 | dataset_type = sys.argv[1]
74 | root = sys.argv[2]
75 | lang = sys.argv[3]
76 |     word_dic, answer_dic = process_question(root, 'train', dataset_type=dataset_type, lang=lang)
77 |     process_question(root, 'val', word_dic, answer_dic, dataset_type=dataset_type, lang=lang)
78 |
79 | with open(f'data/{dataset_type}_{lang}_dic.pkl', 'wb') as f:
80 | pickle.dump({'word_dic': word_dic, 'answer_dic': answer_dic}, f)
81 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Pillow
2 | nltk
3 | tqdm
4 | block.bootstrap.pytorch
5 | murel.bootstrap.pytorch
6 | konlpy
7 | mecab-python3
8 | gensim
9 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | cd data
2 | mkdir gqa && cd gqa
3 | wget https://nlp.stanford.edu/data/gqa/data1.2.zip
4 | unzip data1.2.zip
5 | mkdir questions
6 | mv balanced_train_data.json questions/gqa_train_questions.json
7 | mv balanced_val_data.json questions/gqa_val_questions.json
8 | mv balanced_testdev_data.json questions/gqa_testdev_questions.json
9 | cd ..
10 |
11 |
12 | wget http://nlp.stanford.edu/data/glove.6B.zip
13 | unzip glove.6B.zip
14 | wget http://nlp.stanford.edu/data/gqa/objectFeatures.zip
15 | unzip objectFeatures.zip
16 | cd ..
17 |
18 | python merge.py --name objects
19 | mv data/gqa_objects.hdf5 data/gqa_features.hdf5
20 |
21 | python preprocess.py gqa data/gqa en
22 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pickle
3 | from collections import Counter
4 |
5 | import torch
6 | from tqdm import tqdm
7 | from torch.utils.data import DataLoader
8 | from dataset import CLEVR, collate_data, transform
9 |
10 | batch_size = 64
11 | n_epoch = 180
12 |
13 | train_set = DataLoader(
14 | CLEVR(sys.argv[1], 'val', transform=None),
15 | batch_size=batch_size,
16 | num_workers=4,
17 | collate_fn=collate_data,
18 | )
19 | net = torch.load(sys.argv[2])
20 | net.eval()
21 |
22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23 |
24 | for epoch in range(n_epoch):
25 | dataset = iter(train_set)
26 | pbar = tqdm(dataset)
27 | correct_counts = 0
28 | total_counts = 0
29 |
30 | for image, question, q_len, answer in pbar:
31 | image, question = image.to(device), question.to(device)
32 | output = net(image, question, q_len)
33 | correct = output.detach().argmax(1) == answer.to(device)
34 | for c in correct:
35 | if c:
36 | correct_counts += 1
37 | total_counts += 1
38 |
39 | print('Avg Acc: {:.5f}'.format(correct_counts / total_counts))
40 |
41 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import pickle
3 | import sys
4 |
5 | import torch
6 | from torch import nn
7 | from torch import optim
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 |
11 | from dataset import CLEVR, collate_data, transform, GQA
12 | from model_gqa import MACNetwork
13 |
14 | batch_size = 128
15 | n_epoch = 25
16 | dim_dict = {'CLEVR': 512,
17 | 'gqa': 2048}
18 |
19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 |
21 |
22 | def accumulate(model1, model2, decay=0.999):
23 | par1 = dict(model1.named_parameters())
24 | par2 = dict(model2.named_parameters())
25 |
26 | for k in par1.keys():
27 | par1[k].data.mul_(decay).add_(1 - decay, par2[k].data)
28 |
29 |
30 | def train(epoch, dataset_type, lang='en'):
31 | if dataset_type == "CLEVR":
32 | dataset_object = CLEVR('data/CLEVR', transform=transform, lang=lang)
33 | else:
34 | dataset_object = GQA('data/gqa', transform=transform)
35 |
36 | train_set = DataLoader(
37 | dataset_object, batch_size=batch_size, num_workers=multiprocessing.cpu_count(), collate_fn=collate_data
38 | )
39 |
40 | dataset = iter(train_set)
41 | pbar = tqdm(dataset)
42 | moving_loss = 0
43 |
44 | net.train(True)
45 | for image, question, q_len, answer in pbar:
46 | image, question, answer = (
47 | image.to(device),
48 | question.to(device),
49 | answer.to(device),
50 | )
51 |
52 | net.zero_grad()
53 | output = net(image, question, q_len)
54 | loss = criterion(output, answer)
55 | loss.backward()
56 | optimizer.step()
57 | correct = output.detach().argmax(1) == answer
58 |         correct = correct.float().sum() / batch_size
59 |
60 | if moving_loss == 0:
61 | moving_loss = correct
62 | else:
63 | moving_loss = moving_loss * 0.99 + correct * 0.01
64 |
65 | pbar.set_description('Epoch: {}; Loss: {:.8f}; Acc: {:.5f}'.format(epoch + 1, loss.item(), moving_loss))
66 |
67 | accumulate(net_running, net)
68 |
69 | dataset_object.close()
70 |
71 |
72 | def valid(epoch, dataset_type,lang='en'):
73 | if dataset_type == "CLEVR":
74 | dataset_object = CLEVR('data/CLEVR', 'val', transform=None, lang=lang)
75 | else:
76 | dataset_object = GQA('data/gqa', 'val', transform=None)
77 |
78 | valid_set = DataLoader(
79 | dataset_object, batch_size=4*batch_size, num_workers=multiprocessing.cpu_count(), collate_fn=collate_data
80 | )
81 | dataset = iter(valid_set)
82 |
83 | net_running.train(False)
84 | correct_counts = 0
85 | total_counts = 0
86 | running_loss = 0.0
87 | batches_done = 0
88 | with torch.no_grad():
89 | pbar = tqdm(dataset)
90 | for image, question, q_len, answer in pbar:
91 | image, question, answer = (
92 | image.to(device),
93 | question.to(device),
94 | answer.to(device),
95 | )
96 |
97 | output = net_running(image, question, q_len)
98 | loss = criterion(output, answer)
99 | correct = output.detach().argmax(1) == answer
100 | running_loss += loss.item()
101 |
102 | batches_done += 1
103 | for c in correct:
104 | if c:
105 | correct_counts += 1
106 | total_counts += 1
107 |
108 |             pbar.set_description('Epoch: {}; Loss: {:.8f}; Acc: {:.5f}'.format(epoch + 1, loss.item(), correct_counts / total_counts))
109 |
110 | with open('log/log_{}.txt'.format(str(epoch + 1).zfill(2)), 'w') as w:
111 | w.write('{:.5f}\n'.format(correct_counts / total_counts))
112 |
113 | print('Validation Accuracy: {:.5f}'.format(correct_counts / total_counts))
114 |     print('Validation Loss: {:.8f}'.format(running_loss / batches_done))
115 |
116 | dataset_object.close()
117 |
118 |
119 | if __name__ == '__main__':
120 | dataset_type = sys.argv[1]
121 | lang = sys.argv[2]
122 | with open(f'data/{dataset_type}_dic.pkl', 'rb') as f:
123 | dic = pickle.load(f)
124 |
125 | n_words = len(dic['word_dic']) + 1
126 | n_answers = len(dic['answer_dic'])
127 |
128 | net = MACNetwork(n_words, dim_dict[dataset_type], classes=n_answers, max_step=4).to(device)
129 | net_running = MACNetwork(n_words, dim_dict[dataset_type], classes=n_answers, max_step=4).to(device)
130 | accumulate(net_running, net, 0)
131 |
132 | criterion = nn.CrossEntropyLoss()
133 | optimizer = optim.Adam(net.parameters(), lr=1e-4)
134 |
135 | for epoch in range(n_epoch):
136 | train(epoch, dataset_type,lang=lang)
137 | valid(epoch, dataset_type,lang=lang)
138 |
139 |         with open(f'checkpoint/checkpoint_{lang}.model', 'wb') as f:
140 | torch.save(net_running.state_dict(), f)
141 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | """ Shamelessly take from newer torchvision repository """
2 |
3 | from PIL import Image
4 | import collections
5 | collections.Iterable = collections.abc.Iterable
6 | class Scale(object):
7 | """Rescale the input PIL.Image to the given size.
8 |
9 | Args:
10 | size (sequence or int): Desired output size. If size is a sequence like
11 | (w, h), output size will be matched to this. If size is an int,
12 | smaller edge of the image will be matched to this number.
13 | i.e, if height > width, then image will be rescaled to
14 | (size * height / width, size)
15 | interpolation (int, optional): Desired interpolation. Default is
16 | ``PIL.Image.BILINEAR``
17 | """
18 |
19 | def __init__(self, size, interpolation=Image.BILINEAR):
20 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
21 | self.size = size
22 | self.interpolation = interpolation
23 |
24 | def __call__(self, img):
25 | """
26 | Args:
27 | img (PIL.Image): Image to be scaled.
28 |
29 | Returns:
30 | PIL.Image: Rescaled image.
31 | """
32 | if isinstance(self.size, int):
33 | w, h = img.size
34 | if (w <= h and w == self.size) or (h <= w and h == self.size):
35 | return img
36 | if w < h:
37 | ow = self.size
38 | oh = int(self.size * h / w)
39 | return img.resize((ow, oh), self.interpolation)
40 | else:
41 | oh = self.size
42 | ow = int(self.size * w / h)
43 | return img.resize((ow, oh), self.interpolation)
44 | else:
45 | return img.resize(self.size, self.interpolation)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # File: utils.py
2 | # Author: Ronil Pancholia
3 | # Date: 4/22/19
4 | # Time: 7:57 PM
5 | import pickle
6 | import sys
7 |
8 | import numpy as np
9 | import jieba
10 |
11 | from gensim.models import KeyedVectors
12 | id2word = []
13 | embedding_weights = None
14 |
15 | def get_or_load_embeddings(lang='en'):
16 | """
17 | parameters:
18 |         lang: language of the embedding
19 |
20 | returns:
21 | embedding matrix
22 | """
23 | global embedding_weights, id2word
24 |
25 | if embedding_weights is not None:
26 | return embedding_weights
27 |
28 | dataset_type = sys.argv[1]
29 | with open(f'data/{dataset_type}_dic.pkl', 'rb') as f:
30 | dic = pickle.load(f)
31 |
32 | id2word = set(dic['word_dic'].keys())
33 | id2word.update(set(dic['answer_dic'].keys()))
34 |
35 | word2id = {word: id for id, word in enumerate(id2word)}
36 |
37 | embed_size = 300
38 | vocab_size = len(id2word)
39 | sd = 1 / np.sqrt(embed_size)
40 | embedding_weights = np.random.normal(0, scale=sd, size=[vocab_size, embed_size])
41 | embedding_weights = embedding_weights.astype(np.float32)
42 | if lang == 'en':
43 | with open("/kaggle/input/phow2v/glove.6B.300d.txt", encoding="utf-8", mode="r") as textFile:
44 | for line in textFile:
45 | line = line.split()
46 | word = line[0]
47 |
48 | id = word2id.get(word, None)
49 | if id is not None:
50 | embedding_weights[id] = np.array(line[1:], dtype=np.float32)
51 | elif lang == 'vi':
52 | with open("/kaggle/input/phow2v/word2vec_vi_words_300dims.txt", encoding="utf-8", mode="r") as textFile:
53 | for line in textFile:
54 | line = line.split()
55 | try:
56 | float(line[1])
57 | word = line[0]
58 | id = word2id.get(word, None)
59 | if id is not None:
60 | embedding_weights[id] = np.array(line[1:], dtype=np.float32)
61 | except:
62 | word = '_'.join(line[:2])
63 | id = word2id.get(word, None)
64 | if id is not None:
65 | embedding_weights[id] = np.array(line[2:], dtype=np.float32)
66 | elif lang == 'zh':
67 | fb_model = KeyedVectors.load_word2vec_format('/kaggle/input/phow2v/cc.zh.300.vec')
68 |         for word, vector in zip(fb_model.index_to_key, fb_model.vectors):  # KeyedVectors has no .items(); assumes gensim 4
69 | id = word2id.get(word, None)
70 | if id is not None:
71 | embedding_weights[id] = np.array(vector, dtype=np.float32)
72 | elif lang == 'ja':
73 | fb_model = KeyedVectors.load_word2vec_format('/kaggle/input/phow2v/cc.ja.300.vec')
74 |         for word, vector in zip(fb_model.index_to_key, fb_model.vectors):  # KeyedVectors has no .items(); assumes gensim 4
75 | id = word2id.get(word, None)
76 | if id is not None:
77 | embedding_weights[id] = np.array(vector, dtype=np.float32)
78 | elif lang == 'ko':
79 | fb_model = KeyedVectors.load_word2vec_format('/kaggle/input/phow2v/cc.ko.300.vec')
80 |         for word, vector in zip(fb_model.index_to_key, fb_model.vectors):  # KeyedVectors has no .items(); assumes gensim 4
81 | id = word2id.get(word, None)
82 | if id is not None:
83 | embedding_weights[id] = np.array(vector, dtype=np.float32)
84 | return embedding_weights
85 |
--------------------------------------------------------------------------------