├── data
│   ├── __init__.py
│   ├── additional_transforms.py
│   ├── feature_loader.py
│   ├── transforms.py
│   ├── datamgr.py
│   └── dataset.py
├── requirements.txt
├── .idea
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── TAD.iml
│   ├── remote-mappings.xml
│   ├── misc.xml
│   ├── deployment.xml
│   └── workspace.xml
├── scripts
│   ├── train
│   │   ├── cub_protonet.sh
│   │   └── sun_protonet.sh
│   └── test
│       ├── sun_protonet.sh
│       ├── cub_protonet.sh
│       └── plot_distance_acc.sh
├── configs.py
├── methods
│   ├── protonet.py
│   ├── baselinefinetune.py
│   ├── matchingnet.py
│   ├── baselinetrain.py
│   ├── meta_template.py
│   ├── relationnet.py
│   ├── maml.py
│   └── apnet.py
├── README.md
├── filelists
│   ├── SUN
│   │   ├── write_SUN_filelist.py
│   │   └── write_SUN_attr.py
│   └── CUB
│       └── write_CUB_filelist.py
├── io_utils.py
├── save_features.py
├── utils.py
├── test.py
├── plot_distance_acc.py
├── train.py
└── backbone.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | matplotlib
3 | scikit-learn
4 | tensorboard
5 | h5py
6 | tqdm
--------------------------------------------------------------------------------
/scripts/train/cub_protonet.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=3 python train.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1
2 | CUDA_VISIBLE_DEVICES=3 python train.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 5
--------------------------------------------------------------------------------
/scripts/train/sun_protonet.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python train.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 1
2 | CUDA_VISIBLE_DEVICES=2 python train.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 5
3 |
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | save_dir = '.'
2 | data_dir = {}
3 | data_dir['CUB'] = './filelists/CUB/'
4 | data_dir['AWA2'] = './filelists/AWA2/'
5 | data_dir['SUN'] = './filelists/SUN/'
6 | data_dir['miniImagenet'] = './filelists/miniImagenet/'
7 | data_dir['omniglot'] = './filelists/omniglot/'
8 | data_dir['emnist'] = './filelists/emnist/'
9 |
--------------------------------------------------------------------------------
/scripts/test/sun_protonet.sh:
--------------------------------------------------------------------------------
1 | python save_features.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 1
2 | CUDA_VISIBLE_DEVICES=2 python test.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 1
3 | python save_features.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 5
4 | CUDA_VISIBLE_DEVICES=2 python test.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 5
--------------------------------------------------------------------------------
/scripts/test/cub_protonet.sh:
--------------------------------------------------------------------------------
1 | python save_features.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1
2 | CUDA_VISIBLE_DEVICES=3 python test.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1
3 | python save_features.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 5
4 | CUDA_VISIBLE_DEVICES=3 python test.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 5
--------------------------------------------------------------------------------
/scripts/test/plot_distance_acc.sh:
--------------------------------------------------------------------------------
1 | python save_features.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1
2 | CUDA_VISIBLE_DEVICES=2 python plot_distance_acc.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1 --runs 1
3 | CUDA_VISIBLE_DEVICES=2 python plot_distance_acc.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1 --runs 2
4 | CUDA_VISIBLE_DEVICES=2 python plot_distance_acc.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1 --runs 3
5 | CUDA_VISIBLE_DEVICES=2 python plot_distance_acc.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1 --runs 4
6 | CUDA_VISIBLE_DEVICES=2 python plot_distance_acc.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1 --runs 5
7 |
--------------------------------------------------------------------------------
/data/additional_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import torch
9 | from PIL import ImageEnhance
10 |
11 | transformtypedict=dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color)
12 |
13 |
14 |
15 | class ImageJitter(object):
16 | def __init__(self, transformdict):
17 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict]
18 |
19 |
20 | def __call__(self, img):
21 | out = img
22 | randtensor = torch.rand(len(self.transforms))
23 |
24 | for i, (transformer, alpha) in enumerate(self.transforms):
25 | r = alpha*(randtensor[i]*2.0 -1.0) + 1
26 | out = transformer(out).enhance(r).convert('RGB')
27 |
28 | return out
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/data/feature_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import h5py
4 |
5 |
6 | class SimpleHDF5Dataset:
7 | def __init__(self, file_handle=None):
8 | if file_handle is None:
9 | self.f = ''
10 | self.all_feats_dset = []
11 | self.all_labels = []
12 | self.total = 0
13 | else:
14 | self.f = file_handle
15 | self.all_feats_dset = self.f['all_feats'][...]
16 | self.all_labels = self.f['all_labels'][...]
17 | self.total = self.f['count'][0]
18 | # print('here')
19 |
20 | def __getitem__(self, i):
21 | return torch.Tensor(self.all_feats_dset[i, :]), int(self.all_labels[i])
22 |
23 | def __len__(self):
24 | return self.total
25 |
26 |
27 | def init_loader(filename):
28 | with h5py.File(filename, 'r') as f:
29 | fileset = SimpleHDF5Dataset(f)
30 |
31 | # labels = [ l for l in fileset.all_labels if l != 0]
32 | feats = fileset.all_feats_dset
33 | labels = fileset.all_labels
34 | while np.sum(feats[-1]) == 0:
35 | feats = np.delete(feats, -1, axis=0)
36 | labels = np.delete(labels, -1, axis=0)
37 |
38 | class_list = np.unique(np.array(labels)).tolist()
39 | inds = range(len(labels))
40 |
41 | cl_data_file = {}
42 | for cl in class_list:
43 | cl_data_file[cl] = []
44 | for ind in inds:
45 | cl_data_file[labels[ind]].append(feats[ind])
46 |
47 | return cl_data_file
48 |
--------------------------------------------------------------------------------
/methods/protonet.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/jakesnell/prototypical-networks
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 |
11 | class ProtoNet(MetaTemplate):
12 | def __init__(self, model_func, n_way, n_support):
13 | super(ProtoNet, self).__init__( model_func, n_way, n_support)
14 | self.loss_fn = nn.CrossEntropyLoss()
15 |
16 |
17 | def set_forward(self,x,is_feature = False):
18 | z_support, z_query = self.parse_feature(x,is_feature)
19 |
20 | z_support = z_support.contiguous()
21 | z_proto = z_support.view(self.n_way, self.n_support, -1 ).mean(1) #the shape of z is [n_data, n_dim]
22 | z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 )
23 |
24 | dists = euclidean_dist(z_query, z_proto)
25 | scores = -dists
26 | return scores
27 |
28 |
29 | def set_forward_loss(self, x):
30 | y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query ))
31 | y_query = Variable(y_query.cuda())
32 |
33 | scores = self.set_forward(x)
34 |
35 | return self.loss_fn(scores, y_query )
36 |
37 |
38 | def euclidean_dist( x, y):
39 | # x: N x D
40 | # y: M x D
41 | n = x.size(0)
42 | m = y.size(0)
43 | d = x.size(1)
44 | assert d == y.size(1)
45 |
46 | x = x.unsqueeze(1).expand(n, m, d)
47 | y = y.unsqueeze(0).expand(n, m, d)
48 |
49 | return torch.pow(x - y, 2).sum(2)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Understanding Few-Shot Learning: Measuring Task Relatedness and Adaptation Difficulty via Attributes
3 | This repository is the official implementation of the paper "Understanding Few-Shot Learning: Measuring Task Relatedness and Adaptation Difficulty via Attributes" (NeurIPS 2023). In this project, we provide the Task Attribute Distance (TAD) metric to quantify task relatedness and measure the adaptation difficulty of novel tasks.
4 |
5 | ## Dependencies
6 |
7 | The code is built with the following libraries:
8 | - python 3.7
9 | - PyTorch 1.7.1
10 | - cv2
11 | - matplotlib
12 | - sklearn
13 | - tensorboard
14 | - h5py
15 | - tqdm
16 |
17 | #### Installation
18 | ```setup
19 | conda create -n TAD python=3.7
20 | source activate TAD
21 | conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
22 | pip install -r requirements.txt
23 | ```
24 |
25 | #### Dataset prepare
26 | Please download the CUB and SUN datasets, then put them under `filelists/CUB/` and `filelists/SUN/` respectively.
27 |
28 | Here we provide a [link](https://drive.google.com/file/d/1Je-BZaCVe9fSoUUpkBhBlm8thxalRxkI/view?usp=sharing) of CUB dataset and related files.
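
The filelist scripts locate the raw data through relative paths; the sketch below shows the assumed layout, inferred from the paths referenced in `write_CUB_filelist.py`, `write_SUN_filelist.py` and `utils.py` (assuming each `write_*` script is run from inside its own folder):

```
filelists/
├── CUB/
│   ├── write_CUB_filelist.py
│   └── CUB_200_2011/          # images/, images.txt, parts/, attributes/
└── SUN/
    ├── write_SUN_filelist.py
    ├── write_SUN_attr.py
    └── materials/sun/         # images/ and SUNAttributeDB/ (.mat files)
```

Running the `write_*` scripts then generates the `base` / `val` / `novel` json splits used by the data loaders.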
29 |
30 | ## Training
31 |
32 | To train the FSL models (such as ProtoNet) on the CUB dataset, run this command:
33 |
34 | ```train
35 | bash scripts/train/cub_protonet.sh
36 | ```
37 |
38 | ## Evaluation
39 |
40 | To evaluate models on CUB, run:
41 |
42 | ```eval
43 | bash scripts/test/cub_protonet.sh
44 | ```
45 |
46 | ## Plot task distance and accuracy
47 |
48 | To estimate the average TAD between each novel task and the training tasks, and then plot a figure of average TAD versus accuracy, run:
49 |
50 | ```eval
51 | bash scripts/test/plot_distance_acc.sh
52 | ```
53 |
54 | ## Fast start
55 |
56 | Here we provide some pretrained models for a fast start.
57 |
58 | - [ProtoNet (Conv4NP)](https://drive.google.com/file/d/1AxXRP0QSmH0C5Y3i8GXEHThg6otK8leH/view?usp=sharing) trained on CUB in the 5-way 1-shot setting
59 |
60 | Download the pretrained model to the path `checkpoints/CUB/Conv4NP_protonet_0_aug_5way_1shot/`, and then run the command in the `Plot task distance and accuracy` section.
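
For reference, `save_features.py` composes the checkpoint directory from the dataset, model, method, `exp_str` and shot arguments, and `io_utils.get_best_file` looks for `best_model.tar` first; a sketch of the expected location (the `best_model.tar` filename is an assumption, so rename the downloaded file accordingly if needed):

```
checkpoints/
└── CUB/
    └── Conv4NP_protonet_0_aug_5way_1shot/
        └── best_model.tar   # assumed name; otherwise get_resume_file() falls back to the highest <epoch>.tar
```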
61 |
62 | Our codebase is developed based on the [baseline++](https://github.com/wyharveychen/CloserLookFewShot) from the paper [A Closer Look at Few-shot Classification](https://arxiv.org/abs/1904.04232) and [COMET](https://github.com/snap-stanford/comet) from the paper [Concept Learners for Few-Shot Learning](https://arxiv.org/pdf/2007.07375.pdf).
--------------------------------------------------------------------------------
/methods/baselinefinetune.py:
--------------------------------------------------------------------------------
1 | import backbone
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import torch.nn.functional as F
7 | from methods.meta_template import MetaTemplate
8 |
9 |
10 | class BaselineFinetune(MetaTemplate):
11 | def __init__(self, model_func, n_way, n_support, loss_type="softmax"):
12 | super(BaselineFinetune, self).__init__(model_func, n_way, n_support)
13 | self.loss_type = loss_type
14 |
15 | def set_forward(self, x, is_feature=True):
16 | return self.set_forward_adaptation(x, is_feature)  # Baseline always does adaptation
17 |
18 | def set_forward_adaptation(self, x, is_feature=True):
19 | assert is_feature, 'Baseline only supports testing with features'
20 | z_support, z_query = self.parse_feature(x, is_feature)
21 |
22 | z_support = z_support.contiguous().view(self.n_way * self.n_support, -1)
23 | z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)
24 |
25 | y_support = torch.from_numpy(np.repeat(range(self.n_way), self.n_support))
26 | y_support = Variable(y_support.cuda())
27 |
28 | if self.loss_type == 'softmax':
29 | linear_clf = nn.Linear(self.feat_dim, self.n_way)
30 | elif self.loss_type == 'dist':
31 | linear_clf = backbone.distLinear(self.feat_dim, self.n_way)
32 | linear_clf = linear_clf.cuda()
33 |
34 | set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr=0.01, momentum=0.9, dampening=0.9,
35 | weight_decay=0.001)
36 |
37 | loss_function = nn.CrossEntropyLoss()
38 | loss_function = loss_function.cuda()
39 |
40 | batch_size = 4
41 | support_size = self.n_way * self.n_support
42 | for epoch in range(100):
43 | rand_id = np.random.permutation(support_size)
44 | for i in range(0, support_size, batch_size):
45 | set_optimizer.zero_grad()
46 | selected_id = torch.from_numpy(rand_id[i: min(i + batch_size, support_size)]).cuda()
47 | z_batch = z_support[selected_id]
48 | y_batch = y_support[selected_id]
49 | scores = linear_clf(z_batch)
50 | loss = loss_function(scores, y_batch)
51 | loss.backward()
52 | set_optimizer.step()
53 | scores = linear_clf(z_query)
54 | return scores
55 |
56 | def set_forward_loss(self, x):
57 | raise ValueError('Baseline predicts on pretrained features and does not support finetuning the backbone')
58 |
59 |
60 |
--------------------------------------------------------------------------------
/data/transforms.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import numpy as np
12 | import cv2
13 |
14 | def transform_preds(coords, center, scale, output_size):
15 | target_coords = np.zeros(coords.shape)
16 | trans = get_affine_transform(center, scale, 0, output_size, inv=1)
17 | for p in range(coords.shape[0]):
18 | target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
19 | return target_coords
20 |
21 |
22 | def get_affine_transform(center,
23 | scale,
24 | rot,
25 | output_size,
26 | shift=np.array([0, 0], dtype=np.float32),
27 | inv=0):
28 | if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
29 | print(scale)
30 | scale = np.array([scale, scale])
31 |
32 | scale_tmp = scale * 200.0
33 | src_w = scale_tmp[0]
34 | dst_w = output_size[0]
35 | dst_h = output_size[1]
36 |
37 | rot_rad = np.pi * rot / 180
38 | src_dir = get_dir([0, src_w * -0.5], rot_rad)
39 | dst_dir = np.array([0, dst_w * -0.5], np.float32)
40 |
41 | src = np.zeros((3, 2), dtype=np.float32)
42 | dst = np.zeros((3, 2), dtype=np.float32)
43 | src[0, :] = center + scale_tmp * shift
44 | src[1, :] = center + src_dir + scale_tmp * shift
45 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
46 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
47 |
48 | src[2:, :] = get_3rd_point(src[0, :], src[1, :])
49 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
50 |
51 | if inv:
52 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
53 | else:
54 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
55 |
56 | return trans
57 |
58 |
59 | def affine_transform(pt, t):
60 | new_pt = np.array([pt[0], pt[1], 1.]).T
61 | new_pt = np.dot(t, new_pt)
62 | return new_pt[:2]
63 |
64 |
65 | def get_3rd_point(a, b):
66 | direct = a - b
67 | return b + np.array([-direct[1], direct[0]], dtype=np.float32)
68 |
69 |
70 | def get_dir(src_point, rot_rad):
71 | sn, cs = np.sin(rot_rad), np.cos(rot_rad)
72 |
73 | src_result = [0, 0]
74 | src_result[0] = src_point[0] * cs - src_point[1] * sn
75 | src_result[1] = src_point[0] * sn + src_point[1] * cs
76 |
77 | return src_result
78 |
79 |
80 | def crop(img, center, scale, output_size, rot=0):
81 | trans = get_affine_transform(center, scale, rot, output_size)
82 |
83 | dst_img = cv2.warpAffine(img,
84 | trans,
85 | (int(output_size[0]), int(output_size[1])),
86 | flags=cv2.INTER_LINEAR)
87 |
88 | return dst_img
--------------------------------------------------------------------------------
/data/datamgr.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | import data.additional_transforms as add_transforms
8 | from data.dataset import SimpleDataset, SetDataset, EpisodicBatchSampler
9 | from abc import abstractmethod
10 |
11 | class TransformLoader:
12 | def __init__(self, image_size,
13 | normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
14 | jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)):
15 | self.image_size = image_size
16 | self.normalize_param = normalize_param
17 | self.jitter_param = jitter_param
18 |
19 | def parse_transform(self, transform_type):
20 | if transform_type == 'ImageJitter':
21 | method = add_transforms.ImageJitter(self.jitter_param)
22 | return method
23 | method = getattr(transforms, transform_type)
24 | if transform_type == 'RandomSizedCrop':
25 | return method(self.image_size)
26 | elif transform_type == 'CenterCrop':
27 | return method(self.image_size)
28 | elif transform_type == 'Scale':
29 | return method([int(self.image_size * 1.15), int(self.image_size * 1.15)])
30 | elif transform_type == 'Normalize':
31 | return method(**self.normalize_param)
32 | else:
33 | return method()
34 |
35 | def get_composed_transform(self, aug=False):
36 | if aug:
37 | transform_list = ['ImageJitter', 'ToTensor', 'Normalize']
38 | else:
39 | transform_list = ['ToTensor', 'Normalize']
40 |
41 | transform_funcs = [self.parse_transform(x) for x in transform_list]
42 | transform = transforms.Compose(transform_funcs)
43 | return transform
44 |
45 |
46 | class DataManager:
47 | @abstractmethod
48 | def get_data_loader(self, data_file, aug):
49 | pass
50 |
51 |
52 | class SimpleDataManager(DataManager):
53 | def __init__(self, image_size, batch_size):
54 | super(SimpleDataManager, self).__init__()
55 | self.batch_size = batch_size
56 | self.trans_loader = TransformLoader(image_size)
57 | self.image_size = image_size
58 |
59 | def get_data_loader(self, data_file, aug, is_train=True): # parameters that would change on train/val set
60 | transform = self.trans_loader.get_composed_transform(aug)
61 | dataset = SimpleDataset(data_file, self.image_size, transform, is_train=is_train)
62 | data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=12, pin_memory=True)
63 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
64 |
65 | return data_loader
66 |
67 |
68 | class SetDataManager(DataManager):
69 | def __init__(self, image_size, n_way, n_support, n_query, n_eposide=100):
70 | super(SetDataManager, self).__init__()
71 | self.image_size = image_size
72 | self.n_way = n_way
73 | self.batch_size = n_support + n_query
74 | self.n_eposide = n_eposide
75 |
76 | self.trans_loader = TransformLoader(image_size)
77 |
78 | def get_data_loader(self, data_file, aug, is_train=True, attr_loc=False): # parameters that would change on train/val set
79 | transform = self.trans_loader.get_composed_transform(aug)
80 | dataset = SetDataset(data_file, self.batch_size, self.image_size, transform, is_train=is_train, attr_loc=attr_loc)
81 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide)
82 | data_loader_params = dict(batch_sampler=sampler, num_workers=12, pin_memory=True)
83 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
84 | return data_loader
--------------------------------------------------------------------------------
/filelists/SUN/write_SUN_filelist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os import listdir
3 | from os.path import isfile, isdir, join
4 | import os
5 | import json
6 | import random
7 | import re
8 | import scipy.io as sio
9 |
10 | def read_imgid_label_pair(filename):
11 | cwd = os.getcwd()
12 | prefix = os.path.join(cwd, 'materials/sun/images')
13 |
14 | image_dict = sio.loadmat(filename)
15 | image_path = image_dict['images']
16 | label_to_imgid = {}
17 | label_to_path = {}
18 | for i in range(image_path.shape[0]):
19 | path = image_path[i]
20 | path_list = path[0][0].split('/')
21 | label = ''
22 | for j in range(len(path_list)):
23 | if j > 0 and j < len(path_list) - 1:
24 | label += path_list[j]
25 | if label not in label_to_imgid.keys():
26 | label_to_imgid[label] = []
27 | if label not in label_to_path.keys():
28 | label_to_path[label] = []
29 | label_to_imgid[label].append(i)
30 | label_to_path[label].append(prefix + '/' + path[0][0])
31 | return label_to_imgid, label_to_path
32 |
33 | def read_img_attr_label(filename):
34 | attr_dict = sio.loadmat(filename)
35 | img_attr_labels = attr_dict['labels_cv'] # (14340, 102)
36 | return img_attr_labels
37 |
38 | def get_cl_attr_label(img_attr_labels, label_to_imgid, idx_to_label, cl_num):
39 | cl_attr_probs = []
40 | for i in range(cl_num):
41 | label = idx_to_label[i]
42 | imgid_list = label_to_imgid[label]
43 | attr_probs = img_attr_labels[imgid_list].mean(0)
44 | cl_attr_probs.append(attr_probs)
45 | cl_attr_probs = np.array(cl_attr_probs)
46 | avg_prob = cl_attr_probs.mean()
47 | cl_attr_labels = (cl_attr_probs > avg_prob).astype('float32')
48 | return cl_attr_labels
49 |
50 | if __name__ == '__main__':
51 |
52 | prefix = './materials/sun/SUNAttributeDB'
53 |
54 | label_to_imgid, label_to_path = read_imgid_label_pair(os.path.join(prefix, 'images.mat'))
55 | img_attr_labels = read_img_attr_label(os.path.join(prefix, 'attributeLabels_continuous.mat'))
56 |
57 | all_label_list = list(label_to_imgid.keys())
58 | cl_num = len(all_label_list)
59 | all_label_list.sort()
60 | label_to_idx = {}
61 | idx_to_label = {}
62 | for i in range(len(all_label_list)):
63 | idx_to_label[i] = all_label_list[i]
64 | label_to_idx[all_label_list[i]] = i
65 |
66 | cl_attr_labels = get_cl_attr_label(img_attr_labels, label_to_imgid, idx_to_label, cl_num)
67 |
68 | savedir = './materials/sun/'
69 | dataset_list = ['base','val','novel']
70 |
71 | rs_label_list = list(label_to_imgid.keys())
72 | random.shuffle(rs_label_list)
73 | for dataset in dataset_list:
74 | file_list = []
75 | label_list = []
76 | for i, label in enumerate(rs_label_list):
77 | label_id = label_to_idx[label]
78 | if 'base' in dataset:
79 | if (i >= 0 and i < 430):
80 | file_list = file_list + label_to_path[label]
81 | label_list = label_list + np.repeat(label_id, len(label_to_path[label])).tolist()
82 | if 'val' in dataset:
83 | if (i >= 430 and i < 645):
84 | file_list = file_list + label_to_path[label]
85 | label_list = label_list + np.repeat(label_id, len(label_to_path[label])).tolist()
86 | if 'novel' in dataset:
87 | if (i >= 645):
88 | file_list = file_list + label_to_path[label]
89 | label_list = label_list + np.repeat(label_id, len(label_to_path[label])).tolist()
90 | with open(savedir + dataset + '.json', 'w') as outfile:
91 | json.dump({'label_names':all_label_list, 'image_names':file_list, 'image_labels':label_list,
92 | 'attr_labels': cl_attr_labels.tolist()}, outfile)
93 |
94 | print("%s -OK" %dataset)
95 |
--------------------------------------------------------------------------------
/filelists/CUB/write_CUB_filelist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os import listdir
3 | from os.path import isfile, isdir, join
4 | import os
5 | import json
6 | import random
7 | import re
8 |
9 | def read_img_id_pair(filename):
10 | img_to_id = dict()
11 | with open(filename, 'r') as fin:
12 | for line in fin.readlines():
13 | line_split = line.strip().split(' ')
14 | img_to_id[line_split[1]] = int(line_split[0])
15 | return img_to_id
16 |
17 | def read_parts(filename):
18 | id_to_parts = dict()
19 | with open(filename, 'r') as fin:
20 | for line in fin.readlines():
21 | line_split = line.strip().split(' ')
22 | img_id, part_id, x, y, visible = int(line_split[0]), int(line_split[1]), float(line_split[2]), float(line_split[3]), int(line_split[4])
23 | if part_id == 1:
24 | id_to_parts[img_id] = [[x, y, visible], ]
25 | else:
26 | id_to_parts[img_id].append([x, y, visible])
27 | return id_to_parts
28 |
29 | def read_img_attr_label(filename):
30 | imgid_to_attrlabel = dict()
31 | with open(filename, 'r') as fin:
32 | for line in fin.readlines():
33 | line_split = line.strip().split(' ')
34 | img_id, attr_label, visible = int(line_split[0]), int(line_split[1]), int(line_split[2])
35 | if visible:
36 | if img_id in imgid_to_attrlabel.keys():
37 | imgid_to_attrlabel[img_id].append(attr_label)
38 | else:
39 | imgid_to_attrlabel[img_id] = [attr_label]
40 | return imgid_to_attrlabel
41 |
42 | if __name__ == '__main__':
43 |
44 | img_to_idx = read_img_id_pair('./CUB_200_2011/images.txt')
45 | id_to_parts = read_parts('./CUB_200_2011/parts/part_locs.txt')
46 | id_to_attrlabel = read_img_attr_label('./CUB_200_2011/attributes/image_attribute_labels.txt')
47 |
48 | cwd = os.getcwd()
49 | data_path = join(cwd,'CUB_200_2011/images')
50 | savedir = './'
51 | dataset_list = ['base','val','novel']
52 |
53 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))]
54 | folder_list.sort()
55 | label_dict = dict(zip(folder_list,range(0,len(folder_list))))
56 |
57 | classfile_list_all = []
58 |
59 | for i, folder in enumerate(folder_list):
60 | folder_path = join(data_path, folder)
61 | classfile_list_all.append( [ join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')])
62 | random.shuffle(classfile_list_all[i])
63 |
64 |
65 | for dataset in dataset_list:
66 | file_list = []
67 | label_list = []
68 | for i, classfile_list in enumerate(classfile_list_all):
69 | if 'base' in dataset:
70 | if (i%2 == 0):
71 | file_list = file_list + classfile_list
72 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
73 | if 'val' in dataset:
74 | if (i%4 == 1):
75 | file_list = file_list + classfile_list
76 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
77 | if 'novel' in dataset:
78 | if (i%4 == 3):
79 | file_list = file_list + classfile_list
80 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
81 | attr_label_list = []
82 | for path in file_list:
83 | img = path.split('/')[-2] + '/' + path.split('/')[-1]
84 | attr_label_list.append(id_to_attrlabel[img_to_idx[img]])
85 | part_list = []
86 | for path in file_list:
87 | filename = re.search('/images/(.*)', path, flags=0).group(1)
88 | part_list.append(id_to_parts[img_to_idx[filename]])
89 | with open(savedir + dataset + '.json', 'w') as outfile:
90 | json.dump({'label_names':folder_list, 'image_names':file_list, 'image_labels':label_list, 'part': part_list,
91 | 'attr_labels': attr_label_list}, outfile)
92 |
93 | print("%s -OK" %dataset)
94 |
--------------------------------------------------------------------------------
/methods/matchingnet.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 | import utils
11 | import copy
12 |
13 | class MatchingNet(MetaTemplate):
14 | def __init__(self, model_func, n_way, n_support):
15 | super(MatchingNet, self).__init__( model_func, n_way, n_support)
16 |
17 | self.loss_fn = nn.NLLLoss()
18 |
19 | self.FCE = FullyContextualEmbedding(self.feat_dim)
20 | self.G_encoder = nn.LSTM(self.feat_dim, self.feat_dim, 1, batch_first=True, bidirectional=True)
21 |
22 | self.relu = nn.ReLU()
23 | self.softmax = nn.Softmax()
24 |
25 | def encode_training_set(self, S, G_encoder = None):
26 | if G_encoder is None:
27 | G_encoder = self.G_encoder
28 | out_G = G_encoder(S.unsqueeze(0))[0]
29 | out_G = out_G.squeeze(0)
30 | G = S + out_G[:,:S.size(1)] + out_G[:,S.size(1):]
31 | G_norm = torch.norm(G,p=2, dim =1).unsqueeze(1).expand_as(G)
32 | G_normalized = G.div(G_norm+ 0.00001)
33 | return G, G_normalized
34 |
35 | def get_logprobs(self, f, G, G_normalized, Y_S, FCE = None):
36 | if FCE is None:
37 | FCE = self.FCE
38 | F = FCE(f, G)
39 | F_norm = torch.norm(F,p=2, dim =1).unsqueeze(1).expand_as(F)
40 | F_normalized = F.div(F_norm+ 0.00001)
41 | #scores = F.mm(G_normalized.transpose(0,1))  # the implementation of Ross et al.; not consistent with the original paper and would let large-norm features dominate
42 | scores = self.relu( F_normalized.mm(G_normalized.transpose(0,1)) ) *100  # the original paper uses cosine similarity; we scale it by 100 to sharpen the highest probability after softmax
43 | softmax = self.softmax(scores)
44 | logprobs =(softmax.mm(Y_S)+1e-6).log()
45 | return logprobs
46 |
47 | def set_forward(self, x, is_feature = False):
48 | z_support, z_query = self.parse_feature(x,is_feature)
49 |
50 | z_support = z_support.contiguous().view( self.n_way* self.n_support, -1 )
51 | z_query = z_query.contiguous().view( self.n_way* self.n_query, -1 )
52 | G, G_normalized = self.encode_training_set( z_support)
53 |
54 | y_s = torch.from_numpy(np.repeat(range( self.n_way ), self.n_support ))
55 | Y_S = Variable( utils.one_hot(y_s, self.n_way ) ).cuda()
56 | f = z_query
57 | logprobs = self.get_logprobs(f, G, G_normalized, Y_S)
58 | return logprobs
59 |
60 | def set_forward_loss(self, x):
61 | y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query ))
62 | y_query = Variable(y_query.cuda())
63 |
64 | logprobs = self.set_forward(x)
65 |
66 | return self.loss_fn(logprobs, y_query )
67 |
68 | def cuda(self):
69 | super(MatchingNet, self).cuda()
70 | self.FCE = self.FCE.cuda()
71 | return self
72 |
73 | class FullyContextualEmbedding(nn.Module):
74 | def __init__(self, feat_dim):
75 | super(FullyContextualEmbedding, self).__init__()
76 | self.lstmcell = nn.LSTMCell(feat_dim*2, feat_dim)
77 | self.softmax = nn.Softmax()
78 | self.c_0 = Variable(torch.zeros(1,feat_dim))
79 | self.feat_dim = feat_dim
80 | #self.K = K
81 |
82 | def forward(self, f, G):
83 | h = f
84 | c = self.c_0.expand_as(f)
85 | G_T = G.transpose(0,1)
86 | K = G.size(0)  # number of support embeddings (to be confirmed)
87 | for k in range(K):
88 | logit_a = h.mm(G_T)
89 | a = self.softmax(logit_a)
90 | r = a.mm(G)
91 | x = torch.cat((f, r),1)
92 |
93 | h, c = self.lstmcell(x, (h, c))
94 | h = h + f
95 |
96 | return h
97 | def cuda(self):
98 | super(FullyContextualEmbedding, self).cuda()
99 | self.c_0 = self.c_0.cuda()
100 | return self
101 |
102 |
--------------------------------------------------------------------------------
/methods/baselinetrain.py:
--------------------------------------------------------------------------------
1 | import backbone
2 | import utils
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 |
10 |
11 | class BaselineTrain(nn.Module):
12 | def __init__(self, model_func, num_class, loss_type='softmax'):
13 | super(BaselineTrain, self).__init__()
14 | self.feature = model_func()
15 | if loss_type == 'softmax':
16 | self.classifier = nn.Linear(self.feature.final_feat_dim, num_class)
17 | self.classifier.bias.data.fill_(0)
18 | elif loss_type == 'dist': # Baseline ++
19 | self.classifier = backbone.distLinear(self.feature.final_feat_dim, num_class)
20 | self.loss_type = loss_type # 'softmax' #'dist'
21 | self.num_class = num_class
22 | self.loss_fn = nn.CrossEntropyLoss()
23 | self.DBval = False  # only set True for CUB dataset, see issue #31
24 |
25 | def forward(self, x):
26 | x = Variable(x.cuda())
27 | out = self.feature.forward(x)
28 | scores = self.classifier.forward(out)
29 | return scores
30 |
31 | def forward_loss(self, x, y):
32 | scores = self.forward(x)
33 | y = Variable(y.cuda())
34 | return self.loss_fn(scores, y)
35 |
36 | def train_loop(self, epoch, train_loader, optimizer, tf_writer):
37 | print_freq = 10
38 | avg_loss = 0
39 |
40 | for i, (x, y) in enumerate(train_loader):
41 | optimizer.zero_grad()
42 | loss = self.forward_loss(x, y)
43 | loss.backward()
44 | optimizer.step()
45 |
46 | avg_loss = avg_loss + loss.item()
47 |
48 | if i % print_freq == 0:
49 | # print(optimizer.state_dict()['param_groups'][0]['lr'])
50 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader),
51 | avg_loss / float(i + 1)))
52 | tf_writer.add_scalar('loss/train', avg_loss / float(i + 1), epoch)
53 |
54 | def test_loop(self, val_loader):
55 | if self.DBval:
56 | return self.analysis_loop(val_loader)
57 | else:
58 | return -1 # no validation, just save model during iteration
59 |
60 | def analysis_loop(self, val_loader, record=None):
61 | class_file = {}
62 | for i, (x, y) in enumerate(val_loader):
63 | x = x.cuda()
64 | x_var = Variable(x)
65 | feats = self.feature.forward(x_var).data.cpu().numpy()
66 | labels = y.cpu().numpy()
67 | for f, l in zip(feats, labels):
68 | if l not in class_file.keys():
69 | class_file[l] = []
70 | class_file[l].append(f)
71 |
72 | for cl in class_file:
73 | class_file[cl] = np.array(class_file[cl])
74 |
75 | DB = DBindex(class_file)
76 | print('DB index = %4.2f' % (DB))
77 | return 1 / DB # DB index: the lower the better
78 |
79 |
80 | def DBindex(cl_data_file):
81 | # For the definition of the Davies-Bouldin index (DBindex), see https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index
82 | # The DB index measures intra-class variation relative to inter-class separation
83 | # As baseline/baseline++ do not train a few-shot classifier during training, this is an alternative metric for evaluating the validation set
84 | # Empirically, this only works for the CUB dataset, not for miniImagenet
85 |
86 | class_list = cl_data_file.keys()
87 | cl_num = len(class_list)
88 | cl_means = []
89 | stds = []
90 | DBs = []
91 | for cl in class_list:
92 | cl_means.append(np.mean(cl_data_file[cl], axis=0))
93 | stds.append(np.sqrt(np.mean(np.sum(np.square(cl_data_file[cl] - cl_means[-1]), axis=1))))
94 |
95 | mu_i = np.tile(np.expand_dims(np.array(cl_means), axis=0), (len(class_list), 1, 1))
96 | mu_j = np.transpose(mu_i, (1, 0, 2))
97 | mdists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis=2))
98 |
99 | for i in range(cl_num):
100 | DBs.append(np.max([(stds[i] + stds[j]) / mdists[i, j] for j in range(cl_num) if j != i]))
101 | return np.mean(DBs)
102 |
103 |
--------------------------------------------------------------------------------
/filelists/SUN/write_SUN_attr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os import listdir
3 | from os.path import isfile, isdir, join
4 | import os
5 | import json
6 | import random
7 | import re
8 | import scipy.io as sio
9 |
10 | def read_imgid_label_pair(filename):
11 | cwd = os.getcwd()
12 | prefix = os.path.join(cwd, 'materials/sun/images')
13 |
14 | image_dict = sio.loadmat(filename)
15 | image_path = image_dict['images']
16 | label_to_imgid = {}
17 | label_to_path = {}
18 |
19 | imgid_to_label = {}
20 | for i in range(image_path.shape[0]):
21 | path = image_path[i]
22 | path_list = path[0][0].split('/')
23 | label = ''
24 | for j in range(len(path_list)):
25 | if j > 0 and j < len(path_list) - 1:
26 | label += path_list[j]
27 | if label not in label_to_imgid.keys():
28 | label_to_imgid[label] = []
29 | if label not in label_to_path.keys():
30 | label_to_path[label] = []
31 | label_to_imgid[label].append(i)
32 | label_to_path[label].append(prefix + '/' + path[0][0])
33 |
34 | imgid_to_label[i] = label
35 | return label_to_imgid, label_to_path, imgid_to_label
36 |
37 | def read_img_attr_label(filename):
38 | attr_dict = sio.loadmat(filename)
39 | img_attr_labels = attr_dict['labels_cv'] # (14340, 102)
40 | return img_attr_labels
41 |
42 | # def get_cl_attr_label(img_attr_labels, label_to_imgid, idx_to_label, cl_num):
43 | # cl_attr_probs = []
44 | # for i in range(cl_num):
45 | # label = idx_to_label[i]
46 | # imgid_list = label_to_imgid[label]
47 | # attr_probs = img_attr_labels[imgid_list].mean(0)
48 | # cl_attr_probs.append(attr_probs)
49 | # cl_attr_probs = np.array(cl_attr_probs)
50 | # return cl_attr_probs
51 |
52 | def get_cl_attr_label(img_attr_labels, imgid_to_label, label_to_idx, cl_num):
53 | count = 0
54 | class_attr_count = np.zeros((cl_num, 102, 2))
55 |
56 | for i in range(img_attr_labels.shape[0]):
57 | count += 1
58 | class_label = imgid_to_label[i]
59 | class_label_id = label_to_idx[class_label]
60 |
61 | attr_labels = img_attr_labels[i, :]
62 | for j in range(102):
63 | attr_label_prob = attr_labels[j]
64 | if attr_label_prob >= 0.5:
65 | class_attr_count[class_label_id][j][1] += 1
66 | else:
67 | class_attr_count[class_label_id][j][0] += 1
68 | print("count:", count)
69 |
70 | class_attr_min_label = np.argmin(class_attr_count, axis=2)
71 | class_attr_max_label = np.argmax(class_attr_count, axis=2)
72 | equal_count = np.where(
73 | class_attr_min_label == class_attr_max_label) # check where 0 count = 1 count, set the corresponding class attribute label to be 1
74 | class_attr_max_label[equal_count] = 1
75 |
76 | min_class_count = 10
77 | attr_class_count = np.sum(class_attr_max_label, axis=0)
78 | mask = np.where(attr_class_count >= min_class_count)[
79 | 0] # select attributes that are present (on a class level) in at least [min_class_count] classes
80 | class_attr_label_masked = class_attr_max_label[:, mask]
81 | return class_attr_label_masked
82 |
83 |
84 | if __name__ == '__main__':
85 |
86 | prefix = './materials/sun/SUNAttributeDB'
87 |
88 | label_to_imgid, label_to_path, imgid_to_label = read_imgid_label_pair(os.path.join(prefix, 'images.mat'))
89 | img_attr_labels = read_img_attr_label(os.path.join(prefix, 'attributeLabels_continuous.mat'))
90 |
91 | all_label_list = list(label_to_imgid.keys())
92 | cl_num = len(all_label_list)
93 | all_label_list.sort()
94 | label_to_idx = {}
95 | idx_to_label = {}
96 | for i in range(len(all_label_list)):
97 | idx_to_label[i] = all_label_list[i]
98 | label_to_idx[all_label_list[i]] = i
99 |
100 | #cl_attr_probs = get_cl_attr_label(img_attr_labels, label_to_imgid, idx_to_label, cl_num)
101 | masked_cl_attr_probs = get_cl_attr_label(img_attr_labels, imgid_to_label, label_to_idx, cl_num)
102 |
103 | savedir = './materials/sun/'
104 |
105 | with open(savedir + 'masked_attr_dist.json', 'w') as outfile:
106 | json.dump({'attr_dist': masked_cl_attr_probs.tolist()}, outfile)
107 | print("masked_attr_dist -OK")
108 |
--------------------------------------------------------------------------------
/io_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import argparse
5 | import backbone
6 |
7 | model_dict = dict(
8 | Conv4=backbone.Conv4,
9 | Conv4NP=backbone.Conv4NP,
10 | Conv4S=backbone.Conv4S,
11 | Conv6=backbone.Conv6,
12 | Conv6NP=backbone.Conv6NP,
13 | ResNet10=backbone.ResNet10,
14 | ResNet10NP=backbone.ResNet10NP,
15 | ResNet18=backbone.ResNet18,
16 | ResNet34=backbone.ResNet34,
17 | ResNet50=backbone.ResNet50,
18 | ResNet101=backbone.ResNet101)
19 |
20 |
21 | def parse_args(script):
22 | parser = argparse.ArgumentParser(description='few-shot script %s' % (script))
23 | parser.add_argument('--dataset', default='CUB', help='CUB/miniImagenet/cross/omniglot/cross_char')
24 | parser.add_argument('--model', default='Conv4',
25 | help='model: Conv{4|6} / ResNet{10|18|34|50|101}') # 50 and 101 are not used in the paper
26 | parser.add_argument('--method', default='baseline',
27 | help='baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/maml{_approx}') # relationnet_softmax replace L2 norm with softmax to expedite training, maml_approx use first-order approximation in the gradient for efficiency
28 | parser.add_argument('--train_n_way', default=5, type=int,
29 | help='class num to classify for training') # baseline and baseline++ would ignore this parameter
30 | parser.add_argument('--test_n_way', default=5, type=int,
31 | help='class num to classify for testing (validation) ') # baseline and baseline++ only use this parameter in finetuning
32 | parser.add_argument('--n_shot', default=5, type=int,
33 | help='number of labeled data in each class, same as n_support') # baseline and baseline++ only use this parameter in finetuning
34 | parser.add_argument('--train_aug', action='store_true',
35 | help='perform data augmentation or not during training ') # still required for save_features.py and test.py to find the model path correctly
36 | parser.add_argument('--exp_str', default='0', type=str, help='just to add some clarification for each exp')
37 |
38 | if script == 'train':
39 | parser.add_argument('--num_classes', default=200, type=int,
40 | help='total number of classes in softmax, only used in baseline') # make it larger than the maximum label value in base class
41 | parser.add_argument('--save_freq', default=50, type=int, help='Save frequency')
42 | parser.add_argument('--start_epoch', default=0, type=int, help='Starting epoch')
43 | parser.add_argument('--stop_epoch', default=-1, type=int,
44 | help='Stopping epoch') # for meta-learning methods, each epoch contains 100 episodes. The default epoch number is dataset dependent. See train.py
45 | parser.add_argument('--resume', action='store_true',
46 | help='continue from previous trained model with largest epoch')
47 | parser.add_argument('--warmup', action='store_true',
48 | help='continue from baseline, neglected if resume is true') # never used in the paper
49 | parser.add_argument('--beta', default=0.0, type=float, help='Coefficient for attribute loss')
50 | parser.add_argument('--attr_weight', default=15, type=float, help='attr_loss weight for COMP')
51 | parser.add_argument('--orth_weight', default=0.00035, type=float, help='orth_loss weight for COMP')
52 | elif script == 'save_features':
53 | parser.add_argument('--split', default='novel',
54 | help='base/val/novel') # default novel, but you can also test base/val class accuracy if you want
55 | parser.add_argument('--save_iter', default=-1, type=int,
56 | help='save feature from the model trained in x epoch, use the best model if x is -1')
57 | elif script == 'test':
58 | parser.add_argument('--split', default='novel',
59 | help='base/val/novel') # default novel, but you can also test base/val class accuracy if you want
60 | parser.add_argument('--save_iter', default=-1, type=int,
61 | help='saved feature from the model trained in x epoch, use the best model if x is -1')
62 | parser.add_argument('--adaptation', action='store_true', help='further adaptation in test time or not')
63 | parser.add_argument('--runs', default=None, type=str, help='runs for write the dis and acc files')
64 | else:
65 | raise ValueError('Unknown script')
66 |
67 | return parser.parse_args()
68 |
69 |
70 | def get_assigned_file(checkpoint_dir, num):
71 | assign_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(num))
72 | return assign_file
73 |
74 |
75 | def get_resume_file(checkpoint_dir):
76 | filelist = glob.glob(os.path.join(checkpoint_dir, '*.tar'))
77 | if len(filelist) == 0:
78 | return None
79 |
80 | filelist = [x for x in filelist if os.path.basename(x) != 'best_model.tar']
81 | epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist])
82 | max_epoch = np.max(epochs)
83 | resume_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(max_epoch))
84 | return resume_file
85 |
86 |
87 | def get_best_file(checkpoint_dir):
88 | best_file = os.path.join(checkpoint_dir, 'best_model.tar')
89 | if os.path.isfile(best_file):
90 | return best_file
91 | else:
92 | return get_resume_file(checkpoint_dir)
93 |
--------------------------------------------------------------------------------
/save_features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Variable
4 | import os
5 | import glob
6 | import h5py
7 |
8 | import configs
9 | import backbone
10 | from data.datamgr import SimpleDataManager
11 | from methods.baselinetrain import BaselineTrain
12 | from methods.baselinefinetune import BaselineFinetune
13 | from methods.protonet import ProtoNet
14 | from methods.matchingnet import MatchingNet
15 | from methods.relationnet import RelationNet
16 | from methods.maml import MAML
17 | from io_utils import model_dict, parse_args, get_resume_file, get_best_file, get_assigned_file
18 |
19 |
20 | def save_features(model, data_loader, outfile):
21 | f = h5py.File(outfile, 'w')
22 | max_count = len(data_loader) * data_loader.batch_size
23 | all_labels = f.create_dataset('all_labels', (max_count,), dtype='i')
24 | all_feats = None
25 | count = 0
26 | for i, (x, y) in enumerate(data_loader):
27 | if i % 10 == 0:
28 | print('{:d}/{:d}'.format(i, len(data_loader)))
29 | x = x.cuda()
30 | x_var = Variable(x)
31 | feats = model(x_var)
32 | if all_feats is None:
33 | all_feats = f.create_dataset('all_feats', [max_count] + list(feats.size()[1:]), dtype='f')
34 | all_feats[count:count + feats.size(0)] = feats.data.cpu().numpy()
35 | all_labels[count:count + feats.size(0)] = y.cpu().numpy()
36 | count = count + feats.size(0)
37 |
38 | count_var = f.create_dataset('count', (1,), dtype='i')
39 | count_var[0] = count
40 |
41 | f.close()
42 |
43 |
44 | if __name__ == '__main__':
45 | params = parse_args('save_features')
46 | assert params.method != 'maml' and params.method != 'maml_approx', 'MAML does not support saving features'
47 |
48 | if 'Conv' in params.model:
49 | if params.dataset in ['omniglot', 'cross_char']:
50 | image_size = 28
51 | else:
52 | image_size = 84
53 | else:
54 | image_size = 224
55 |
56 | if params.dataset in ['omniglot', 'cross_char']:
57 | assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
58 | params.model = 'Conv4S'
59 |
60 | split = params.split
61 | if params.dataset == 'cross':
62 | if split == 'base':
63 | loadfile = configs.data_dir['miniImagenet'] + 'all.json'
64 | else:
65 | loadfile = configs.data_dir['CUB'] + split + '.json'
66 | elif params.dataset == 'cross_char':
67 | if split == 'base':
68 | loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
69 | else:
70 | loadfile = configs.data_dir['emnist'] + split + '.json'
71 | elif params.dataset == 'AWA2':
72 | # for AWA2, we use both validation and test classes as novel classes, because the number of test classes in AWA2 is too small (only 10)
73 | loadfile = configs.data_dir[params.dataset] + 'val_novel.json'
74 | else:
75 | loadfile = configs.data_dir[params.dataset] + split + '.json'
76 |
77 | checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % (
78 | configs.save_dir, params.dataset, params.model, params.method, params.exp_str)
79 | if params.train_aug:
80 | checkpoint_dir += '_aug'
81 | if not params.method in ['baseline', 'baseline++', 'comp']:
82 | checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
83 |
84 | if params.save_iter != -1:
85 | modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
86 | else:
87 | modelfile = get_best_file(checkpoint_dir)
88 |
89 | if params.save_iter != -1:
90 | outfile = os.path.join(checkpoint_dir.replace("checkpoints", "features"),
91 | split + "_" + str(params.save_iter) + ".hdf5")
92 | else:
93 | outfile = os.path.join(checkpoint_dir.replace("checkpoints", "features"), split + ".hdf5")
94 |
95 | datamgr = SimpleDataManager(image_size, batch_size=64)
96 | data_loader = datamgr.get_data_loader(loadfile, aug=False, is_train=False)
97 |
98 | if params.method in ['relationnet', 'relationnet_softmax']:
99 | if params.model == 'Conv4':
100 | model = backbone.Conv4NP()
101 | elif params.model == 'Conv6':
102 | model = backbone.Conv6NP()
103 | elif params.model == 'Conv4S':
104 | model = backbone.Conv4SNP()
105 | else:
106 | model = model_dict[params.model](flatten=False)
107 | elif params.method in ['maml', 'maml_approx']:
108 | raise ValueError('MAML does not support saving features')
109 | else:
110 | model = model_dict[params.model]()
111 |
112 | model = model.cuda()
113 | print(modelfile)
114 | tmp = torch.load(modelfile)
115 | state = tmp['state']
116 | state_keys = list(state.keys())
117 | for i, key in enumerate(state_keys):
118 | if "feature." in key:
119 | newkey = key.replace("feature.",
120 | "") # an architecture model has attribute 'feature', load architecture feature to backbone by casting name from 'feature.trunk.xx' to 'trunk.xx'
121 | state[newkey] = state.pop(key)
122 | else:
123 | state.pop(key)
124 |
125 | model.load_state_dict(state)
126 | model.eval()
127 |
128 | dirname = os.path.dirname(outfile)
129 | if not os.path.isdir(dirname):
130 | os.makedirs(dirname)
131 | save_features(model, data_loader, outfile)
132 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
4 | import matplotlib.pyplot as plt
5 |
6 | def one_hot(y, num_class):
7 | return torch.zeros((len(y), num_class)).scatter_(1, y.unsqueeze(1), 1)
8 |
9 |
10 | def DBindex(cl_data_file):
11 | class_list = cl_data_file.keys()
12 | cl_num = len(class_list)
13 | cl_means = []
14 | stds = []
15 | DBs = []
16 | for cl in class_list:
17 | cl_means.append(np.mean(cl_data_file[cl], axis=0))
18 | stds.append(np.sqrt(np.mean(np.sum(np.square(cl_data_file[cl] - cl_means[-1]), axis=1))))
19 |
20 | mu_i = np.tile(np.expand_dims(np.array(cl_means), axis=0), (len(class_list), 1, 1))
21 | mu_j = np.transpose(mu_i, (1, 0, 2))
22 | mdists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis=2))
23 |
24 | for i in range(cl_num):
25 | DBs.append(np.max([(stds[i] + stds[j]) / mdists[i, j] for j in range(cl_num) if j != i]))
26 | return np.mean(DBs)
27 |
28 |
29 | def sparsity(cl_data_file):
30 | class_list = cl_data_file.keys()
31 | cl_sparsity = []
32 | for cl in class_list:
33 | cl_sparsity.append(np.mean([np.sum(x != 0) for x in cl_data_file[cl]]))
34 |
35 | return np.mean(cl_sparsity)
36 |
37 | """
38 | Files for plot figs of adaptation difficulty
39 | """
40 |
41 | def read_attr_dists(trainloader, dataset):
42 | if dataset == 'SUN':
43 | print("attribute distance for SUN!")
44 | attr_dists = trainloader.dataset.meta['attr_labels']
45 | attr_dists_array = np.array(attr_dists).astype('float32')
46 | attr_dists = torch.from_numpy(attr_dists_array)
47 |
48 | base_labels = trainloader.dataset.cl_list
49 | base_ind = np.unique(base_labels).tolist()
50 | elif dataset == 'CUB':
51 | print("attribute distance for CUB!")
52 | filename = 'filelists/CUB/CUB_200_2011/masked_class_attribute_labels.txt'
53 | attr_dists = []
54 | with open(filename, 'r') as f:
55 | for line in f.readlines():
56 | line_split = line.strip().split(' ')
57 | float_line = []
58 | for str_num in line_split:
59 | float_line.append(float(str_num))
60 | attr_dists.append(float_line)
61 |
62 | attr_dists_array = np.array(attr_dists)
63 | attr_dists = torch.from_numpy(attr_dists_array)
64 |
65 | base_ind = []
66 | for i in range(200):
67 | if i % 2 == 0:
68 | base_ind.append(i)
69 | elif dataset == 'AWA2':
70 | print("attribute distance for AWA2!")
71 | filename = 'filelists/AWA2/class_attribute_label.txt'
72 | attr_dists = []
73 | with open(filename, 'r') as f:
74 | for line in f.readlines():
75 | line_split = line.strip().split(' ')
76 | float_line = []
77 | for str_num in line_split:
78 | float_line.append(float(str_num))
79 | attr_dists.append(float_line)
80 |
81 | attr_dists_array = np.array(attr_dists)
82 | attr_dists = torch.from_numpy(attr_dists_array)
83 |
84 | base_labels = trainloader.dataset.cl_list
85 | base_ind = np.unique(base_labels).tolist()
86 | else:
87 | raise NotImplementedError("attribute distance not implemented for dataset: %s" % dataset)
88 | return attr_dists, base_ind
89 |
90 | def get_attr_distance(trainloader, dataset):
91 | attr_dists, base_ind = read_attr_dists(trainloader, dataset)
92 |
93 | # class-agnostic or task-agnostic
94 | # part_dists = _dists_check(part_dists)
95 | # base_dists = part_dists[base_ind, :].mean(0) # (102,)
96 | #
97 | # all_cls_dists = part_dists
98 | # base_cls_dists = part_dists[base_ind, :] # (100, 102)
99 |
100 | # original
101 | import random
102 | base_cls_dists = []
103 | sc_cls_lists = [random.sample(base_ind, 5) for _ in range(10000)]
104 | for sc_cls in sc_cls_lists:
105 | sc_dists = attr_dists[sc_cls, :]
106 | base_cls_dists.append(sc_dists)
107 | base_cls_dists = torch.stack(base_cls_dists, dim=0)
108 | all_cls_dists = attr_dists
109 | base_dists = base_cls_dists.mean(1) # (task_num, 102)
110 |
111 | return all_cls_dists, base_dists, base_cls_dists
112 |
113 | def interval_avg(acc_all, dist_all):
114 |     min_d = np.min(dist_all)
115 |     max_d = np.max(dist_all)
116 |     inr = (max_d - min_d) / 9
117 |     acc_inr = [0 for _ in range(9)]
118 |     dis_inr = [0 for _ in range(9)]
119 |     cout_inr = [0 for _ in range(9)]
120 |     for dis, acc in zip(dist_all, acc_all):
121 |         for i in range(9):
122 |             min_i = min_d + i * inr
123 |             max_i = min_d + (i + 1) * inr
124 |
125 |             if dis >= min_i and dis <= max_i:
126 |                 acc_inr[i] += acc
127 |                 dis_inr[i] += dis
128 |                 cout_inr[i] += 1
129 |
130 |     acc_avg_inr, dis_avg_inr = [], []
131 |     for acc, dis, num in zip(acc_inr, dis_inr, cout_inr):
132 |         if num != 0:
133 |             acc_avg_inr.append(1.0 * acc / num)
134 |             dis_avg_inr.append(1.0 * dis / num)
135 |     return acc_avg_inr, dis_avg_inr
136 |
137 |
138 | def plot_fig(acc_all, dist_all):
139 |     acc_avg_inr, dis_avg_inr = interval_avg(acc_all, dist_all)
140 |     print("acc_avg_inr:", acc_avg_inr)
141 |     print("dis_avg_inr:", dis_avg_inr)
142 |
143 |     plt.scatter(dist_all, acc_all)
144 |     plt.scatter(dis_avg_inr, acc_avg_inr, s=40, marker='x', c='red')
145 |
146 |     for x, y in zip(dis_avg_inr, acc_avg_inr):
147 |         plt.annotate("%.1f" % (y), xy=(x, y), xytext=(x - 0.005, y + 1.5), color='r', weight='heavy')
148 |
149 |     plt.plot(dis_avg_inr, acc_avg_inr, c='red')
150 |     # plt.show()
151 |     plt.savefig('dist_acc.pdf')
152 |     plt.close()
--------------------------------------------------------------------------------
/methods/meta_template.py:
--------------------------------------------------------------------------------
1 | import backbone
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import torch.nn.functional as F
7 | import utils
8 | from abc import abstractmethod
9 |
10 |
11 | class MetaTemplate(nn.Module):
12 | def __init__(self, model_func, n_way, n_support, change_way=True):
13 | super(MetaTemplate, self).__init__()
14 | self.n_way = n_way
15 | self.n_support = n_support
16 | self.n_query = -1 # (change depends on input)
17 | self.feature = model_func()
18 | self.feat_dim = self.feature.final_feat_dim
19 |         self.change_way = change_way  # some methods allow a different n_way for training and testing
20 |
21 | @abstractmethod
22 | def set_forward(self, x, is_feature):
23 | pass
24 |
25 | @abstractmethod
26 | def set_forward_loss(self, x):
27 | pass
28 |
29 | def forward(self, x):
30 | out = self.feature.forward(x)
31 | return out
32 |
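    # parse_feature() expects x of shape [n_way, n_support + n_query, C, H, W]
    # (or already-extracted features when is_feature=True) and returns
    # z_support: [n_way, n_support, feat_dim] and z_query: [n_way, n_query, feat_dim].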
33 | def parse_feature(self, x, is_feature):
34 | x = Variable(x.cuda())
35 | if is_feature:
36 | z_all = x
37 | else:
38 | x = x.contiguous().view(self.n_way * (self.n_support + self.n_query), *x.size()[2:])
39 | z_all = self.feature.forward(x)
40 | z_all = z_all.view(self.n_way, self.n_support + self.n_query, -1)
41 | z_support = z_all[:, :self.n_support]
42 | z_query = z_all[:, self.n_support:]
43 |
44 | return z_support, z_query
45 |
46 | def correct(self, x):
47 | scores = self.set_forward(x)
48 | y_query = np.repeat(range(self.n_way), self.n_query)
49 |
50 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
51 | topk_ind = topk_labels.cpu().numpy()
52 | top1_correct = np.sum(topk_ind[:, 0] == y_query)
53 | return float(top1_correct), len(y_query)
54 |
55 | def train_loop(self, epoch, train_loader, optimizer, tf_writer):
56 | print_freq = 10
57 |
58 | avg_loss = 0
59 | for i, (x, _) in enumerate(train_loader):
60 | self.n_query = x.size(1) - self.n_support
61 | if self.change_way:
62 | self.n_way = x.size(0)
63 | optimizer.zero_grad()
64 | loss = self.set_forward_loss(x)
65 | loss.backward()
66 | optimizer.step()
67 | avg_loss = avg_loss + loss.item()
68 |
69 | if i % print_freq == 0:
70 | # print(optimizer.state_dict()['param_groups'][0]['lr'])
71 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader),
72 | avg_loss / float(i + 1)))
73 | tf_writer.add_scalar('loss/train', avg_loss / float(i + 1), epoch)
74 |
75 | def test_loop(self, test_loader, record=None):
76 | correct = 0
77 | count = 0
78 | acc_all = []
79 |
80 | iter_num = len(test_loader)
81 | from tqdm import tqdm
82 | for i, (x, _) in enumerate(tqdm(test_loader)):
83 | self.n_query = x.size(1) - self.n_support
84 | if self.change_way:
85 | self.n_way = x.size(0)
86 | correct_this, count_this = self.correct(x)
87 | acc_all.append(correct_this / count_this * 100)
88 |
89 | acc_all = np.asarray(acc_all)
90 | acc_mean = np.mean(acc_all)
91 | acc_std = np.std(acc_all)
92 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
93 |
94 | return acc_mean
95 |
96 | def set_forward_adaptation(self, x,
97 |                                is_feature=True):  # further adaptation: by default, fix the feature extractor and train a new softmax classifier
98 | assert is_feature == True, 'Feature is fixed in further adaptation'
99 | z_support, z_query = self.parse_feature(x, is_feature)
100 |
101 | z_support = z_support.contiguous().view(self.n_way * self.n_support, -1)
102 | z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)
103 |
104 | y_support = torch.from_numpy(np.repeat(range(self.n_way), self.n_support))
105 | y_support = Variable(y_support.cuda())
106 |
107 | linear_clf = nn.Linear(self.feat_dim, self.n_way)
108 | linear_clf = linear_clf.cuda()
109 |
110 | set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr=0.01, momentum=0.9, dampening=0.9,
111 | weight_decay=0.001)
112 |
113 | loss_function = nn.CrossEntropyLoss()
114 | loss_function = loss_function.cuda()
115 |
116 | batch_size = 4
117 | support_size = self.n_way * self.n_support
118 | for epoch in range(100):
119 | rand_id = np.random.permutation(support_size)
120 | for i in range(0, support_size, batch_size):
121 | set_optimizer.zero_grad()
122 | selected_id = torch.from_numpy(rand_id[i: min(i + batch_size, support_size)]).cuda()
123 | z_batch = z_support[selected_id]
124 | y_batch = y_support[selected_id]
125 | scores = linear_clf(z_batch)
126 | loss = loss_function(scores, y_batch)
127 | loss.backward()
128 | set_optimizer.step()
129 |
130 | scores = linear_clf(z_query)
131 | return scores
132 |
133 | def test_loop_with_dists(self, test_loader, all_cls_dists, base_dists, base_cls_dists, attr_num):
134 | acc_all, dist_all = [], []
135 |
136 | iter_num = len(test_loader)
137 | from tqdm import tqdm
138 | for i, (x, y) in enumerate(tqdm(test_loader)):
139 | x, y = x.cuda(), y.cuda()
140 | self.n_query = x.size(1) - self.n_support
141 | if self.change_way:
142 | self.n_way = x.size(0)
143 | correct_this, count_this = self.correct(x)
144 | acc_all.append(correct_this / count_this * 100)
145 |
146 | sc_cls = y.unique()
147 | # original mean-task (down trend)
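            # task_dists: mean attribute vector of the sampled novel classes, shape (1, attr_num);
            # the distance is the L1 gap to each sampled base task's mean attribute vector
            # (base_dists), averaged over base tasks and normalised by attr_num.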
148 | task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
149 | dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)
150 |
151 | acc_all = np.asarray(acc_all)
152 | acc_mean = np.mean(acc_all)
153 | acc_std = np.std(acc_all)
154 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
155 |
156 | return acc_all, dist_all
--------------------------------------------------------------------------------
/methods/relationnet.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/floodsung/LearningToCompare_FSL
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 | import utils
11 |
12 |
13 | class RelationNet(MetaTemplate):
14 | def __init__(self, model_func, n_way, n_support, loss_type='mse'):
15 | super(RelationNet, self).__init__(model_func, n_way, n_support)
16 |
17 |         self.loss_type = loss_type  # 'mse' or 'softmax'
18 | self.relation_module = RelationModule(self.feat_dim, 8,
19 | self.loss_type) # relation net features are not pooled, so self.feat_dim is [dim, w, h]
20 |
21 | if self.loss_type == 'mse':
22 | self.loss_fn = nn.MSELoss()
23 | else:
24 | self.loss_fn = nn.CrossEntropyLoss()
25 |
26 | def set_forward(self, x, is_feature=False):
27 | z_support, z_query = self.parse_feature(x, is_feature)
28 |
29 | z_support = z_support.contiguous()
30 | z_proto = z_support.view(self.n_way, self.n_support, *self.feat_dim).mean(1)
31 | z_query = z_query.contiguous().view(self.n_way * self.n_query, *self.feat_dim)
32 |
33 | z_proto_ext = z_proto.unsqueeze(0).repeat(self.n_query * self.n_way, 1, 1, 1, 1)
34 | z_query_ext = z_query.unsqueeze(0).repeat(self.n_way, 1, 1, 1, 1)
35 | z_query_ext = torch.transpose(z_query_ext, 0, 1)
36 | extend_final_feat_dim = self.feat_dim.copy()
37 | extend_final_feat_dim[0] *= 2
38 | relation_pairs = torch.cat((z_proto_ext, z_query_ext), 2).view(-1, *extend_final_feat_dim)
39 | relations = self.relation_module(relation_pairs).view(-1, self.n_way)
40 |
41 | return relations
42 |
43 | def set_forward_adaptation(self, x, is_feature=True): # overwrite parent function
44 |         assert is_feature == True, 'Finetuning only supports fixed features'
45 | full_n_support = self.n_support
46 | full_n_query = self.n_query
47 | relation_module_clone = RelationModule(self.feat_dim, 8, self.loss_type)
48 | relation_module_clone.load_state_dict(self.relation_module.state_dict())
49 |
50 | z_support, z_query = self.parse_feature(x, is_feature)
51 | z_support = z_support.contiguous()
52 | set_optimizer = torch.optim.SGD(self.relation_module.parameters(), lr=0.01, momentum=0.9, dampening=0.9,
53 | weight_decay=0.001)
54 |
55 | self.n_support = 3
56 | self.n_query = 2
57 |
58 | z_support_cpu = z_support.data.cpu().numpy()
59 | for epoch in range(100):
60 | perm_id = np.random.permutation(full_n_support).tolist()
61 | sub_x = np.array([z_support_cpu[i, perm_id, :, :, :] for i in range(z_support.size(0))])
62 | sub_x = torch.Tensor(sub_x).cuda()
63 | if self.change_way:
64 | self.n_way = sub_x.size(0)
65 | set_optimizer.zero_grad()
66 | y = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
67 | scores = self.set_forward(sub_x, is_feature=True)
68 | if self.loss_type == 'mse':
69 | y_oh = utils.one_hot(y, self.n_way)
70 | y_oh = Variable(y_oh.cuda())
71 |
72 | loss = self.loss_fn(scores, y_oh)
73 | else:
74 | y = Variable(y.cuda())
75 | loss = self.loss_fn(scores, y)
76 | loss.backward()
77 | set_optimizer.step()
78 |
79 | self.n_support = full_n_support
80 | self.n_query = full_n_query
81 | z_proto = z_support.view(self.n_way, self.n_support, *self.feat_dim).mean(1)
82 | z_query = z_query.contiguous().view(self.n_way * self.n_query, *self.feat_dim)
83 |
84 | z_proto_ext = z_proto.unsqueeze(0).repeat(self.n_query * self.n_way, 1, 1, 1, 1)
85 | z_query_ext = z_query.unsqueeze(0).repeat(self.n_way, 1, 1, 1, 1)
86 | z_query_ext = torch.transpose(z_query_ext, 0, 1)
87 | extend_final_feat_dim = self.feat_dim.copy()
88 | extend_final_feat_dim[0] *= 2
89 | relation_pairs = torch.cat((z_proto_ext, z_query_ext), 2).view(-1, *extend_final_feat_dim)
90 | relations = self.relation_module(relation_pairs).view(-1, self.n_way)
91 |
92 | self.relation_module.load_state_dict(relation_module_clone.state_dict())
93 | return relations
94 |
95 | def set_forward_loss(self, x):
96 | y = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
97 |
98 | scores = self.set_forward(x)
99 | if self.loss_type == 'mse':
100 | y_oh = utils.one_hot(y, self.n_way)
101 | y_oh = Variable(y_oh.cuda())
102 |
103 | return self.loss_fn(scores, y_oh)
104 | else:
105 | y = Variable(y.cuda())
106 | return self.loss_fn(scores, y)
107 |
108 |
109 | class RelationConvBlock(nn.Module):
110 | def __init__(self, indim, outdim, padding=0):
111 | super(RelationConvBlock, self).__init__()
112 | self.indim = indim
113 | self.outdim = outdim
114 | self.C = nn.Conv2d(indim, outdim, 3, padding=padding)
115 | self.BN = nn.BatchNorm2d(outdim, momentum=1, affine=True)
116 | self.relu = nn.ReLU()
117 | self.pool = nn.MaxPool2d(2)
118 |
119 | self.parametrized_layers = [self.C, self.BN, self.relu, self.pool]
120 |
121 | for layer in self.parametrized_layers:
122 | backbone.init_layer(layer)
123 |
124 | self.trunk = nn.Sequential(*self.parametrized_layers)
125 |
126 | def forward(self, x):
127 | out = self.trunk(x)
128 | return out
129 |
130 |
131 | class RelationModule(nn.Module):
132 |     """Relation module that scores concatenated prototype-query feature maps"""
133 |
134 | def __init__(self, input_size, hidden_size, loss_type='mse'):
135 | super(RelationModule, self).__init__()
136 |
137 | self.loss_type = loss_type
138 |         # when using ResNet, the conv map without avg pooling is 7x7, so the blocks need padding before pooling
139 |         padding = 1 if (input_size[1] < 10) and (input_size[2] < 10) else 0
140 |
141 | self.layer1 = RelationConvBlock(input_size[0] * 2, input_size[0], padding=padding)
142 | self.layer2 = RelationConvBlock(input_size[0], input_size[0], padding=padding)
143 |
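        # shrink_s computes the spatial size left after the two conv(3x3, padding) + maxpool(2)
        # blocks above, so that fc1's input dimension matches the flattened feature map.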
144 | shrink_s = lambda s: int((int((s - 2 + 2 * padding) / 2) - 2 + 2 * padding) / 2)
145 |
146 | self.fc1 = nn.Linear(input_size[0] * shrink_s(input_size[1]) * shrink_s(input_size[2]), hidden_size)
147 | self.fc2 = nn.Linear(hidden_size, 1)
148 |
149 | def forward(self, x):
150 | out = self.layer1(x)
151 | out = self.layer2(out)
152 | out = out.view(out.size(0), -1)
153 | out = F.relu(self.fc1(out))
154 | if self.loss_type == 'mse':
155 |             out = torch.sigmoid(self.fc2(out))
156 | elif self.loss_type == 'softmax':
157 | out = self.fc2(out)
158 |
159 | return out
160 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.autograd import Variable
4 | import torch.nn as nn
5 | import torch.optim
6 | import json
7 | import torch.utils.data.sampler
8 | import os
9 | import glob
10 | import random
11 | import time
12 |
13 | import configs
14 | import backbone
15 | import data.feature_loader as feat_loader
16 | from data.datamgr import SetDataManager
17 | from methods.baselinetrain import BaselineTrain
18 | from methods.baselinefinetune import BaselineFinetune
19 | from methods.protonet import ProtoNet
20 | from methods.matchingnet import MatchingNet
21 | from methods.relationnet import RelationNet
22 | from methods.maml import MAML
23 | from methods.apnet import APNet_w_attrLoc, APNet_wo_attrLoc
24 | from io_utils import model_dict, parse_args, get_resume_file, get_best_file, get_assigned_file
25 |
26 | seed = 1
27 | np.random.seed(seed)
28 | torch.random.manual_seed(seed)
29 |
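# feature_evaluation() builds one few-shot episode from pre-extracted features:
# cl_data_file maps each class label to its list of feature vectors; n_way classes and
# n_support + n_query features per class are sampled, and the query accuracy (%) of the
# (optionally adapted) model is returned.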
30 | def feature_evaluation(cl_data_file, model, n_way=5, n_support=5, n_query=15, adaptation=False):
31 | class_list = cl_data_file.keys()
32 |
33 |     select_class = random.sample(list(class_list), n_way)  # list() so random.sample also works on dict keys
34 | z_all = []
35 | for cl in select_class:
36 | img_feat = cl_data_file[cl]
37 | perm_ids = np.random.permutation(len(img_feat)).tolist()
38 | z_all.append([np.squeeze(img_feat[perm_ids[i]]) for i in range(n_support + n_query)]) # stack each batch
39 |
40 | z_all = torch.from_numpy(np.array(z_all))
41 |
42 | model.n_query = n_query
43 | if adaptation:
44 | scores = model.set_forward_adaptation(z_all, is_feature=True)
45 | else:
46 | scores = model.set_forward(z_all, is_feature=True)
47 | pred = scores.data.cpu().numpy().argmax(axis=1)
48 | y = np.repeat(range(n_way), n_query)
49 | acc = np.mean(pred == y) * 100
50 | return acc
51 |
52 |
53 | if __name__ == '__main__':
54 | params = parse_args('test')
55 |
56 | acc_all = []
57 | iter_num = 600
58 | attr_loc = False
59 | few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
60 |
61 | split = params.split
62 | if params.save_iter != -1:
63 | split_str = split + "_" + str(params.save_iter)
64 | else:
65 | split_str = split
66 | if 'Conv' in params.model:
67 | image_size = 84
68 | else:
69 | image_size = 224
70 | datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params)
71 | loadfile = configs.data_dir[params.dataset] + split + '.json'
72 |
73 | if params.method == 'baseline':
74 | model = BaselineFinetune(model_dict[params.model], **few_shot_params)
75 | elif params.method == 'baseline++':
76 | model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
77 | elif params.method == 'protonet':
78 | model = ProtoNet(model_dict[params.model], **few_shot_params)
79 | elif params.method == 'comet':
80 | assert params.dataset == 'CUB'
81 |         model = COMET(model_dict[params.model], **few_shot_params)  # note: COMET is not part of this repo and is not imported above
82 | elif params.method == 'matchingnet':
83 | model = MatchingNet(model_dict[params.model], **few_shot_params)
84 | elif params.method in ['relationnet', 'relationnet_softmax']:
85 | if params.model == 'Conv4':
86 | feature_model = backbone.Conv4NP
87 | elif params.model == 'Conv6':
88 | feature_model = backbone.Conv6NP
89 | elif params.model == 'Conv4S':
90 | feature_model = backbone.Conv4SNP
91 | else:
92 | feature_model = lambda: model_dict[params.model](flatten=False)
93 | loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
94 | model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
95 | elif params.method in ['maml', 'maml_approx']:
96 | backbone.ConvBlock.maml = True
97 | backbone.SimpleBlock.maml = True
98 | backbone.BottleneckBlock.maml = True
99 | backbone.ResNet.maml = True
100 | model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **few_shot_params)
101 |         if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different parameters on Omniglot
102 | model.n_task = 32
103 | model.task_update_num = 1
104 | model.train_lr = 0.1
105 | elif params.method == 'apnet':
106 | if params.dataset == 'CUB':
107 | attr_loc = True
108 | attr_num = 109
109 | elif params.dataset == 'SUN':
110 | attr_num = 102
111 | elif params.dataset == 'AWA2':
112 | attr_num = 85
113 | else:
114 |             raise NotImplementedError("dataset not implemented!")
115 |
116 | few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot, attr_num=attr_num, attr_loc=attr_loc, dataset=params.dataset)
117 | if attr_loc:
118 | model = APNet_w_attrLoc(model_dict[params.model], **few_shot_params)
119 | else:
120 | model = APNet_wo_attrLoc(model_dict[params.model], **few_shot_params)
121 | else:
122 | raise ValueError('Unknown method')
123 |
124 | model = model.cuda()
125 |
126 | checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % (
127 | configs.save_dir, params.dataset, params.model, params.method, params.exp_str)
128 | if params.train_aug:
129 | checkpoint_dir += '_aug'
130 | if not params.method in ['baseline', 'baseline++']:
131 | checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
132 |
133 | if not params.method in ['baseline', 'baseline++']:
134 | if params.save_iter != -1:
135 | modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
136 | else:
137 | modelfile = get_best_file(checkpoint_dir)
138 | if modelfile is not None:
139 | tmp = torch.load(modelfile)
140 | model.load_state_dict(tmp['state'])
141 |
142 |     if params.method in ['maml', 'maml_approx']:  # MAML does not support testing with pre-extracted features
143 | novel_loader = datamgr.get_data_loader(loadfile, aug=False, is_train=False)
144 | if params.dataset == 'SUN' and params.model == 'Conv4':
145 | model.train_lr = 0.1
146 | if params.adaptation:
147 | model.task_update_num = 100 # We perform adaptation on MAML simply by updating more times.
148 | model.eval()
149 | acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
150 | else:
151 | novel_file = os.path.join(checkpoint_dir.replace("checkpoints", "features"),
152 |                                   split_str + ".hdf5")  # default split = novel, but you can also test the base or val classes
153 | cl_data_file = feat_loader.init_loader(novel_file)
154 |
155 | from tqdm import tqdm
156 | for i in tqdm(range(iter_num)):
157 | acc = feature_evaluation(cl_data_file, model, n_query=15, adaptation=params.adaptation, **few_shot_params)
158 | acc_all.append(acc)
159 |
160 | acc_all = np.asarray(acc_all)
161 | acc_mean = np.mean(acc_all)
162 | acc_std = np.std(acc_all)
163 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from PIL import Image
5 | import json
6 | import numpy as np
7 | import torchvision.transforms as transforms
8 | import os
9 | from torch.utils.data import Dataset
10 | import random
11 | import copy
12 | import cv2
13 | from .transforms import *
14 | import PIL
15 |
16 | identity = lambda x: x
17 |
18 | class SimpleDataset:
19 | def __init__(self, data_file, image_size, transform, target_transform=identity, is_train=True):
20 | with open(data_file, 'r') as f:
21 | self.meta = json.load(f)
22 | self.transform = transform
23 | self.target_transform = target_transform
24 | self.flip = is_train
25 | self.image_size = image_size
26 | self.is_train = is_train
27 |
28 | def __getitem__(self, i):
29 | image_path = os.path.join(self.meta['image_names'][i])
30 | # data_numpy = cv2.imread(image_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
31 |
32 |         # use PIL here; for SUN, cv2.imread returns None for some images
33 | data_numpy = np.array(PIL.Image.open(image_path).convert('RGB'))[:, :, ::-1] # to BGR
34 |
35 | if data_numpy is None:
36 | raise ValueError('Fail to read {}'.format(image_path))
37 |
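        # Affine-crop parameters: r is the rotation angle, c the image centre and s a per-axis
        # scale derived from (width, height) // 160. During training the scale is jittered by
        # up to +/-25%, the rotation (applied with prob. 0.6) is drawn from N(0, 30) clipped to
        # +/-60 degrees, and a horizontal flip is applied with prob. 0.5.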
38 | r = 0
39 | c = np.array([data_numpy.shape[1], data_numpy.shape[0]]) // 2
40 | s = np.array([data_numpy.shape[1], data_numpy.shape[0]]) // 160
41 |
42 | if self.is_train:
43 | sf = 0.25
44 | rf = 30
45 | s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
46 | r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
47 | if random.random() <= 0.6 else 0
48 |
49 | if self.flip and random.random() <= 0.5:
50 | data_numpy = data_numpy[:, ::-1, :]
51 | c[0] = data_numpy.shape[1] - c[0] - 1
52 |
53 | trans = get_affine_transform(c, s, r, [self.image_size, self.image_size])
54 | input = cv2.warpAffine(
55 | data_numpy,
56 | trans,
57 | (int(self.image_size), int(self.image_size)),
58 | flags=cv2.INTER_LINEAR)
59 | input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
60 | input = Image.fromarray(input.transpose((1, 0, 2)))
61 |
62 | if self.transform:
63 | input = self.transform(input)
64 | target = self.target_transform(self.meta['image_labels'][i])
65 | return input, target
66 |
67 | def __len__(self):
68 | return len(self.meta['image_names'])
69 |
70 |
71 | class SetDataset:
72 | def __init__(self, data_file, batch_size, image_size, transform, is_train=True, attr_loc=False):
73 | with open(data_file, 'r') as f:
74 | self.meta = json.load(f)
75 |
76 | self.cl_list = np.unique(self.meta['image_labels']).tolist()
77 |
78 | self.sub_meta = {}
79 | for cl in self.cl_list:
80 | self.sub_meta[cl] = []
81 |
82 | if 'part' in self.meta:
83 | for x, y, z in zip(self.meta['image_names'], self.meta['image_labels'], self.meta['part']):
84 | self.sub_meta[y].append({'path': x, 'part': z})
85 | else:
86 |             print("attribute locations are not used or are unavailable!")
87 | for x, y in zip(self.meta['image_names'], self.meta['image_labels']):
88 | self.sub_meta[y].append({'path': x})
89 |
90 | self.sub_dataloader = []
91 | sub_data_loader_params = dict(batch_size=batch_size,
92 | shuffle=True,
93 | num_workers=0, # use main thread only or may receive multiple batches
94 | pin_memory=False)
95 |
96 | for cl in self.cl_list:
97 | sub_dataset = SubDataset(self.sub_meta[cl], cl, image_size, attr_loc, transform=transform, is_train=is_train)
98 | self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params))
99 |
100 | def __getitem__(self, i):
101 | return next(iter(self.sub_dataloader[i]))
102 |
103 | def __len__(self):
104 | return len(self.cl_list)
105 |
106 |
107 | class EpisodicBatchSampler(object):
108 | def __init__(self, n_classes, n_way, n_episodes):
109 | self.n_classes = n_classes
110 | self.n_way = n_way
111 | self.n_episodes = n_episodes
112 |
113 | def __len__(self):
114 | return self.n_episodes
115 |
116 | def __iter__(self):
117 | for i in range(self.n_episodes):
118 | yield torch.randperm(self.n_classes)[:self.n_way]
119 |
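# EpisodicBatchSampler yields, for each of n_episodes episodes, the indices of n_way randomly
# chosen classes; together with SetDataset (whose __getitem__ returns one per-class batch) this
# produces episodic batches. A minimal usage sketch -- the actual wiring lives in data/datamgr.py,
# so the names below are illustrative only:
#   sampler = EpisodicBatchSampler(len(set_dataset), n_way=5, n_episodes=100)
#   loader = torch.utils.data.DataLoader(set_dataset, batch_sampler=sampler)
#   # each batch then stacks to [n_way, batch_size, C, H, W] images plus labels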
120 |
121 | class SubDataset(Dataset):
122 | def __init__(self, sub_meta, cl, image_size, attr_loc=False, transform=transforms.ToTensor(), target_transform=identity,
123 | is_train=True):
124 | self.num_joints = 15
125 |
126 | self.is_train = is_train
127 | self.sub_meta = sub_meta
128 | self.cl = cl
129 | self.transform = transform
130 | self.target_transform = target_transform
131 |
132 | self.flip = is_train
133 | self.attr_loc = attr_loc
134 |
135 | self.image_size = image_size
136 |
137 | self.transform = transform
138 | self.target_transform = target_transform
139 |
140 | def __len__(self, ):
141 | return len(self.sub_meta)
142 |
143 | def __getitem__(self, idx):
144 | image_file = os.path.join(self.sub_meta[idx]['path'])
145 |
146 | #data_numpy = cv2.imread(image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
147 |
148 |         # use PIL here; for SUN, cv2.imread returns None for some images
149 | data_numpy = np.array(PIL.Image.open(image_file).convert('RGB'))[:, :, ::-1] # to BGR
150 |
151 | if data_numpy is None:
152 | raise ValueError('Fail to read {}'.format(image_file))
153 |
154 | if self.attr_loc:
155 | joints_vis = self.sub_meta[idx]['part']
156 | joints_vis = np.array(joints_vis)
157 |
158 | r = 0
159 | c = np.array([data_numpy.shape[1], data_numpy.shape[0]]) // 2
160 | s = np.array([data_numpy.shape[1], data_numpy.shape[0]]) // 160
161 |
162 | if self.is_train:
163 | sf = 0.25
164 | rf = 30
165 | s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
166 | r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
167 | if random.random() <= 0.6 else 0
168 |
169 | if self.flip and random.random() <= 0.5:
170 | data_numpy = data_numpy[:, ::-1, :]
171 | if self.attr_loc:
172 | for i in range(self.num_joints):
173 | if joints_vis[i, 2] > 0.0:
174 | joints_vis[i, 0] = data_numpy.shape[1] - joints_vis[i, 0]
175 | c[0] = data_numpy.shape[1] - c[0] - 1
176 |
177 | trans = get_affine_transform(c, s, r, [self.image_size, self.image_size])
178 | input = cv2.warpAffine(
179 | data_numpy,
180 | trans,
181 | (int(self.image_size), int(self.image_size)),
182 | flags=cv2.INTER_LINEAR)
183 | input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
184 | input = Image.fromarray(input.transpose((1, 0, 2)))
185 |
186 | if self.transform:
187 | input = self.transform(input)
188 |
189 | target = self.target_transform(self.cl)
190 |
191 | if self.attr_loc is False:
192 | return input, target
193 | else:
194 | for i in range(self.num_joints):
195 | if joints_vis[i, 2] > 0.0:
196 | joints_vis[i, 0:2] = affine_transform(joints_vis[i, 0:2], trans)
197 |
198 | joints_vis = self.target_transform(joints_vis)
199 | return input, target, joints_vis
--------------------------------------------------------------------------------
/methods/maml.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 |
11 |
12 | class MAML(MetaTemplate):
13 | def __init__(self, model_func, n_way, n_support, approx=False):
14 | super(MAML, self).__init__(model_func, n_way, n_support, change_way=False)
15 |
16 | self.loss_fn = nn.CrossEntropyLoss()
17 | self.classifier = backbone.Linear_fw(self.feat_dim, n_way)
18 | self.classifier.bias.data.fill_(0)
19 |
20 | self.n_task = 4
21 | self.task_update_num = 5
22 | self.train_lr = 0.01
23 | self.approx = approx # first order approx.
24 |
25 | def forward(self, x):
26 | out = self.feature.forward(x)
27 | scores = self.classifier.forward(out)
28 | return scores
29 |
30 | def set_forward(self, x, is_feature=False):
31 |         assert is_feature == False, 'MAML does not support fixed features'
32 | x = x.cuda()
33 | x_var = Variable(x)
34 | x_a_i = x_var[:, :self.n_support, :, :, :].contiguous().view(self.n_way * self.n_support,
35 | *x.size()[2:]) # support data
36 | x_b_i = x_var[:, self.n_support:, :, :, :].contiguous().view(self.n_way * self.n_query,
37 | *x.size()[2:]) # query data
38 | y_a_i = Variable(
39 | torch.from_numpy(np.repeat(range(self.n_way), self.n_support))).cuda() # label for support data
40 |
41 |         fast_parameters = list(self.parameters())  # the first gradient computed by torch.autograd.grad below is based on the original weights
42 | for weight in self.parameters():
43 | weight.fast = None
44 | self.zero_grad()
45 |
46 | for task_step in range(self.task_update_num):
47 | scores = self.forward(x_a_i)
48 | set_loss = self.loss_fn(scores, y_a_i)
49 | grad = torch.autograd.grad(set_loss, fast_parameters,
50 | create_graph=True) # build full graph support gradient of gradient
51 | if self.approx:
52 | grad = [g.detach() for g in
53 | grad] # do not calculate gradient of gradient if using first order approximation
54 | fast_parameters = []
55 | for k, weight in enumerate(self.parameters()):
56 | # for usage of weight.fast, please see Linear_fw, Conv_fw in backbone.py
57 | if weight.fast is None:
58 | weight.fast = weight - self.train_lr * grad[k] # create weight.fast
59 | else:
60 | weight.fast = weight.fast - self.train_lr * grad[
61 |                         k]  # create an updated weight.fast; the '-' does not update in place but creates a new weight.fast tensor
62 | fast_parameters.append(
63 |                     weight.fast)  # the next inner-loop gradient is calculated w.r.t. the newest fast weights, but the graph retains the link to older weight.fast values
64 |
65 | scores = self.forward(x_b_i)
66 | return scores
67 |
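    # set_forward (above) runs the MAML inner loop: task_update_num SGD steps (step size
    # train_lr) on the support set, stored as 'fast' weights in each parameter's .fast
    # attribute, followed by scoring the query set with the adapted weights. The outer update
    # in train_loop() sums the query losses of n_task such episodes before optimizer.step().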
68 |     def set_forward_adaptation(self, x, is_feature=False):  # overwrite parent function
69 |         raise ValueError('MAML performs further adaptation simply by increasing task_update_num')
70 |
71 | def set_forward_loss(self, x):
72 | scores = self.set_forward(x, is_feature=False)
73 | y_b_i = Variable(torch.from_numpy(np.repeat(range(self.n_way), self.n_query))).cuda()
74 | loss = self.loss_fn(scores, y_b_i)
75 |
76 | return loss
77 |
78 |     def train_loop(self, epoch, train_loader, optimizer, tf_writer):  # overwrite parent function
79 | print_freq = 10
80 | avg_loss = 0
81 | task_count = 0
82 | loss_all = []
83 | optimizer.zero_grad()
84 |
85 | # train
86 | for i, (x, _) in enumerate(train_loader):
87 | self.n_query = x.size(1) - self.n_support
88 |             assert self.n_way == x.size(0), "MAML does not support changing n_way"
89 |
90 | loss = self.set_forward_loss(x)
91 | avg_loss = avg_loss + loss.item()
92 | loss_all.append(loss)
93 |
94 | task_count += 1
95 |
96 |             if task_count == self.n_task:  # MAML accumulates several tasks into one update
97 | loss_q = torch.stack(loss_all).sum(0)
98 | loss_q.backward()
99 |
100 | optimizer.step()
101 | task_count = 0
102 | loss_all = []
103 | optimizer.zero_grad()
104 | if i % print_freq == 0:
105 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader),
106 | avg_loss / float(i + 1)))
107 | tf_writer.add_scalar('loss/train', avg_loss / float(i + 1), epoch)
108 |
109 |     def test_loop(self, test_loader, return_std=False):  # overwrite parent function
110 | correct = 0
111 | count = 0
112 | acc_all = []
113 |
114 | iter_num = len(test_loader)
115 | from tqdm import tqdm
116 | for i, (x, _) in enumerate(tqdm(test_loader)):
117 | self.n_query = x.size(1) - self.n_support
118 |             assert self.n_way == x.size(0), "MAML does not support changing n_way"
119 | correct_this, count_this = self.correct(x)
120 | acc_all.append(correct_this / count_this * 100)
121 |
122 | acc_all = np.asarray(acc_all)
123 | acc_mean = np.mean(acc_all)
124 | acc_std = np.std(acc_all)
125 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
126 | if return_std:
127 | return acc_mean, acc_std
128 | else:
129 | return acc_mean
130 |
131 | def test_loop_with_dists(self, test_loader, all_cls_dists, base_dists,
132 |                              base_cls_dists, attr_num):  # overwrite parent function
133 | correct = 0
134 | count = 0
135 | acc_all = []
136 | dist_all = []
137 |
138 | iter_num = len(test_loader)
139 | from tqdm import tqdm
140 | for i, (x, y) in enumerate(tqdm(test_loader)):
141 | # original mean-task (down trend)
142 | sc_cls = y.unique()
143 | task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
144 | dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)
145 |
146 | self.n_query = x.size(1) - self.n_support
147 |             assert self.n_way == x.size(0), "MAML does not support changing n_way"
148 | correct_this, count_this = self.correct(x)
149 | acc_all.append(correct_this / count_this * 100)
150 |
151 | # task-agnostic (up trend)
152 | # task_cls_dists = all_cls_dists[sc_cls, :].unsqueeze(1)
153 | # distance = torch.abs(base_cls_dists.unsqueeze(0) - task_cls_dists).sum(-1).mean()
154 | # dist_all.append(distance.item())
155 |
156 | # class-agnostic (down trend)
157 | # task_dists = all_cls_dists[sc_cls, :].mean(0)
158 | # dist_all.append(torch.abs(base_dists - task_dists).sum().item() / len(attr_label_split))
159 |
160 | # original (up trend)
161 | # task_cls_dists = all_cls_dists[sc_cls, :].unsqueeze(0)
162 | # distance = torch.abs(base_cls_dists - task_cls_dists).sum(-1).mean()
163 | # dist_all.append(distance.item())
164 |
165 | # original mini-task (first up then down ?)
166 | # task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
167 | # dist_all.append(torch.abs(base_dists - task_dists).sum(-1).min().item())
168 |
169 | # ideal with task-agnostic (down trend)
170 | # task_cls_dists = all_cls_dists[sc_cls, :]
171 | # d = 0
172 | # for j in range(task_cls_dists.shape[0]):
173 | # task_dist = task_cls_dists[j]
174 | #
175 | # start = 0
176 | # for split in attr_label_split:
177 | # end = start + split
178 | # task_dist_split = task_dist[start:end].unsqueeze(0)
179 | # base_dist_split = base_cls_dists[:, start:end]
180 | # d_split = torch.abs(base_dist_split - task_dist_split).sum(-1).min()
181 | # d += d_split.item()
182 | # start = end
183 | # d /= (len(attr_label_split)*task_cls_dists.shape[0])
184 | # dist_all.append(d)
185 |
186 | acc_all = np.asarray(acc_all)
187 | acc_mean = np.mean(acc_all)
188 | acc_std = np.std(acc_all)
189 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
190 | return acc_all, dist_all
191 |
192 |
--------------------------------------------------------------------------------
/plot_distance_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.autograd import Variable
4 | import torch.nn as nn
5 | import torch.optim
6 | import json
7 | import torch.utils.data.sampler
8 | import os
9 | import glob
10 | import random
11 | import time
12 |
13 | import configs
14 | import backbone
15 | import data.feature_loader as feat_loader
16 | from data.datamgr import SetDataManager
17 | from methods.baselinetrain import BaselineTrain
18 | from methods.baselinefinetune import BaselineFinetune
19 | from methods.protonet import ProtoNet
20 | from methods.matchingnet import MatchingNet
21 | from methods.relationnet import RelationNet
22 | from methods.maml import MAML
23 | from methods.apnet import APNet_w_attrLoc, APNet_wo_attrLoc
24 | from io_utils import model_dict, parse_args, get_resume_file, get_best_file, get_assigned_file
25 | from utils import get_attr_distance, plot_fig
26 |
27 | # seed = 1
28 | # np.random.seed(seed)
29 | # torch.random.manual_seed(seed)
30 |
31 | def feature_evaluation_with_dists(cl_data_file, model, n_way=5, n_support=5, n_query=15, adaptation=False):
32 | class_list = cl_data_file.keys()
33 |
34 |     select_class = random.sample(list(class_list), n_way)  # list() so random.sample also works on dict keys
35 | z_all = []
36 | for cl in select_class:
37 | img_feat = cl_data_file[cl]
38 | perm_ids = np.random.permutation(len(img_feat)).tolist()
39 | z_all.append([np.squeeze(img_feat[perm_ids[i]]) for i in range(n_support + n_query)]) # stack each batch
40 |
41 | z_all = torch.from_numpy(np.array(z_all))
42 |
43 | model.n_query = n_query
44 | if adaptation:
45 | scores = model.set_forward_adaptation(z_all, is_feature=True)
46 | else:
47 | scores = model.set_forward(z_all, is_feature=True)
48 | pred = scores.data.cpu().numpy().argmax(axis=1)
49 | y = np.repeat(range(n_way), n_query)
50 | acc = np.mean(pred == y) * 100
51 | return acc, select_class
52 |
53 |
54 | if __name__ == '__main__':
55 | params = parse_args('test')
56 |
57 | acc_all = []
58 | iter_num, n_query = 600, 35 # to reduce the influence of random factors
59 | attr_loc = False
60 | few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
61 |
62 | split = params.split
63 | if params.save_iter != -1:
64 | split_str = split + "_" + str(params.save_iter)
65 | else:
66 | split_str = split
67 | if 'Conv' in params.model:
68 | image_size = 84
69 | else:
70 | image_size = 224
71 | datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
72 | if params.dataset == 'AWA2':
73 | # for AWA2, we use both validation and test classes as novel classes, because the number of test classes in AWA2 is too small (only 10)
74 | loadfile = configs.data_dir[params.dataset] + 'val_novel.json'
75 | else:
76 | loadfile = configs.data_dir[params.dataset] + split + '.json'
77 | novel_loader = datamgr.get_data_loader(loadfile, aug=False, is_train=False)
78 |
79 | base_datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
80 | base_loadfile = configs.data_dir[params.dataset] + 'base.json'
81 | base_loader = base_datamgr.get_data_loader(base_loadfile, aug=False, is_train=False)
82 | all_cls_dists, base_dists, base_cls_dists = get_attr_distance(base_loader, params.dataset)
83 |
84 | if params.dataset == 'CUB':
85 | if params.method in ['comet', 'apnet']:
86 | attr_loc = True
87 | attr_num = 109
88 | elif params.dataset == 'SUN':
89 | attr_num = 102
90 | elif params.dataset == 'AWA2':
91 | attr_num = 85
92 | else:
93 |         raise NotImplementedError("dataset not implemented!")
94 |
95 | if params.method == 'baseline':
96 | model = BaselineFinetune(model_dict[params.model], **few_shot_params)
97 | elif params.method == 'baseline++':
98 | model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
99 | elif params.method == 'protonet':
100 | model = ProtoNet(model_dict[params.model], **few_shot_params)
101 | elif params.method == 'comet':
102 | assert params.dataset == 'CUB'
103 | model = COMET(model_dict[params.model], **few_shot_params)
104 | elif params.method == 'matchingnet':
105 | model = MatchingNet(model_dict[params.model], **few_shot_params)
106 | elif params.method in ['relationnet', 'relationnet_softmax']:
107 | if params.model == 'Conv4':
108 | feature_model = backbone.Conv4NP
109 | elif params.model == 'Conv6':
110 | feature_model = backbone.Conv6NP
111 | elif params.model == 'Conv4S':
112 | feature_model = backbone.Conv4SNP
113 | else:
114 | feature_model = lambda: model_dict[params.model](flatten=False)
115 | loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
116 | model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
117 | elif params.method in ['maml', 'maml_approx']:
118 | backbone.ConvBlock.maml = True
119 | backbone.SimpleBlock.maml = True
120 | backbone.BottleneckBlock.maml = True
121 | backbone.ResNet.maml = True
122 | model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **few_shot_params)
123 |         if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different parameters on Omniglot
124 | model.n_task = 32
125 | model.task_update_num = 1
126 | model.train_lr = 0.1
127 | elif params.method == 'apnet':
128 | few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot, attr_num=attr_num, attr_loc=attr_loc, dataset=params.dataset)
129 | if attr_loc:
130 | model = APNet_w_attrLoc(model_dict[params.model], **few_shot_params)
131 | else:
132 | model = APNet_wo_attrLoc(model_dict[params.model], **few_shot_params)
133 | else:
134 | raise ValueError('Unknown method')
135 |
136 | model = model.cuda()
137 |
138 | checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % (
139 | configs.save_dir, params.dataset, params.model, params.method, params.exp_str)
140 | if params.train_aug:
141 | checkpoint_dir += '_aug'
142 | if not params.method in ['baseline', 'baseline++']:
143 | checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
144 |
145 | if not params.method in ['baseline', 'baseline++']:
146 | if params.save_iter != -1:
147 | modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
148 | else:
149 | modelfile = get_best_file(checkpoint_dir)
150 | if modelfile is not None:
151 | tmp = torch.load(modelfile)
152 | model.load_state_dict(tmp['state'])
153 |
154 |     if params.method in ['maml', 'maml_approx']:  # MAML does not support testing with pre-extracted features
155 | if params.dataset == 'SUN' and params.model == 'Conv4':
156 | model.train_lr = 0.1
157 | if params.adaptation:
158 | model.task_update_num = 100 # We perform adaptation on MAML simply by updating more times.
159 | model.eval()
160 | acc_all, dist_all = model.test_loop_with_dists(novel_loader, all_cls_dists, base_dists, base_cls_dists, attr_num)
161 | else:
162 | novel_file = os.path.join(checkpoint_dir.replace("checkpoints", "features"),
163 |                                   split_str + ".hdf5")  # default split = novel, but you can also test the base or val classes
164 | print("novel_file:", novel_file)
165 | cl_data_file = feat_loader.init_loader(novel_file)
166 |
167 | dist_all = []
168 | from tqdm import tqdm
169 | for i in tqdm(range(iter_num)):
170 | acc, sc_cls = feature_evaluation_with_dists(cl_data_file, model, n_query=15, adaptation=params.adaptation,
171 | **few_shot_params)
172 | acc_all.append(acc)
173 |
174 | # original mean-task (down trend)
175 | task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
176 | dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)
177 |
178 | acc_all = np.asarray(acc_all)
179 | acc_mean = np.mean(acc_all)
180 | acc_std = np.std(acc_all)
181 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
182 | plot_fig(acc_all, dist_all)
183 |
184 | if params.runs is None:
185 | distance_file = './distance-{}.txt'.format(params.n_shot)
186 | accuracy_file = './accuracy-{}.txt'.format(params.n_shot)
187 | else:
188 | distance_file = './distance-{}-runs{}-{}.txt'.format(params.n_shot, params.runs, params.method)
189 | accuracy_file = './accuracy-{}-runs{}-{}.txt'.format(params.n_shot, params.runs, params.method)
190 |
191 | with open(distance_file, 'w') as f:
192 | for dis in dist_all:
193 | f.write(str(dis) + ' ')
194 |
195 | with open(accuracy_file, 'w') as f:
196 | for acc in acc_all:
197 | f.write(str(acc) + ' ')
198 |     print("Wrote distance and accuracy files successfully!")
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import torch.optim
6 | import torch.optim.lr_scheduler as lr_scheduler
7 | import time
8 | import os
9 | import glob
10 |
11 | import configs
12 | import backbone
13 | from data.datamgr import SimpleDataManager, SetDataManager
14 | from methods.baselinetrain import BaselineTrain
15 | from methods.baselinefinetune import BaselineFinetune
16 | from methods.protonet import ProtoNet
17 | from methods.matchingnet import MatchingNet
18 | from methods.relationnet import RelationNet
19 | from methods.maml import MAML
20 | from methods.apnet import APNet_w_attrLoc, APNet_wo_attrLoc
21 | from io_utils import model_dict, parse_args, get_resume_file
22 | from torch.utils.tensorboard import SummaryWriter
23 |
24 | def train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params, tf_writer):
25 | if optimization == 'Adam':
26 | optimizer = torch.optim.Adam(model.parameters())
27 | else:
28 |         raise ValueError('Unknown optimization; please define it yourself')
29 |
30 | max_acc = 0
31 |
32 | for epoch in range(start_epoch, stop_epoch):
33 | model.train()
34 | # model.train_loop(epoch, base_loader, optimizer, tf_writer, params.beta) #model are called by reference, no need to return
35 | model.train_loop(epoch, base_loader, optimizer, tf_writer)
36 | model.eval()
37 |
38 | if not os.path.isdir(params.checkpoint_dir):
39 | os.makedirs(params.checkpoint_dir)
40 |
41 | acc = model.test_loop(val_loader)
42 | tf_writer.add_scalar('acc/test', acc, epoch)
43 |
44 |         if acc > max_acc:  # for baseline and baseline++, we do not use validation by default (acc stays -1), but there is an option to validate with the DB index
45 | print("best model! save...")
46 | max_acc = acc
47 | outfile = os.path.join(params.checkpoint_dir, 'best_model.tar')
48 | torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)
49 |
50 | if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
51 | outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
52 | torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)
53 |
54 | return model
55 |
56 |
57 | if __name__ == '__main__':
58 | np.random.seed(10)
59 | params = parse_args('train')
60 |
61 | base_file = configs.data_dir[params.dataset] + 'base.json'
62 | val_file = configs.data_dir[params.dataset] + 'val.json'
63 |
64 | if 'Conv' in params.model:
65 | image_size = 84
66 | else:
67 | image_size = 224
68 |
69 | optimization = 'Adam'
70 |
71 | if params.stop_epoch == -1:
72 | if params.method in ['baseline', 'baseline++']:
73 | if params.dataset in ['CUB']:
74 |                 params.stop_epoch = 200  # this differs from what is stated in the open-review paper; using 400 epochs for the baseline actually leads to over-fitting
75 | else:
76 | params.stop_epoch = 400 # default
77 | else: # meta-learning methods
78 | if params.n_shot == 1:
79 | params.stop_epoch = 600
80 | elif params.n_shot == 5:
81 | params.stop_epoch = 400
82 | else:
83 | params.stop_epoch = 600 # default
84 |
85 | if params.method in ['baseline', 'baseline++'] :
86 | base_datamgr = SimpleDataManager(image_size, batch_size = 16)
87 | base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug, is_train=True)
88 | val_datamgr = SimpleDataManager(image_size, batch_size = 64)
89 | val_loader = val_datamgr.get_data_loader( val_file, aug = False, is_train=False)
90 |
91 | if params.dataset == 'omniglot':
92 |         assert params.num_classes >= 4112, 'num_classes needs to be larger than the max label id in the base classes'
93 | if params.dataset == 'cross_char':
94 |         assert params.num_classes >= 1597, 'num_classes needs to be larger than the max label id in the base classes'
95 |
96 | if params.method == 'baseline':
97 | model = BaselineTrain( model_dict[params.model], params.num_classes)
98 | elif params.method == 'baseline++':
99 | model = BaselineTrain( model_dict[params.model], params.num_classes, loss_type = 'dist')
100 |
101 | elif params.method in ['protonet', 'comet', 'matchingnet','relationnet', 'relationnet_softmax', 'maml', 'maml_approx', 'apnet']:
102 | n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
103 |
104 | if params.dataset == 'CUB' and params.method in ['comet', 'apnet']:
105 | attr_loc = True
106 | else:
107 | attr_loc = False
108 |
109 | train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot)
110 | base_datamgr = SetDataManager(image_size, n_query = n_query, **train_few_shot_params)
111 | base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug, is_train=True, attr_loc=attr_loc)
112 |
113 | test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot)
114 | val_datamgr = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
115 | val_loader = val_datamgr.get_data_loader( val_file, aug = False, is_train=False, attr_loc=attr_loc)
116 | #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor
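        # e.g. with train_n_way = test_n_way = 5 the loaders use n_query = 16, so a
        # 5-way 1-shot Conv4 batch is a [5, 17, 3, 84, 84] tensor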
117 |
118 | if params.method == 'protonet':
119 | model = ProtoNet( model_dict[params.model], **train_few_shot_params )
120 | elif params.method == 'comet':
121 |             model = COMET( model_dict[params.model], **train_few_shot_params )  # note: COMET is not part of this repo and is not imported above
122 | elif params.method == 'matchingnet':
123 | model = MatchingNet( model_dict[params.model], **train_few_shot_params )
124 | elif params.method in ['relationnet', 'relationnet_softmax']:
125 | if params.model == 'Conv4':
126 | feature_model = backbone.Conv4NP
127 | elif params.model == 'Conv6':
128 | feature_model = backbone.Conv6NP
129 | elif params.model == 'Conv4S':
130 | feature_model = backbone.Conv4SNP
131 | else:
132 | feature_model = lambda: model_dict[params.model]( flatten = False )
133 | loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
134 |
135 | model = RelationNet( feature_model, loss_type = loss_type , **train_few_shot_params )
136 | elif params.method in ['maml' , 'maml_approx']:
137 | backbone.ConvBlock.maml = True
138 | backbone.SimpleBlock.maml = True
139 | backbone.BottleneckBlock.maml = True
140 | backbone.ResNet.maml = True
141 | model = MAML( model_dict[params.model], approx = (params.method == 'maml_approx') , **train_few_shot_params )
142 |         if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different parameters on Omniglot
143 | model.n_task = 32
144 | model.task_update_num = 1
145 | model.train_lr = 0.1
146 | elif params.method == 'apnet':
147 | if params.dataset == 'CUB':
148 | attr_num = 109
149 | elif params.dataset == 'SUN':
150 | attr_num = 102
151 | elif params.dataset == 'AWA2':
152 | attr_num = 85
153 | else:
154 |                 raise NotImplementedError('dataset not implemented!')
155 | train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot,
156 | attr_num=attr_num, attr_loc=attr_loc, dataset=params.dataset)
157 | if attr_loc:
158 | model = APNet_w_attrLoc(model_dict[params.model], **train_few_shot_params)
159 | else:
160 | model = APNet_wo_attrLoc(model_dict[params.model], **train_few_shot_params)
161 | else:
162 | raise ValueError('Unknown method')
163 |
164 | model = model.cuda()
165 |
166 | params.checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % (
167 | configs.save_dir, params.dataset, params.model, params.method, params.exp_str)
168 | if params.train_aug:
169 | params.checkpoint_dir += '_aug'
170 | if not params.method in ['baseline', 'baseline++']:
171 | params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
172 |
173 | if not os.path.isdir(params.checkpoint_dir):
174 | os.makedirs(params.checkpoint_dir)
175 |
176 | store_name = '_'.join([params.dataset, params.model, params.method, params.exp_str])
177 | # wandb.init(project="fewshot_images", tensorboard=True, name=store_name)
178 |
179 | start_epoch = params.start_epoch
180 | stop_epoch = params.stop_epoch
181 | if params.method == 'maml' or params.method == 'maml_approx':
182 | stop_epoch = params.stop_epoch * model.n_task # maml use multiple tasks in one update
183 |
184 | if params.resume:
185 | resume_file = get_resume_file(params.checkpoint_dir)
186 | if resume_file is not None:
187 | tmp = torch.load(resume_file)
188 | start_epoch = tmp['epoch'] + 1
189 | model.load_state_dict(tmp['state'])
190 |     elif params.warmup:  # we also support warm-up from pretrained baseline features, but we never used it in our paper
191 | baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
192 | configs.save_dir, params.dataset, params.model, 'baseline')
193 | if params.train_aug:
194 | baseline_checkpoint_dir += '_aug'
195 | warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
196 | tmp = torch.load(warmup_resume_file)
197 | if tmp is not None:
198 | state = tmp['state']
199 | state_keys = list(state.keys())
200 | for i, key in enumerate(state_keys):
201 | if "feature." in key:
202 |                     # the model has a 'feature' attribute; load its pretrained weights into the backbone by renaming 'feature.trunk.xx' to 'trunk.xx'
203 |                     newkey = key.replace("feature.", "")
204 | state[newkey] = state.pop(key)
205 | else:
206 | state.pop(key)
207 | model.feature.load_state_dict(state)
208 | else:
209 | raise ValueError('No warm_up file')
210 |
211 | tf_writer = SummaryWriter(log_dir=params.checkpoint_dir)
212 |
213 | # print("beta:", params.beta)
214 | # # model2 = CBModel(model, params.method)
215 | # # model2 = MTModel(model, params.method)
216 | # # model2 = model2.cuda()
217 | # model2 = model
218 | # model2 = train(base_loader, val_loader, model2, optimization, start_epoch, stop_epoch, params, tf_writer)
219 | # print('*' * 10 + 'Test' + '*' * 10)
220 | # acc = model2.test_loop(test_loader)
221 | # tf_writer.add_scalar('acc/final_test', acc, stop_epoch)
222 | model = train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params, tf_writer)
--------------------------------------------------------------------------------
/methods/apnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import random
6 | import time
7 | import json
8 | from sklearn.metrics import confusion_matrix
9 | from methods.meta_template import MetaTemplate
10 | from torch.autograd import Variable
11 |
12 | class APNetTemplate(MetaTemplate):
13 | def __init__(self, model_func, n_way, n_support, attr_num, attr_loc=False, dataset='CUB'):
14 | super(APNetTemplate, self).__init__(model_func, n_way, n_support)
15 | self.attr_loc = attr_loc
16 |
17 | final_feat_dim = self.feature.final_feat_dim
18 | if isinstance(final_feat_dim, list):
19 | if self.attr_loc:
20 | self.input_dim = final_feat_dim[0] * 16
21 | else:
22 | self.input_dim = 1
23 | for dim in final_feat_dim:
24 | self.input_dim *= dim
25 | else:
26 | self.input_dim = final_feat_dim
27 | if self.attr_loc:
28 | self.input_dim *= 16
29 | self.attr_num = attr_num
30 | self.beta = 0.6
31 |
32 | self.classifier = nn.Linear(self.input_dim, self.attr_num*2) # only consider binary attribute
33 | self.loss_fn = nn.CrossEntropyLoss()
34 | self.read_attr_labels(dataset)
35 |
36 | def read_attr_labels(self, dataset):
37 |         # TODO: change the filenames to your own paths; for CUB, extra files are used to filter the attribute labels
38 | if dataset == 'CUB':
39 | filename = '/home/huminyang/Code/comet/CUB/filelists/CUB/CUB_200_2011/masked_class_attribute_labels.txt'
40 | elif dataset == 'AWA2':
41 | filename = '/home/huminyang/Code/APNet/filelists/AWA2/class_attribute_label.txt'
42 | else:
43 |             raise NotImplementedError('dataset not implemented!')
44 |
45 | attr_labels_binary = []
46 | with open(filename, 'r') as f:
47 | for line in f.readlines():
48 | line_split = line.strip().split(' ')
49 | float_line = []
50 | for str_num in line_split:
51 | float_line.append(int(str_num))
52 | attr_labels_binary.append(float_line)
53 | self.attr_labels_split = torch.from_numpy(np.array(attr_labels_binary)).cuda()
54 |
55 | def correct(self, x, joints):
56 | # correct for image classification
57 | scores = self.set_forward(x, joints)
58 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
59 | topk_ind = topk_labels.cpu().numpy()
60 |
61 | y_query = np.repeat(range(self.n_way), self.n_query)
62 | top1_correct = np.sum(topk_ind[:, 0] == y_query)
63 | return float(top1_correct), len(y_query)
64 |
65 | def forward(self, x, joints):
66 | if joints is None:
67 | z_support, z_query = super().parse_feature(x, is_feature=False)
68 | else:
69 | z_support, z_query = self.parse_feature(x, joints, is_feature=False)
70 | feature = torch.cat([z_support, z_query], dim=1) # (n_way, n_support+n_query, dim)
71 | feature = feature.view(-1, self.input_dim)
72 | logits = self.classifier(feature)
73 | return torch.split(logits, 2, -1)
74 |
75 | def set_forward(self, x, joints):
76 | logits = self.forward(x, joints)
77 | logits = torch.cat(logits, dim=-1)
78 | logits = logits.view(self.n_way, self.n_support + self.n_query, -1)
79 |
80 | z_support = logits[:, :self.n_support]
81 | z_query = logits[:, self.n_support:]
82 |
83 | z_support = z_support.contiguous()
84 | z_proto = z_support.view(self.n_way, self.n_support, -1).mean(1) # the shape of z is [n_data, n_dim]
85 | z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)
86 |
87 | scores = F.cosine_similarity(z_query.unsqueeze(1), z_proto, dim=-1)
88 | return scores / 0.2
89 |
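    # set_forward (above) builds a 2*attr_num-dimensional per-image embedding from the
    # per-attribute logits, averages the support embeddings into class prototypes, and scores
    # each query by its cosine similarity to the prototypes scaled by a temperature of 0.2.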
90 | def set_forward_loss1(self, x, joints):
91 | y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
92 | y_query = y_query.cuda()
93 |
94 | scores = self.set_forward(x, joints)
95 | return self.loss_fn(scores, y_query)
96 |
97 | def set_forward_loss2(self, x, ys, joints):
98 | ys = ys.view(-1, self.attr_num)
99 | logits = self.forward(x, joints)
100 | logits = torch.cat(logits, dim=0)
101 | return self.loss_fn(logits, ys.transpose(1, 0).reshape(-1))
102 |
103 | class APNet_w_attrLoc(APNetTemplate):
104 | def __init__(self, model_func, n_way, n_support, attr_num, attr_loc=False, dataset='CUB'):
105 | super(APNet_w_attrLoc, self).__init__(model_func, n_way, n_support, attr_num, attr_loc, dataset)
106 | self.globalpool = nn.AdaptiveAvgPool2d((1, 1))
107 |
108 | # this function originates from comet
109 | def parse_feature(self, x, joints, is_feature):
110 | x = Variable(x.cuda())
111 | if is_feature:
112 | z_all = x
113 | else:
114 | x = x.contiguous().view(self.n_way * (self.n_support + self.n_query), *x.size()[2:])
115 | z_all = self.feature.forward(x)
116 | z_avg = self.globalpool(z_all).view(z_all.size(0), z_all.size(1))
117 |
118 | joints = joints.contiguous().view(self.n_way * (self.n_support + self.n_query), *joints.size()[2:])
119 | img_len = x.size()[-1]
120 | feat_len = z_all.size()[-1]
121 | joints[:, :, :2] = joints[:, :, :2] / img_len * feat_len
122 | joints = joints.round().int()
123 | joints_num = joints.size(1)
124 |
125 | avg_mask = (joints[:, :, 2] == 0) + (joints[:, :, 0] < 0) + (joints[:, :, 1] < 0) + (
126 | joints[:, :, 0] >= feat_len) + (joints[:, :, 1] >= feat_len)
127 | avg_mask = (avg_mask > 0).long().unsqueeze(-1).cuda() # (85, 15, 1)
128 | mask_joints = joints.cuda() * (1 - avg_mask)
129 | mask_joints = (mask_joints[:, :, 0] * 7 + mask_joints[:, :, 1]).unsqueeze(1).repeat(1, 64, 1) # 7 and 64 are hard-coded for the 7x7, 64-channel Conv4NP feature map
130 | z_all_2D = z_all.view(z_all.size(0), z_all.size(1), -1)
131 | mask_z = torch.gather(z_all_2D, dim=-1, index=mask_joints)
132 | mask_z = mask_z.permute(0, 2, 1) # (85, 15, 64)
133 | mask_z_avg = z_avg.unsqueeze(1).repeat(1, joints_num, 1) * avg_mask
134 | z_all_tmp = mask_z * (1 - avg_mask) + mask_z_avg
135 | z_all = torch.cat([z_all_tmp, z_avg.unsqueeze(1)], dim=1).view(self.n_way, self.n_support + self.n_query, -1)
136 | z_support = z_all[:, :self.n_support]
137 | z_query = z_all[:, self.n_support:]
138 | return z_support, z_query
139 |
140 | def train_loop(self, epoch, train_loader, optimizer, tf_writer):
141 | print_freq = 10
142 |
143 | avg_loss, avg_loss1, avg_loss2 = 0, 0, 0
144 | start_time = time.time()
145 | for i, (x, y, joints) in enumerate(train_loader):
146 | self.n_query = x.size(1) - self.n_support
147 | x, y = x.cuda(), y.cuda()
148 | attr_labels = self.attr_labels_split[y]
149 | optimizer.zero_grad()
150 | loss1 = self.set_forward_loss1(x, joints)
151 | loss2 = self.set_forward_loss2(x, attr_labels, joints)
152 | loss = loss1 + self.beta * loss2
153 | loss.backward()
154 | optimizer.step()
155 | avg_loss = avg_loss + loss.item()
156 | avg_loss1 = avg_loss1 + loss1.item()
157 | avg_loss2 = avg_loss2 + loss2.item()
158 |
159 | if i % print_freq == 0:
160 | # print(optimizer.state_dict()['param_groups'][0]['lr'])
161 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} | Loss1 {:f} | Loss2 {:f}'.format(epoch, i, len(train_loader),
162 | avg_loss / float(i + 1), avg_loss1 / float(i + 1), avg_loss2 / float(i + 1)))
163 | tf_writer.add_scalar('loss/train', avg_loss / float(i + 1), epoch)
164 | tf_writer.add_scalar('loss1/train', avg_loss1 / float(i + 1), epoch)
165 | tf_writer.add_scalar('loss2/train', avg_loss2 / float(i + 1), epoch)
166 | print("Epoch (train) uses %.2f s!" % (time.time() - start_time))
167 |
168 | def test_loop(self, test_loader, return_std=False):
169 | acc_all = []
170 |
171 | iter_num = len(test_loader)
172 | start_time = time.time()
173 | for i, (x, _, joints) in enumerate(test_loader):
174 | x = x.cuda()
175 | self.n_query = x.size(1) - self.n_support
176 | correct_this, count_this = self.correct(x, joints)
177 |
178 | acc_all.append(correct_this / count_this * 100)
179 | print("Epoch (test) uses %.2f s!" % (time.time() - start_time))
180 |
181 | acc_all = np.asarray(acc_all)
182 | acc_mean = np.mean(acc_all)
183 | acc_std = np.std(acc_all)
184 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
185 |
186 | if return_std:
187 | return acc_mean, acc_std
188 | else:
189 | return acc_mean
190 |
191 | def test_loop_with_dists(self, test_loader, all_cls_dists, base_dists, base_cls_dists, attr_num):
192 | acc_all, dist_all = [], []
193 | attr_num = all_cls_dists.shape[1]
194 |
195 | iter_num = len(test_loader)
196 | from tqdm import tqdm
197 | for i, (x, y, joints) in enumerate(tqdm(test_loader)):
198 | x, y = x.cuda(), y.cuda()
199 | self.n_query = x.size(1) - self.n_support
200 | if self.change_way:
201 | self.n_way = x.size(0)
202 | correct_this, count_this = self.correct(x, joints)
203 | acc_all.append(correct_this / count_this * 100)
204 |
205 | sc_cls = y.unique()
206 | # original mean-task (down trend)
207 | task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
208 | dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)
209 |
210 | acc_all = np.asarray(acc_all)
211 | acc_mean = np.mean(acc_all)
212 | acc_std = np.std(acc_all)
213 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
214 |
215 | return acc_all, dist_all
216 |
217 | class APNet_wo_attrLoc(APNetTemplate):
218 | def __init__(self, model_func, n_way, n_support, attr_num, attr_loc=False, dataset='CUB'):
219 | super(APNet_wo_attrLoc, self).__init__(model_func, n_way, n_support, attr_num, attr_loc, dataset)
220 |
221 | def train_loop(self, epoch, train_loader, optimizer, tf_writer):
222 | print_freq = 10
223 |
224 | avg_loss, avg_loss1, avg_loss2 = 0, 0, 0
225 | start_time = time.time()
226 | for i, (x, y) in enumerate(train_loader):
227 | self.n_query = x.size(1) - self.n_support
228 | x, y = x.cuda(), y.cuda()
229 | attr_labels = self.attr_labels_split[y]
230 | optimizer.zero_grad()
231 | loss1 = self.set_forward_loss1(x, None)
232 | loss2 = self.set_forward_loss2(x, attr_labels, None)
233 | loss = loss1 + self.beta * loss2
234 | loss.backward()
235 | optimizer.step()
236 | avg_loss = avg_loss + loss.item()
237 | avg_loss1 = avg_loss1 + loss1.item()
238 | avg_loss2 = avg_loss2 + loss2.item()
239 |
240 | if i % print_freq == 0:
241 | # print(optimizer.state_dict()['param_groups'][0]['lr'])
242 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} | Loss1 {:f} | Loss2 {:f}'.format(epoch, i, len(train_loader),
243 | avg_loss / float(i + 1), avg_loss1 / float(i + 1), avg_loss2 / float(i + 1)))
244 | tf_writer.add_scalar('loss/train', avg_loss / float(i + 1), epoch)
245 | tf_writer.add_scalar('loss1/train', avg_loss1 / float(i + 1), epoch)
246 | tf_writer.add_scalar('loss2/train', avg_loss2 / float(i + 1), epoch)
247 | print("Epoch (train) uses %.2f s!" % (time.time() - start_time))
248 |
249 | def test_loop(self, test_loader, return_std=False):
250 | acc_all = []
251 |
252 | iter_num = len(test_loader)
253 | start_time = time.time()
254 | for i, (x, _) in enumerate(test_loader):
255 | x = x.cuda()
256 | self.n_query = x.size(1) - self.n_support
257 | correct_this, count_this = self.correct(x, None)
258 |
259 | acc_all.append(correct_this / count_this * 100)
260 | print("Epoch (test) uses %.2f s!" % (time.time() - start_time))
261 |
262 | acc_all = np.asarray(acc_all)
263 | acc_mean = np.mean(acc_all)
264 | acc_std = np.std(acc_all)
265 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
266 |
267 | if return_std:
268 | return acc_mean, acc_std
269 | else:
270 | return acc_mean
271 |
272 | def test_loop_with_dists(self, test_loader, all_cls_dists, base_dists, base_cls_dists, attr_num):
273 | acc_all, dist_all = [], []
274 |
275 | iter_num = len(test_loader)
276 | from tqdm import tqdm
277 | for i, (x, y) in enumerate(tqdm(test_loader)):
278 | x, y = x.cuda(), y.cuda()
279 | self.n_query = x.size(1) - self.n_support
280 | if self.change_way:
281 | self.n_way = x.size(0)
282 | correct_this, count_this = self.correct(x, None)
283 | acc_all.append(correct_this / count_this * 100)
284 |
285 | sc_cls = y.unique()
286 | # original mean-task (down trend)
287 | task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
288 | dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)
289 |
290 | acc_all = np.asarray(acc_all)
291 | acc_mean = np.mean(acc_all)
292 | acc_std = np.std(acc_all)
293 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
294 |
295 | return acc_all, dist_all
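296 | 
297 | 
298 | if __name__ == '__main__':
299 |     # Minimal usage sketch (illustrative only): build APNet_wo_attrLoc on a Conv4
300 |     # backbone and score one fake 5-way 1-shot episode. The attr_num value, the
301 |     # 84x84 input size and the 16 queries per class are assumptions; running this
302 |     # needs a CUDA device and the attribute-label file read by read_attr_labels.
303 |     import backbone
304 |     model = APNet_wo_attrLoc(backbone.Conv4, n_way=5, n_support=1,
305 |                              attr_num=312, dataset='CUB').cuda()
306 |     model.n_query = 16
307 |     x = torch.randn(5, 1 + 16, 3, 84, 84)  # (n_way, n_support + n_query, C, H, W)
308 |     scores = model.set_forward(x, None)  # -> (n_way * n_query, n_way) class scores
309 |     print(scores.shape)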
--------------------------------------------------------------------------------
/backbone.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from torch.autograd import Variable
5 | import torch.nn as nn
6 | import math
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from torch.nn.utils.weight_norm import WeightNorm
10 |
11 |
12 | # Basic ResNet model
13 |
14 | def init_layer(L):
15 | # Initialization using fan-in
16 | if isinstance(L, nn.Conv2d):
17 | n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
18 | L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
19 | elif isinstance(L, nn.BatchNorm2d):
20 | L.weight.data.fill_(1)
21 | L.bias.data.fill_(0)
22 |
23 |
24 | class distLinear(nn.Module):
25 | def __init__(self, indim, outdim):
26 | super(distLinear, self).__init__()
27 | self.L = nn.Linear(indim, outdim, bias=False)
28 | self.class_wise_learnable_norm = True # see issues #4 and #8 on the GitHub repo
29 | if self.class_wise_learnable_norm:
30 | WeightNorm.apply(self.L, 'weight', dim=0) # split the weight into a direction component and a norm component
31 |
32 | if outdim <= 200:
33 | self.scale_factor = 2 # fixed scale factor that maps cosine values to a reasonably large softmax input; to reproduce the CUB result with ResNet10, use 4 (see issue #31 on GitHub)
34 | else:
35 | self.scale_factor = 10 # on Omniglot a larger scale factor is needed to handle >1000 output classes
36 |
37 | def forward(self, x):
38 | x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
39 | x_normalized = x.div(x_norm + 0.00001)
40 | if not self.class_wise_learnable_norm:
41 | L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data)
42 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001)
43 | # matrix product in the forward pass; with WeightNorm this also multiplies the cosine distance by a class-wise learnable norm (see issues #4 and #8 on GitHub)
44 | cos_dist = self.L(x_normalized)
45 | scores = self.scale_factor * (cos_dist)
46 |
47 | return scores
48 |
49 |
50 | class Flatten(nn.Module):
51 | def __init__(self):
52 | super(Flatten, self).__init__()
53 |
54 | def forward(self, x):
55 | return x.view(x.size(0), -1)
56 |
57 |
58 | class Linear_fw(nn.Linear): # used in MAML to forward input with fast weight
59 | def __init__(self, in_features, out_features):
60 | super(Linear_fw, self).__init__(in_features, out_features)
61 | self.weight.fast = None # Lazy hack to add fast weight link
62 | self.bias.fast = None
63 |
64 | def forward(self, x):
65 | if self.weight.fast is not None and self.bias.fast is not None:
66 | out = F.linear(x, self.weight.fast,
67 | self.bias.fast) # weight.fast (fast weight) is the temporarily adapted weight
68 | else:
69 | out = super(Linear_fw, self).forward(x)
70 | return out
71 |
72 |
73 | class Conv2d_fw(nn.Conv2d): # used in MAML to forward input with fast weight
74 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
75 | super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
76 | bias=bias)
77 | self.weight.fast = None
78 | if self.bias is not None:
79 | self.bias.fast = None
80 |
81 | def forward(self, x):
82 | if self.bias is None:
83 | if self.weight.fast is not None:
84 | out = F.conv2d(x, self.weight.fast, None, stride=self.stride, padding=self.padding)
85 | else:
86 | out = super(Conv2d_fw, self).forward(x)
87 | else:
88 | if self.weight.fast is not None and self.bias.fast is not None:
89 | out = F.conv2d(x, self.weight.fast, self.bias.fast, stride=self.stride, padding=self.padding)
90 | else:
91 | out = super(Conv2d_fw, self).forward(x)
92 |
93 | return out
94 |
95 |
96 | class BatchNorm2d_fw(nn.BatchNorm2d): # used in MAML to forward input with fast weight
97 | def __init__(self, num_features):
98 | super(BatchNorm2d_fw, self).__init__(num_features)
99 | self.weight.fast = None
100 | self.bias.fast = None
101 |
102 | def forward(self, x):
103 | running_mean = torch.zeros(x.data.size()[1]).cuda()
104 | running_var = torch.ones(x.data.size()[1]).cuda()
105 | if self.weight.fast is not None and self.bias.fast is not None:
106 | out = F.batch_norm(x, running_mean, running_var, self.weight.fast, self.bias.fast, training=True,
107 | momentum=1)
108 | # batch_norm momentum hack: follow hack of Kate Rakelly in pytorch-maml/src/layers.py
109 | else:
110 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, training=True, momentum=1)
111 | return out
112 |
113 |
114 | # Simple Conv Block
115 | class ConvBlock(nn.Module):
116 | maml = False # Default
117 |
118 | def __init__(self, indim, outdim, pool=True, padding=1):
119 | super(ConvBlock, self).__init__()
120 | self.indim = indim
121 | self.outdim = outdim
122 | if self.maml:
123 | self.C = Conv2d_fw(indim, outdim, 3, padding=padding)
124 | self.BN = BatchNorm2d_fw(outdim)
125 | else:
126 | self.C = nn.Conv2d(indim, outdim, 3, padding=padding)
127 | self.BN = nn.BatchNorm2d(outdim)
128 | self.relu = nn.ReLU(inplace=True)
129 |
130 | self.parametrized_layers = [self.C, self.BN, self.relu]
131 | if pool:
132 | self.pool = nn.MaxPool2d(2)
133 | self.parametrized_layers.append(self.pool)
134 |
135 | for layer in self.parametrized_layers:
136 | init_layer(layer)
137 |
138 | self.trunk = nn.Sequential(*self.parametrized_layers)
139 |
140 | def forward(self, x):
141 | out = self.trunk(x)
142 | return out
143 |
144 |
145 | # Simple ResNet Block
146 | class SimpleBlock(nn.Module):
147 | maml = False # Default
148 |
149 | def __init__(self, indim, outdim, half_res):
150 | super(SimpleBlock, self).__init__()
151 | self.indim = indim
152 | self.outdim = outdim
153 | if self.maml:
154 | self.C1 = Conv2d_fw(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
155 | self.BN1 = BatchNorm2d_fw(outdim)
156 | self.C2 = Conv2d_fw(outdim, outdim, kernel_size=3, padding=1, bias=False)
157 | self.BN2 = BatchNorm2d_fw(outdim)
158 | else:
159 | self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
160 | self.BN1 = nn.BatchNorm2d(outdim)
161 | self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False)
162 | self.BN2 = nn.BatchNorm2d(outdim)
163 | self.relu1 = nn.ReLU(inplace=True)
164 | self.relu2 = nn.ReLU(inplace=True)
165 |
166 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]
167 |
168 | self.half_res = half_res
169 |
170 | # if the number of input channels differs from the number of output channels, a 1x1 convolution is needed on the shortcut
171 | if indim != outdim:
172 | if self.maml:
173 | self.shortcut = Conv2d_fw(indim, outdim, 1, 2 if half_res else 1, bias=False)
174 | self.BNshortcut = BatchNorm2d_fw(outdim)
175 | else:
176 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False)
177 | self.BNshortcut = nn.BatchNorm2d(outdim)
178 |
179 | self.parametrized_layers.append(self.shortcut)
180 | self.parametrized_layers.append(self.BNshortcut)
181 | self.shortcut_type = '1x1'
182 | else:
183 | self.shortcut_type = 'identity'
184 |
185 | for layer in self.parametrized_layers:
186 | init_layer(layer)
187 |
188 | def forward(self, x):
189 | out = self.C1(x)
190 | out = self.BN1(out)
191 | out = self.relu1(out)
192 | out = self.C2(out)
193 | out = self.BN2(out)
194 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x))
195 | out = out + short_out
196 | out = self.relu2(out)
197 | return out
198 |
199 |
200 | # Bottleneck block
201 | class BottleneckBlock(nn.Module):
202 | maml = False # Default
203 |
204 | def __init__(self, indim, outdim, half_res):
205 | super(BottleneckBlock, self).__init__()
206 | bottleneckdim = int(outdim / 4)
207 | self.indim = indim
208 | self.outdim = outdim
209 | if self.maml:
210 | self.C1 = Conv2d_fw(indim, bottleneckdim, kernel_size=1, bias=False)
211 | self.BN1 = BatchNorm2d_fw(bottleneckdim)
212 | self.C2 = Conv2d_fw(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1, padding=1)
213 | self.BN2 = BatchNorm2d_fw(bottleneckdim)
214 | self.C3 = Conv2d_fw(bottleneckdim, outdim, kernel_size=1, bias=False)
215 | self.BN3 = BatchNorm2d_fw(outdim)
216 | else:
217 | self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False)
218 | self.BN1 = nn.BatchNorm2d(bottleneckdim)
219 | self.C2 = nn.Conv2d(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1, padding=1)
220 | self.BN2 = nn.BatchNorm2d(bottleneckdim)
221 | self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False)
222 | self.BN3 = nn.BatchNorm2d(outdim)
223 |
224 | self.relu = nn.ReLU()
225 | self.parametrized_layers = [self.C1, self.BN1, self.C2, self.BN2, self.C3, self.BN3]
226 | self.half_res = half_res
227 |
228 | # if the number of input channels differs from the number of output channels, a 1x1 convolution is needed on the shortcut
229 | if indim != outdim:
230 | if self.maml:
231 | self.shortcut = Conv2d_fw(indim, outdim, 1, stride=2 if half_res else 1, bias=False)
232 | else:
233 | self.shortcut = nn.Conv2d(indim, outdim, 1, stride=2 if half_res else 1, bias=False)
234 |
235 | self.parametrized_layers.append(self.shortcut)
236 | self.shortcut_type = '1x1'
237 | else:
238 | self.shortcut_type = 'identity'
239 |
240 | for layer in self.parametrized_layers:
241 | init_layer(layer)
242 |
243 | def forward(self, x):
244 |
245 | short_out = x if self.shortcut_type == 'identity' else self.shortcut(x)
246 | out = self.C1(x)
247 | out = self.BN1(out)
248 | out = self.relu(out)
249 | out = self.C2(out)
250 | out = self.BN2(out)
251 | out = self.relu(out)
252 | out = self.C3(out)
253 | out = self.BN3(out)
254 | out = out + short_out
255 |
256 | out = self.relu(out)
257 | return out
258 |
259 |
260 | class ConvNet(nn.Module):
261 | def __init__(self, depth, flatten=True):
262 | super(ConvNet, self).__init__()
263 | trunk = []
264 | for i in range(depth):
265 | indim = 3 if i == 0 else 64
266 | outdim = 64
267 | B = ConvBlock(indim, outdim, pool=(i < 4)) # pooling only in the first 4 layers
268 | trunk.append(B)
269 |
270 | if flatten:
271 | trunk.append(Flatten())
272 |
273 | self.trunk = nn.Sequential(*trunk)
274 | self.final_feat_dim = 1600
275 |
276 | def forward(self, x):
277 | out = self.trunk(x)
278 | return out
279 |
280 |
281 | class ConvNetNopool(nn.Module):
282 | # RelationNet-style 4-layer conv with pooling only in the first two layers and no pooling elsewhere
283 | def __init__(self, depth, flatten=True):
284 | super(ConvNetNopool, self).__init__()
285 | trunk = []
286 | for i in range(depth):
287 | indim = 3 if i == 0 else 64
288 | outdim = 64
289 | B = ConvBlock(indim, outdim, pool=(i in [0, 1]),
290 | padding=0 if i in [0, 1] else 1) # only the first two layers have pooling and use no padding
291 | trunk.append(B)
292 |
293 | if flatten:
294 | lastpool = nn.AdaptiveAvgPool2d((1, 1))
295 | trunk.append(lastpool)
296 | trunk.append(Flatten())
297 | self.final_feat_dim = 64
298 | else:
299 | lastpool = nn.AdaptiveAvgPool2d((7, 7))
300 | trunk.append(lastpool)
301 | self.final_feat_dim = [64, 7, 7]
302 |
303 | self.trunk = nn.Sequential(*trunk)
304 |
305 | def forward(self, x):
306 | out = self.trunk(x)
307 | return out
308 |
309 |
310 | class ConvNetS(nn.Module): # For omniglot, only 1 input channel, output dim is 64
311 | def __init__(self, depth, flatten=True):
312 | super(ConvNetS, self).__init__()
313 | trunk = []
314 | for i in range(depth):
315 | indim = 1 if i == 0 else 64
316 | outdim = 64
317 | B = ConvBlock(indim, outdim, pool=(i < 4)) # pooling only in the first 4 layers
318 | trunk.append(B)
319 |
320 | if flatten:
321 | trunk.append(Flatten())
322 |
323 | self.trunk = nn.Sequential(*trunk)
324 | self.final_feat_dim = 64
325 |
326 | def forward(self, x):
327 | out = x[:, 0:1, :, :] # only use the first channel
328 | out = self.trunk(out)
329 | return out
330 |
331 |
332 | class ConvNetSNopool(nn.Module):
333 | # RelationNet-style 4-layer conv with pooling only in the first two layers; for Omniglot, 1 input channel, output dim is [64, 5, 5]
334 | def __init__(self, depth):
335 | super(ConvNetSNopool, self).__init__()
336 | trunk = []
337 | for i in range(depth):
338 | indim = 1 if i == 0 else 64
339 | outdim = 64
340 | B = ConvBlock(indim, outdim, pool=(i in [0, 1]),
341 | padding=0 if i in [0, 1] else 1) # only the first two layers have pooling and use no padding
342 | trunk.append(B)
343 |
344 | self.trunk = nn.Sequential(*trunk)
345 | self.final_feat_dim = [64, 5, 5]
346 |
347 | def forward(self, x):
348 | out = x[:, 0:1, :, :] # only use the first channel
349 | out = self.trunk(out)
350 | return out
351 |
352 |
353 | class ResNet(nn.Module):
354 | maml = False # Default
355 |
356 | def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=True):
357 | # list_of_num_layers specifies number of layers in each stage
358 | # list_of_out_dims specifies number of output channel for each stage
359 | super(ResNet, self).__init__()
360 | assert len(list_of_num_layers) == 4, 'Can have only four stages'
361 | if self.maml:
362 | conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3,
363 | bias=False)
364 | bn1 = BatchNorm2d_fw(64)
365 | else:
366 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
367 | bias=False)
368 | bn1 = nn.BatchNorm2d(64)
369 |
370 | relu = nn.ReLU()
371 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
372 |
373 | init_layer(conv1)
374 | init_layer(bn1)
375 |
376 | trunk = [conv1, bn1, relu, pool1]
377 |
378 | indim = 64
379 | for i in range(4):
380 |
381 | for j in range(list_of_num_layers[i]):
382 | half_res = (i >= 1) and (j == 0)
383 | B = block(indim, list_of_out_dims[i], half_res)
384 | trunk.append(B)
385 | indim = list_of_out_dims[i]
386 |
387 | if flatten:
388 | avgpool = nn.AvgPool2d(7)
389 | trunk.append(avgpool)
390 | trunk.append(Flatten())
391 | self.final_feat_dim = indim
392 | else:
393 | self.final_feat_dim = [indim, 7, 7]
394 |
395 | self.trunk = nn.Sequential(*trunk)
396 |
397 | def forward(self, x):
398 | out = self.trunk(x)
399 | return out
400 |
401 |
402 | def Conv4():
403 | return ConvNetNopool(4, flatten=True)
404 |
405 |
406 | def Conv6():
407 | return ConvNet(6)
408 |
409 |
410 | def Conv4NP(flatten=False):
411 | return ConvNetNopool(4, flatten=flatten)
412 |
413 |
414 | def Conv6NP(flatten=False):
415 | return ConvNetNopool(6, flatten=flatten)
416 |
417 |
418 | def Conv4S():
419 | return ConvNetS(4)
420 |
421 |
422 | def Conv4SNP():
423 | return ConvNetSNopool(4)
424 |
425 |
426 | def ResNet10(flatten=True):
427 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten)
428 |
429 |
430 | def ResNet10NP(flatten=False):
431 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten)
432 |
433 |
434 | def ResNet18(flatten=True):
435 | return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], flatten)
436 |
437 |
438 | def ResNet34(flatten=True):
439 | return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], flatten)
440 |
441 |
442 | def ResNet50(flatten=True):
443 | return ResNet(BottleneckBlock, [3, 4, 6, 3], [256, 512, 1024, 2048], flatten)
444 |
445 |
446 | def ResNet101(flatten=True):
447 | return ResNet(BottleneckBlock, [3, 4, 23, 3], [256, 512, 1024, 2048], flatten)
448 |
449 |
450 |
451 |
452 |
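453 | if __name__ == '__main__':
454 |     # Minimal shape check (illustrative sketch): the 84x84 input size is an assumption
455 |     # based on the Conv4-style backbones used by the few-shot scripts in this repo.
456 |     # Conv4NP keeps the spatial map, so the expected output is (N, 64, 7, 7);
457 |     # Conv4 flattens it to a 64-d vector and ResNet10 to a 512-d vector.
458 |     net = Conv4NP()
459 |     with torch.no_grad():
460 |         feat = net(torch.randn(2, 3, 84, 84))
461 |     print(feat.shape)  # expected: torch.Size([2, 64, 7, 7])
462 |     print(Conv4().final_feat_dim, ResNet10().final_feat_dim)  # 64 512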
--------------------------------------------------------------------------------