├── .gitignore ├── LICENSE ├── README.md ├── bench.py ├── checkpoint └── .gitignore ├── config.py ├── data └── .gitignore ├── dataset.py ├── image_feature.py ├── log └── .gitignore ├── merge.py ├── model.py ├── model_gqa.py ├── preprocess.py ├── requirements.txt ├── setup.sh ├── test.py ├── train.py ├── transforms.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data files 2 | *.hdf5 3 | *.model 4 | log/*.txt 5 | *.pkl 6 | data/** 7 | w2v/** 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # dotenv 90 | .env 91 | 92 | # virtualenv 93 | .venv 94 | venv/ 95 | ENV/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# mac-network-pytorch

Memory, Attention and Composition (MAC) Network for CLEVR/GQA from *Compositional Attention Networks for Machine Reasoning* (https://arxiv.org/abs/1803.03067), implemented in PyTorch.
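
For orientation, the MAC cell in `model.py` boils down to a recurrent control, read, write loop over a fixed number of reasoning steps. Below is a simplified sketch of that recurrence (variational dropout masks omitted; `control_unit`, `read_unit` and `write_unit` stand in for the `ControlUnit`, `ReadUnit` and `WriteUnit` modules defined in `model.py`):

```python
def mac_steps(context, question, knowledge, control, memory,
              control_unit, read_unit, write_unit, max_step=12):
    """Simplified view of MACUnit.forward in model.py (no dropout masks)."""
    controls, memories = [control], [memory]
    for step in range(max_step):
        # attend over the question words to compute the next control state
        control = control_unit(step, context, question, control)
        controls.append(control)
        # read from the image / knowledge-base features, guided by control and memory
        read = read_unit(memories, knowledge, controls)
        # integrate what was read into the next memory state
        memory = write_unit(memories, read, controls)
        memories.append(memory)
    return memory  # the final memory feeds the output classifier
```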

Requirements:

- Python 3 with PyTorch and torchvision
- the Python packages listed in `requirements.txt`

To train:

1. Download and extract either the
   CLEVR v1.0 dataset from http://cs.stanford.edu/people/jcjohns/clevr/ or the
   GQA dataset from https://cs.stanford.edu/people/dorarad/gqa/download.html

For GQA:

```
cd data
mkdir gqa && cd gqa
wget https://nlp.stanford.edu/data/gqa/data1.2.zip
unzip data1.2.zip

mkdir questions
mv balanced_train_data.json questions/gqa_train_questions.json
mv balanced_val_data.json questions/gqa_val_questions.json
mv balanced_testdev_data.json questions/gqa_testdev_questions.json
cd ..

wget http://nlp.stanford.edu/data/glove.6B.zip
unzip glove.6B.zip
wget http://nlp.stanford.edu/data/gqa/objectFeatures.zip
unzip objectFeatures.zip
cd ..
```

2. Preprocess question data and extract image features using ResNet-101 (not required for GQA).

For CLEVR:

a. Extract image features

```
python image_feature.py data/CLEVR_v1.0
```

b. Preprocess questions

```
python preprocess.py CLEVR data/CLEVR_v1.0 en
```

(the third argument selects the question language used by `preprocess.py`: en, vi, zh, ja or ko)
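
Preprocessing writes a pickled list of (image, question tokens, answer index) triples per split, plus a word/answer dictionary. A quick way to inspect the result (a minimal sketch; the file names assume the default language `en`, matching what `preprocess.py` produces):

```python
import pickle

# word/answer dictionaries written by preprocess.py
with open('data/CLEVR_en_dic.pkl', 'rb') as f:
    dic = pickle.load(f)
print(len(dic['word_dic']), 'question words,', len(dic['answer_dic']), 'answers')

# per-split list of (image filename, tokenized question, answer index)
with open('data/CLEVR_train_en.pkl', 'rb') as f:
    train = pickle.load(f)
print(len(train), 'training questions; first entry:', train[0])
```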

For GQA:

a. Merge object features (this may take some time)

```
python merge.py --name objects
mv data/gqa_objects.hdf5 data/gqa_features.hdf5
```
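
After the merge, everything ends up in a single HDF5 file plus an index JSON. A quick sanity check (a sketch; the dataset names and shapes follow the `spec` table in `merge.py`):

```python
import json

import h5py

# merged object features (renamed to gqa_features.hdf5 above)
with h5py.File('data/gqa_features.hdf5', 'r') as f:
    print(f['features'].shape)  # expected: (148855, 100, 2048)
    print(f['bboxes'].shape)    # expected: (148855, 100, 4)

# maps each imageId to its row index in the merged datasets
with open('data/gqa_objects_merged_info.json') as f:
    info = json.load(f)
print(len(info), 'images indexed')
```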

b. Preprocess questions

```
python preprocess.py gqa data/gqa en
```

!CAUTION! The file created by image_feature.py is very large! You may enable HDF5 compression, but it will slow down feature extraction.

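If disk space is the constraint, one option is to pass a gzip filter when the feature dataset is created in `image_feature.py`. A minimal sketch of that change (the dataset name and per-image shape are the ones already used there; the image count is illustrative):

```python
import h5py

n_images = 70000  # illustrative; image_feature.py uses len(dataloader) * batch_size

f = h5py.File('data/CLEVR_features.hdf5', 'w', libver='latest')
# Same layout as image_feature.py, but gzip-compressed and chunked per image.
# Writes get noticeably slower; the file gets considerably smaller.
dset = f.create_dataset('data', (n_images, 1024, 14, 14), dtype='f4',
                        compression='gzip', compression_opts=4,
                        chunks=(1, 1024, 14, 14))
f.close()
```
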
3. Run train.py with the dataset type (gqa or CLEVR) and the question language as arguments:

```
python train.py gqa en
```
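
train.py saves the exponential-moving-average model (`net_running`) to `checkpoint/checkpoint_<lang>.model` each epoch. A sketch of how one might restore it for evaluation (the hyperparameters must match the ones used for training; note that `model_gqa.MACNetwork` loads word embeddings through `utils.get_or_load_embeddings()`, which reads the dataset type from `sys.argv[1]`, so this would be run as e.g. `python eval_sketch.py gqa`, where `eval_sketch.py` is a hypothetical file name):

```python
import pickle
import sys

import torch

from model_gqa import MACNetwork

dataset_type = sys.argv[1]  # e.g. 'gqa'; also consumed by utils.get_or_load_embeddings
dim_dict = {'CLEVR': 512, 'gqa': 2048}

with open(f'data/{dataset_type}_dic.pkl', 'rb') as f:
    dic = pickle.load(f)

net = MACNetwork(len(dic['word_dic']) + 1, dim_dict[dataset_type],
                 classes=len(dic['answer_dic']), max_step=4)
# train.py stores net_running.state_dict(), so load_state_dict restores it
net.load_state_dict(torch.load('checkpoint/checkpoint_en.model', map_location='cpu'))
net.eval()
```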

CLEVR -> This implementation reaches 95.75% accuracy at epoch 10 and 96.5% accuracy at epoch 20.


Parts of the code are borrowed from https://github.com/rosinality/mac-network-pytorch and
https://github.com/stanfordnlp/mac-network.

64 | -------------------------------------------------------------------------------- /bench.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pickle 3 | from collections import Counter 4 | 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from torch import optim 9 | from torch.autograd import Variable 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from dataset import CLEVR, collate_data, transform 14 | from model_gqa import MACNetwork 15 | 16 | batch_size = 64 17 | n_epoch = 20 18 | dim = 512 19 | 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | def accumulate(model1, model2, decay=0.999): 23 | par1 = dict(model1.named_parameters()) 24 | par2 = dict(model2.named_parameters()) 25 | 26 | for k in par1.keys(): 27 | par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) 28 | 29 | def train(): 30 | moving_loss = 0 31 | 32 | net.train(True) 33 | 34 | image = torch.randn(batch_size, 1024, 14, 14, device=device) 35 | question = torch.randint(0, 28, (batch_size, 30), dtype=torch.int64, device=device) 36 | answer = torch.randint(0, 28, (batch_size,), dtype=torch.int64, device=device) 37 | q_len = torch.tensor([30] * batch_size, dtype=torch.int64, device=device) 38 | 39 | for i in range(30): 40 | net.zero_grad() 41 | output = net(image, question, q_len) 42 | loss = criterion(output, answer) 43 | loss.backward() 44 | optimizer.step() 45 | correct = output.detach().argmax(1) \ 46 | == answer 47 | correct = torch.tensor(correct, dtype=torch.float32).sum() \ 48 | / batch_size 49 | 50 | if moving_loss == 0: 51 | moving_loss = correct 52 | 53 | else: 54 | moving_loss = moving_loss * 0.99 + correct * 0.01 55 | 56 | accumulate(net_running, net) 57 | 58 | if __name__ == '__main__': 59 | with open('data/dic.pkl', 'rb') as f: 60 | dic = pickle.load(f) 61 | 62 | n_words = len(dic['word_dic']) + 1 63 | n_answers = len(dic['answer_dic']) 64 | 65 | net = MACNetwork(n_words, dim).to(device) 66 | net_running = MACNetwork(n_words, dim).to(device) 67 | accumulate(net_running, net, 0) 68 | 69 | criterion = nn.CrossEntropyLoss() 70 | optimizer = optim.Adam(net.parameters(), lr=1e-4) 71 | 72 | train() -------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | # Data files 2 | *.hdf5 3 | *.model 4 | log/*.txt 5 | *.pkl 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # File: config.py 2 | # Author: Ronil Pancholia 3 | # Date: 4/20/19 4 | # Time: 6:37 PM 5 | 6 | ## dropouts 7 | encInputDropout = 0.2 # dropout of the rnn inputs to the Question Input Unit 8 | encStateDropout = 0.0 # dropout of the rnn states of the Question Input Unit 9 | stemDropout = 0.2 # dropout of the Image Input Unit (the stem) 10 | qDropout = 0.08 # dropout on the question vector 11 | qDropoutOut = 0 # dropout on the question vector the goes to the output unit 12 | memoryDropout = 0.15 # dropout on the recurrent memory 13 | readDropout = 0.15 # dropout of the read unit 14 | writeDropout = 1.0 # dropout of the write unit 15 | outputDropout = 0.85 # dropout of the output unit 16 | controlPreDropout = 1.0 # dropout of the write unit 17 | controlPostDropout = 1.0 # dropout of the write unit 18 | wordEmbDropout = 1.0 # dropout of the write unit -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Data files 2 | *.hdf5 3 | *.model 4 | log/*.txt 5 | *.pkl 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | from PIL import Image 7 | import torch 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms 10 | import h5py 11 | 12 | from transforms import Scale 13 | 14 | class CLEVR(Dataset): 15 | def __init__(self, root, split='train', transform=None, lang='en'): 16 | with open(f'data/CLEVR_{split}_{lang}.pkl', 'rb') as f: 17 | self.data = pickle.load(f) 18 | 19 | # self.transform = transform 20 | self.root = root 21 | self.split = split 22 | with open(f'mini_CLEVR_{split}_questions_translated.json') as f: 23 | data = json.load(f) 24 | self.img_idx_map = {} 25 | i=0 26 | for question in data['questions']: 27 | if question['image_index'] not in self.img_idx_map.keys(): 28 | self.img_idx_map[question['image_index']] = i 29 | i+=1 30 | self.idx_img_map = {v:k for k,v in self.img_idx_map.items()} 31 | self.h = h5py.File(f'data/CLEVR_features.hdf5'.format(split), 'r') 32 | self.img = self.h['data'] 33 | 34 | def close(self): 35 | self.h.close() 36 | 37 | def __getitem__(self, index): 38 | imgfile, question, answer = self.data[index] 39 | # img = Image.open(os.path.join(self.root, 'images', 40 | # self.split, imgfile)).convert('RGB') 41 | 42 | # img = self.transform(img) 43 | id = int(imgfile.rsplit('_', 1)[1][:-4]) 44 | img = torch.from_numpy(self.img[self.img_idx_map[id]]) 45 | 46 | return img, question, len(question), answer 47 | 48 | def __len__(self): 49 | return len(self.data) 50 | 51 | class GQA(Dataset): 52 | def __init__(self, root, split='train', transform=None): 53 | with open(f'data/gqa_{split}.pkl', 'rb') as f: 54 | self.data = pickle.load(f) 55 | 56 | self.root = root 57 | self.split = split 58 | 59 | self.h = h5py.File('data/gqa_features.hdf5'.format(split), 'r') 60 | self.img = self.h['features'] 61 | self.img_info = json.load(open('data/gqa_objects_merged_info.json', 'r')) 62 | 63 | def close(self): 64 | self.h.close() 65 | 66 | def __getitem__(self, index): 67 | imgfile, question, answer = self.data[index] 68 | idx = int(self.img_info[imgfile]['index']) 69 | img = torch.from_numpy(self.img[idx]) 70 | return img, question, len(question), answer 71 | 72 | def __len__(self): 73 | return len(self.data) 74 | 75 | transform = transforms.Compose([ 76 | 
Scale([224, 224]), 77 | transforms.Pad(4), 78 | transforms.RandomCrop([224, 224]), 79 | transforms.ToTensor(), 80 | transforms.Normalize(mean=[0.5, 0.5, 0.5], 81 | std=[0.5, 0.5, 0.5]) 82 | ]) 83 | 84 | def collate_data(batch): 85 | images, lengths, answers = [], [], [] 86 | batch_size = len(batch) 87 | 88 | max_len = max(map(lambda x: len(x[1]), batch)) 89 | 90 | questions = np.zeros((batch_size, max_len), dtype=np.int64) 91 | sort_by_len = sorted(batch, key=lambda x: len(x[1]), reverse=True) 92 | 93 | for i, b in enumerate(sort_by_len): 94 | image, question, length, answer = b 95 | images.append(image) 96 | length = len(question) 97 | questions[i, :length] = question 98 | lengths.append(length) 99 | answers.append(answer) 100 | 101 | return torch.stack(images), torch.from_numpy(questions), \ 102 | lengths, torch.LongTensor(answers) 103 | -------------------------------------------------------------------------------- /image_feature.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import torch 3 | from torchvision.models.resnet import ResNet, resnet101 4 | from torchvision import transforms 5 | from torch.utils.data import Dataset, DataLoader 6 | from transforms import Scale 7 | import sys 8 | import os 9 | from PIL import Image 10 | from tqdm import tqdm 11 | from torch.autograd import Variable 12 | import json 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | def forward(self, x): 17 | x = self.conv1(x) 18 | x = self.bn1(x) 19 | x = self.relu(x) 20 | x = self.maxpool(x) 21 | 22 | x = self.layer1(x) 23 | x = self.layer2(x) 24 | x = self.layer3(x) 25 | 26 | return x 27 | 28 | transform = transforms.Compose([ 29 | Scale([224, 224]), 30 | transforms.ToTensor(), 31 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 32 | std=[0.229, 0.224, 0.225])]) 33 | 34 | class CLEVR(Dataset): 35 | def __init__(self, root, split='train'): 36 | self.root = root 37 | self.split = split 38 | self.length = len(os.listdir(os.path.join(root, 39 | 'images', split))) 40 | 41 | with open(f'mini_CLEVR_{split}_questions.json') as f: 42 | data = json.load(f) 43 | self.img_idx_map = {} 44 | i=0 45 | for question in data['questions']: 46 | if question['image_index'] not in self.img_idx_map.keys(): 47 | self.img_idx_map[question['image_index']] = i 48 | i+=1 49 | self.idx_img_map = {v:k for k,v in self.img_idx_map.items()} 50 | 51 | def __getitem__(self, index): 52 | img = os.path.join(self.root, 'images', 53 | self.split, 54 | 'CLEVR_{}_{}.png'.format(self.split, 55 | str(self.idx_img_map[index]).zfill(6))) 56 | img = Image.open(img).convert('RGB') 57 | return transform(img) 58 | 59 | def __len__(self): 60 | return self.length 61 | 62 | batch_size = 50 63 | 64 | resnet = resnet101(True).to(device) 65 | resnet.eval() 66 | resnet.forward = forward.__get__(resnet, ResNet) 67 | 68 | def create_dataset(split): 69 | dataloader = DataLoader(CLEVR(sys.argv[1], split), batch_size=batch_size, 70 | num_workers=4) 71 | 72 | size = len(dataloader) 73 | 74 | print(split, 'total', size * batch_size) 75 | 76 | f = h5py.File('data/CLEVR_features.hdf5'.format(split), 'w', libver='latest') 77 | dset = f.create_dataset('data', (size * batch_size, 1024, 14, 14), 78 | dtype='f4') 79 | 80 | with torch.no_grad(): 81 | for i, image in tqdm(enumerate(dataloader)): 82 | image = image.to(device) 83 | features = resnet(image).detach().cpu().numpy() 84 | try: 85 | dset[i * batch_size:(i + 1) * batch_size] = features 86 | except: 87 | dset[i * 
batch_size:i * batch_size+features.shape[0]] = features 88 | 89 | f.close() 90 | 91 | create_dataset('val') 92 | create_dataset('train') -------------------------------------------------------------------------------- /log/.gitignore: -------------------------------------------------------------------------------- 1 | # Data files 2 | *.hdf5 3 | *.model 4 | log/*.txt 5 | *.pkl 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /merge.py: -------------------------------------------------------------------------------- 1 | # File: merge.py 2 | # Author: Ronil Pancholia 3 | # Date: 4/21/19 4 | # Time: 9:20 PM 5 | 6 | # Script to merge hdf5 chunk files to one and update info.json accordingly 7 | 8 | import warnings 9 | 10 | warnings.filterwarnings("ignore", category=FutureWarning) 11 | warnings.filterwarnings("ignore", message="size changed") 12 | 13 | from tqdm import tqdm 14 | import argparse 15 | import h5py 16 | import json 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--name', type=str, help="features directory name") 20 | parser.add_argument('--chunksNum', type=int, default=16, help="number of file chunks") 21 | parser.add_argument('--chunkSize', type=int, default=10000, help="file chunk size") 22 | args = parser.parse_args() 23 | 24 | print("Merging features file for gqa_{}. 
This may take a while (and may be 0 for some time).".format(args.name)) 25 | 26 | # Format specification for features files 27 | spec = { 28 | "spatial": {"features": (148855, 2048, 7, 7)}, 29 | "objects": {"features": (148855, 100, 2048), 30 | "bboxes": (148855, 100, 4)} 31 | } 32 | 33 | # Merge hdf5 files 34 | lengths = [0] 35 | with h5py.File("data/gqa_{name}.hdf5".format(name=args.name)) as out: 36 | datasets = {} 37 | for dname in spec[args.name]: 38 | datasets[dname] = out.create_dataset(dname, spec[args.name][dname]) 39 | 40 | low = 0 41 | for i in tqdm(range(args.chunksNum)): 42 | with h5py.File("data/{name}/gqa_{name}_{index}.h5".format(name=args.name, index=i)) as chunk: 43 | high = low + chunk["features"].shape[0] 44 | 45 | for dname in spec[args.name]: 46 | # low = i * args.chunkSize 47 | # high = (i + 1) * args.chunkSize if i < args.chunksNum -1 else spec[args.name][dname][0] 48 | datasets[dname][low:high] = chunk[dname][:] 49 | 50 | low = high 51 | lengths.append(high) 52 | 53 | # Update info file 54 | with open("data/{name}/gqa_{name}_info.json".format(name=args.name)) as infoIn: 55 | info = json.load(infoIn) 56 | for imageId in info: 57 | info[imageId]["index"] = lengths[info[imageId]["file"]] + info[imageId]["idx"] 58 | # info[imageId]["index"] = info[imageId]["file"] * args.chunkSize + info[imageId]["idx"] 59 | del info[imageId]["idx"] 60 | del info[imageId]["file"] 61 | 62 | with open("data/gqa_{name}_merged_info.json".format(name=args.name), "w") as infoOut: 63 | json.dump(info, infoOut) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch import nn 4 | from torch.nn.init import kaiming_uniform_, xavier_uniform_, normal 5 | import torch.nn.functional as F 6 | 7 | def linear(in_dim, out_dim, bias=True): 8 | lin = nn.Linear(in_dim, out_dim, bias=bias) 9 | xavier_uniform_(lin.weight) 10 | if bias: 11 | lin.bias.data.zero_() 12 | 13 | return lin 14 | 15 | class ControlUnit(nn.Module): 16 | def __init__(self, dim, max_step): 17 | super().__init__() 18 | 19 | self.position_aware = nn.ModuleList() 20 | for i in range(max_step): 21 | self.position_aware.append(linear(dim * 2, dim)) 22 | 23 | self.control_question = linear(dim * 2, dim) 24 | self.attn = linear(dim, 1) 25 | 26 | self.dim = dim 27 | 28 | def forward(self, step, context, question, control): 29 | position_aware = self.position_aware[step](question) 30 | 31 | control_question = torch.cat([control, position_aware], 1) 32 | control_question = self.control_question(control_question) 33 | control_question = control_question.unsqueeze(1) 34 | 35 | context_prod = control_question * context 36 | attn_weight = self.attn(context_prod) 37 | 38 | attn = F.softmax(attn_weight, 1) 39 | 40 | next_control = (attn * context).sum(1) 41 | 42 | return next_control 43 | 44 | 45 | class ReadUnit(nn.Module): 46 | def __init__(self, dim): 47 | super().__init__() 48 | 49 | self.mem = linear(dim, dim) 50 | self.concat = linear(dim * 2, dim) 51 | self.attn = linear(dim, 1) 52 | 53 | def forward(self, memory, know, control): 54 | mem = self.mem(memory[-1]).unsqueeze(2) 55 | concat = self.concat(torch.cat([mem * know, know], 1) \ 56 | .permute(0, 2, 1)) 57 | attn = concat * control[-1].unsqueeze(1) 58 | attn = self.attn(attn).squeeze(2) 59 | attn = F.softmax(attn, 1).unsqueeze(1) 60 | 61 | read = (attn * know).sum(2) 62 | 63 | return read 64 | 65 | 66 | 
class WriteUnit(nn.Module): 67 | def __init__(self, dim, self_attention=False, memory_gate=False): 68 | super().__init__() 69 | 70 | self.concat = linear(dim * 2, dim) 71 | 72 | if self_attention: 73 | self.attn = linear(dim, 1) 74 | self.mem = linear(dim, dim) 75 | 76 | if memory_gate: 77 | self.control = linear(dim, 1) 78 | 79 | self.self_attention = self_attention 80 | self.memory_gate = memory_gate 81 | 82 | def forward(self, memories, retrieved, controls): 83 | prev_mem = memories[-1] 84 | concat = self.concat(torch.cat([retrieved, prev_mem], 1)) 85 | next_mem = concat 86 | 87 | if self.self_attention: 88 | controls_cat = torch.stack(controls[:-1], 2) 89 | attn = controls[-1].unsqueeze(2) * controls_cat 90 | attn = self.attn(attn.permute(0, 2, 1)) 91 | attn = F.softmax(attn, 1).permute(0, 2, 1) 92 | 93 | memories_cat = torch.stack(memories, 2) 94 | attn_mem = (attn * memories_cat).sum(2) 95 | next_mem = self.mem(attn_mem) + concat 96 | 97 | if self.memory_gate: 98 | control = self.control(controls[-1]) 99 | gate = F.sigmoid(control) 100 | next_mem = gate * prev_mem + (1 - gate) * next_mem 101 | 102 | return next_mem 103 | 104 | 105 | class MACUnit(nn.Module): 106 | def __init__(self, dim, max_step=12, 107 | self_attention=False, memory_gate=False, 108 | dropout=0.15): 109 | super().__init__() 110 | 111 | self.control = ControlUnit(dim, max_step) 112 | self.read = ReadUnit(dim) 113 | self.write = WriteUnit(dim, self_attention, memory_gate) 114 | 115 | self.mem_0 = nn.Parameter(torch.zeros(1, dim)) 116 | self.control_0 = nn.Parameter(torch.zeros(1, dim)) 117 | 118 | self.dim = dim 119 | self.max_step = max_step 120 | self.dropout = dropout 121 | 122 | def get_mask(self, x, dropout): 123 | mask = torch.empty_like(x).bernoulli_(1 - dropout) 124 | mask = mask / (1 - dropout) 125 | 126 | return mask 127 | 128 | def forward(self, context, question, knowledge): 129 | b_size = question.size(0) 130 | 131 | control = self.control_0.expand(b_size, self.dim) 132 | memory = self.mem_0.expand(b_size, self.dim) 133 | 134 | if self.training: 135 | control_mask = self.get_mask(control, self.dropout) 136 | memory_mask = self.get_mask(memory, self.dropout) 137 | control = control * control_mask 138 | memory = memory * memory_mask 139 | 140 | controls = [control] 141 | memories = [memory] 142 | 143 | for i in range(self.max_step): 144 | control = self.control(i, context, question, control) 145 | if self.training: 146 | control = control * control_mask 147 | controls.append(control) 148 | 149 | read = self.read(memories, knowledge, controls) 150 | memory = self.write(memories, read, controls) 151 | if self.training: 152 | memory = memory * memory_mask 153 | memories.append(memory) 154 | 155 | return memory 156 | 157 | 158 | class MACNetwork(nn.Module): 159 | def __init__(self, n_vocab, dim, embed_hidden=300, 160 | max_step=12, self_attention=False, memory_gate=False, 161 | classes=28, dropout=0.15): 162 | super().__init__() 163 | 164 | self.conv = nn.Sequential(nn.Conv2d(1024, dim, 3, padding=1), 165 | nn.ELU(), 166 | nn.Conv2d(dim, dim, 3, padding=1), 167 | nn.ELU()) 168 | 169 | self.embed = nn.Embedding(n_vocab, embed_hidden) 170 | self.lstm = nn.LSTM(embed_hidden, dim, 171 | batch_first=True, bidirectional=True) 172 | self.lstm_proj = nn.Linear(dim * 2, dim) 173 | 174 | self.mac = MACUnit(dim, max_step, 175 | self_attention, memory_gate, dropout) 176 | 177 | 178 | self.classifier = nn.Sequential(linear(dim * 3, dim), 179 | nn.ELU(), 180 | linear(dim, classes)) 181 | 182 | self.max_step = max_step 183 | 
self.dim = dim 184 | 185 | self.reset() 186 | 187 | def reset(self): 188 | self.embed.weight.data.uniform_(0, 1) 189 | 190 | kaiming_uniform_(self.conv[0].weight) 191 | self.conv[0].bias.data.zero_() 192 | kaiming_uniform_(self.conv[2].weight) 193 | self.conv[2].bias.data.zero_() 194 | 195 | kaiming_uniform_(self.classifier[0].weight) 196 | 197 | def forward(self, image, question, question_len, dropout=0.15): 198 | b_size = question.size(0) 199 | 200 | img = self.conv(image) 201 | img = img.view(b_size, self.dim, -1) 202 | 203 | embed = self.embed(question) 204 | embed = nn.utils.rnn.pack_padded_sequence(embed, question_len, 205 | batch_first=True) 206 | lstm_out, (h, _) = self.lstm(embed) 207 | lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, 208 | batch_first=True) 209 | lstm_out = self.lstm_proj(lstm_out) 210 | h = h.permute(1, 0, 2).contiguous().view(b_size, -1) 211 | 212 | memory = self.mac(lstm_out, h, img) 213 | 214 | out = torch.cat([memory, h], 1) 215 | out = self.classifier(out) 216 | 217 | return out -------------------------------------------------------------------------------- /model_gqa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from block.models.networks.fusions.fusions import Tucker 4 | from torch import nn 5 | from torch.nn.init import kaiming_uniform_, xavier_uniform_ 6 | 7 | import config 8 | from utils import get_or_load_embeddings 9 | 10 | 11 | def linear(in_dim, out_dim, bias=True): 12 | lin = nn.Linear(in_dim, out_dim, bias=bias) 13 | xavier_uniform_(lin.weight) 14 | if bias: 15 | lin.bias.data.zero_() 16 | 17 | return lin 18 | 19 | class ControlUnit(nn.Module): 20 | def __init__(self, dim, max_step): 21 | super().__init__() 22 | 23 | self.position_aware = nn.ModuleList() 24 | for i in range(max_step): 25 | self.position_aware.append(linear(dim * 2, dim)) 26 | 27 | self.control_question = linear(dim * 2, dim) 28 | self.attn = linear(dim, 1) 29 | 30 | self.dim = dim 31 | 32 | def forward(self, step, context, question, control): 33 | position_aware = self.position_aware[step](question) 34 | 35 | control_question = torch.cat([control, position_aware], 1) 36 | control_question = self.control_question(control_question) 37 | control_question = control_question.unsqueeze(1) 38 | 39 | context_prod = control_question * context 40 | attn_weight = self.attn(context_prod) 41 | 42 | attn = F.softmax(attn_weight, 1) 43 | 44 | next_control = (attn * context).sum(1) 45 | 46 | return next_control 47 | 48 | 49 | class ReadUnit(nn.Module): 50 | def __init__(self, dim): 51 | super().__init__() 52 | 53 | self.mem = linear(dim, dim) 54 | self.concat = linear(dim * 2, dim) 55 | self.attn = linear(dim, 1) 56 | self.tucker = Tucker((2048, 2048), 1, mm_dim=50, shared=True) 57 | 58 | def forward(self, memory, know, control): 59 | mem = self.mem(memory[-1]).unsqueeze(2) 60 | s_matrix = (mem * know) 61 | s_matrix = s_matrix.view(-1, 2048) 62 | attn = self.tucker([s_matrix, control[-1].repeat(know.size(2), 1)]).view(know.size(2), know.size(0)) 63 | attn = attn.transpose(0, 1) 64 | attn = F.softmax(attn, 1).unsqueeze(1) 65 | read = (attn * know).sum(2) 66 | 67 | return read 68 | 69 | 70 | class WriteUnit(nn.Module): 71 | def __init__(self, dim, self_attention=False, memory_gate=False): 72 | super().__init__() 73 | 74 | self.concat = linear(dim * 2, dim) 75 | 76 | if self_attention: 77 | self.attn = linear(dim, 1) 78 | self.mem = linear(dim, dim) 79 | 80 | if memory_gate: 81 | self.control = 
linear(dim, 1) 82 | 83 | self.self_attention = self_attention 84 | self.memory_gate = memory_gate 85 | 86 | def forward(self, memories, retrieved, controls): 87 | prev_mem = memories[-1] 88 | concat = self.concat(torch.cat([retrieved, prev_mem], 1)) 89 | next_mem = concat 90 | 91 | if self.self_attention: 92 | controls_cat = torch.stack(controls[:-1], 2) 93 | attn = controls[-1].unsqueeze(2) * controls_cat 94 | attn = self.attn(attn.permute(0, 2, 1)) 95 | attn = F.softmax(attn, 1).permute(0, 2, 1) 96 | 97 | memories_cat = torch.stack(memories, 2) 98 | attn_mem = (attn * memories_cat).sum(2) 99 | next_mem = self.mem(attn_mem) + concat 100 | 101 | if self.memory_gate: 102 | control = self.control(controls[-1]) 103 | gate = F.sigmoid(control) 104 | next_mem = gate * prev_mem + (1 - gate) * next_mem 105 | 106 | return next_mem 107 | 108 | 109 | class MACUnit(nn.Module): 110 | def __init__(self, dim, max_step=12, 111 | self_attention=False, memory_gate=False, 112 | dropout=0.15): 113 | super().__init__() 114 | 115 | self.control = ControlUnit(dim, max_step) 116 | self.read = ReadUnit(dim) 117 | self.write = WriteUnit(dim, self_attention, memory_gate) 118 | 119 | self.mem_0 = nn.Parameter(torch.zeros(1, dim)) 120 | self.control_0 = nn.Parameter(torch.zeros(1, dim)) 121 | 122 | self.dim = dim 123 | self.max_step = max_step 124 | self.dropout = dropout 125 | self.dropouts = {} 126 | self.dropouts["encInput"]: config.encInputDropout 127 | self.dropouts["encState"]: config.encStateDropout 128 | self.dropouts["stem"]: config.stemDropout 129 | self.dropouts["question"]: config.qDropout 130 | self.dropouts["memory"]: config.memoryDropout 131 | self.dropouts["read"]: config.readDropout 132 | self.dropouts["write"]: config.writeDropout 133 | self.dropouts["output"]: config.outputDropout 134 | self.dropouts["controlPre"]: config.controlPreDropout 135 | self.dropouts["controlPost"]: config.controlPostDropout 136 | self.dropouts["wordEmb"]: config.wordEmbDropout 137 | self.dropouts["word"]: config.wordDp 138 | self.dropouts["vocab"]: config.vocabDp 139 | self.dropouts["object"]: config.objectDp 140 | self.dropouts["wordStandard"]: config.wordStandardDp 141 | 142 | def get_mask(self, x, dropout): 143 | mask = torch.empty_like(x).bernoulli_(1 - dropout) 144 | mask = mask / (1 - dropout) 145 | 146 | return mask 147 | 148 | def forward(self, context, question, knowledge): 149 | b_size = question.size(0) 150 | 151 | control = self.control_0.expand(b_size, self.dim) 152 | memory = self.mem_0.expand(b_size, self.dim) 153 | 154 | if self.training: 155 | control_mask = self.get_mask(control, self.dropout) 156 | memory_mask = self.get_mask(memory, self.dropout) 157 | control = control * control_mask 158 | memory = memory * memory_mask 159 | 160 | controls = [control] 161 | memories = [memory] 162 | 163 | for i in range(self.max_step): 164 | control = self.control(i, context, question, control) 165 | if self.training: 166 | control = control * control_mask 167 | controls.append(control) 168 | 169 | read = self.read(memories, knowledge, controls) 170 | memory = self.write(memories, read, controls) 171 | if self.training: 172 | memory = memory * memory_mask 173 | memories.append(memory) 174 | 175 | return memory 176 | 177 | 178 | class MACNetwork(nn.Module): 179 | def __init__(self, n_vocab, dim, embed_hidden=300, 180 | max_step=12, self_attention=False, memory_gate=False, 181 | classes=28, dropout=0.15): 182 | super().__init__() 183 | 184 | self.conv = nn.Sequential(nn.Conv2d(1024, dim, 3, padding=1), 185 | nn.ELU(), 
186 | nn.Conv2d(dim, dim, 3, padding=1), 187 | nn.ELU()) 188 | 189 | self.embed = nn.Embedding(n_vocab, embed_hidden) 190 | self.embed.weight.data = torch.Tensor(get_or_load_embeddings()) 191 | self.embed.weight.requires_grad = False 192 | self.lstm = nn.LSTM(embed_hidden, dim, 193 | batch_first=True, bidirectional=True) 194 | self.lstm_proj = nn.Linear(dim * 2, dim) 195 | 196 | self.mac = MACUnit(dim, max_step, 197 | self_attention, memory_gate, dropout) 198 | 199 | 200 | self.classifier = nn.Sequential(linear(dim * 3, dim), 201 | nn.ELU(), 202 | linear(dim, classes)) 203 | 204 | self.max_step = max_step 205 | self.dim = dim 206 | 207 | self.reset() 208 | 209 | def reset(self): 210 | self.embed.weight.data.uniform_(0, 1) 211 | 212 | kaiming_uniform_(self.conv[0].weight) 213 | self.conv[0].bias.data.zero_() 214 | kaiming_uniform_(self.conv[2].weight) 215 | self.conv[2].bias.data.zero_() 216 | 217 | kaiming_uniform_(self.classifier[0].weight) 218 | 219 | def forward(self, image, question, question_len, dropout=0.15): 220 | b_size = question.size(0) 221 | 222 | img = image 223 | img = img.view(b_size, self.dim, -1) 224 | 225 | embed = self.embed(question) 226 | embed = nn.utils.rnn.pack_padded_sequence(embed, question_len, 227 | batch_first=True) 228 | lstm_out, (h, _) = self.lstm(embed) 229 | lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True) 230 | lstm_out = self.lstm_proj(lstm_out) 231 | h = h.permute(1, 0, 2).contiguous().view(b_size, -1) 232 | 233 | memory = self.mac(lstm_out, h, img) 234 | 235 | out = torch.cat([memory, h], 1) 236 | out = self.classifier(out) 237 | 238 | return out -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import pickle 5 | import jieba 6 | import nltk 7 | import tqdm 8 | from torchvision import transforms 9 | from PIL import Image 10 | from transforms import Scale 11 | import MeCab 12 | from konlpy.tag import Okt 13 | image_index = {'CLEVR': 'image_filename', 14 | 'gqa': 'imageId'} 15 | 16 | 17 | def process_question(root, split, word_dic=None, answer_dic=None, dataset_type='CLEVR',lang='en'): 18 | if word_dic is None: 19 | word_dic = {} 20 | 21 | if answer_dic is None: 22 | answer_dic = {} 23 | 24 | with open(os.path.join(root, 'questions', f'mini_{dataset_type}_{split}_questions.json'), encoding='utf-8') as f: 25 | data = json.load(f) 26 | 27 | result = [] 28 | word_index = 1 29 | answer_index = 0 30 | 31 | for question in tqdm.tqdm(data['questions']): 32 | if lang == 'en': 33 | words = nltk.word_tokenize(question['question']) 34 | elif lang =='vi': 35 | words = nltk.word_tokenize(question['translated']['vietnamese']) 36 | elif lang == 'zh': 37 | words = jieba.lcut(question['translated']['chinese']) 38 | elif lang == 'ja': 39 | mecab = MeCab.Tagger("-Owakati") 40 | words = mecab.parse(question['translated']['japanese']).strip().split() 41 | elif lang == 'ko': 42 | okt = Okt() 43 | words = okt.morphs(question['translated']['korean']) 44 | question_token = [] 45 | 46 | for word in words: 47 | try: 48 | question_token.append(word_dic[word]) 49 | 50 | except: 51 | question_token.append(word_index) 52 | word_dic[word] = word_index 53 | word_index += 1 54 | 55 | answer_word = question['answer'] 56 | 57 | try: 58 | answer = answer_dic[answer_word] 59 | except: 60 | answer = answer_index 61 | answer_dic[answer_word] = answer_index 62 | answer_index += 1 63 | 64 | 
result.append((question[image_index[dataset_type]], question_token, answer)) 65 | 66 | with open(f'data/{dataset_type}_{split}_{lang}.pkl', 'wb') as f: 67 | pickle.dump(result, f) 68 | 69 | return word_dic, answer_dic 70 | 71 | 72 | if __name__ == '__main__': 73 | dataset_type = sys.argv[1] 74 | root = sys.argv[2] 75 | lang = sys.argv[3] 76 | word_dic, answer_dic = process_question(root, 'train', dataset_type=dataset_type) 77 | process_question(root, 'val', word_dic, answer_dic, dataset_type=dataset_type) 78 | 79 | with open(f'data/{dataset_type}_{lang}_dic.pkl', 'wb') as f: 80 | pickle.dump({'word_dic': word_dic, 'answer_dic': answer_dic}, f) 81 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | nltk 3 | tqdm 4 | block.bootstrap.pytorch 5 | murel.bootstrap.pytorch 6 | konlpy 7 | mecab-python3 8 | gensim 9 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | cd data 2 | mkdir gqa && cd gqa 3 | wget https://nlp.stanford.edu/data/gqa/data1.2.zip 4 | unzip data1.2.zip 5 | mkdir questions 6 | mv balanced_train_data.json questions/gqa_train_questions.json 7 | mv balanced_val_data.json questions/gqa_val_questions.json 8 | mv balanced_testdev_data.json questions/gqa_testdev_questions.json 9 | cd .. 10 | 11 | 12 | wget http://nlp.stanford.edu/data/glove.6B.zip 13 | unzip glove.6B.zip 14 | wget http://nlp.stanford.edu/data/gqa/objectFeatures.zip 15 | unzip objectFeatures.zip 16 | cd .. 17 | 18 | python merge.py --name objects 19 | mv data/gqa_objects.hdf5 data/gqa_features.hdf5 20 | 21 | python preprocess.py gqa data/gqa 22 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pickle 3 | from collections import Counter 4 | 5 | import torch 6 | from tqdm import tqdm 7 | from torch.utils.data import DataLoader 8 | from dataset import CLEVR, collate_data, transform 9 | 10 | batch_size = 64 11 | n_epoch = 180 12 | 13 | train_set = DataLoader( 14 | CLEVR(sys.argv[1], 'val', transform=None), 15 | batch_size=batch_size, 16 | num_workers=4, 17 | collate_fn=collate_data, 18 | ) 19 | net = torch.load(sys.argv[2]) 20 | net.eval() 21 | 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | 24 | for epoch in range(n_epoch): 25 | dataset = iter(train_set) 26 | pbar = tqdm(dataset) 27 | correct_counts = 0 28 | total_counts = 0 29 | 30 | for image, question, q_len, answer in pbar: 31 | image, question = image.to(device), question.to(device) 32 | output = net(image, question, q_len) 33 | correct = output.detach().argmax(1) == answer.to(device) 34 | for c in correct: 35 | if c: 36 | correct_counts += 1 37 | total_counts += 1 38 | 39 | print('Avg Acc: {:.5f}'.format(correct_counts / total_counts)) 40 | 41 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import pickle 3 | import sys 4 | 5 | import torch 6 | from torch import nn 7 | from torch import optim 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | from dataset import CLEVR, collate_data, transform, GQA 12 | from model_gqa import MACNetwork 13 
| 14 | batch_size = 128 15 | n_epoch = 25 16 | dim_dict = {'CLEVR': 512, 17 | 'gqa': 2048} 18 | 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | 22 | def accumulate(model1, model2, decay=0.999): 23 | par1 = dict(model1.named_parameters()) 24 | par2 = dict(model2.named_parameters()) 25 | 26 | for k in par1.keys(): 27 | par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) 28 | 29 | 30 | def train(epoch, dataset_type, lang='en'): 31 | if dataset_type == "CLEVR": 32 | dataset_object = CLEVR('data/CLEVR', transform=transform, lang=lang) 33 | else: 34 | dataset_object = GQA('data/gqa', transform=transform) 35 | 36 | train_set = DataLoader( 37 | dataset_object, batch_size=batch_size, num_workers=multiprocessing.cpu_count(), collate_fn=collate_data 38 | ) 39 | 40 | dataset = iter(train_set) 41 | pbar = tqdm(dataset) 42 | moving_loss = 0 43 | 44 | net.train(True) 45 | for image, question, q_len, answer in pbar: 46 | image, question, answer = ( 47 | image.to(device), 48 | question.to(device), 49 | answer.to(device), 50 | ) 51 | 52 | net.zero_grad() 53 | output = net(image, question, q_len) 54 | loss = criterion(output, answer) 55 | loss.backward() 56 | optimizer.step() 57 | correct = output.detach().argmax(1) == answer 58 | correct = torch.tensor(correct, dtype=torch.float32).sum() / batch_size 59 | 60 | if moving_loss == 0: 61 | moving_loss = correct 62 | else: 63 | moving_loss = moving_loss * 0.99 + correct * 0.01 64 | 65 | pbar.set_description('Epoch: {}; Loss: {:.8f}; Acc: {:.5f}'.format(epoch + 1, loss.item(), moving_loss)) 66 | 67 | accumulate(net_running, net) 68 | 69 | dataset_object.close() 70 | 71 | 72 | def valid(epoch, dataset_type,lang='en'): 73 | if dataset_type == "CLEVR": 74 | dataset_object = CLEVR('data/CLEVR', 'val', transform=None, lang=lang) 75 | else: 76 | dataset_object = GQA('data/gqa', 'val', transform=None) 77 | 78 | valid_set = DataLoader( 79 | dataset_object, batch_size=4*batch_size, num_workers=multiprocessing.cpu_count(), collate_fn=collate_data 80 | ) 81 | dataset = iter(valid_set) 82 | 83 | net_running.train(False) 84 | correct_counts = 0 85 | total_counts = 0 86 | running_loss = 0.0 87 | batches_done = 0 88 | with torch.no_grad(): 89 | pbar = tqdm(dataset) 90 | for image, question, q_len, answer in pbar: 91 | image, question, answer = ( 92 | image.to(device), 93 | question.to(device), 94 | answer.to(device), 95 | ) 96 | 97 | output = net_running(image, question, q_len) 98 | loss = criterion(output, answer) 99 | correct = output.detach().argmax(1) == answer 100 | running_loss += loss.item() 101 | 102 | batches_done += 1 103 | for c in correct: 104 | if c: 105 | correct_counts += 1 106 | total_counts += 1 107 | 108 | pbar.set_description('Epoch: {}; Loss: {:.8f}; Acc: {:.5f}'.format(epoch + 1, loss.item(), correct_counts / batches_done)) 109 | 110 | with open('log/log_{}.txt'.format(str(epoch + 1).zfill(2)), 'w') as w: 111 | w.write('{:.5f}\n'.format(correct_counts / total_counts)) 112 | 113 | print('Validation Accuracy: {:.5f}'.format(correct_counts / total_counts)) 114 | print('Validation Loss: {:.8f}'.format(running_loss / total_counts)) 115 | 116 | dataset_object.close() 117 | 118 | 119 | if __name__ == '__main__': 120 | dataset_type = sys.argv[1] 121 | lang = sys.argv[2] 122 | with open(f'data/{dataset_type}_dic.pkl', 'rb') as f: 123 | dic = pickle.load(f) 124 | 125 | n_words = len(dic['word_dic']) + 1 126 | n_answers = len(dic['answer_dic']) 127 | 128 | net = MACNetwork(n_words, dim_dict[dataset_type], classes=n_answers, 
max_step=4).to(device) 129 | net_running = MACNetwork(n_words, dim_dict[dataset_type], classes=n_answers, max_step=4).to(device) 130 | accumulate(net_running, net, 0) 131 | 132 | criterion = nn.CrossEntropyLoss() 133 | optimizer = optim.Adam(net.parameters(), lr=1e-4) 134 | 135 | for epoch in range(n_epoch): 136 | train(epoch, dataset_type,lang=lang) 137 | valid(epoch, dataset_type,lang=lang) 138 | 139 | with open(f'checkpoint/checkpoint_{lang}.model'.format(str(epoch + 1).zfill(2)), 'wb') as f: 140 | torch.save(net_running.state_dict(), f) 141 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | """ Shamelessly take from newer torchvision repository """ 2 | 3 | from PIL import Image 4 | import collections 5 | collections.Iterable = collections.abc.Iterable 6 | class Scale(object): 7 | """Rescale the input PIL.Image to the given size. 8 | 9 | Args: 10 | size (sequence or int): Desired output size. If size is a sequence like 11 | (w, h), output size will be matched to this. If size is an int, 12 | smaller edge of the image will be matched to this number. 13 | i.e, if height > width, then image will be rescaled to 14 | (size * height / width, size) 15 | interpolation (int, optional): Desired interpolation. Default is 16 | ``PIL.Image.BILINEAR`` 17 | """ 18 | 19 | def __init__(self, size, interpolation=Image.BILINEAR): 20 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 21 | self.size = size 22 | self.interpolation = interpolation 23 | 24 | def __call__(self, img): 25 | """ 26 | Args: 27 | img (PIL.Image): Image to be scaled. 28 | 29 | Returns: 30 | PIL.Image: Rescaled image. 31 | """ 32 | if isinstance(self.size, int): 33 | w, h = img.size 34 | if (w <= h and w == self.size) or (h <= w and h == self.size): 35 | return img 36 | if w < h: 37 | ow = self.size 38 | oh = int(self.size * h / w) 39 | return img.resize((ow, oh), self.interpolation) 40 | else: 41 | oh = self.size 42 | ow = int(self.size * w / h) 43 | return img.resize((ow, oh), self.interpolation) 44 | else: 45 | return img.resize(self.size, self.interpolation) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # File: utils.py 2 | # Author: Ronil Pancholia 3 | # Date: 4/22/19 4 | # Time: 7:57 PM 5 | import pickle 6 | import sys 7 | 8 | import numpy as np 9 | import jieba 10 | 11 | from gensim.models import KeyedVectors 12 | id2word = [] 13 | embedding_weights = None 14 | 15 | def get_or_load_embeddings(lang='en'): 16 | """ 17 | parameters: 18 | lang: language of the embeding 19 | 20 | returns: 21 | embedding matrix 22 | """ 23 | global embedding_weights, id2word 24 | 25 | if embedding_weights is not None: 26 | return embedding_weights 27 | 28 | dataset_type = sys.argv[1] 29 | with open(f'data/{dataset_type}_dic.pkl', 'rb') as f: 30 | dic = pickle.load(f) 31 | 32 | id2word = set(dic['word_dic'].keys()) 33 | id2word.update(set(dic['answer_dic'].keys())) 34 | 35 | word2id = {word: id for id, word in enumerate(id2word)} 36 | 37 | embed_size = 300 38 | vocab_size = len(id2word) 39 | sd = 1 / np.sqrt(embed_size) 40 | embedding_weights = np.random.normal(0, scale=sd, size=[vocab_size, embed_size]) 41 | embedding_weights = embedding_weights.astype(np.float32) 42 | if lang == 'en': 43 | with open("/kaggle/input/phow2v/glove.6B.300d.txt", 
encoding="utf-8", mode="r") as textFile: 44 | for line in textFile: 45 | line = line.split() 46 | word = line[0] 47 | 48 | id = word2id.get(word, None) 49 | if id is not None: 50 | embedding_weights[id] = np.array(line[1:], dtype=np.float32) 51 | elif lang == 'vi': 52 | with open("/kaggle/input/phow2v/word2vec_vi_words_300dims.txt", encoding="utf-8", mode="r") as textFile: 53 | for line in textFile: 54 | line = line.split() 55 | try: 56 | float(line[1]) 57 | word = line[0] 58 | id = word2id.get(word, None) 59 | if id is not None: 60 | embedding_weights[id] = np.array(line[1:], dtype=np.float32) 61 | except: 62 | word = '_'.join(line[:2]) 63 | id = word2id.get(word, None) 64 | if id is not None: 65 | embedding_weights[id] = np.array(line[2:], dtype=np.float32) 66 | elif lang == 'zh': 67 | fb_model = KeyedVectors.load_word2vec_format('/kaggle/input/phow2v/cc.zh.300.vec') 68 | for word, vector in fb_model.items(): 69 | id = word2id.get(word, None) 70 | if id is not None: 71 | embedding_weights[id] = np.array(vector, dtype=np.float32) 72 | elif lang == 'ja': 73 | fb_model = KeyedVectors.load_word2vec_format('/kaggle/input/phow2v/cc.ja.300.vec') 74 | for word, vector in fb_model.items(): 75 | id = word2id.get(word, None) 76 | if id is not None: 77 | embedding_weights[id] = np.array(vector, dtype=np.float32) 78 | elif lang == 'ko': 79 | fb_model = KeyedVectors.load_word2vec_format('/kaggle/input/phow2v/cc.ko.300.vec') 80 | for word, vector in fb_model.items(): 81 | id = word2id.get(word, None) 82 | if id is not None: 83 | embedding_weights[id] = np.array(vector, dtype=np.float32) 84 | return embedding_weights 85 | --------------------------------------------------------------------------------