├── README.md
├── dataset_utility.py
├── ncd.py
├── requirements.txt
├── test.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
# Unsupervised Abstract Reasoning for Raven’s Problem Matrices

This repository is the implementation of our [TIP paper](https://arxiv.org/pdf/2109.10011.pdf).

It is the first unsupervised abstract reasoning method for Raven's Progressive Matrices, and an extension of our earlier [arXiv preprint](https://arxiv.org/pdf/2002.01646.pdf).

## Comparison with supervised methods

### Average testing accuracy on the RAVEN, I-RAVEN, and PGM datasets

| Method            | RAVEN     | I-RAVEN   | PGM   |
|-------------------|-----------|-----------|-------|
| CNN               | 36.97     | 13.26     | 33.00 |
| ResNet50          | 86.26     | -         | 42.00 |
| DCNet (ICLR2021)  | **93.58** | **49.36** | 68.57 |
| NCD (Ours)        | 36.99     | 48.22     | 47.62 |

### Generalization test results on the PGM dataset

| Method            | Neutral  | Interpolation | Extrapolation |
|-------------------|----------|---------------|---------------|
| WReN (ICML2018)   | 62.6     | 64.4          | 17.2          |
| DCNet (ICLR2021)  | 68.6     | 59.7          | 17.8          |
| MXGNet (ICLR2020) | **89.6** | **84.6**      | 18.4          |
| NCD (Ours)        | 47.6     | 47.0          | **24.9**      |

## Citation
If our code is useful for your research, please cite the following papers.

```
@article{zhuo2021unsup,
  title={Unsupervised Abstract Reasoning for Raven’s Problem Matrices},
  author={Tao Zhuo and Qiang Huang and Mohan Kankanhalli},
  journal={IEEE Transactions on Image Processing},
  year={2021}
}
```

```
@article{zhuo2020solving,
  title={Solving Raven's Progressive Matrices with Neural Networks},
  author={Tao Zhuo and Mohan Kankanhalli},
  journal={arXiv preprint arXiv:2002.01646},
  year={2020}
}
```

```
@inproceedings{iclr2021,
  author={Tao Zhuo and Mohan Kankanhalli},
  title={Effective Abstract Reasoning with Dual-Contrast Network},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2021}
}
```
--------------------------------------------------------------------------------
/dataset_utility.py:
--------------------------------------------------------------------------------
import os
import glob
import numpy as np
import cv2

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils


class ToTensor(object):
    def __call__(self, sample):
        return torch.tensor(sample, dtype=torch.float32)


class dataset(Dataset):
    def __init__(self, root, dataset_type, fig_type='*', img_size=160, transform=None, train_mode=False):
        self.transform = transform
        self.img_size = img_size
        self.train_mode = train_mode
        self.file_names = [f for f in glob.glob(os.path.join(root, fig_type, '*.npz')) if dataset_type in f]

        if self.train_mode:
            idx = list(range(len(self.file_names)))
            np.random.shuffle(idx)
            # self.file_names = [self.file_names[i] for i in idx[0:100000]]  # randomly select 100K samples for fast training on a large-scale dataset

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        data = np.load(self.file_names[idx])
        image = data['image'].reshape(16, 160, 160)
        target = data['target']
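        # Each .npz problem (RAVEN / I-RAVEN / PGM format) stores 16 grayscale 160x160 panels:
        # the first 8 are the context panels of the 3x3 matrix (last cell missing) and the
        # last 8 are the candidate answers. 'target' is the index (0-7) of the correct
        # candidate; it is only used for evaluation, never for training.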

        del data

        resize_image = image
        if self.img_size is not None:
            resize_image = []
            for i in range(0, 16):
                resize_image.append(cv2.resize(image[i, :], (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST))
            resize_image = np.stack(resize_image)

        if self.transform:
            resize_image = self.transform(resize_image)
        target = torch.tensor(target, dtype=torch.long)

        return resize_image, target

--------------------------------------------------------------------------------
/ncd.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class NCD(nn.Module):
    def __init__(self, mode='rc'):
        super(NCD, self).__init__()
        self.mode = mode
        net = models.resnet18(pretrained=True)

        # all layers of ResNet-18 except the final average pooling and the classifier
        self.feature = nn.Sequential(*list(net.children())[0:-2])

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(net.fc.in_features, 1))

        # put the pretrained BatchNorm layers in eval mode
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    # input x: [b, 10, 3, h, w]
    def feat(self, x):
        b = x.shape[0]

        x = x.view((b*10, -1) + x.shape[3:])  # [b*10, 3, h, w]
        x = self.feature(x)
        x = self.avgpool(x)

        x = x.view(b, 10, -1)
        # subtract the average feature of the two complete context groups (rows or columns)
        x = x - 0.5 * (x[:, 0:1] + x[:, 1:2])  # [b, 10, 512]

        return x

    # input shape: [b, 16, h, w]
    def forward(self, x):
        b = x.shape[0]

        # images of the choices
        choices = x[:, 8:].unsqueeze(dim=2)  # [b, 8, 1, h, w]

        # images of the rows
        row1 = x[:, 0:3].unsqueeze(1)  # [b, 1, 3, h, w]
        row2 = x[:, 3:6].unsqueeze(1)  # [b, 1, 3, h, w]

        row3_p = x[:, 6:8].unsqueeze(dim=1).repeat(1, 8, 1, 1, 1)  # [b, 8, 2, h, w]
        row3 = torch.cat((row3_p, choices), dim=2)  # [b, 8, 3, h, w]

        rows = torch.cat((row1, row2, row3), dim=1)  # [b, 10, 3, h, w]

        if self.mode == 'r':
            x = self.feat(rows)

        elif self.mode == 'rc':
            # images of the columns
            col1 = x[:, 0:8:3].unsqueeze(1)  # [b, 1, 3, h, w]
            col2 = x[:, 1:8:3].unsqueeze(1)  # [b, 1, 3, h, w]

            col3_p = x[:, 2:8:3].unsqueeze(dim=1).repeat(1, 8, 1, 1, 1)  # [b, 8, 2, h, w]
            col3 = torch.cat((col3_p, choices), dim=2)  # [b, 8, 3, h, w]

            cols = torch.cat((col1, col2, col3), dim=1)  # [b, 10, 3, h, w]

            x = self.feat(rows) + self.feat(cols)

        x = self.fc(x.view(b*10, -1))

        return x.view(b, 10)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
pillow
opencv-python
tqdm
torch
torchvision
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
import torch
import numpy as np
import argparse
import time
import torch.nn.functional as F

from tqdm import trange
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable

from dataset_utility import dataset, ToTensor
from ncd import NCD

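# The two GPU ids below are an assumption; adjust CUDA_VISIBLE_DEVICES to match your machine.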
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='ncd')

parser.add_argument('--dataset', type=str, default='i-raven')
parser.add_argument('--root', type=str, default='../dataset/rpm')
parser.add_argument('--pretrained_model', type=str, default='pretrained_models/model_iraven.pth')

parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--img_size', type=int, default=256)
parser.add_argument('--workers', type=int, default=8)
parser.add_argument('--seed', type=int, default=123)

args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)

tf = transforms.Compose([ToTensor()])

# image size and batch size follow the settings used for each dataset
if args.dataset == 'raven' or args.dataset == 'i-raven':
    mode = 'r'
    args.img_size = 256
    args.batch_size = 64
elif args.dataset == 'pgm':
    mode = 'rc'
    args.img_size = 96
    args.batch_size = 256

model = NCD(mode).to(device)
if torch.cuda.device_count() > 0:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load(args.pretrained_model))


def test(test_loader):
    model.eval()
    metrics = {'correct': [], 'count': []}

    test_loader_iter = iter(test_loader)
    for _ in trange(len(test_loader_iter)):
        image, target = next(test_loader_iter)

        image = Variable(image, requires_grad=False).to(device)
        target = Variable(target, requires_grad=False).to(device)

        with torch.no_grad():
            predict = model(image)

        # the first two scores belong to the complete context rows;
        # the answer is the highest-scoring candidate among the remaining eight
        pred = torch.max(predict[:, 2:], 1)[1]
        correct = pred.eq(target.data).cpu().sum().numpy()

        metrics['correct'].append(correct)
        metrics['count'].append(target.size(0))

    return metrics


if __name__ == '__main__':

    if args.dataset == 'raven' or args.dataset == 'i-raven':
        fig_types = ['center_single', 'distribute_four', 'distribute_nine',
                     'left_center_single_right_center_single', 'up_center_single_down_center_single',
                     'in_center_single_out_center_single', 'in_distribute_four_out_center_single']

        accuracy_list = []
        for i in range(len(fig_types)):
            test_set = dataset(os.path.join(args.root, args.dataset), 'test', fig_types[i], args.img_size, tf)
            test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

            metrics_test = test(test_loader)
            acc_test = 100 * np.sum(metrics_test['correct']) / np.sum(metrics_test['count'])
            accuracy_list.append(acc_test)

            print('FigType: {:s}, Accuracy: {:.3f} \n'.format(fig_types[i], acc_test))

        print(accuracy_list)
        print('Average Accuracy: {:.3f} \n'.format(np.mean(accuracy_list)))

    elif args.dataset == 'pgm':
        # the PGM regime ('neutral', 'interpolation', 'extrapolation') is hardcoded here
        test_set = dataset(args.root, 'test', 'interpolation', args.img_size, tf)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

        metrics_test = test(test_loader)
        acc_test = 100 * np.sum(metrics_test['correct']) / np.sum(metrics_test['count'])

        print('Average Accuracy: {:.3f} \n'.format(acc_test))

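
# Example usage (the paths below are simply the argparse defaults above; adjust to your setup):
#   python test.py --dataset i-raven --root ../dataset/rpm --pretrained_model pretrained_models/model_iraven.pth
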
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import torch
import numpy as np
import argparse
import time
import torch.nn.functional as F

from tqdm import trange
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable

from dataset_utility import dataset, ToTensor
from ncd import NCD

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='ncd')
parser.add_argument('--num_neg', type=int, default=4)

parser.add_argument('--fig_type', type=str, default='*')
parser.add_argument('--dataset', type=str, default='i-raven')
parser.add_argument('--root', type=str, default='../dataset/rpm')

#parser.add_argument('--fig_type', type=str, default='neutral')  # neutral, interpolation, extrapolation
#parser.add_argument('--dataset', type=str, default='pgm')
#parser.add_argument('--root', type=str, default='../dataset/rpm/pgm')

# caveat: with type=bool, argparse turns any non-empty string (including 'False') into True
parser.add_argument('--train_mode', type=bool, default=True)
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--epochs', type=int, default=20)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--img_size', type=int, default=256)
parser.add_argument('--workers', type=int, default=16)
parser.add_argument('--seed', type=int, default=123)

args = parser.parse_args()

# image size and batch size follow the settings used for each dataset
if args.dataset == 'raven' or args.dataset == 'i-raven':
    mode = 'r'
    args.img_size = 256
    args.batch_size = 64
elif args.dataset == 'pgm':
    mode = 'rc'
    args.img_size = 96
    args.batch_size = 256

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)

tf = transforms.Compose([ToTensor()])

train_set = dataset(os.path.join(args.root, args.dataset), 'train', args.fig_type, args.img_size, tf, args.train_mode)
test_set = dataset(os.path.join(args.root, args.dataset), 'test', args.fig_type, args.img_size, tf)

print('test length:', len(test_set), args.fig_type)

train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

save_name = args.model_name + '_' + args.fig_type + '_' + str(args.num_neg) + '_' + str(args.img_size) + '_' + str(args.batch_size)

save_path_model = os.path.join(args.dataset, 'models', save_name)
if not os.path.exists(save_path_model):
    os.makedirs(save_path_model)

save_path_log = os.path.join(args.dataset, 'logs')
if not os.path.exists(save_path_log):
    os.makedirs(save_path_log)

model = NCD(mode)
if torch.cuda.device_count() > 0:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = torch.nn.DataParallel(model)
model.to(device)
#model.load_state_dict(torch.load(save_path_model+'/model_06.pth'))

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)

time_now = datetime.now().strftime('%D-%H:%M:%S')
save_log_name = os.path.join(save_path_log, 'log_{:s}.txt'.format(save_name))
with open(save_log_name, 'a') as f:
    f.write('\n------ lr: {:f}, batch_size: {:d}, img_size: {:d}, time: {:s} ------\n'.format(
        args.lr, args.batch_size, args.img_size, time_now))

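
# Training is unsupervised: the answer labels are never used (train() below discards them).
# NCD scores 10 rows per problem: the two complete context rows plus the 8 third rows
# formed by appending each candidate answer (see ncd.py). replace_ans() swaps the last
# `num_neg` candidate panels with candidates taken from other problems in the batch, so
# the rows completed with those panels are almost surely invalid. compute_loss() then
# applies a binary cross-entropy loss against pseudo targets: 1 for the two context rows
# and 0 for every candidate-completed row.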
def replace_ans(x, k=8):
    if k > 0:
        x1 = x[:, 0:16-k]
        x2 = x[:, 16-k:]

        # permute the last k candidate panels across the batch
        indices = np.random.permutation(x.size(0))
        x = torch.cat((x1, x2[indices, :]), dim=1)

    return x


def compute_loss(predict):
    pseudo_target = torch.zeros(predict.shape)
    pseudo_target = Variable(pseudo_target, requires_grad=False).to(device)
    pseudo_target[:, 0:2] = 1

    return F.binary_cross_entropy_with_logits(predict, pseudo_target)


def train(epoch):
    model.train()
    metrics = {'loss': []}

    train_loader_iter = iter(train_loader)
    for batch_idx in trange(len(train_loader_iter)):
        image, _ = next(train_loader_iter)

        image = Variable(image, requires_grad=True).to(device)

        image = replace_ans(image, args.num_neg)
        predict = model(image)
        loss = compute_loss(predict)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        metrics['loss'].append(loss.item())

        if batch_idx > 1 and batch_idx % 12000 == 0:
            print('Epoch: {:d}/{:d}, Loss: {:.3f}'.format(epoch, args.epochs, np.mean(metrics['loss'])))

    print('Epoch: {:d}/{:d}, Loss: {:.3f}'.format(epoch, args.epochs, np.mean(metrics['loss'])))

    return metrics


def test(epoch):
    model.eval()
    metrics = {'correct': [], 'count': []}

    test_loader_iter = iter(test_loader)
    for _ in trange(len(test_loader_iter)):
        image, target = next(test_loader_iter)

        image = Variable(image, requires_grad=False).to(device)
        target = Variable(target, requires_grad=False).to(device)

        with torch.no_grad():
            predict = model(image)

        # the answer is the highest-scoring candidate among the last eight scores
        pred = torch.max(predict[:, 2:], 1)[1]
        correct = pred.eq(target.data).cpu().sum().numpy()

        metrics['correct'].append(correct)
        metrics['count'].append(target.size(0))

    accuracy = 100 * np.sum(metrics['correct']) / np.sum(metrics['count'])

    print('Testing Epoch: {:d}/{:d}, Accuracy: {:.3f} \n'.format(epoch, args.epochs, accuracy))

    return metrics


if __name__ == '__main__':
    for epoch in range(1, args.epochs+1):

        #metrics_test = test(epoch)
        #break

        metrics_train = train(epoch)

        # Save the model after every epoch
        if epoch > 0:
            save_name = os.path.join(save_path_model, 'model_{:02d}.pth'.format(epoch))
            torch.save(model.state_dict(), save_name)

        metrics_test = test(epoch)

        loss_train = np.mean(metrics_train['loss'])
        acc_test = 100 * np.sum(metrics_test['correct']) / np.sum(metrics_test['count'])

        time_now = datetime.now().strftime('%H:%M:%S')

        with open(save_log_name, 'a') as f:
            f.write('Epoch {:02d}: Accuracy: {:.3f}, Loss: {:.3f}, Time: {:s}\n'.format(
                epoch, acc_test, loss_train, time_now))

--------------------------------------------------------------------------------