├── data
│   ├── __init__.py
│   ├── additional_transforms.py
│   ├── feature_loader.py
│   ├── transforms.py
│   ├── datamgr.py
│   └── dataset.py
├── requirements.txt
├── .idea
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── TAD.iml
│   ├── remote-mappings.xml
│   ├── misc.xml
│   ├── deployment.xml
│   └── workspace.xml
├── scripts
│   ├── train
│   │   ├── cub_protonet.sh
│   │   └── sun_protonet.sh
│   └── test
│       ├── sun_protonet.sh
│       ├── cub_protonet.sh
│       └── plot_distance_acc.sh
├── configs.py
├── methods
│   ├── protonet.py
│   ├── baselinefinetune.py
│   ├── matchingnet.py
│   ├── baselinetrain.py
│   ├── meta_template.py
│   ├── relationnet.py
│   ├── maml.py
│   └── apnet.py
├── README.md
├── filelists
│   ├── SUN
│   │   ├── write_SUN_filelist.py
│   │   └── write_SUN_attr.py
│   └── CUB
│       └── write_CUB_filelist.py
├── io_utils.py
├── save_features.py
├── utils.py
├── test.py
├── plot_distance_acc.py
├── train.py
└── backbone.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | matplotlib
3 | scikit-learn
4 | tensorboard
5 | h5py
6 | tqdm
--------------------------------------------------------------------------------
/scripts/train/cub_protonet.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=3 python train.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1
2 | CUDA_VISIBLE_DEVICES=3 python train.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 5
--------------------------------------------------------------------------------
/scripts/train/sun_protonet.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python train.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 1
2 | CUDA_VISIBLE_DEVICES=2 python train.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 5
3 |
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | save_dir = '.'
2 | data_dir = {}
3 | data_dir['CUB'] = './filelists/CUB/'
4 | data_dir['AWA2'] = './filelists/AWA2/'
5 | data_dir['SUN'] = './filelists/SUN/'
6 | data_dir['miniImagenet'] = './filelists/miniImagenet/'
7 | data_dir['omniglot'] = './filelists/omniglot/'
8 | data_dir['emnist'] = './filelists/emnist/'
9 |
--------------------------------------------------------------------------------
/scripts/test/sun_protonet.sh:
--------------------------------------------------------------------------------
1 | python save_features.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 1
2 | CUDA_VISIBLE_DEVICES=2 python test.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 1
3 | python save_features.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 5
4 | CUDA_VISIBLE_DEVICES=2 python test.py --dataset SUN --model Conv4 --method protonet --train_aug --n_shot 5
--------------------------------------------------------------------------------
/scripts/test/cub_protonet.sh:
--------------------------------------------------------------------------------
1 | python save_features.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1
2 | CUDA_VISIBLE_DEVICES=3 python test.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1
3 | python save_features.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 5
4 | CUDA_VISIBLE_DEVICES=3 python test.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 5
--------------------------------------------------------------------------------
/scripts/test/plot_distance_acc.sh:
--------------------------------------------------------------------------------
1 | python save_features.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1
2 | CUDA_VISIBLE_DEVICES=2 python plot_distance_acc.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1 --runs 1
3 | CUDA_VISIBLE_DEVICES=2 python plot_distance_acc.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1 --runs 2
4 | CUDA_VISIBLE_DEVICES=2 python plot_distance_acc.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1 --runs 3
5 | CUDA_VISIBLE_DEVICES=2 python plot_distance_acc.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1 --runs 4
6 | CUDA_VISIBLE_DEVICES=2 python plot_distance_acc.py --dataset CUB --model Conv4NP --method protonet --train_aug --n_shot 1 --runs 5
7 |
--------------------------------------------------------------------------------
/data/additional_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import torch
9 | from PIL import ImageEnhance
10 |
11 | transformtypedict=dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color)
12 |
13 |
14 |
15 | class ImageJitter(object):
16 | def __init__(self, transformdict):
17 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict]
18 |
19 |
20 | def __call__(self, img):
21 | out = img
22 | randtensor = torch.rand(len(self.transforms))
23 |
24 | for i, (transformer, alpha) in enumerate(self.transforms):
25 | r = alpha*(randtensor[i]*2.0 -1.0) + 1
26 | out = transformer(out).enhance(r).convert('RGB')
27 |
28 | return out
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/data/feature_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import h5py
4 |
5 |
6 | class SimpleHDF5Dataset:
7 | def __init__(self, file_handle=None):
8 | if file_handle is None:
9 | self.f = ''
10 | self.all_feats_dset = []
11 | self.all_labels = []
12 | self.total = 0
13 | else:
14 | self.f = file_handle
15 | self.all_feats_dset = self.f['all_feats'][...]
16 | self.all_labels = self.f['all_labels'][...]
17 | self.total = self.f['count'][0]
18 | # print('here')
19 |
20 | def __getitem__(self, i):
21 | return torch.Tensor(self.all_feats_dset[i, :]), int(self.all_labels[i])
22 |
23 | def __len__(self):
24 | return self.total
25 |
26 |
27 | def init_loader(filename):
28 | with h5py.File(filename, 'r') as f:
29 | fileset = SimpleHDF5Dataset(f)
30 |
31 | # labels = [ l for l in fileset.all_labels if l != 0]
32 | feats = fileset.all_feats_dset
33 | labels = fileset.all_labels
34 | while np.sum(feats[-1]) == 0:
35 | feats = np.delete(feats, -1, axis=0)
36 | labels = np.delete(labels, -1, axis=0)
37 |
38 | class_list = np.unique(np.array(labels)).tolist()
39 | inds = range(len(labels))
40 |
41 | cl_data_file = {}
42 | for cl in class_list:
43 | cl_data_file[cl] = []
44 | for ind in inds:
45 | cl_data_file[labels[ind]].append(feats[ind])
46 |
47 | return cl_data_file
48 |
--------------------------------------------------------------------------------
/methods/protonet.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/jakesnell/prototypical-networks
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 |
11 | class ProtoNet(MetaTemplate):
12 | def __init__(self, model_func, n_way, n_support):
13 | super(ProtoNet, self).__init__( model_func, n_way, n_support)
14 | self.loss_fn = nn.CrossEntropyLoss()
15 |
16 |
17 | def set_forward(self,x,is_feature = False):
18 | z_support, z_query = self.parse_feature(x,is_feature)
19 |
20 | z_support = z_support.contiguous()
21 | z_proto = z_support.view(self.n_way, self.n_support, -1 ).mean(1) #the shape of z is [n_data, n_dim]
22 | z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 )
23 |
24 | dists = euclidean_dist(z_query, z_proto)
25 | scores = -dists
26 | return scores
27 |
28 |
29 | def set_forward_loss(self, x):
30 | y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query ))
31 | y_query = Variable(y_query.cuda())
32 |
33 | scores = self.set_forward(x)
34 |
35 | return self.loss_fn(scores, y_query )
36 |
37 |
38 | def euclidean_dist( x, y):
39 | # x: N x D
40 | # y: M x D
41 | n = x.size(0)
42 | m = y.size(0)
43 | d = x.size(1)
44 | assert d == y.size(1)
45 |
46 | x = x.unsqueeze(1).expand(n, m, d)
47 | y = y.unsqueeze(0).expand(n, m, d)
48 |
49 | return torch.pow(x - y, 2).sum(2)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Understanding Few-Shot Learning: Measuring Task Relatedness and Adaptation Difficulty via Attributes
3 | This repository is the official implementation of the paper "Understanding Few-Shot Learning: Measuring Task Relatedness and Adaptation Difficulty via Attributes" (NeurIPS 2023). In this project, we provide the Task Attribute Distance (TAD) metric to quantify task relatedness and measure the adaptation difficulty of novel tasks.
4 |
5 | ## Dependencies
6 |
7 | The code is built with the following libraries:
8 | - python 3.7
9 | - PyTorch 1.7.1
10 | - cv2
11 | - matplotlib
12 | - sklearn
13 | - tensorboard
14 | - h5py
15 | - tqdm
16 |
17 | #### Installation
18 | ```setup
19 | conda create -n TAD python=3.7
20 | source activate TAD
21 | conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
22 | pip install -r requirements.txt
23 | ```
24 |
25 | #### Dataset prepare
26 | Please download the CUB and SUN datasets, then put them under `filelists/CUB/` and `filelists/SUN/` respectively.
27 |
28 | Here we provide a [link](https://drive.google.com/file/d/1Je-BZaCVe9fSoUUpkBhBlm8thxalRxkI/view?usp=sharing) of CUB dataset and related files.
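
The filelist scripts locate the raw data through relative paths; the sketch below shows the assumed layout, inferred from the paths referenced in `write_CUB_filelist.py`, `write_SUN_filelist.py` and `utils.py` (assuming each `write_*` script is run from inside its own folder):

```
filelists/
├── CUB/
│   ├── write_CUB_filelist.py
│   └── CUB_200_2011/          # images/, images.txt, parts/, attributes/
└── SUN/
    ├── write_SUN_filelist.py
    ├── write_SUN_attr.py
    └── materials/sun/         # images/ and SUNAttributeDB/ (.mat files)
```

Running the `write_*` scripts then generates the `base` / `val` / `novel` json splits used by the data loaders.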
29 |
30 | ## Training
31 |
32 | To train the FSL models (such as ProtoNet) on the CUB dataset, run this command:
33 |
34 | ```train
35 | bash scripts/train/cub_protonet.sh
36 | ```
37 |
38 | ## Evaluation
39 |
40 | To evaluate models on CUB, run:
41 |
42 | ```eval
43 | bash scripts/test/cub_protonet.sh
44 | ```
45 |
46 | ## Plot task distance and accuracy
47 |
48 | To estimate the average TAD between each novel task and the training tasks, and then plot a figure of average TAD versus accuracy, run:
49 |
50 | ```eval
51 | bash scripts/test/plot_distance_acc.sh
52 | ```
53 |
54 | ## Fast start
55 |
56 | Here we provide some pretrained models for a fast start.
57 |
58 | - [ProtoNet (Conv4NP)](https://drive.google.com/file/d/1AxXRP0QSmH0C5Y3i8GXEHThg6otK8leH/view?usp=sharing) trained on CUB in the 5-way 1-shot setting
59 |
60 | Download the pretrained model to the path `checkpoints/CUB/Conv4NP_protonet_0_aug_5way_1shot/`, and then run the command in the `Plot task distance and accuracy` section.
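
For reference, `save_features.py` composes the checkpoint directory from the dataset, model, method, `exp_str` and shot arguments, and `io_utils.get_best_file` looks for `best_model.tar` first; a sketch of the expected location (the `best_model.tar` filename is an assumption, so rename the downloaded file accordingly if needed):

```
checkpoints/
└── CUB/
    └── Conv4NP_protonet_0_aug_5way_1shot/
        └── best_model.tar   # assumed name; otherwise get_resume_file() falls back to the highest <epoch>.tar
```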
61 |
62 | Our codebase is developed based on the [baseline++](https://github.com/wyharveychen/CloserLookFewShot) from the paper [A Closer Look at Few-shot Classification](https://arxiv.org/abs/1904.04232) and [COMET](https://github.com/snap-stanford/comet) from the paper [Concept Learners for Few-Shot Learning](https://arxiv.org/pdf/2007.07375.pdf).
--------------------------------------------------------------------------------
/methods/baselinefinetune.py:
--------------------------------------------------------------------------------
1 | import backbone
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import torch.nn.functional as F
7 | from methods.meta_template import MetaTemplate
8 |
9 |
10 | class BaselineFinetune(MetaTemplate):
11 | def __init__(self, model_func, n_way, n_support, loss_type="softmax"):
12 | super(BaselineFinetune, self).__init__(model_func, n_way, n_support)
13 | self.loss_type = loss_type
14 |
15 | def set_forward(self, x, is_feature=True):
16 | return self.set_forward_adaptation(x, is_feature)  # Baseline always does adaptation
17 |
18 | def set_forward_adaptation(self, x, is_feature=True):
19 | assert is_feature, 'Baseline only supports testing with features'
20 | z_support, z_query = self.parse_feature(x, is_feature)
21 |
22 | z_support = z_support.contiguous().view(self.n_way * self.n_support, -1)
23 | z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)
24 |
25 | y_support = torch.from_numpy(np.repeat(range(self.n_way), self.n_support))
26 | y_support = Variable(y_support.cuda())
27 |
28 | if self.loss_type == 'softmax':
29 | linear_clf = nn.Linear(self.feat_dim, self.n_way)
30 | elif self.loss_type == 'dist':
31 | linear_clf = backbone.distLinear(self.feat_dim, self.n_way)
32 | linear_clf = linear_clf.cuda()
33 |
34 | set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr=0.01, momentum=0.9, dampening=0.9,
35 | weight_decay=0.001)
36 |
37 | loss_function = nn.CrossEntropyLoss()
38 | loss_function = loss_function.cuda()
39 |
40 | batch_size = 4
41 | support_size = self.n_way * self.n_support
42 | for epoch in range(100):
43 | rand_id = np.random.permutation(support_size)
44 | for i in range(0, support_size, batch_size):
45 | set_optimizer.zero_grad()
46 | selected_id = torch.from_numpy(rand_id[i: min(i + batch_size, support_size)]).cuda()
47 | z_batch = z_support[selected_id]
48 | y_batch = y_support[selected_id]
49 | scores = linear_clf(z_batch)
50 | loss = loss_function(scores, y_batch)
51 | loss.backward()
52 | set_optimizer.step()
53 | scores = linear_clf(z_query)
54 | return scores
55 |
56 | def set_forward_loss(self, x):
57 | raise ValueError('Baseline predicts on pretrained features and does not support finetuning the backbone')
58 |
59 |
60 |
--------------------------------------------------------------------------------
/data/transforms.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import numpy as np
12 | import cv2
13 |
14 | def transform_preds(coords, center, scale, output_size):
15 | target_coords = np.zeros(coords.shape)
16 | trans = get_affine_transform(center, scale, 0, output_size, inv=1)
17 | for p in range(coords.shape[0]):
18 | target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
19 | return target_coords
20 |
21 |
22 | def get_affine_transform(center,
23 | scale,
24 | rot,
25 | output_size,
26 | shift=np.array([0, 0], dtype=np.float32),
27 | inv=0):
28 | if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
29 | print(scale)
30 | scale = np.array([scale, scale])
31 |
32 | scale_tmp = scale * 200.0
33 | src_w = scale_tmp[0]
34 | dst_w = output_size[0]
35 | dst_h = output_size[1]
36 |
37 | rot_rad = np.pi * rot / 180
38 | src_dir = get_dir([0, src_w * -0.5], rot_rad)
39 | dst_dir = np.array([0, dst_w * -0.5], np.float32)
40 |
41 | src = np.zeros((3, 2), dtype=np.float32)
42 | dst = np.zeros((3, 2), dtype=np.float32)
43 | src[0, :] = center + scale_tmp * shift
44 | src[1, :] = center + src_dir + scale_tmp * shift
45 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
46 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
47 |
48 | src[2:, :] = get_3rd_point(src[0, :], src[1, :])
49 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
50 |
51 | if inv:
52 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
53 | else:
54 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
55 |
56 | return trans
57 |
58 |
59 | def affine_transform(pt, t):
60 | new_pt = np.array([pt[0], pt[1], 1.]).T
61 | new_pt = np.dot(t, new_pt)
62 | return new_pt[:2]
63 |
64 |
65 | def get_3rd_point(a, b):
66 | direct = a - b
67 | return b + np.array([-direct[1], direct[0]], dtype=np.float32)
68 |
69 |
70 | def get_dir(src_point, rot_rad):
71 | sn, cs = np.sin(rot_rad), np.cos(rot_rad)
72 |
73 | src_result = [0, 0]
74 | src_result[0] = src_point[0] * cs - src_point[1] * sn
75 | src_result[1] = src_point[0] * sn + src_point[1] * cs
76 |
77 | return src_result
78 |
79 |
80 | def crop(img, center, scale, output_size, rot=0):
81 | trans = get_affine_transform(center, scale, rot, output_size)
82 |
83 | dst_img = cv2.warpAffine(img,
84 | trans,
85 | (int(output_size[0]), int(output_size[1])),
86 | flags=cv2.INTER_LINEAR)
87 |
88 | return dst_img
--------------------------------------------------------------------------------
/data/datamgr.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | import data.additional_transforms as add_transforms
8 | from data.dataset import SimpleDataset, SetDataset, EpisodicBatchSampler
9 | from abc import abstractmethod
10 |
11 | class TransformLoader:
12 | def __init__(self, image_size,
13 | normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
14 | jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)):
15 | self.image_size = image_size
16 | self.normalize_param = normalize_param
17 | self.jitter_param = jitter_param
18 |
19 | def parse_transform(self, transform_type):
20 | if transform_type == 'ImageJitter':
21 | method = add_transforms.ImageJitter(self.jitter_param)
22 | return method
23 | method = getattr(transforms, transform_type)
24 | if transform_type == 'RandomSizedCrop':
25 | return method(self.image_size)
26 | elif transform_type == 'CenterCrop':
27 | return method(self.image_size)
28 | elif transform_type == 'Scale':
29 | return method([int(self.image_size * 1.15), int(self.image_size * 1.15)])
30 | elif transform_type == 'Normalize':
31 | return method(**self.normalize_param)
32 | else:
33 | return method()
34 |
35 | def get_composed_transform(self, aug=False):
36 | if aug:
37 | transform_list = ['ImageJitter', 'ToTensor', 'Normalize']
38 | else:
39 | transform_list = ['ToTensor', 'Normalize']
40 |
41 | transform_funcs = [self.parse_transform(x) for x in transform_list]
42 | transform = transforms.Compose(transform_funcs)
43 | return transform
44 |
45 |
46 | class DataManager:
47 | @abstractmethod
48 | def get_data_loader(self, data_file, aug):
49 | pass
50 |
51 |
52 | class SimpleDataManager(DataManager):
53 | def __init__(self, image_size, batch_size):
54 | super(SimpleDataManager, self).__init__()
55 | self.batch_size = batch_size
56 | self.trans_loader = TransformLoader(image_size)
57 | self.image_size = image_size
58 |
59 | def get_data_loader(self, data_file, aug, is_train=True): # parameters that would change on train/val set
60 | transform = self.trans_loader.get_composed_transform(aug)
61 | dataset = SimpleDataset(data_file, self.image_size, transform, is_train=is_train)
62 | data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=12, pin_memory=True)
63 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
64 |
65 | return data_loader
66 |
67 |
68 | class SetDataManager(DataManager):
69 | def __init__(self, image_size, n_way, n_support, n_query, n_eposide=100):
70 | super(SetDataManager, self).__init__()
71 | self.image_size = image_size
72 | self.n_way = n_way
73 | self.batch_size = n_support + n_query
74 | self.n_eposide = n_eposide
75 |
76 | self.trans_loader = TransformLoader(image_size)
77 |
78 | def get_data_loader(self, data_file, aug, is_train=True, attr_loc=False): # parameters that would change on train/val set
79 | transform = self.trans_loader.get_composed_transform(aug)
80 | dataset = SetDataset(data_file, self.batch_size, self.image_size, transform, is_train=is_train, attr_loc=attr_loc)
81 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide)
82 | data_loader_params = dict(batch_sampler=sampler, num_workers=12, pin_memory=True)
83 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
84 | return data_loader
--------------------------------------------------------------------------------
/filelists/SUN/write_SUN_filelist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os import listdir
3 | from os.path import isfile, isdir, join
4 | import os
5 | import json
6 | import random
7 | import re
8 | import scipy.io as sio
9 |
10 | def read_imgid_label_pair(filename):
11 | cwd = os.getcwd()
12 | prefix = os.path.join(cwd, 'materials/sun/images')
13 |
14 | image_dict = sio.loadmat(filename)
15 | image_path = image_dict['images']
16 | label_to_imgid = {}
17 | label_to_path = {}
18 | for i in range(image_path.shape[0]):
19 | path = image_path[i]
20 | path_list = path[0][0].split('/')
21 | label = ''
22 | for j in range(len(path_list)):
23 | if j > 0 and j < len(path_list) - 1:
24 | label += path_list[j]
25 | if label not in label_to_imgid.keys():
26 | label_to_imgid[label] = []
27 | if label not in label_to_path.keys():
28 | label_to_path[label] = []
29 | label_to_imgid[label].append(i)
30 | label_to_path[label].append(prefix + '/' + path[0][0])
31 | return label_to_imgid, label_to_path
32 |
33 | def read_img_attr_label(filename):
34 | attr_dict = sio.loadmat(filename)
35 | img_attr_labels = attr_dict['labels_cv'] # (14340, 102)
36 | return img_attr_labels
37 |
38 | def get_cl_attr_label(img_attr_labels, label_to_imgid, idx_to_label, cl_num):
39 | cl_attr_probs = []
40 | for i in range(cl_num):
41 | label = idx_to_label[i]
42 | imgid_list = label_to_imgid[label]
43 | attr_probs = img_attr_labels[imgid_list].mean(0)
44 | cl_attr_probs.append(attr_probs)
45 | cl_attr_probs = np.array(cl_attr_probs)
46 | avg_prob = cl_attr_probs.mean()
47 | cl_attr_labels = (cl_attr_probs > avg_prob).astype('float32')
48 | return cl_attr_labels
49 |
50 | if __name__ == '__main__':
51 |
52 | prefix = './materials/sun/SUNAttributeDB'
53 |
54 | label_to_imgid, label_to_path = read_imgid_label_pair(os.path.join(prefix, 'images.mat'))
55 | img_attr_labels = read_img_attr_label(os.path.join(prefix, 'attributeLabels_continuous.mat'))
56 |
57 | all_label_list = list(label_to_imgid.keys())
58 | cl_num = len(all_label_list)
59 | all_label_list.sort()
60 | label_to_idx = {}
61 | idx_to_label = {}
62 | for i in range(len(all_label_list)):
63 | idx_to_label[i] = all_label_list[i]
64 | label_to_idx[all_label_list[i]] = i
65 |
66 | cl_attr_labels = get_cl_attr_label(img_attr_labels, label_to_imgid, idx_to_label, cl_num)
67 |
68 | savedir = './materials/sun/'
69 | dataset_list = ['base','val','novel']
70 |
71 | rs_label_list = list(label_to_imgid.keys())
72 | random.shuffle(rs_label_list)
73 | for dataset in dataset_list:
74 | file_list = []
75 | label_list = []
76 | for i, label in enumerate(rs_label_list):
77 | label_id = label_to_idx[label]
78 | if 'base' in dataset:
79 | if (i >= 0 and i < 430):
80 | file_list = file_list + label_to_path[label]
81 | label_list = label_list + np.repeat(label_id, len(label_to_path[label])).tolist()
82 | if 'val' in dataset:
83 | if (i >= 430 and i < 645):
84 | file_list = file_list + label_to_path[label]
85 | label_list = label_list + np.repeat(label_id, len(label_to_path[label])).tolist()
86 | if 'novel' in dataset:
87 | if (i >= 645):
88 | file_list = file_list + label_to_path[label]
89 | label_list = label_list + np.repeat(label_id, len(label_to_path[label])).tolist()
90 | with open(savedir + dataset + '.json', 'w') as outfile:
91 | json.dump({'label_names':all_label_list, 'image_names':file_list, 'image_labels':label_list,
92 | 'attr_labels': cl_attr_labels.tolist()}, outfile)
93 |
94 | print("%s -OK" %dataset)
95 |
--------------------------------------------------------------------------------
/filelists/CUB/write_CUB_filelist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os import listdir
3 | from os.path import isfile, isdir, join
4 | import os
5 | import json
6 | import random
7 | import re
8 |
9 | def read_img_id_pair(filename):
10 | img_to_id = dict()
11 | with open(filename, 'r') as fin:
12 | for line in fin.readlines():
13 | line_split = line.strip().split(' ')
14 | img_to_id[line_split[1]] = int(line_split[0])
15 | return img_to_id
16 |
17 | def read_parts(filename):
18 | id_to_parts = dict()
19 | with open(filename, 'r') as fin:
20 | for line in fin.readlines():
21 | line_split = line.strip().split(' ')
22 | img_id, part_id, x, y, visible = int(line_split[0]), int(line_split[1]), float(line_split[2]), float(line_split[3]), int(line_split[4])
23 | if part_id == 1:
24 | id_to_parts[img_id] = [[x, y, visible], ]
25 | else:
26 | id_to_parts[img_id].append([x, y, visible])
27 | return id_to_parts
28 |
29 | def read_img_attr_label(filename):
30 | imgid_to_attrlabel = dict()
31 | with open(filename, 'r') as fin:
32 | for line in fin.readlines():
33 | line_split = line.strip().split(' ')
34 | img_id, attr_label, visible = int(line_split[0]), int(line_split[1]), int(line_split[2])
35 | if visible:
36 | if img_id in imgid_to_attrlabel.keys():
37 | imgid_to_attrlabel[img_id].append(attr_label)
38 | else:
39 | imgid_to_attrlabel[img_id] = [attr_label]
40 | return imgid_to_attrlabel
41 |
42 | if __name__ == '__main__':
43 |
44 | img_to_idx = read_img_id_pair('./CUB_200_2011/images.txt')
45 | id_to_parts = read_parts('./CUB_200_2011/parts/part_locs.txt')
46 | id_to_attrlabel = read_img_attr_label('./CUB_200_2011/attributes/image_attribute_labels.txt')
47 |
48 | cwd = os.getcwd()
49 | data_path = join(cwd,'CUB_200_2011/images')
50 | savedir = './'
51 | dataset_list = ['base','val','novel']
52 |
53 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))]
54 | folder_list.sort()
55 | label_dict = dict(zip(folder_list,range(0,len(folder_list))))
56 |
57 | classfile_list_all = []
58 |
59 | for i, folder in enumerate(folder_list):
60 | folder_path = join(data_path, folder)
61 | classfile_list_all.append( [ join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')])
62 | random.shuffle(classfile_list_all[i])
63 |
64 |
65 | for dataset in dataset_list:
66 | file_list = []
67 | label_list = []
68 | for i, classfile_list in enumerate(classfile_list_all):
69 | if 'base' in dataset:
70 | if (i%2 == 0):
71 | file_list = file_list + classfile_list
72 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
73 | if 'val' in dataset:
74 | if (i%4 == 1):
75 | file_list = file_list + classfile_list
76 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
77 | if 'novel' in dataset:
78 | if (i%4 == 3):
79 | file_list = file_list + classfile_list
80 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
81 | attr_label_list = []
82 | for path in file_list:
83 | img = path.split('/')[-2] + '/' + path.split('/')[-1]
84 | attr_label_list.append(id_to_attrlabel[img_to_idx[img]])
85 | part_list = []
86 | for path in file_list:
87 | filename = re.search('/images/(.*)', path, flags=0).group(1)
88 | part_list.append(id_to_parts[img_to_idx[filename]])
89 | with open(savedir + dataset + '.json', 'w') as outfile:
90 | json.dump({'label_names':folder_list, 'image_names':file_list, 'image_labels':label_list, 'part': part_list,
91 | 'attr_labels': attr_label_list}, outfile)
92 |
93 | print("%s -OK" %dataset)
94 |
--------------------------------------------------------------------------------
/methods/matchingnet.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 | import utils
11 | import copy
12 |
13 | class MatchingNet(MetaTemplate):
14 | def __init__(self, model_func, n_way, n_support):
15 | super(MatchingNet, self).__init__( model_func, n_way, n_support)
16 |
17 | self.loss_fn = nn.NLLLoss()
18 |
19 | self.FCE = FullyContextualEmbedding(self.feat_dim)
20 | self.G_encoder = nn.LSTM(self.feat_dim, self.feat_dim, 1, batch_first=True, bidirectional=True)
21 |
22 | self.relu = nn.ReLU()
23 | self.softmax = nn.Softmax()
24 |
25 | def encode_training_set(self, S, G_encoder = None):
26 | if G_encoder is None:
27 | G_encoder = self.G_encoder
28 | out_G = G_encoder(S.unsqueeze(0))[0]
29 | out_G = out_G.squeeze(0)
30 | G = S + out_G[:,:S.size(1)] + out_G[:,S.size(1):]
31 | G_norm = torch.norm(G,p=2, dim =1).unsqueeze(1).expand_as(G)
32 | G_normalized = G.div(G_norm+ 0.00001)
33 | return G, G_normalized
34 |
35 | def get_logprobs(self, f, G, G_normalized, Y_S, FCE = None):
36 | if FCE is None:
37 | FCE = self.FCE
38 | F = FCE(f, G)
39 | F_norm = torch.norm(F,p=2, dim =1).unsqueeze(1).expand_as(F)
40 | F_normalized = F.div(F_norm+ 0.00001)
41 | #scores = F.mm(G_normalized.transpose(0,1))  # the implementation of Ross et al.; not consistent with the original paper and would let large-norm features dominate
42 | scores = self.relu( F_normalized.mm(G_normalized.transpose(0,1)) ) *100  # the original paper uses cosine similarity; we scale it by 100 to sharpen the highest probability after softmax
43 | softmax = self.softmax(scores)
44 | logprobs =(softmax.mm(Y_S)+1e-6).log()
45 | return logprobs
46 |
47 | def set_forward(self, x, is_feature = False):
48 | z_support, z_query = self.parse_feature(x,is_feature)
49 |
50 | z_support = z_support.contiguous().view( self.n_way* self.n_support, -1 )
51 | z_query = z_query.contiguous().view( self.n_way* self.n_query, -1 )
52 | G, G_normalized = self.encode_training_set( z_support)
53 |
54 | y_s = torch.from_numpy(np.repeat(range( self.n_way ), self.n_support ))
55 | Y_S = Variable( utils.one_hot(y_s, self.n_way ) ).cuda()
56 | f = z_query
57 | logprobs = self.get_logprobs(f, G, G_normalized, Y_S)
58 | return logprobs
59 |
60 | def set_forward_loss(self, x):
61 | y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query ))
62 | y_query = Variable(y_query.cuda())
63 |
64 | logprobs = self.set_forward(x)
65 |
66 | return self.loss_fn(logprobs, y_query )
67 |
68 | def cuda(self):
69 | super(MatchingNet, self).cuda()
70 | self.FCE = self.FCE.cuda()
71 | return self
72 |
73 | class FullyContextualEmbedding(nn.Module):
74 | def __init__(self, feat_dim):
75 | super(FullyContextualEmbedding, self).__init__()
76 | self.lstmcell = nn.LSTMCell(feat_dim*2, feat_dim)
77 | self.softmax = nn.Softmax()
78 | self.c_0 = Variable(torch.zeros(1,feat_dim))
79 | self.feat_dim = feat_dim
80 | #self.K = K
81 |
82 | def forward(self, f, G):
83 | h = f
84 | c = self.c_0.expand_as(f)
85 | G_T = G.transpose(0,1)
86 | K = G.size(0)  # number of support embeddings (to be confirmed)
87 | for k in range(K):
88 | logit_a = h.mm(G_T)
89 | a = self.softmax(logit_a)
90 | r = a.mm(G)
91 | x = torch.cat((f, r),1)
92 |
93 | h, c = self.lstmcell(x, (h, c))
94 | h = h + f
95 |
96 | return h
97 | def cuda(self):
98 | super(FullyContextualEmbedding, self).cuda()
99 | self.c_0 = self.c_0.cuda()
100 | return self
101 |
102 |
--------------------------------------------------------------------------------
/methods/baselinetrain.py:
--------------------------------------------------------------------------------
1 | import backbone
2 | import utils
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 |
10 |
11 | class BaselineTrain(nn.Module):
12 | def __init__(self, model_func, num_class, loss_type='softmax'):
13 | super(BaselineTrain, self).__init__()
14 | self.feature = model_func()
15 | if loss_type == 'softmax':
16 | self.classifier = nn.Linear(self.feature.final_feat_dim, num_class)
17 | self.classifier.bias.data.fill_(0)
18 | elif loss_type == 'dist': # Baseline ++
19 | self.classifier = backbone.distLinear(self.feature.final_feat_dim, num_class)
20 | self.loss_type = loss_type # 'softmax' #'dist'
21 | self.num_class = num_class
22 | self.loss_fn = nn.CrossEntropyLoss()
23 | self.DBval = False  # only set True for CUB dataset, see issue #31
24 |
25 | def forward(self, x):
26 | x = Variable(x.cuda())
27 | out = self.feature.forward(x)
28 | scores = self.classifier.forward(out)
29 | return scores
30 |
31 | def forward_loss(self, x, y):
32 | scores = self.forward(x)
33 | y = Variable(y.cuda())
34 | return self.loss_fn(scores, y)
35 |
36 | def train_loop(self, epoch, train_loader, optimizer, tf_writer):
37 | print_freq = 10
38 | avg_loss = 0
39 |
40 | for i, (x, y) in enumerate(train_loader):
41 | optimizer.zero_grad()
42 | loss = self.forward_loss(x, y)
43 | loss.backward()
44 | optimizer.step()
45 |
46 | avg_loss = avg_loss + loss.item()
47 |
48 | if i % print_freq == 0:
49 | # print(optimizer.state_dict()['param_groups'][0]['lr'])
50 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader),
51 | avg_loss / float(i + 1)))
52 | tf_writer.add_scalar('loss/train', avg_loss / float(i + 1), epoch)
53 |
54 | def test_loop(self, val_loader):
55 | if self.DBval:
56 | return self.analysis_loop(val_loader)
57 | else:
58 | return -1 # no validation, just save model during iteration
59 |
60 | def analysis_loop(self, val_loader, record=None):
61 | class_file = {}
62 | for i, (x, y) in enumerate(val_loader):
63 | x = x.cuda()
64 | x_var = Variable(x)
65 | feats = self.feature.forward(x_var).data.cpu().numpy()
66 | labels = y.cpu().numpy()
67 | for f, l in zip(feats, labels):
68 | if l not in class_file.keys():
69 | class_file[l] = []
70 | class_file[l].append(f)
71 |
72 | for cl in class_file:
73 | class_file[cl] = np.array(class_file[cl])
74 |
75 | DB = DBindex(class_file)
76 | print('DB index = %4.2f' % (DB))
77 | return 1 / DB # DB index: the lower the better
78 |
79 |
80 | def DBindex(cl_data_file):
81 | # For the definition of the Davies-Bouldin index (DBindex), see https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index
82 | # The DB index measures intra-class variation relative to inter-class separation
83 | # As baseline/baseline++ do not train a few-shot classifier during training, this is an alternative metric for evaluating the validation set
84 | # Empirically, this only works for the CUB dataset, not for miniImagenet
85 |
86 | class_list = cl_data_file.keys()
87 | cl_num = len(class_list)
88 | cl_means = []
89 | stds = []
90 | DBs = []
91 | for cl in class_list:
92 | cl_means.append(np.mean(cl_data_file[cl], axis=0))
93 | stds.append(np.sqrt(np.mean(np.sum(np.square(cl_data_file[cl] - cl_means[-1]), axis=1))))
94 |
95 | mu_i = np.tile(np.expand_dims(np.array(cl_means), axis=0), (len(class_list), 1, 1))
96 | mu_j = np.transpose(mu_i, (1, 0, 2))
97 | mdists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis=2))
98 |
99 | for i in range(cl_num):
100 | DBs.append(np.max([(stds[i] + stds[j]) / mdists[i, j] for j in range(cl_num) if j != i]))
101 | return np.mean(DBs)
102 |
103 |
--------------------------------------------------------------------------------
/filelists/SUN/write_SUN_attr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os import listdir
3 | from os.path import isfile, isdir, join
4 | import os
5 | import json
6 | import random
7 | import re
8 | import scipy.io as sio
9 |
10 | def read_imgid_label_pair(filename):
11 | cwd = os.getcwd()
12 | prefix = os.path.join(cwd, 'materials/sun/images')
13 |
14 | image_dict = sio.loadmat(filename)
15 | image_path = image_dict['images']
16 | label_to_imgid = {}
17 | label_to_path = {}
18 |
19 | imgid_to_label = {}
20 | for i in range(image_path.shape[0]):
21 | path = image_path[i]
22 | path_list = path[0][0].split('/')
23 | label = ''
24 | for j in range(len(path_list)):
25 | if j > 0 and j < len(path_list) - 1:
26 | label += path_list[j]
27 | if label not in label_to_imgid.keys():
28 | label_to_imgid[label] = []
29 | if label not in label_to_path.keys():
30 | label_to_path[label] = []
31 | label_to_imgid[label].append(i)
32 | label_to_path[label].append(prefix + '/' + path[0][0])
33 |
34 | imgid_to_label[i] = label
35 | return label_to_imgid, label_to_path, imgid_to_label
36 |
37 | def read_img_attr_label(filename):
38 | attr_dict = sio.loadmat(filename)
39 | img_attr_labels = attr_dict['labels_cv'] # (14340, 102)
40 | return img_attr_labels
41 |
42 | # def get_cl_attr_label(img_attr_labels, label_to_imgid, idx_to_label, cl_num):
43 | # cl_attr_probs = []
44 | # for i in range(cl_num):
45 | # label = idx_to_label[i]
46 | # imgid_list = label_to_imgid[label]
47 | # attr_probs = img_attr_labels[imgid_list].mean(0)
48 | # cl_attr_probs.append(attr_probs)
49 | # cl_attr_probs = np.array(cl_attr_probs)
50 | # return cl_attr_probs
51 |
52 | def get_cl_attr_label(img_attr_labels, imgid_to_label, label_to_idx, cl_num):
53 | count = 0
54 | class_attr_count = np.zeros((cl_num, 102, 2))
55 |
56 | for i in range(img_attr_labels.shape[0]):
57 | count += 1
58 | class_label = imgid_to_label[i]
59 | class_label_id = label_to_idx[class_label]
60 |
61 | attr_labels = img_attr_labels[i, :]
62 | for j in range(102):
63 | attr_label_prob = attr_labels[j]
64 | if attr_label_prob >= 0.5:
65 | class_attr_count[class_label_id][j][1] += 1
66 | else:
67 | class_attr_count[class_label_id][j][0] += 1
68 | print("count:", count)
69 |
70 | class_attr_min_label = np.argmin(class_attr_count, axis=2)
71 | class_attr_max_label = np.argmax(class_attr_count, axis=2)
72 | equal_count = np.where(
73 | class_attr_min_label == class_attr_max_label) # check where 0 count = 1 count, set the corresponding class attribute label to be 1
74 | class_attr_max_label[equal_count] = 1
75 |
76 | min_class_count = 10
77 | attr_class_count = np.sum(class_attr_max_label, axis=0)
78 | mask = np.where(attr_class_count >= min_class_count)[
79 | 0] # select attributes that are present (on a class level) in at least [min_class_count] classes
80 | class_attr_label_masked = class_attr_max_label[:, mask]
81 | return class_attr_label_masked
82 |
83 |
84 | if __name__ == '__main__':
85 |
86 | prefix = './materials/sun/SUNAttributeDB'
87 |
88 | label_to_imgid, label_to_path, imgid_to_label = read_imgid_label_pair(os.path.join(prefix, 'images.mat'))
89 | img_attr_labels = read_img_attr_label(os.path.join(prefix, 'attributeLabels_continuous.mat'))
90 |
91 | all_label_list = list(label_to_imgid.keys())
92 | cl_num = len(all_label_list)
93 | all_label_list.sort()
94 | label_to_idx = {}
95 | idx_to_label = {}
96 | for i in range(len(all_label_list)):
97 | idx_to_label[i] = all_label_list[i]
98 | label_to_idx[all_label_list[i]] = i
99 |
100 | #cl_attr_probs = get_cl_attr_label(img_attr_labels, label_to_imgid, idx_to_label, cl_num)
101 | masked_cl_attr_probs = get_cl_attr_label(img_attr_labels, imgid_to_label, label_to_idx, cl_num)
102 |
103 | savedir = './materials/sun/'
104 |
105 | with open(savedir + 'masked_attr_dist.json', 'w') as outfile:
106 | json.dump({'attr_dist': masked_cl_attr_probs.tolist()}, outfile)
107 | print("masked_attr_dist -OK")
108 |
--------------------------------------------------------------------------------
/io_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import argparse
5 | import backbone
6 |
7 | model_dict = dict(
8 | Conv4=backbone.Conv4,
9 | Conv4NP=backbone.Conv4NP,
10 | Conv4S=backbone.Conv4S,
11 | Conv6=backbone.Conv6,
12 | Conv6NP=backbone.Conv6NP,
13 | ResNet10=backbone.ResNet10,
14 | ResNet10NP=backbone.ResNet10NP,
15 | ResNet18=backbone.ResNet18,
16 | ResNet34=backbone.ResNet34,
17 | ResNet50=backbone.ResNet50,
18 | ResNet101=backbone.ResNet101)
19 |
20 |
21 | def parse_args(script):
22 | parser = argparse.ArgumentParser(description='few-shot script %s' % (script))
23 | parser.add_argument('--dataset', default='CUB', help='CUB/miniImagenet/cross/omniglot/cross_char')
24 | parser.add_argument('--model', default='Conv4',
25 | help='model: Conv{4|6} / ResNet{10|18|34|50|101}') # 50 and 101 are not used in the paper
26 | parser.add_argument('--method', default='baseline',
27 | help='baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/maml{_approx}') # relationnet_softmax replace L2 norm with softmax to expedite training, maml_approx use first-order approximation in the gradient for efficiency
28 | parser.add_argument('--train_n_way', default=5, type=int,
29 | help='class num to classify for training') # baseline and baseline++ would ignore this parameter
30 | parser.add_argument('--test_n_way', default=5, type=int,
31 | help='class num to classify for testing (validation) ') # baseline and baseline++ only use this parameter in finetuning
32 | parser.add_argument('--n_shot', default=5, type=int,
33 | help='number of labeled data in each class, same as n_support') # baseline and baseline++ only use this parameter in finetuning
34 | parser.add_argument('--train_aug', action='store_true',
35 | help='perform data augmentation or not during training ') # still required for save_features.py and test.py to find the model path correctly
36 | parser.add_argument('--exp_str', default='0', type=str, help='just to add some clarification for each exp')
37 |
38 | if script == 'train':
39 | parser.add_argument('--num_classes', default=200, type=int,
40 | help='total number of classes in softmax, only used in baseline') # make it larger than the maximum label value in base class
41 | parser.add_argument('--save_freq', default=50, type=int, help='Save frequency')
42 | parser.add_argument('--start_epoch', default=0, type=int, help='Starting epoch')
43 | parser.add_argument('--stop_epoch', default=-1, type=int,
44 | help='Stopping epoch') # for meta-learning methods, each epoch contains 100 episodes. The default epoch number is dataset dependent. See train.py
45 | parser.add_argument('--resume', action='store_true',
46 | help='continue from previous trained model with largest epoch')
47 | parser.add_argument('--warmup', action='store_true',
48 | help='continue from baseline, neglected if resume is true') # never used in the paper
49 | parser.add_argument('--beta', default=0.0, type=float, help='Coefficient for attribute loss')
50 | parser.add_argument('--attr_weight', default=15, type=float, help='attr_loss weight for COMP')
51 | parser.add_argument('--orth_weight', default=0.00035, type=float, help='orth_loss weight for COMP')
52 | elif script == 'save_features':
53 | parser.add_argument('--split', default='novel',
54 | help='base/val/novel') # default novel, but you can also test base/val class accuracy if you want
55 | parser.add_argument('--save_iter', default=-1, type=int,
56 | help='save feature from the model trained in x epoch, use the best model if x is -1')
57 | elif script == 'test':
58 | parser.add_argument('--split', default='novel',
59 | help='base/val/novel') # default novel, but you can also test base/val class accuracy if you want
60 | parser.add_argument('--save_iter', default=-1, type=int,
61 | help='saved feature from the model trained in x epoch, use the best model if x is -1')
62 | parser.add_argument('--adaptation', action='store_true', help='further adaptation in test time or not')
63 | parser.add_argument('--runs', default=None, type=str, help='runs for write the dis and acc files')
64 | else:
65 | raise ValueError('Unknown script')
66 |
67 | return parser.parse_args()
68 |
69 |
70 | def get_assigned_file(checkpoint_dir, num):
71 | assign_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(num))
72 | return assign_file
73 |
74 |
75 | def get_resume_file(checkpoint_dir):
76 | filelist = glob.glob(os.path.join(checkpoint_dir, '*.tar'))
77 | if len(filelist) == 0:
78 | return None
79 |
80 | filelist = [x for x in filelist if os.path.basename(x) != 'best_model.tar']
81 | epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist])
82 | max_epoch = np.max(epochs)
83 | resume_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(max_epoch))
84 | return resume_file
85 |
86 |
87 | def get_best_file(checkpoint_dir):
88 | best_file = os.path.join(checkpoint_dir, 'best_model.tar')
89 | if os.path.isfile(best_file):
90 | return best_file
91 | else:
92 | return get_resume_file(checkpoint_dir)
93 |
--------------------------------------------------------------------------------
/save_features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Variable
4 | import os
5 | import glob
6 | import h5py
7 |
8 | import configs
9 | import backbone
10 | from data.datamgr import SimpleDataManager
11 | from methods.baselinetrain import BaselineTrain
12 | from methods.baselinefinetune import BaselineFinetune
13 | from methods.protonet import ProtoNet
14 | from methods.matchingnet import MatchingNet
15 | from methods.relationnet import RelationNet
16 | from methods.maml import MAML
17 | from io_utils import model_dict, parse_args, get_resume_file, get_best_file, get_assigned_file
18 |
19 |
20 | def save_features(model, data_loader, outfile):
21 | f = h5py.File(outfile, 'w')
22 | max_count = len(data_loader) * data_loader.batch_size
23 | all_labels = f.create_dataset('all_labels', (max_count,), dtype='i')
24 | all_feats = None
25 | count = 0
26 | for i, (x, y) in enumerate(data_loader):
27 | if i % 10 == 0:
28 | print('{:d}/{:d}'.format(i, len(data_loader)))
29 | x = x.cuda()
30 | x_var = Variable(x)
31 | feats = model(x_var)
32 | if all_feats is None:
33 | all_feats = f.create_dataset('all_feats', [max_count] + list(feats.size()[1:]), dtype='f')
34 | all_feats[count:count + feats.size(0)] = feats.data.cpu().numpy()
35 | all_labels[count:count + feats.size(0)] = y.cpu().numpy()
36 | count = count + feats.size(0)
37 |
38 | count_var = f.create_dataset('count', (1,), dtype='i')
39 | count_var[0] = count
40 |
41 | f.close()
42 |
43 |
44 | if __name__ == '__main__':
45 | params = parse_args('save_features')
46 | assert params.method != 'maml' and params.method != 'maml_approx', 'MAML does not support saving features'
47 |
48 | if 'Conv' in params.model:
49 | if params.dataset in ['omniglot', 'cross_char']:
50 | image_size = 28
51 | else:
52 | image_size = 84
53 | else:
54 | image_size = 224
55 |
56 | if params.dataset in ['omniglot', 'cross_char']:
57 | assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
58 | params.model = 'Conv4S'
59 |
60 | split = params.split
61 | if params.dataset == 'cross':
62 | if split == 'base':
63 | loadfile = configs.data_dir['miniImagenet'] + 'all.json'
64 | else:
65 | loadfile = configs.data_dir['CUB'] + split + '.json'
66 | elif params.dataset == 'cross_char':
67 | if split == 'base':
68 | loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
69 | else:
70 | loadfile = configs.data_dir['emnist'] + split + '.json'
71 | elif params.dataset == 'AWA2':
72 | # for AWA2, we use both validation and test classes as novel classes, because the number of test classes in AWA2 is too small (only 10)
73 | loadfile = configs.data_dir[params.dataset] + 'val_novel.json'
74 | else:
75 | loadfile = configs.data_dir[params.dataset] + split + '.json'
76 |
77 | checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % (
78 | configs.save_dir, params.dataset, params.model, params.method, params.exp_str)
79 | if params.train_aug:
80 | checkpoint_dir += '_aug'
81 | if not params.method in ['baseline', 'baseline++', 'comp']:
82 | checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
83 |
84 | if params.save_iter != -1:
85 | modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
86 | else:
87 | modelfile = get_best_file(checkpoint_dir)
88 |
89 | if params.save_iter != -1:
90 | outfile = os.path.join(checkpoint_dir.replace("checkpoints", "features"),
91 | split + "_" + str(params.save_iter) + ".hdf5")
92 | else:
93 | outfile = os.path.join(checkpoint_dir.replace("checkpoints", "features"), split + ".hdf5")
94 |
95 | datamgr = SimpleDataManager(image_size, batch_size=64)
96 | data_loader = datamgr.get_data_loader(loadfile, aug=False, is_train=False)
97 |
98 | if params.method in ['relationnet', 'relationnet_softmax']:
99 | if params.model == 'Conv4':
100 | model = backbone.Conv4NP()
101 | elif params.model == 'Conv6':
102 | model = backbone.Conv6NP()
103 | elif params.model == 'Conv4S':
104 | model = backbone.Conv4SNP()
105 | else:
106 | model = model_dict[params.model](flatten=False)
107 | elif params.method in ['maml', 'maml_approx']:
108 | raise ValueError('MAML does not support saving features')
109 | else:
110 | model = model_dict[params.model]()
111 |
112 | model = model.cuda()
113 | print(modelfile)
114 | tmp = torch.load(modelfile)
115 | state = tmp['state']
116 | state_keys = list(state.keys())
117 | for i, key in enumerate(state_keys):
118 | if "feature." in key:
119 | newkey = key.replace("feature.",
120 | "") # an architecture model has attribute 'feature', load architecture feature to backbone by casting name from 'feature.trunk.xx' to 'trunk.xx'
121 | state[newkey] = state.pop(key)
122 | else:
123 | state.pop(key)
124 |
125 | model.load_state_dict(state)
126 | model.eval()
127 |
128 | dirname = os.path.dirname(outfile)
129 | if not os.path.isdir(dirname):
130 | os.makedirs(dirname)
131 | save_features(model, data_loader, outfile)
132 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
4 | import matplotlib.pyplot as plt
5 |
6 | def one_hot(y, num_class):
7 | return torch.zeros((len(y), num_class)).scatter_(1, y.unsqueeze(1), 1)
8 |
9 |
10 | def DBindex(cl_data_file):
11 | class_list = cl_data_file.keys()
12 | cl_num = len(class_list)
13 | cl_means = []
14 | stds = []
15 | DBs = []
16 | for cl in class_list:
17 | cl_means.append(np.mean(cl_data_file[cl], axis=0))
18 | stds.append(np.sqrt(np.mean(np.sum(np.square(cl_data_file[cl] - cl_means[-1]), axis=1))))
19 |
20 | mu_i = np.tile(np.expand_dims(np.array(cl_means), axis=0), (len(class_list), 1, 1))
21 | mu_j = np.transpose(mu_i, (1, 0, 2))
22 | mdists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis=2))
23 |
24 | for i in range(cl_num):
25 | DBs.append(np.max([(stds[i] + stds[j]) / mdists[i, j] for j in range(cl_num) if j != i]))
26 | return np.mean(DBs)
27 |
28 |
29 | def sparsity(cl_data_file):
30 | class_list = cl_data_file.keys()
31 | cl_sparsity = []
32 | for cl in class_list:
33 | cl_sparsity.append(np.mean([np.sum(x != 0) for x in cl_data_file[cl]]))
34 |
35 | return np.mean(cl_sparsity)
36 |
37 | """
38 | Files for plot figs of adaptation difficulty
39 | """
40 |
41 | def read_attr_dists(trainloader, dataset):
42 | if dataset == 'SUN':
43 | print("attribute distance for SUN!")
44 | attr_dists = trainloader.dataset.meta['attr_labels']
45 | attr_dists_array = np.array(attr_dists).astype('float32')
46 | attr_dists = torch.from_numpy(attr_dists_array)
47 |
48 | base_labels = trainloader.dataset.cl_list
49 | base_ind = np.unique(base_labels).tolist()
50 | elif dataset == 'CUB':
51 | print("attribute distance for CUB!")
52 | filename = 'filelists/CUB/CUB_200_2011/masked_class_attribute_labels.txt'
53 | attr_dists = []
54 | with open(filename, 'r') as f:
55 | for line in f.readlines():
56 | line_split = line.strip().split(' ')
57 | float_line = []
58 | for str_num in line_split:
59 | float_line.append(float(str_num))
60 | attr_dists.append(float_line)
61 |
62 | attr_dists_array = np.array(attr_dists)
63 | attr_dists = torch.from_numpy(attr_dists_array)
64 |
65 | base_ind = []
66 | for i in range(200):
67 | if i % 2 == 0:
68 | base_ind.append(i)
69 | elif dataset == 'AWA2':
70 | print("attribute distance for AWA2!")
71 | filename = 'filelists/AWA2/class_attribute_label.txt'
72 | attr_dists = []
73 | with open(filename, 'r') as f:
74 | for line in f.readlines():
75 | line_split = line.strip().split(' ')
76 | float_line = []
77 | for str_num in line_split:
78 | float_line.append(float(str_num))
79 | attr_dists.append(float_line)
80 |
81 | attr_dists_array = np.array(attr_dists)
82 | attr_dists = torch.from_numpy(attr_dists_array)
83 |
84 | base_labels = trainloader.dataset.cl_list
85 | base_ind = np.unique(base_labels).tolist()
86 | else:
87 | raise NotImplementedError("attribute distance not implemented for dataset: %s" % dataset)
88 | return attr_dists, base_ind
89 |
90 | def get_attr_distance(trainloader, dataset):
91 | attr_dists, base_ind = read_attr_dists(trainloader, dataset)
92 |
93 | # class-agnostic or task-agnostic
94 | # part_dists = _dists_check(part_dists)
95 | # base_dists = part_dists[base_ind, :].mean(0) # (102,)
96 | #
97 | # all_cls_dists = part_dists
98 | # base_cls_dists = part_dists[base_ind, :] # (100, 102)
99 |
100 | # original
101 | import random
102 | base_cls_dists = []
103 | sc_cls_lists = [random.sample(base_ind, 5) for _ in range(10000)]
104 | for sc_cls in sc_cls_lists:
105 | sc_dists = attr_dists[sc_cls, :]
106 | base_cls_dists.append(sc_dists)
107 | base_cls_dists = torch.stack(base_cls_dists, dim=0)
108 | all_cls_dists = attr_dists
109 | base_dists = base_cls_dists.mean(1) # (task_num, 102)
110 |
111 | return all_cls_dists, base_dists, base_cls_dists
112 |
113 | def interval_avg(acc_all, dist_all):
114 |     min_d = np.min(dist_all)
115 |     max_d = np.max(dist_all)
116 |     inr = (max_d - min_d) / 9
117 |     acc_inr = [0 for _ in range(9)]
118 |     dis_inr = [0 for _ in range(9)]
119 |     cout_inr = [0 for _ in range(9)]
120 |     for dis, acc in zip(dist_all, acc_all):
121 |         for i in range(9):
122 |             min_i = min_d + i * inr
123 |             max_i = min_d + (i + 1) * inr
124 |
125 |             if dis >= min_i and dis <= max_i:
126 |                 acc_inr[i] += acc
127 |                 dis_inr[i] += dis
128 |                 cout_inr[i] += 1
129 |
130 |     acc_avg_inr, dis_avg_inr = [], []
131 |     for acc, dis, num in zip(acc_inr, dis_inr, cout_inr):
132 |         if num != 0:
133 |             acc_avg_inr.append(1.0 * acc / num)
134 |             dis_avg_inr.append(1.0 * dis / num)
135 |     return acc_avg_inr, dis_avg_inr
136 |
137 |
138 | def plot_fig(acc_all, dist_all):
139 |     acc_avg_inr, dis_avg_inr = interval_avg(acc_all, dist_all)
140 |     print("acc_avg_inr:", acc_avg_inr)
141 |     print("dis_avg_inr:", dis_avg_inr)
142 |
143 |     plt.scatter(dist_all, acc_all)
144 |     plt.scatter(dis_avg_inr, acc_avg_inr, s=40, marker='x', c='red')
145 |
146 |     for x, y in zip(dis_avg_inr, acc_avg_inr):
147 |         plt.annotate("%.1f" % (y), xy=(x, y), xytext=(x - 0.005, y + 1.5), color='r', weight='heavy')
148 |
149 |     plt.plot(dis_avg_inr, acc_avg_inr, c='red')
150 |     # plt.show()
151 |     plt.savefig('dist_acc.pdf')
152 |     plt.close()
--------------------------------------------------------------------------------
/methods/meta_template.py:
--------------------------------------------------------------------------------
1 | import backbone
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import torch.nn.functional as F
7 | import utils
8 | from abc import abstractmethod
9 |
10 |
11 | class MetaTemplate(nn.Module):
12 | def __init__(self, model_func, n_way, n_support, change_way=True):
13 | super(MetaTemplate, self).__init__()
14 | self.n_way = n_way
15 | self.n_support = n_support
16 | self.n_query = -1 # (change depends on input)
17 | self.feature = model_func()
18 | self.feat_dim = self.feature.final_feat_dim
19 |         self.change_way = change_way  # some methods allow a different n_way for training and testing
20 |
21 | @abstractmethod
22 | def set_forward(self, x, is_feature):
23 | pass
24 |
25 | @abstractmethod
26 | def set_forward_loss(self, x):
27 | pass
28 |
29 | def forward(self, x):
30 | out = self.feature.forward(x)
31 | return out
32 |
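    # parse_feature() expects x of shape [n_way, n_support + n_query, C, H, W]
    # (or already-extracted features when is_feature=True) and returns
    # z_support: [n_way, n_support, feat_dim] and z_query: [n_way, n_query, feat_dim].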
33 | def parse_feature(self, x, is_feature):
34 | x = Variable(x.cuda())
35 | if is_feature:
36 | z_all = x
37 | else:
38 | x = x.contiguous().view(self.n_way * (self.n_support + self.n_query), *x.size()[2:])
39 | z_all = self.feature.forward(x)
40 | z_all = z_all.view(self.n_way, self.n_support + self.n_query, -1)
41 | z_support = z_all[:, :self.n_support]
42 | z_query = z_all[:, self.n_support:]
43 |
44 | return z_support, z_query
45 |
46 | def correct(self, x):
47 | scores = self.set_forward(x)
48 | y_query = np.repeat(range(self.n_way), self.n_query)
49 |
50 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
51 | topk_ind = topk_labels.cpu().numpy()
52 | top1_correct = np.sum(topk_ind[:, 0] == y_query)
53 | return float(top1_correct), len(y_query)
54 |
55 | def train_loop(self, epoch, train_loader, optimizer, tf_writer):
56 | print_freq = 10
57 |
58 | avg_loss = 0
59 | for i, (x, _) in enumerate(train_loader):
60 | self.n_query = x.size(1) - self.n_support
61 | if self.change_way:
62 | self.n_way = x.size(0)
63 | optimizer.zero_grad()
64 | loss = self.set_forward_loss(x)
65 | loss.backward()
66 | optimizer.step()
67 | avg_loss = avg_loss + loss.item()
68 |
69 | if i % print_freq == 0:
70 | # print(optimizer.state_dict()['param_groups'][0]['lr'])
71 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader),
72 | avg_loss / float(i + 1)))
73 | tf_writer.add_scalar('loss/train', avg_loss / float(i + 1), epoch)
74 |
75 | def test_loop(self, test_loader, record=None):
76 | correct = 0
77 | count = 0
78 | acc_all = []
79 |
80 | iter_num = len(test_loader)
81 | from tqdm import tqdm
82 | for i, (x, _) in enumerate(tqdm(test_loader)):
83 | self.n_query = x.size(1) - self.n_support
84 | if self.change_way:
85 | self.n_way = x.size(0)
86 | correct_this, count_this = self.correct(x)
87 | acc_all.append(correct_this / count_this * 100)
88 |
89 | acc_all = np.asarray(acc_all)
90 | acc_mean = np.mean(acc_all)
91 | acc_std = np.std(acc_all)
92 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
93 |
94 | return acc_mean
95 |
96 | def set_forward_adaptation(self, x,
97 |                                is_feature=True):  # further adaptation: by default, fix the feature extractor and train a new softmax classifier
98 | assert is_feature == True, 'Feature is fixed in further adaptation'
99 | z_support, z_query = self.parse_feature(x, is_feature)
100 |
101 | z_support = z_support.contiguous().view(self.n_way * self.n_support, -1)
102 | z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)
103 |
104 | y_support = torch.from_numpy(np.repeat(range(self.n_way), self.n_support))
105 | y_support = Variable(y_support.cuda())
106 |
107 | linear_clf = nn.Linear(self.feat_dim, self.n_way)
108 | linear_clf = linear_clf.cuda()
109 |
110 | set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr=0.01, momentum=0.9, dampening=0.9,
111 | weight_decay=0.001)
112 |
113 | loss_function = nn.CrossEntropyLoss()
114 | loss_function = loss_function.cuda()
115 |
116 | batch_size = 4
117 | support_size = self.n_way * self.n_support
118 | for epoch in range(100):
119 | rand_id = np.random.permutation(support_size)
120 | for i in range(0, support_size, batch_size):
121 | set_optimizer.zero_grad()
122 | selected_id = torch.from_numpy(rand_id[i: min(i + batch_size, support_size)]).cuda()
123 | z_batch = z_support[selected_id]
124 | y_batch = y_support[selected_id]
125 | scores = linear_clf(z_batch)
126 | loss = loss_function(scores, y_batch)
127 | loss.backward()
128 | set_optimizer.step()
129 |
130 | scores = linear_clf(z_query)
131 | return scores
132 |
133 | def test_loop_with_dists(self, test_loader, all_cls_dists, base_dists, base_cls_dists, attr_num):
134 | acc_all, dist_all = [], []
135 |
136 | iter_num = len(test_loader)
137 | from tqdm import tqdm
138 | for i, (x, y) in enumerate(tqdm(test_loader)):
139 | x, y = x.cuda(), y.cuda()
140 | self.n_query = x.size(1) - self.n_support
141 | if self.change_way:
142 | self.n_way = x.size(0)
143 | correct_this, count_this = self.correct(x)
144 | acc_all.append(correct_this / count_this * 100)
145 |
146 | sc_cls = y.unique()
147 | # original mean-task (down trend)
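            # task_dists: mean attribute vector of the sampled novel classes, shape (1, attr_num);
            # the distance is the L1 gap to each sampled base task's mean attribute vector
            # (base_dists), averaged over base tasks and normalised by attr_num.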
148 | task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
149 | dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)
150 |
151 | acc_all = np.asarray(acc_all)
152 | acc_mean = np.mean(acc_all)
153 | acc_std = np.std(acc_all)
154 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
155 |
156 | return acc_all, dist_all
--------------------------------------------------------------------------------
/methods/relationnet.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/floodsung/LearningToCompare_FSL
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 | import utils
11 |
12 |
13 | class RelationNet(MetaTemplate):
14 | def __init__(self, model_func, n_way, n_support, loss_type='mse'):
15 | super(RelationNet, self).__init__(model_func, n_way, n_support)
16 |
17 |         self.loss_type = loss_type  # 'mse' or 'softmax'
18 | self.relation_module = RelationModule(self.feat_dim, 8,
19 | self.loss_type) # relation net features are not pooled, so self.feat_dim is [dim, w, h]
20 |
21 | if self.loss_type == 'mse':
22 | self.loss_fn = nn.MSELoss()
23 | else:
24 | self.loss_fn = nn.CrossEntropyLoss()
25 |
26 | def set_forward(self, x, is_feature=False):
27 | z_support, z_query = self.parse_feature(x, is_feature)
28 |
29 | z_support = z_support.contiguous()
30 | z_proto = z_support.view(self.n_way, self.n_support, *self.feat_dim).mean(1)
31 | z_query = z_query.contiguous().view(self.n_way * self.n_query, *self.feat_dim)
32 |
33 | z_proto_ext = z_proto.unsqueeze(0).repeat(self.n_query * self.n_way, 1, 1, 1, 1)
34 | z_query_ext = z_query.unsqueeze(0).repeat(self.n_way, 1, 1, 1, 1)
35 | z_query_ext = torch.transpose(z_query_ext, 0, 1)
36 | extend_final_feat_dim = self.feat_dim.copy()
37 | extend_final_feat_dim[0] *= 2
38 | relation_pairs = torch.cat((z_proto_ext, z_query_ext), 2).view(-1, *extend_final_feat_dim)
39 | relations = self.relation_module(relation_pairs).view(-1, self.n_way)
40 |
41 | return relations
42 |
43 | def set_forward_adaptation(self, x, is_feature=True): # overwrite parent function
44 |         assert is_feature == True, 'Finetuning only supports fixed features'
45 | full_n_support = self.n_support
46 | full_n_query = self.n_query
47 | relation_module_clone = RelationModule(self.feat_dim, 8, self.loss_type)
48 | relation_module_clone.load_state_dict(self.relation_module.state_dict())
49 |
50 | z_support, z_query = self.parse_feature(x, is_feature)
51 | z_support = z_support.contiguous()
52 | set_optimizer = torch.optim.SGD(self.relation_module.parameters(), lr=0.01, momentum=0.9, dampening=0.9,
53 | weight_decay=0.001)
54 |
55 | self.n_support = 3
56 | self.n_query = 2
57 |
58 | z_support_cpu = z_support.data.cpu().numpy()
59 | for epoch in range(100):
60 | perm_id = np.random.permutation(full_n_support).tolist()
61 | sub_x = np.array([z_support_cpu[i, perm_id, :, :, :] for i in range(z_support.size(0))])
62 | sub_x = torch.Tensor(sub_x).cuda()
63 | if self.change_way:
64 | self.n_way = sub_x.size(0)
65 | set_optimizer.zero_grad()
66 | y = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
67 | scores = self.set_forward(sub_x, is_feature=True)
68 | if self.loss_type == 'mse':
69 | y_oh = utils.one_hot(y, self.n_way)
70 | y_oh = Variable(y_oh.cuda())
71 |
72 | loss = self.loss_fn(scores, y_oh)
73 | else:
74 | y = Variable(y.cuda())
75 | loss = self.loss_fn(scores, y)
76 | loss.backward()
77 | set_optimizer.step()
78 |
79 | self.n_support = full_n_support
80 | self.n_query = full_n_query
81 | z_proto = z_support.view(self.n_way, self.n_support, *self.feat_dim).mean(1)
82 | z_query = z_query.contiguous().view(self.n_way * self.n_query, *self.feat_dim)
83 |
84 | z_proto_ext = z_proto.unsqueeze(0).repeat(self.n_query * self.n_way, 1, 1, 1, 1)
85 | z_query_ext = z_query.unsqueeze(0).repeat(self.n_way, 1, 1, 1, 1)
86 | z_query_ext = torch.transpose(z_query_ext, 0, 1)
87 | extend_final_feat_dim = self.feat_dim.copy()
88 | extend_final_feat_dim[0] *= 2
89 | relation_pairs = torch.cat((z_proto_ext, z_query_ext), 2).view(-1, *extend_final_feat_dim)
90 | relations = self.relation_module(relation_pairs).view(-1, self.n_way)
91 |
92 | self.relation_module.load_state_dict(relation_module_clone.state_dict())
93 | return relations
94 |
95 | def set_forward_loss(self, x):
96 | y = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
97 |
98 | scores = self.set_forward(x)
99 | if self.loss_type == 'mse':
100 | y_oh = utils.one_hot(y, self.n_way)
101 | y_oh = Variable(y_oh.cuda())
102 |
103 | return self.loss_fn(scores, y_oh)
104 | else:
105 | y = Variable(y.cuda())
106 | return self.loss_fn(scores, y)
107 |
108 |
109 | class RelationConvBlock(nn.Module):
110 | def __init__(self, indim, outdim, padding=0):
111 | super(RelationConvBlock, self).__init__()
112 | self.indim = indim
113 | self.outdim = outdim
114 | self.C = nn.Conv2d(indim, outdim, 3, padding=padding)
115 | self.BN = nn.BatchNorm2d(outdim, momentum=1, affine=True)
116 | self.relu = nn.ReLU()
117 | self.pool = nn.MaxPool2d(2)
118 |
119 | self.parametrized_layers = [self.C, self.BN, self.relu, self.pool]
120 |
121 | for layer in self.parametrized_layers:
122 | backbone.init_layer(layer)
123 |
124 | self.trunk = nn.Sequential(*self.parametrized_layers)
125 |
126 | def forward(self, x):
127 | out = self.trunk(x)
128 | return out
129 |
130 |
131 | class RelationModule(nn.Module):
132 |     """Relation module that scores concatenated prototype-query feature maps"""
133 |
134 | def __init__(self, input_size, hidden_size, loss_type='mse'):
135 | super(RelationModule, self).__init__()
136 |
137 | self.loss_type = loss_type
138 |         # when using ResNet, the conv map without avg pooling is 7x7, so the blocks need padding before pooling
139 |         padding = 1 if (input_size[1] < 10) and (input_size[2] < 10) else 0
140 |
141 | self.layer1 = RelationConvBlock(input_size[0] * 2, input_size[0], padding=padding)
142 | self.layer2 = RelationConvBlock(input_size[0], input_size[0], padding=padding)
143 |
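        # shrink_s computes the spatial size left after the two conv(3x3, padding) + maxpool(2)
        # blocks above, so that fc1's input dimension matches the flattened feature map.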
144 | shrink_s = lambda s: int((int((s - 2 + 2 * padding) / 2) - 2 + 2 * padding) / 2)
145 |
146 | self.fc1 = nn.Linear(input_size[0] * shrink_s(input_size[1]) * shrink_s(input_size[2]), hidden_size)
147 | self.fc2 = nn.Linear(hidden_size, 1)
148 |
149 | def forward(self, x):
150 | out = self.layer1(x)
151 | out = self.layer2(out)
152 | out = out.view(out.size(0), -1)
153 | out = F.relu(self.fc1(out))
154 | if self.loss_type == 'mse':
155 |             out = torch.sigmoid(self.fc2(out))
156 | elif self.loss_type == 'softmax':
157 | out = self.fc2(out)
158 |
159 | return out
160 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.autograd import Variable
4 | import torch.nn as nn
5 | import torch.optim
6 | import json
7 | import torch.utils.data.sampler
8 | import os
9 | import glob
10 | import random
11 | import time
12 |
13 | import configs
14 | import backbone
15 | import data.feature_loader as feat_loader
16 | from data.datamgr import SetDataManager
17 | from methods.baselinetrain import BaselineTrain
18 | from methods.baselinefinetune import BaselineFinetune
19 | from methods.protonet import ProtoNet
20 | from methods.matchingnet import MatchingNet
21 | from methods.relationnet import RelationNet
22 | from methods.maml import MAML
23 | from methods.apnet import APNet_w_attrLoc, APNet_wo_attrLoc
24 | from io_utils import model_dict, parse_args, get_resume_file, get_best_file, get_assigned_file
25 |
26 | seed = 1
27 | np.random.seed(seed)
28 | torch.random.manual_seed(seed)
29 |
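# feature_evaluation() builds one few-shot episode from pre-extracted features:
# cl_data_file maps each class label to its list of feature vectors; n_way classes and
# n_support + n_query features per class are sampled, and the query accuracy (%) of the
# (optionally adapted) model is returned.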
30 | def feature_evaluation(cl_data_file, model, n_way=5, n_support=5, n_query=15, adaptation=False):
31 | class_list = cl_data_file.keys()
32 |
33 |     select_class = random.sample(list(class_list), n_way)  # list() so random.sample also works on dict keys
34 | z_all = []
35 | for cl in select_class:
36 | img_feat = cl_data_file[cl]
37 | perm_ids = np.random.permutation(len(img_feat)).tolist()
38 | z_all.append([np.squeeze(img_feat[perm_ids[i]]) for i in range(n_support + n_query)]) # stack each batch
39 |
40 | z_all = torch.from_numpy(np.array(z_all))
41 |
42 | model.n_query = n_query
43 | if adaptation:
44 | scores = model.set_forward_adaptation(z_all, is_feature=True)
45 | else:
46 | scores = model.set_forward(z_all, is_feature=True)
47 | pred = scores.data.cpu().numpy().argmax(axis=1)
48 | y = np.repeat(range(n_way), n_query)
49 | acc = np.mean(pred == y) * 100
50 | return acc
51 |
52 |
53 | if __name__ == '__main__':
54 | params = parse_args('test')
55 |
56 | acc_all = []
57 | iter_num = 600
58 | attr_loc = False
59 | few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
60 |
61 | split = params.split
62 | if params.save_iter != -1:
63 | split_str = split + "_" + str(params.save_iter)
64 | else:
65 | split_str = split
66 | if 'Conv' in params.model:
67 | image_size = 84
68 | else:
69 | image_size = 224
70 | datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params)
71 | loadfile = configs.data_dir[params.dataset] + split + '.json'
72 |
73 | if params.method == 'baseline':
74 | model = BaselineFinetune(model_dict[params.model], **few_shot_params)
75 | elif params.method == 'baseline++':
76 | model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
77 | elif params.method == 'protonet':
78 | model = ProtoNet(model_dict[params.model], **few_shot_params)
79 | elif params.method == 'comet':
80 | assert params.dataset == 'CUB'
81 |         model = COMET(model_dict[params.model], **few_shot_params)  # note: COMET is not part of this repo and is not imported above
82 | elif params.method == 'matchingnet':
83 | model = MatchingNet(model_dict[params.model], **few_shot_params)
84 | elif params.method in ['relationnet', 'relationnet_softmax']:
85 | if params.model == 'Conv4':
86 | feature_model = backbone.Conv4NP
87 | elif params.model == 'Conv6':
88 | feature_model = backbone.Conv6NP
89 | elif params.model == 'Conv4S':
90 | feature_model = backbone.Conv4SNP
91 | else:
92 | feature_model = lambda: model_dict[params.model](flatten=False)
93 | loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
94 | model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
95 | elif params.method in ['maml', 'maml_approx']:
96 | backbone.ConvBlock.maml = True
97 | backbone.SimpleBlock.maml = True
98 | backbone.BottleneckBlock.maml = True
99 | backbone.ResNet.maml = True
100 | model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **few_shot_params)
101 |         if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different parameters on Omniglot
102 | model.n_task = 32
103 | model.task_update_num = 1
104 | model.train_lr = 0.1
105 | elif params.method == 'apnet':
106 | if params.dataset == 'CUB':
107 | attr_loc = True
108 | attr_num = 109
109 | elif params.dataset == 'SUN':
110 | attr_num = 102
111 | elif params.dataset == 'AWA2':
112 | attr_num = 85
113 | else:
114 |             raise NotImplementedError("dataset not implemented!")
115 |
116 | few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot, attr_num=attr_num, attr_loc=attr_loc, dataset=params.dataset)
117 | if attr_loc:
118 | model = APNet_w_attrLoc(model_dict[params.model], **few_shot_params)
119 | else:
120 | model = APNet_wo_attrLoc(model_dict[params.model], **few_shot_params)
121 | else:
122 | raise ValueError('Unknown method')
123 |
124 | model = model.cuda()
125 |
126 | checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % (
127 | configs.save_dir, params.dataset, params.model, params.method, params.exp_str)
128 | if params.train_aug:
129 | checkpoint_dir += '_aug'
130 | if not params.method in ['baseline', 'baseline++']:
131 | checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
132 |
133 | if not params.method in ['baseline', 'baseline++']:
134 | if params.save_iter != -1:
135 | modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
136 | else:
137 | modelfile = get_best_file(checkpoint_dir)
138 | if modelfile is not None:
139 | tmp = torch.load(modelfile)
140 | model.load_state_dict(tmp['state'])
141 |
142 |     if params.method in ['maml', 'maml_approx']:  # MAML does not support testing with pre-extracted features
143 | novel_loader = datamgr.get_data_loader(loadfile, aug=False, is_train=False)
144 | if params.dataset == 'SUN' and params.model == 'Conv4':
145 | model.train_lr = 0.1
146 | if params.adaptation:
147 | model.task_update_num = 100 # We perform adaptation on MAML simply by updating more times.
148 | model.eval()
149 | acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
150 | else:
151 | novel_file = os.path.join(checkpoint_dir.replace("checkpoints", "features"),
152 |                                   split_str + ".hdf5")  # default split = novel, but you can also test the base or val classes
153 | cl_data_file = feat_loader.init_loader(novel_file)
154 |
155 | from tqdm import tqdm
156 | for i in tqdm(range(iter_num)):
157 | acc = feature_evaluation(cl_data_file, model, n_query=15, adaptation=params.adaptation, **few_shot_params)
158 | acc_all.append(acc)
159 |
160 | acc_all = np.asarray(acc_all)
161 | acc_mean = np.mean(acc_all)
162 | acc_std = np.std(acc_all)
163 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from PIL import Image
5 | import json
6 | import numpy as np
7 | import torchvision.transforms as transforms
8 | import os
9 | from torch.utils.data import Dataset
10 | import random
11 | import copy
12 | import cv2
13 | from .transforms import *
14 | import PIL
15 |
16 | identity = lambda x: x
17 |
18 | class SimpleDataset:
19 | def __init__(self, data_file, image_size, transform, target_transform=identity, is_train=True):
20 | with open(data_file, 'r') as f:
21 | self.meta = json.load(f)
22 | self.transform = transform
23 | self.target_transform = target_transform
24 | self.flip = is_train
25 | self.image_size = image_size
26 | self.is_train = is_train
27 |
28 | def __getitem__(self, i):
29 | image_path = os.path.join(self.meta['image_names'][i])
30 | # data_numpy = cv2.imread(image_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
31 |
32 |         # use PIL here; for SUN, cv2.imread returns None for some images
33 | data_numpy = np.array(PIL.Image.open(image_path).convert('RGB'))[:, :, ::-1] # to BGR
34 |
35 | if data_numpy is None:
36 | raise ValueError('Fail to read {}'.format(image_path))
37 |
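        # Affine-crop parameters: r is the rotation angle, c the image centre and s a per-axis
        # scale derived from (width, height) // 160. During training the scale is jittered by
        # up to +/-25%, the rotation (applied with prob. 0.6) is drawn from N(0, 30) clipped to
        # +/-60 degrees, and a horizontal flip is applied with prob. 0.5.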
38 | r = 0
39 | c = np.array([data_numpy.shape[1], data_numpy.shape[0]]) // 2
40 | s = np.array([data_numpy.shape[1], data_numpy.shape[0]]) // 160
41 |
42 | if self.is_train:
43 | sf = 0.25
44 | rf = 30
45 | s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
46 | r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
47 | if random.random() <= 0.6 else 0
48 |
49 | if self.flip and random.random() <= 0.5:
50 | data_numpy = data_numpy[:, ::-1, :]
51 | c[0] = data_numpy.shape[1] - c[0] - 1
52 |
53 | trans = get_affine_transform(c, s, r, [self.image_size, self.image_size])
54 | input = cv2.warpAffine(
55 | data_numpy,
56 | trans,
57 | (int(self.image_size), int(self.image_size)),
58 | flags=cv2.INTER_LINEAR)
59 | input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
60 | input = Image.fromarray(input.transpose((1, 0, 2)))
61 |
62 | if self.transform:
63 | input = self.transform(input)
64 | target = self.target_transform(self.meta['image_labels'][i])
65 | return input, target
66 |
67 | def __len__(self):
68 | return len(self.meta['image_names'])
69 |
70 |
71 | class SetDataset:
72 | def __init__(self, data_file, batch_size, image_size, transform, is_train=True, attr_loc=False):
73 | with open(data_file, 'r') as f:
74 | self.meta = json.load(f)
75 |
76 | self.cl_list = np.unique(self.meta['image_labels']).tolist()
77 |
78 | self.sub_meta = {}
79 | for cl in self.cl_list:
80 | self.sub_meta[cl] = []
81 |
82 | if 'part' in self.meta:
83 | for x, y, z in zip(self.meta['image_names'], self.meta['image_labels'], self.meta['part']):
84 | self.sub_meta[y].append({'path': x, 'part': z})
85 | else:
86 |             print("attribute locations are not used or are unavailable!")
87 | for x, y in zip(self.meta['image_names'], self.meta['image_labels']):
88 | self.sub_meta[y].append({'path': x})
89 |
90 | self.sub_dataloader = []
91 | sub_data_loader_params = dict(batch_size=batch_size,
92 | shuffle=True,
93 | num_workers=0, # use main thread only or may receive multiple batches
94 | pin_memory=False)
95 |
96 | for cl in self.cl_list:
97 | sub_dataset = SubDataset(self.sub_meta[cl], cl, image_size, attr_loc, transform=transform, is_train=is_train)
98 | self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params))
99 |
100 | def __getitem__(self, i):
101 | return next(iter(self.sub_dataloader[i]))
102 |
103 | def __len__(self):
104 | return len(self.cl_list)
105 |
106 |
107 | class EpisodicBatchSampler(object):
108 | def __init__(self, n_classes, n_way, n_episodes):
109 | self.n_classes = n_classes
110 | self.n_way = n_way
111 | self.n_episodes = n_episodes
112 |
113 | def __len__(self):
114 | return self.n_episodes
115 |
116 | def __iter__(self):
117 | for i in range(self.n_episodes):
118 | yield torch.randperm(self.n_classes)[:self.n_way]
119 |
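# EpisodicBatchSampler yields, for each of n_episodes episodes, the indices of n_way randomly
# chosen classes; together with SetDataset (whose __getitem__ returns one per-class batch) this
# produces episodic batches. A minimal usage sketch -- the actual wiring lives in data/datamgr.py,
# so the names below are illustrative only:
#   sampler = EpisodicBatchSampler(len(set_dataset), n_way=5, n_episodes=100)
#   loader = torch.utils.data.DataLoader(set_dataset, batch_sampler=sampler)
#   # each batch then stacks to [n_way, batch_size, C, H, W] images plus labels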
120 |
121 | class SubDataset(Dataset):
122 | def __init__(self, sub_meta, cl, image_size, attr_loc=False, transform=transforms.ToTensor(), target_transform=identity,
123 | is_train=True):
124 | self.num_joints = 15
125 |
126 | self.is_train = is_train
127 | self.sub_meta = sub_meta
128 | self.cl = cl
129 | self.transform = transform
130 | self.target_transform = target_transform
131 |
132 | self.flip = is_train
133 | self.attr_loc = attr_loc
134 |
135 | self.image_size = image_size
136 |
137 | self.transform = transform
138 | self.target_transform = target_transform
139 |
140 | def __len__(self, ):
141 | return len(self.sub_meta)
142 |
143 | def __getitem__(self, idx):
144 | image_file = os.path.join(self.sub_meta[idx]['path'])
145 |
146 | #data_numpy = cv2.imread(image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
147 |
148 |         # use PIL here; for SUN, cv2.imread returns None for some images
149 | data_numpy = np.array(PIL.Image.open(image_file).convert('RGB'))[:, :, ::-1] # to BGR
150 |
151 | if data_numpy is None:
152 | raise ValueError('Fail to read {}'.format(image_file))
153 |
154 | if self.attr_loc:
155 | joints_vis = self.sub_meta[idx]['part']
156 | joints_vis = np.array(joints_vis)
157 |
158 | r = 0
159 | c = np.array([data_numpy.shape[1], data_numpy.shape[0]]) // 2
160 | s = np.array([data_numpy.shape[1], data_numpy.shape[0]]) // 160
161 |
162 | if self.is_train:
163 | sf = 0.25
164 | rf = 30
165 | s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
166 | r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
167 | if random.random() <= 0.6 else 0
168 |
169 | if self.flip and random.random() <= 0.5:
170 | data_numpy = data_numpy[:, ::-1, :]
171 | if self.attr_loc:
172 | for i in range(self.num_joints):
173 | if joints_vis[i, 2] > 0.0:
174 | joints_vis[i, 0] = data_numpy.shape[1] - joints_vis[i, 0]
175 | c[0] = data_numpy.shape[1] - c[0] - 1
176 |
177 | trans = get_affine_transform(c, s, r, [self.image_size, self.image_size])
178 | input = cv2.warpAffine(
179 | data_numpy,
180 | trans,
181 | (int(self.image_size), int(self.image_size)),
182 | flags=cv2.INTER_LINEAR)
183 | input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
184 | input = Image.fromarray(input.transpose((1, 0, 2)))
185 |
186 | if self.transform:
187 | input = self.transform(input)
188 |
189 | target = self.target_transform(self.cl)
190 |
191 | if self.attr_loc is False:
192 | return input, target
193 | else:
194 | for i in range(self.num_joints):
195 | if joints_vis[i, 2] > 0.0:
196 | joints_vis[i, 0:2] = affine_transform(joints_vis[i, 0:2], trans)
197 |
198 | joints_vis = self.target_transform(joints_vis)
199 | return input, target, joints_vis
--------------------------------------------------------------------------------
/methods/maml.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 |
11 |
12 | class MAML(MetaTemplate):
13 | def __init__(self, model_func, n_way, n_support, approx=False):
14 | super(MAML, self).__init__(model_func, n_way, n_support, change_way=False)
15 |
16 | self.loss_fn = nn.CrossEntropyLoss()
17 | self.classifier = backbone.Linear_fw(self.feat_dim, n_way)
18 | self.classifier.bias.data.fill_(0)
19 |
20 | self.n_task = 4
21 | self.task_update_num = 5
22 | self.train_lr = 0.01
23 | self.approx = approx # first order approx.
24 |
25 | def forward(self, x):
26 | out = self.feature.forward(x)
27 | scores = self.classifier.forward(out)
28 | return scores
29 |
30 | def set_forward(self, x, is_feature=False):
31 |         assert is_feature == False, 'MAML does not support fixed features'
32 | x = x.cuda()
33 | x_var = Variable(x)
34 | x_a_i = x_var[:, :self.n_support, :, :, :].contiguous().view(self.n_way * self.n_support,
35 | *x.size()[2:]) # support data
36 | x_b_i = x_var[:, self.n_support:, :, :, :].contiguous().view(self.n_way * self.n_query,
37 | *x.size()[2:]) # query data
38 | y_a_i = Variable(
39 | torch.from_numpy(np.repeat(range(self.n_way), self.n_support))).cuda() # label for support data
40 |
41 |         fast_parameters = list(self.parameters())  # the first gradient computed by torch.autograd.grad below is based on the original weights
42 | for weight in self.parameters():
43 | weight.fast = None
44 | self.zero_grad()
45 |
46 | for task_step in range(self.task_update_num):
47 | scores = self.forward(x_a_i)
48 | set_loss = self.loss_fn(scores, y_a_i)
49 | grad = torch.autograd.grad(set_loss, fast_parameters,
50 | create_graph=True) # build full graph support gradient of gradient
51 | if self.approx:
52 | grad = [g.detach() for g in
53 | grad] # do not calculate gradient of gradient if using first order approximation
54 | fast_parameters = []
55 | for k, weight in enumerate(self.parameters()):
56 | # for usage of weight.fast, please see Linear_fw, Conv_fw in backbone.py
57 | if weight.fast is None:
58 | weight.fast = weight - self.train_lr * grad[k] # create weight.fast
59 | else:
60 | weight.fast = weight.fast - self.train_lr * grad[
61 |                         k]  # create an updated weight.fast; the '-' does not update in place but creates a new weight.fast tensor
62 | fast_parameters.append(
63 |                     weight.fast)  # the next inner-loop gradient is calculated w.r.t. the newest fast weights, but the graph retains the link to older weight.fast values
64 |
65 | scores = self.forward(x_b_i)
66 | return scores
67 |
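    # set_forward (above) runs the MAML inner loop: task_update_num SGD steps (step size
    # train_lr) on the support set, stored as 'fast' weights in each parameter's .fast
    # attribute, followed by scoring the query set with the adapted weights. The outer update
    # in train_loop() sums the query losses of n_task such episodes before optimizer.step().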
68 |     def set_forward_adaptation(self, x, is_feature=False):  # overwrite parent function
69 |         raise ValueError('MAML performs further adaptation simply by increasing task_update_num')
70 |
71 | def set_forward_loss(self, x):
72 | scores = self.set_forward(x, is_feature=False)
73 | y_b_i = Variable(torch.from_numpy(np.repeat(range(self.n_way), self.n_query))).cuda()
74 | loss = self.loss_fn(scores, y_b_i)
75 |
76 | return loss
77 |
78 |     def train_loop(self, epoch, train_loader, optimizer, tf_writer):  # overwrite parent function
79 | print_freq = 10
80 | avg_loss = 0
81 | task_count = 0
82 | loss_all = []
83 | optimizer.zero_grad()
84 |
85 | # train
86 | for i, (x, _) in enumerate(train_loader):
87 | self.n_query = x.size(1) - self.n_support
88 |             assert self.n_way == x.size(0), "MAML does not support changing n_way"
89 |
90 | loss = self.set_forward_loss(x)
91 | avg_loss = avg_loss + loss.item()
92 | loss_all.append(loss)
93 |
94 | task_count += 1
95 |
96 |             if task_count == self.n_task:  # MAML accumulates several tasks into one update
97 | loss_q = torch.stack(loss_all).sum(0)
98 | loss_q.backward()
99 |
100 | optimizer.step()
101 | task_count = 0
102 | loss_all = []
103 | optimizer.zero_grad()
104 | if i % print_freq == 0:
105 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader),
106 | avg_loss / float(i + 1)))
107 | tf_writer.add_scalar('loss/train', avg_loss / float(i + 1), epoch)
108 |
109 |     def test_loop(self, test_loader, return_std=False):  # overwrite parent function
110 | correct = 0
111 | count = 0
112 | acc_all = []
113 |
114 | iter_num = len(test_loader)
115 | from tqdm import tqdm
116 | for i, (x, _) in enumerate(tqdm(test_loader)):
117 | self.n_query = x.size(1) - self.n_support
118 |             assert self.n_way == x.size(0), "MAML does not support changing n_way"
119 | correct_this, count_this = self.correct(x)
120 | acc_all.append(correct_this / count_this * 100)
121 |
122 | acc_all = np.asarray(acc_all)
123 | acc_mean = np.mean(acc_all)
124 | acc_std = np.std(acc_all)
125 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
126 | if return_std:
127 | return acc_mean, acc_std
128 | else:
129 | return acc_mean
130 |
131 | def test_loop_with_dists(self, test_loader, all_cls_dists, base_dists,
132 |                              base_cls_dists, attr_num):  # overwrite parent function
133 | correct = 0
134 | count = 0
135 | acc_all = []
136 | dist_all = []
137 |
138 | iter_num = len(test_loader)
139 | from tqdm import tqdm
140 | for i, (x, y) in enumerate(tqdm(test_loader)):
141 | # original mean-task (down trend)
142 | sc_cls = y.unique()
143 | task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
144 | dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)
145 |
146 | self.n_query = x.size(1) - self.n_support
147 |             assert self.n_way == x.size(0), "MAML does not support changing n_way"
148 | correct_this, count_this = self.correct(x)
149 | acc_all.append(correct_this / count_this * 100)
150 |
151 | # task-agnostic (up trend)
152 | # task_cls_dists = all_cls_dists[sc_cls, :].unsqueeze(1)
153 | # distance = torch.abs(base_cls_dists.unsqueeze(0) - task_cls_dists).sum(-1).mean()
154 | # dist_all.append(distance.item())
155 |
156 | # class-agnostic (down trend)
157 | # task_dists = all_cls_dists[sc_cls, :].mean(0)
158 | # dist_all.append(torch.abs(base_dists - task_dists).sum().item() / len(attr_label_split))
159 |
160 | # original (up trend)
161 | # task_cls_dists = all_cls_dists[sc_cls, :].unsqueeze(0)
162 | # distance = torch.abs(base_cls_dists - task_cls_dists).sum(-1).mean()
163 | # dist_all.append(distance.item())
164 |
165 | # original mini-task (first up then down ?)
166 | # task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
167 | # dist_all.append(torch.abs(base_dists - task_dists).sum(-1).min().item())
168 |
169 | # ideal with task-agnostic (down trend)
170 | # task_cls_dists = all_cls_dists[sc_cls, :]
171 | # d = 0
172 | # for j in range(task_cls_dists.shape[0]):
173 | # task_dist = task_cls_dists[j]
174 | #
175 | # start = 0
176 | # for split in attr_label_split:
177 | # end = start + split
178 | # task_dist_split = task_dist[start:end].unsqueeze(0)
179 | # base_dist_split = base_cls_dists[:, start:end]
180 | # d_split = torch.abs(base_dist_split - task_dist_split).sum(-1).min()
181 | # d += d_split.item()
182 | # start = end
183 | # d /= (len(attr_label_split)*task_cls_dists.shape[0])
184 | # dist_all.append(d)
185 |
186 | acc_all = np.asarray(acc_all)
187 | acc_mean = np.mean(acc_all)
188 | acc_std = np.std(acc_all)
189 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
190 | return acc_all, dist_all
191 |
192 |
--------------------------------------------------------------------------------
/plot_distance_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.autograd import Variable
4 | import torch.nn as nn
5 | import torch.optim
6 | import json
7 | import torch.utils.data.sampler
8 | import os
9 | import glob
10 | import random
11 | import time
12 |
13 | import configs
14 | import backbone
15 | import data.feature_loader as feat_loader
16 | from data.datamgr import SetDataManager
17 | from methods.baselinetrain import BaselineTrain
18 | from methods.baselinefinetune import BaselineFinetune
19 | from methods.protonet import ProtoNet
20 | from methods.matchingnet import MatchingNet
21 | from methods.relationnet import RelationNet
22 | from methods.maml import MAML
23 | from methods.apnet import APNet_w_attrLoc, APNet_wo_attrLoc
24 | from io_utils import model_dict, parse_args, get_resume_file, get_best_file, get_assigned_file
25 | from utils import get_attr_distance, plot_fig
26 |
27 | # seed = 1
28 | # np.random.seed(seed)
29 | # torch.random.manual_seed(seed)
30 |
31 | def feature_evaluation_with_dists(cl_data_file, model, n_way=5, n_support=5, n_query=15, adaptation=False):
32 | class_list = cl_data_file.keys()
33 |
34 |     select_class = random.sample(list(class_list), n_way)  # list() so random.sample also works on dict keys
35 | z_all = []
36 | for cl in select_class:
37 | img_feat = cl_data_file[cl]
38 | perm_ids = np.random.permutation(len(img_feat)).tolist()
39 | z_all.append([np.squeeze(img_feat[perm_ids[i]]) for i in range(n_support + n_query)]) # stack each batch
40 |
41 | z_all = torch.from_numpy(np.array(z_all))
42 |
43 | model.n_query = n_query
44 | if adaptation:
45 | scores = model.set_forward_adaptation(z_all, is_feature=True)
46 | else:
47 | scores = model.set_forward(z_all, is_feature=True)
48 | pred = scores.data.cpu().numpy().argmax(axis=1)
49 | y = np.repeat(range(n_way), n_query)
50 | acc = np.mean(pred == y) * 100
51 | return acc, select_class
52 |
53 |
54 | if __name__ == '__main__':
55 | params = parse_args('test')
56 |
57 | acc_all = []
58 | iter_num, n_query = 600, 35 # to reduce the influence of random factors
59 | attr_loc = False
60 | few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
61 |
62 | split = params.split
63 | if params.save_iter != -1:
64 | split_str = split + "_" + str(params.save_iter)
65 | else:
66 | split_str = split
67 | if 'Conv' in params.model:
68 | image_size = 84
69 | else:
70 | image_size = 224
71 | datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
72 | if params.dataset == 'AWA2':
73 | # for AWA2, we use both validation and test classes as novel classes, because the number of test classes in AWA2 is too small (only 10)
74 | loadfile = configs.data_dir[params.dataset] + 'val_novel.json'
75 | else:
76 | loadfile = configs.data_dir[params.dataset] + split + '.json'
77 | novel_loader = datamgr.get_data_loader(loadfile, aug=False, is_train=False)
78 |
79 | base_datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
80 | base_loadfile = configs.data_dir[params.dataset] + 'base.json'
81 | base_loader = base_datamgr.get_data_loader(base_loadfile, aug=False, is_train=False)
82 | all_cls_dists, base_dists, base_cls_dists = get_attr_distance(base_loader, params.dataset)
83 |
84 | if params.dataset == 'CUB':
85 | if params.method in ['comet', 'apnet']:
86 | attr_loc = True
87 | attr_num = 109
88 | elif params.dataset == 'SUN':
89 | attr_num = 102
90 | elif params.dataset == 'AWA2':
91 | attr_num = 85
92 | else:
93 |         raise NotImplementedError("dataset not implemented!")
94 |
95 | if params.method == 'baseline':
96 | model = BaselineFinetune(model_dict[params.model], **few_shot_params)
97 | elif params.method == 'baseline++':
98 | model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
99 | elif params.method == 'protonet':
100 | model = ProtoNet(model_dict[params.model], **few_shot_params)
101 | elif params.method == 'comet':
102 | assert params.dataset == 'CUB'
103 | model = COMET(model_dict[params.model], **few_shot_params)
104 | elif params.method == 'matchingnet':
105 | model = MatchingNet(model_dict[params.model], **few_shot_params)
106 | elif params.method in ['relationnet', 'relationnet_softmax']:
107 | if params.model == 'Conv4':
108 | feature_model = backbone.Conv4NP
109 | elif params.model == 'Conv6':
110 | feature_model = backbone.Conv6NP
111 | elif params.model == 'Conv4S':
112 | feature_model = backbone.Conv4SNP
113 | else:
114 | feature_model = lambda: model_dict[params.model](flatten=False)
115 | loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
116 | model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
117 | elif params.method in ['maml', 'maml_approx']:
118 | backbone.ConvBlock.maml = True
119 | backbone.SimpleBlock.maml = True
120 | backbone.BottleneckBlock.maml = True
121 | backbone.ResNet.maml = True
122 | model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **few_shot_params)
123 |         if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different parameters on Omniglot
124 | model.n_task = 32
125 | model.task_update_num = 1
126 | model.train_lr = 0.1
127 | elif params.method == 'apnet':
128 | few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot, attr_num=attr_num, attr_loc=attr_loc, dataset=params.dataset)
129 | if attr_loc:
130 | model = APNet_w_attrLoc(model_dict[params.model], **few_shot_params)
131 | else:
132 | model = APNet_wo_attrLoc(model_dict[params.model], **few_shot_params)
133 | else:
134 | raise ValueError('Unknown method')
135 |
136 | model = model.cuda()
137 |
138 | checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % (
139 | configs.save_dir, params.dataset, params.model, params.method, params.exp_str)
140 | if params.train_aug:
141 | checkpoint_dir += '_aug'
142 | if not params.method in ['baseline', 'baseline++']:
143 | checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
144 |
145 | if not params.method in ['baseline', 'baseline++']:
146 | if params.save_iter != -1:
147 | modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
148 | else:
149 | modelfile = get_best_file(checkpoint_dir)
150 | if modelfile is not None:
151 | tmp = torch.load(modelfile)
152 | model.load_state_dict(tmp['state'])
153 |
154 |     if params.method in ['maml', 'maml_approx']:  # MAML does not support testing with pre-extracted features
155 | if params.dataset == 'SUN' and params.model == 'Conv4':
156 | model.train_lr = 0.1
157 | if params.adaptation:
158 | model.task_update_num = 100 # We perform adaptation on MAML simply by updating more times.
159 | model.eval()
160 | acc_all, dist_all = model.test_loop_with_dists(novel_loader, all_cls_dists, base_dists, base_cls_dists, attr_num)
161 | else:
162 | novel_file = os.path.join(checkpoint_dir.replace("checkpoints", "features"),
163 |                                   split_str + ".hdf5")  # default split = novel, but you can also test the base or val classes
164 | print("novel_file:", novel_file)
165 | cl_data_file = feat_loader.init_loader(novel_file)
166 |
167 | dist_all = []
168 | from tqdm import tqdm
169 | for i in tqdm(range(iter_num)):
170 | acc, sc_cls = feature_evaluation_with_dists(cl_data_file, model, n_query=15, adaptation=params.adaptation,
171 | **few_shot_params)
172 | acc_all.append(acc)
173 |
174 | # original mean-task (down trend)
175 | task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
176 | dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)
177 |
178 | acc_all = np.asarray(acc_all)
179 | acc_mean = np.mean(acc_all)
180 | acc_std = np.std(acc_all)
181 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
182 | plot_fig(acc_all, dist_all)
183 |
184 | if params.runs is None:
185 | distance_file = './distance-{}.txt'.format(params.n_shot)
186 | accuracy_file = './accuracy-{}.txt'.format(params.n_shot)
187 | else:
188 | distance_file = './distance-{}-runs{}-{}.txt'.format(params.n_shot, params.runs, params.method)
189 | accuracy_file = './accuracy-{}-runs{}-{}.txt'.format(params.n_shot, params.runs, params.method)
190 |
191 | with open(distance_file, 'w') as f:
192 | for dis in dist_all:
193 | f.write(str(dis) + ' ')
194 |
195 | with open(accuracy_file, 'w') as f:
196 | for acc in acc_all:
197 | f.write(str(acc) + ' ')
198 |     print("Wrote distance and accuracy files successfully!")
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import torch.optim
6 | import torch.optim.lr_scheduler as lr_scheduler
7 | import time
8 | import os
9 | import glob
10 |
11 | import configs
12 | import backbone
13 | from data.datamgr import SimpleDataManager, SetDataManager
14 | from methods.baselinetrain import BaselineTrain
15 | from methods.baselinefinetune import BaselineFinetune
16 | from methods.protonet import ProtoNet
17 | from methods.matchingnet import MatchingNet
18 | from methods.relationnet import RelationNet
19 | from methods.maml import MAML
20 | from methods.apnet import APNet_w_attrLoc, APNet_wo_attrLoc
21 | from io_utils import model_dict, parse_args, get_resume_file
22 | from torch.utils.tensorboard import SummaryWriter
23 |
24 | def train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params, tf_writer):
25 | if optimization == 'Adam':
26 | optimizer = torch.optim.Adam(model.parameters())
27 | else:
28 |         raise ValueError('Unknown optimization; please define it yourself')
29 |
30 | max_acc = 0
31 |
32 | for epoch in range(start_epoch, stop_epoch):
33 | model.train()
34 | # model.train_loop(epoch, base_loader, optimizer, tf_writer, params.beta) #model are called by reference, no need to return
35 | model.train_loop(epoch, base_loader, optimizer, tf_writer)
36 | model.eval()
37 |
38 | if not os.path.isdir(params.checkpoint_dir):
39 | os.makedirs(params.checkpoint_dir)
40 |
41 | acc = model.test_loop(val_loader)
42 | tf_writer.add_scalar('acc/test', acc, epoch)
43 |
44 |         if acc > max_acc:  # for baseline and baseline++, we do not use validation by default (acc stays -1), but there is an option to validate with the DB index
45 | print("best model! save...")
46 | max_acc = acc
47 | outfile = os.path.join(params.checkpoint_dir, 'best_model.tar')
48 | torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)
49 |
50 | if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
51 | outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
52 | torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)
53 |
54 | return model
55 |
56 |
57 | if __name__ == '__main__':
58 | np.random.seed(10)
59 | params = parse_args('train')
60 |
61 | base_file = configs.data_dir[params.dataset] + 'base.json'
62 | val_file = configs.data_dir[params.dataset] + 'val.json'
63 |
64 | if 'Conv' in params.model:
65 | image_size = 84
66 | else:
67 | image_size = 224
68 |
69 | optimization = 'Adam'
70 |
71 | if params.stop_epoch == -1:
72 | if params.method in ['baseline', 'baseline++']:
73 | if params.dataset in ['CUB']:
74 |                 params.stop_epoch = 200  # this differs from what is stated in the open-review paper; using 400 epochs for the baseline actually leads to over-fitting
75 | else:
76 | params.stop_epoch = 400 # default
77 | else: # meta-learning methods
78 | if params.n_shot == 1:
79 | params.stop_epoch = 600
80 | elif params.n_shot == 5:
81 | params.stop_epoch = 400
82 | else:
83 | params.stop_epoch = 600 # default
84 |
85 | if params.method in ['baseline', 'baseline++'] :
86 | base_datamgr = SimpleDataManager(image_size, batch_size = 16)
87 | base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug, is_train=True)
88 | val_datamgr = SimpleDataManager(image_size, batch_size = 64)
89 | val_loader = val_datamgr.get_data_loader( val_file, aug = False, is_train=False)
90 |
91 | if params.dataset == 'omniglot':
92 |         assert params.num_classes >= 4112, 'num_classes needs to be larger than the max label id in the base classes'
93 | if params.dataset == 'cross_char':
94 |         assert params.num_classes >= 1597, 'num_classes needs to be larger than the max label id in the base classes'
95 |
96 | if params.method == 'baseline':
97 | model = BaselineTrain( model_dict[params.model], params.num_classes)
98 | elif params.method == 'baseline++':
99 | model = BaselineTrain( model_dict[params.model], params.num_classes, loss_type = 'dist')
100 |
101 | elif params.method in ['protonet', 'comet', 'matchingnet','relationnet', 'relationnet_softmax', 'maml', 'maml_approx', 'apnet']:
102 | n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
103 |
104 | if params.dataset == 'CUB' and params.method in ['comet', 'apnet']:
105 | attr_loc = True
106 | else:
107 | attr_loc = False
108 |
109 | train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot)
110 | base_datamgr = SetDataManager(image_size, n_query = n_query, **train_few_shot_params)
111 | base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug, is_train=True, attr_loc=attr_loc)
112 |
113 | test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot)
114 | val_datamgr = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
115 | val_loader = val_datamgr.get_data_loader( val_file, aug = False, is_train=False, attr_loc=attr_loc)
116 | #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor
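        # e.g. with train_n_way = test_n_way = 5 the loaders use n_query = 16, so a
        # 5-way 1-shot Conv4 batch is a [5, 17, 3, 84, 84] tensor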
117 |
118 | if params.method == 'protonet':
119 | model = ProtoNet( model_dict[params.model], **train_few_shot_params )
120 | elif params.method == 'comet':
121 |             model = COMET( model_dict[params.model], **train_few_shot_params )  # note: COMET is not part of this repo and is not imported above
122 | elif params.method == 'matchingnet':
123 | model = MatchingNet( model_dict[params.model], **train_few_shot_params )
124 | elif params.method in ['relationnet', 'relationnet_softmax']:
125 | if params.model == 'Conv4':
126 | feature_model = backbone.Conv4NP
127 | elif params.model == 'Conv6':
128 | feature_model = backbone.Conv6NP
129 | elif params.model == 'Conv4S':
130 | feature_model = backbone.Conv4SNP
131 | else:
132 | feature_model = lambda: model_dict[params.model]( flatten = False )
133 | loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
134 |
135 | model = RelationNet( feature_model, loss_type = loss_type , **train_few_shot_params )
136 | elif params.method in ['maml' , 'maml_approx']:
137 | backbone.ConvBlock.maml = True
138 | backbone.SimpleBlock.maml = True
139 | backbone.BottleneckBlock.maml = True
140 | backbone.ResNet.maml = True
141 | model = MAML( model_dict[params.model], approx = (params.method == 'maml_approx') , **train_few_shot_params )
142 |         if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different parameters on Omniglot
143 | model.n_task = 32
144 | model.task_update_num = 1
145 | model.train_lr = 0.1
146 | elif params.method == 'apnet':
147 | if params.dataset == 'CUB':
148 | attr_num = 109
149 | elif params.dataset == 'SUN':
150 | attr_num = 102
151 | elif params.dataset == 'AWA2':
152 | attr_num = 85
153 | else:
154 |                 raise NotImplementedError('dataset not implemented!')
155 | train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot,
156 | attr_num=attr_num, attr_loc=attr_loc, dataset=params.dataset)
157 | if attr_loc:
158 | model = APNet_w_attrLoc(model_dict[params.model], **train_few_shot_params)
159 | else:
160 | model = APNet_wo_attrLoc(model_dict[params.model], **train_few_shot_params)
161 | else:
162 | raise ValueError('Unknown method')
163 |
164 | model = model.cuda()
165 |
166 | params.checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % (
167 | configs.save_dir, params.dataset, params.model, params.method, params.exp_str)
168 | if params.train_aug:
169 | params.checkpoint_dir += '_aug'
170 | if not params.method in ['baseline', 'baseline++']:
171 | params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
172 |
173 | if not os.path.isdir(params.checkpoint_dir):
174 | os.makedirs(params.checkpoint_dir)
175 |
176 | store_name = '_'.join([params.dataset, params.model, params.method, params.exp_str])
177 | # wandb.init(project="fewshot_images", tensorboard=True, name=store_name)
178 |
179 | start_epoch = params.start_epoch
180 | stop_epoch = params.stop_epoch
181 | if params.method == 'maml' or params.method == 'maml_approx':
182 | stop_epoch = params.stop_epoch * model.n_task # maml use multiple tasks in one update
183 |
184 | if params.resume:
185 | resume_file = get_resume_file(params.checkpoint_dir)
186 | if resume_file is not None:
187 | tmp = torch.load(resume_file)
188 | start_epoch = tmp['epoch'] + 1
189 | model.load_state_dict(tmp['state'])
190 |     elif params.warmup:  # we also support warm-up from pretrained baseline features, but we never used it in our paper
191 | baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
192 | configs.save_dir, params.dataset, params.model, 'baseline')
193 | if params.train_aug:
194 | baseline_checkpoint_dir += '_aug'
195 | warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
196 | tmp = torch.load(warmup_resume_file)
197 | if tmp is not None:
198 | state = tmp['state']
199 | state_keys = list(state.keys())
200 | for i, key in enumerate(state_keys):
201 | if "feature." in key:
202 |                     # the model has a 'feature' attribute; load its pretrained weights into the backbone by renaming 'feature.trunk.xx' to 'trunk.xx'
203 |                     newkey = key.replace("feature.", "")
204 | state[newkey] = state.pop(key)
205 | else:
206 | state.pop(key)
207 | model.feature.load_state_dict(state)
208 | else:
209 | raise ValueError('No warm_up file')
210 |
211 | tf_writer = SummaryWriter(log_dir=params.checkpoint_dir)
212 |
213 | # print("beta:", params.beta)
214 | # # model2 = CBModel(model, params.method)
215 | # # model2 = MTModel(model, params.method)
216 | # # model2 = model2.cuda()
217 | # model2 = model
218 | # model2 = train(base_loader, val_loader, model2, optimization, start_epoch, stop_epoch, params, tf_writer)
219 | # print('*' * 10 + 'Test' + '*' * 10)
220 | # acc = model2.test_loop(test_loader)
221 | # tf_writer.add_scalar('acc/final_test', acc, stop_epoch)
222 | model = train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params, tf_writer)
--------------------------------------------------------------------------------
/methods/apnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import random
6 | import time
7 | import json
8 | from sklearn.metrics import confusion_matrix
9 | from methods.meta_template import MetaTemplate
10 | from torch.autograd import Variable
11 |
12 | class APNetTemplate(MetaTemplate):
13 | def __init__(self, model_func, n_way, n_support, attr_num, attr_loc=False, dataset='CUB'):
14 | super(APNetTemplate, self).__init__(model_func, n_way, n_support)
15 | self.attr_loc = attr_loc
16 |
17 | final_feat_dim = self.feature.final_feat_dim
18 | if isinstance(final_feat_dim, list):
19 | if self.attr_loc:
20 | self.input_dim = final_feat_dim[0] * 16
21 | else:
22 | self.input_dim = 1
23 | for dim in final_feat_dim:
24 | self.input_dim *= dim
25 | else:
26 | self.input_dim = final_feat_dim
27 | if self.attr_loc:
28 | self.input_dim *= 16
29 | self.attr_num = attr_num
30 | self.beta = 0.6
31 |
32 | self.classifier = nn.Linear(self.input_dim, self.attr_num*2) # only consider binary attribute
33 | self.loss_fn = nn.CrossEntropyLoss()
34 | self.read_attr_labels(dataset)
35 |
36 | def read_attr_labels(self, dataset):
37 |         # TODO: change the filenames to your own paths; for CUB, extra files are used to filter the attribute labels
38 | if dataset == 'CUB':
39 | filename = '/home/huminyang/Code/comet/CUB/filelists/CUB/CUB_200_2011/masked_class_attribute_labels.txt'
40 | elif dataset == 'AWA2':
41 | filename = '/home/huminyang/Code/APNet/filelists/AWA2/class_attribute_label.txt'
42 | else:
43 |             raise NotImplementedError('dataset not implemented!')
44 |
45 | attr_labels_binary = []
46 | with open(filename, 'r') as f:
47 | for line in f.readlines():
48 | line_split = line.strip().split(' ')
49 | float_line = []
50 | for str_num in line_split:
51 | float_line.append(int(str_num))
52 | attr_labels_binary.append(float_line)
53 | self.attr_labels_split = torch.from_numpy(np.array(attr_labels_binary)).cuda()
54 |
55 | def correct(self, x, joints):
56 | # correct for image classification
57 | scores = self.set_forward(x, joints)
58 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
59 | topk_ind = topk_labels.cpu().numpy()
60 |
61 | y_query = np.repeat(range(self.n_way), self.n_query)
62 | top1_correct = np.sum(topk_ind[:, 0] == y_query)
63 | return float(top1_correct), len(y_query)
64 |
65 | def forward(self, x, joints):
66 | if joints is None:
67 | z_support, z_query = super().parse_feature(x, is_feature=False)
68 | else:
69 | z_support, z_query = self.parse_feature(x, joints, is_feature=False)
70 | feature = torch.cat([z_support, z_query], dim=1) # (n_way, n_support+n_query, dim)
71 | feature = feature.view(-1, self.input_dim)
72 | logits = self.classifier(feature)
73 | return torch.split(logits, 2, -1)
74 |
75 | def set_forward(self, x, joints):
76 | logits = self.forward(x, joints)
77 | logits = torch.cat(logits, dim=-1)
78 | logits = logits.view(self.n_way, self.n_support + self.n_query, -1)
79 |
80 | z_support = logits[:, :self.n_support]
81 | z_query = logits[:, self.n_support:]
82 |
83 | z_support = z_support.contiguous()
84 | z_proto = z_support.view(self.n_way, self.n_support, -1).mean(1) # the shape of z is [n_data, n_dim]
85 | z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)
86 |
87 | scores = F.cosine_similarity(z_query.unsqueeze(1), z_proto, dim=-1)
88 | return scores / 0.2
89 |
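    # set_forward (above) builds a 2*attr_num-dimensional per-image embedding from the
    # per-attribute logits, averages the support embeddings into class prototypes, and scores
    # each query by its cosine similarity to the prototypes scaled by a temperature of 0.2.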
90 | def set_forward_loss1(self, x, joints):
91 | y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
92 | y_query = y_query.cuda()
93 |
94 | scores = self.set_forward(x, joints)
95 | return self.loss_fn(scores, y_query)
96 |
97 | def set_forward_loss2(self, x, ys, joints):
98 | ys = ys.view(-1, self.attr_num)
99 | logits = self.forward(x, joints)
100 | logits = torch.cat(logits, dim=0)
101 | return self.loss_fn(logits, ys.transpose(1, 0).reshape(-1))
102 |
103 | class APNet_w_attrLoc(APNetTemplate):
104 | def __init__(self, model_func, n_way, n_support, attr_num, attr_loc=False, dataset='CUB'):
105 | super(APNet_w_attrLoc, self).__init__(model_func, n_way, n_support, attr_num, attr_loc, dataset)
106 | self.globalpool = nn.AdaptiveAvgPool2d((1, 1))
107 |
108 | # this function originates from comet
109 | def parse_feature(self, x, joints, is_feature):
110 | x = Variable(x.cuda())
111 | if is_feature:
112 | z_all = x
113 | else:
114 | x = x.contiguous().view(self.n_way * (self.n_support + self.n_query), *x.size()[2:])
115 | z_all = self.feature.forward(x)
116 | z_avg = self.globalpool(z_all).view(z_all.size(0), z_all.size(1))
117 |
118 | joints = joints.contiguous().view(self.n_way * (self.n_support + self.n_query), *joints.size()[2:])
119 | img_len = x.size()[-1]
120 | feat_len = z_all.size()[-1]
121 | joints[:, :, :2] = joints[:, :, :2] / img_len * feat_len
122 | joints = joints.round().int()
123 | joints_num = joints.size(1)
124 |
125 | avg_mask = (joints[:, :, 2] == 0) + (joints[:, :, 0] < 0) + (joints[:, :, 1] < 0) + (
126 | joints[:, :, 0] >= feat_len) + (joints[:, :, 1] >= feat_len)
127 | avg_mask = (avg_mask > 0).long().unsqueeze(-1).cuda() # (85, 15, 1)
128 | mask_joints = joints.cuda() * (1 - avg_mask)
129 | mask_joints = (mask_joints[:, :, 0] * 7 + mask_joints[:, :, 1]).unsqueeze(1).repeat(1, 64, 1) # 7 and 64 are hard-coded for the 7x7, 64-channel Conv4NP feature map
130 | z_all_2D = z_all.view(z_all.size(0), z_all.size(1), -1)
131 | mask_z = torch.gather(z_all_2D, dim=-1, index=mask_joints)
132 | mask_z = mask_z.permute(0, 2, 1) # (85, 15, 64)
133 | mask_z_avg = z_avg.unsqueeze(1).repeat(1, joints_num, 1) * avg_mask
134 | z_all_tmp = mask_z * (1 - avg_mask) + mask_z_avg
135 | z_all = torch.cat([z_all_tmp, z_avg.unsqueeze(1)], dim=1).view(self.n_way, self.n_support + self.n_query, -1)
136 | z_support = z_all[:, :self.n_support]
137 | z_query = z_all[:, self.n_support:]
138 | return z_support, z_query
139 |
140 | def train_loop(self, epoch, train_loader, optimizer, tf_writer):
141 | print_freq = 10
142 |
143 | avg_loss, avg_loss1, avg_loss2 = 0, 0, 0
144 | start_time = time.time()
145 | for i, (x, y, joints) in enumerate(train_loader):
146 | self.n_query = x.size(1) - self.n_support
147 | x, y = x.cuda(), y.cuda()
148 | attr_labels = self.attr_labels_split[y]
149 | optimizer.zero_grad()
150 | loss1 = self.set_forward_loss1(x, joints)
151 | loss2 = self.set_forward_loss2(x, attr_labels, joints)
152 | loss = loss1 + self.beta * loss2
153 | loss.backward()
154 | optimizer.step()
155 | avg_loss = avg_loss + loss.item()
156 | avg_loss1 = avg_loss1 + loss1.item()
157 | avg_loss2 = avg_loss2 + loss2.item()
158 |
159 | if i % print_freq == 0:
160 | # print(optimizer.state_dict()['param_groups'][0]['lr'])
161 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} | Loss1 {:f} | Loss2 {:f}'.format(epoch, i, len(train_loader),
162 | avg_loss / float(i + 1), avg_loss1 / float(i + 1), avg_loss2 / float(i + 1)))
163 | tf_writer.add_scalar('loss/train', avg_loss / float(i + 1), epoch)
164 | tf_writer.add_scalar('loss1/train', avg_loss1 / float(i + 1), epoch)
165 | tf_writer.add_scalar('loss2/train', avg_loss2 / float(i + 1), epoch)
166 | print("Epoch (train) uses %.2f s!" % (time.time() - start_time))
167 |
168 | def test_loop(self, test_loader, return_std=False):
169 | acc_all = []
170 |
171 | iter_num = len(test_loader)
172 | start_time = time.time()
173 | for i, (x, _, joints) in enumerate(test_loader):
174 | x = x.cuda()
175 | self.n_query = x.size(1) - self.n_support
176 | correct_this, count_this = self.correct(x, joints)
177 |
178 | acc_all.append(correct_this / count_this * 100)
179 | print("Epoch (test) uses %.2f s!" % (time.time() - start_time))
180 |
181 | acc_all = np.asarray(acc_all)
182 | acc_mean = np.mean(acc_all)
183 | acc_std = np.std(acc_all)
184 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
185 |
186 | if return_std:
187 | return acc_mean, acc_std
188 | else:
189 | return acc_mean
190 |
191 | def test_loop_with_dists(self, test_loader, all_cls_dists, base_dists, base_cls_dists, attr_num):
192 | acc_all, dist_all = [], []
193 | attr_num = all_cls_dists.shape[1]
194 |
195 | iter_num = len(test_loader)
196 | from tqdm import tqdm
197 | for i, (x, y, joints) in enumerate(tqdm(test_loader)):
198 | x, y = x.cuda(), y.cuda()
199 | self.n_query = x.size(1) - self.n_support
200 | if self.change_way:
201 | self.n_way = x.size(0)
202 | correct_this, count_this = self.correct(x, joints)
203 | acc_all.append(correct_this / count_this * 100)
204 |
205 | sc_cls = y.unique()
206 | # original mean-task (down trend)
207 | task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
208 | dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)
209 |
210 | acc_all = np.asarray(acc_all)
211 | acc_mean = np.mean(acc_all)
212 | acc_std = np.std(acc_all)
213 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
214 |
215 | return acc_all, dist_all
216 |
217 | class APNet_wo_attrLoc(APNetTemplate):
218 | def __init__(self, model_func, n_way, n_support, attr_num, attr_loc=False, dataset='CUB'):
219 | super(APNet_wo_attrLoc, self).__init__(model_func, n_way, n_support, attr_num, attr_loc, dataset)
220 |
221 | def train_loop(self, epoch, train_loader, optimizer, tf_writer):
222 | print_freq = 10
223 |
224 | avg_loss, avg_loss1, avg_loss2 = 0, 0, 0
225 | start_time = time.time()
226 | for i, (x, y) in enumerate(train_loader):
227 | self.n_query = x.size(1) - self.n_support
228 | x, y = x.cuda(), y.cuda()
229 | attr_labels = self.attr_labels_split[y]
230 | optimizer.zero_grad()
231 | loss1 = self.set_forward_loss1(x, None)
232 | loss2 = self.set_forward_loss2(x, attr_labels, None)
233 | loss = loss1 + self.beta * loss2
234 | loss.backward()
235 | optimizer.step()
236 | avg_loss = avg_loss + loss.item()
237 | avg_loss1 = avg_loss1 + loss1.item()
238 | avg_loss2 = avg_loss2 + loss2.item()
239 |
240 | if i % print_freq == 0:
241 | # print(optimizer.state_dict()['param_groups'][0]['lr'])
242 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} | Loss1 {:f} | Loss2 {:f}'.format(epoch, i, len(train_loader),
243 | avg_loss / float(i + 1), avg_loss1 / float(i + 1), avg_loss2 / float(i + 1)))
244 | tf_writer.add_scalar('loss/train', avg_loss / float(i + 1), epoch)
245 | tf_writer.add_scalar('loss1/train', avg_loss1 / float(i + 1), epoch)
246 | tf_writer.add_scalar('loss2/train', avg_loss2 / float(i + 1), epoch)
247 | print("Epoch (train) uses %.2f s!" % (time.time() - start_time))
248 |
249 | def test_loop(self, test_loader, return_std=False):
250 | acc_all = []
251 |
252 | iter_num = len(test_loader)
253 | start_time = time.time()
254 | for i, (x, _) in enumerate(test_loader):
255 | x = x.cuda()
256 | self.n_query = x.size(1) - self.n_support
257 | correct_this, count_this = self.correct(x, None)
258 |
259 | acc_all.append(correct_this / count_this * 100)
260 | print("Epoch (test) uses %.2f s!" % (time.time() - start_time))
261 |
262 | acc_all = np.asarray(acc_all)
263 | acc_mean = np.mean(acc_all)
264 | acc_std = np.std(acc_all)
265 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
266 |
267 | if return_std:
268 | return acc_mean, acc_std
269 | else:
270 | return acc_mean
271 |
272 | def test_loop_with_dists(self, test_loader, all_cls_dists, base_dists, base_cls_dists, attr_num):
273 | acc_all, dist_all = [], []
274 |
275 | iter_num = len(test_loader)
276 | from tqdm import tqdm
277 | for i, (x, y) in enumerate(tqdm(test_loader)):
278 | x, y = x.cuda(), y.cuda()
279 | self.n_query = x.size(1) - self.n_support
280 | if self.change_way:
281 | self.n_way = x.size(0)
282 | correct_this, count_this = self.correct(x, None)
283 | acc_all.append(correct_this / count_this * 100)
284 |
285 | sc_cls = y.unique()
286 | # original mean-task (down trend)
287 | task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
288 | dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)
289 |
290 | acc_all = np.asarray(acc_all)
291 | acc_mean = np.mean(acc_all)
292 | acc_std = np.std(acc_all)
293 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
294 |
295 | return acc_all, dist_all
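296 | 
297 | 
298 | if __name__ == '__main__':
299 |     # Minimal usage sketch (illustrative only): build APNet_wo_attrLoc on a Conv4
300 |     # backbone and score one fake 5-way 1-shot episode. The attr_num value, the
301 |     # 84x84 input size and the 16 queries per class are assumptions; running this
302 |     # needs a CUDA device and the attribute-label file read by read_attr_labels.
303 |     import backbone
304 |     model = APNet_wo_attrLoc(backbone.Conv4, n_way=5, n_support=1,
305 |                              attr_num=312, dataset='CUB').cuda()
306 |     model.n_query = 16
307 |     x = torch.randn(5, 1 + 16, 3, 84, 84)  # (n_way, n_support + n_query, C, H, W)
308 |     scores = model.set_forward(x, None)  # -> (n_way * n_query, n_way) class scores
309 |     print(scores.shape)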
--------------------------------------------------------------------------------
/backbone.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from torch.autograd import Variable
5 | import torch.nn as nn
6 | import math
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from torch.nn.utils.weight_norm import WeightNorm
10 |
11 |
12 | # Basic ResNet model
13 |
14 | def init_layer(L):
15 | # Initialization using fan-in
16 | if isinstance(L, nn.Conv2d):
17 | n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
18 | L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
19 | elif isinstance(L, nn.BatchNorm2d):
20 | L.weight.data.fill_(1)
21 | L.bias.data.fill_(0)
22 |
23 |
24 | class distLinear(nn.Module):
25 | def __init__(self, indim, outdim):
26 | super(distLinear, self).__init__()
27 | self.L = nn.Linear(indim, outdim, bias=False)
28 | self.class_wise_learnable_norm = True # see issues #4 and #8 on the GitHub repo
29 | if self.class_wise_learnable_norm:
30 | WeightNorm.apply(self.L, 'weight', dim=0) # split the weight into a direction component and a norm component
31 |
32 | if outdim <= 200:
33 | self.scale_factor = 2 # fixed scale factor that maps cosine values to a reasonably large softmax input; to reproduce the CUB result with ResNet10, use 4 (see issue #31 on GitHub)
34 | else:
35 | self.scale_factor = 10 # on Omniglot a larger scale factor is needed to handle >1000 output classes
36 |
37 | def forward(self, x):
38 | x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
39 | x_normalized = x.div(x_norm + 0.00001)
40 | if not self.class_wise_learnable_norm:
41 | L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data)
42 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001)
43 | # matrix product in the forward pass; with WeightNorm this also multiplies the cosine distance by a class-wise learnable norm (see issues #4 and #8 on GitHub)
44 | cos_dist = self.L(x_normalized)
45 | scores = self.scale_factor * (cos_dist)
46 |
47 | return scores
48 |
49 |
50 | class Flatten(nn.Module):
51 | def __init__(self):
52 | super(Flatten, self).__init__()
53 |
54 | def forward(self, x):
55 | return x.view(x.size(0), -1)
56 |
57 |
58 | class Linear_fw(nn.Linear): # used in MAML to forward input with fast weight
59 | def __init__(self, in_features, out_features):
60 | super(Linear_fw, self).__init__(in_features, out_features)
61 | self.weight.fast = None # Lazy hack to add fast weight link
62 | self.bias.fast = None
63 |
64 | def forward(self, x):
65 | if self.weight.fast is not None and self.bias.fast is not None:
66 | out = F.linear(x, self.weight.fast,
67 | self.bias.fast) # weight.fast (fast weight) is the temporarily adapted weight
68 | else:
69 | out = super(Linear_fw, self).forward(x)
70 | return out
71 |
72 |
73 | class Conv2d_fw(nn.Conv2d): # used in MAML to forward input with fast weight
74 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
75 | super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
76 | bias=bias)
77 | self.weight.fast = None
78 | if self.bias is not None:
79 | self.bias.fast = None
80 |
81 | def forward(self, x):
82 | if self.bias is None:
83 | if self.weight.fast is not None:
84 | out = F.conv2d(x, self.weight.fast, None, stride=self.stride, padding=self.padding)
85 | else:
86 | out = super(Conv2d_fw, self).forward(x)
87 | else:
88 | if self.weight.fast is not None and self.bias.fast is not None:
89 | out = F.conv2d(x, self.weight.fast, self.bias.fast, stride=self.stride, padding=self.padding)
90 | else:
91 | out = super(Conv2d_fw, self).forward(x)
92 |
93 | return out
94 |
95 |
96 | class BatchNorm2d_fw(nn.BatchNorm2d): # used in MAML to forward input with fast weight
97 | def __init__(self, num_features):
98 | super(BatchNorm2d_fw, self).__init__(num_features)
99 | self.weight.fast = None
100 | self.bias.fast = None
101 |
102 | def forward(self, x):
103 | running_mean = torch.zeros(x.data.size()[1]).cuda()
104 | running_var = torch.ones(x.data.size()[1]).cuda()
105 | if self.weight.fast is not None and self.bias.fast is not None:
106 | out = F.batch_norm(x, running_mean, running_var, self.weight.fast, self.bias.fast, training=True,
107 | momentum=1)
108 | # batch_norm momentum hack: follow hack of Kate Rakelly in pytorch-maml/src/layers.py
109 | else:
110 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, training=True, momentum=1)
111 | return out
112 |
113 |
114 | # Simple Conv Block
115 | class ConvBlock(nn.Module):
116 | maml = False # Default
117 |
118 | def __init__(self, indim, outdim, pool=True, padding=1):
119 | super(ConvBlock, self).__init__()
120 | self.indim = indim
121 | self.outdim = outdim
122 | if self.maml:
123 | self.C = Conv2d_fw(indim, outdim, 3, padding=padding)
124 | self.BN = BatchNorm2d_fw(outdim)
125 | else:
126 | self.C = nn.Conv2d(indim, outdim, 3, padding=padding)
127 | self.BN = nn.BatchNorm2d(outdim)
128 | self.relu = nn.ReLU(inplace=True)
129 |
130 | self.parametrized_layers = [self.C, self.BN, self.relu]
131 | if pool:
132 | self.pool = nn.MaxPool2d(2)
133 | self.parametrized_layers.append(self.pool)
134 |
135 | for layer in self.parametrized_layers:
136 | init_layer(layer)
137 |
138 | self.trunk = nn.Sequential(*self.parametrized_layers)
139 |
140 | def forward(self, x):
141 | out = self.trunk(x)
142 | return out
143 |
144 |
145 | # Simple ResNet Block
146 | class SimpleBlock(nn.Module):
147 | maml = False # Default
148 |
149 | def __init__(self, indim, outdim, half_res):
150 | super(SimpleBlock, self).__init__()
151 | self.indim = indim
152 | self.outdim = outdim
153 | if self.maml:
154 | self.C1 = Conv2d_fw(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
155 | self.BN1 = BatchNorm2d_fw(outdim)
156 | self.C2 = Conv2d_fw(outdim, outdim, kernel_size=3, padding=1, bias=False)
157 | self.BN2 = BatchNorm2d_fw(outdim)
158 | else:
159 | self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
160 | self.BN1 = nn.BatchNorm2d(outdim)
161 | self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False)
162 | self.BN2 = nn.BatchNorm2d(outdim)
163 | self.relu1 = nn.ReLU(inplace=True)
164 | self.relu2 = nn.ReLU(inplace=True)
165 |
166 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]
167 |
168 | self.half_res = half_res
169 |
170 | # if the number of input channels differs from the number of output channels, a 1x1 convolution is needed on the shortcut
171 | if indim != outdim:
172 | if self.maml:
173 | self.shortcut = Conv2d_fw(indim, outdim, 1, 2 if half_res else 1, bias=False)
174 | self.BNshortcut = BatchNorm2d_fw(outdim)
175 | else:
176 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False)
177 | self.BNshortcut = nn.BatchNorm2d(outdim)
178 |
179 | self.parametrized_layers.append(self.shortcut)
180 | self.parametrized_layers.append(self.BNshortcut)
181 | self.shortcut_type = '1x1'
182 | else:
183 | self.shortcut_type = 'identity'
184 |
185 | for layer in self.parametrized_layers:
186 | init_layer(layer)
187 |
188 | def forward(self, x):
189 | out = self.C1(x)
190 | out = self.BN1(out)
191 | out = self.relu1(out)
192 | out = self.C2(out)
193 | out = self.BN2(out)
194 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x))
195 | out = out + short_out
196 | out = self.relu2(out)
197 | return out
198 |
199 |
200 | # Bottleneck block
201 | class BottleneckBlock(nn.Module):
202 | maml = False # Default
203 |
204 | def __init__(self, indim, outdim, half_res):
205 | super(BottleneckBlock, self).__init__()
206 | bottleneckdim = int(outdim / 4)
207 | self.indim = indim
208 | self.outdim = outdim
209 | if self.maml:
210 | self.C1 = Conv2d_fw(indim, bottleneckdim, kernel_size=1, bias=False)
211 | self.BN1 = BatchNorm2d_fw(bottleneckdim)
212 | self.C2 = Conv2d_fw(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1, padding=1)
213 | self.BN2 = BatchNorm2d_fw(bottleneckdim)
214 | self.C3 = Conv2d_fw(bottleneckdim, outdim, kernel_size=1, bias=False)
215 | self.BN3 = BatchNorm2d_fw(outdim)
216 | else:
217 | self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False)
218 | self.BN1 = nn.BatchNorm2d(bottleneckdim)
219 | self.C2 = nn.Conv2d(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1, padding=1)
220 | self.BN2 = nn.BatchNorm2d(bottleneckdim)
221 | self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False)
222 | self.BN3 = nn.BatchNorm2d(outdim)
223 |
224 | self.relu = nn.ReLU()
225 | self.parametrized_layers = [self.C1, self.BN1, self.C2, self.BN2, self.C3, self.BN3]
226 | self.half_res = half_res
227 |
228 | # if the number of input channels differs from the number of output channels, a 1x1 convolution is needed on the shortcut
229 | if indim != outdim:
230 | if self.maml:
231 | self.shortcut = Conv2d_fw(indim, outdim, 1, stride=2 if half_res else 1, bias=False)
232 | else:
233 | self.shortcut = nn.Conv2d(indim, outdim, 1, stride=2 if half_res else 1, bias=False)
234 |
235 | self.parametrized_layers.append(self.shortcut)
236 | self.shortcut_type = '1x1'
237 | else:
238 | self.shortcut_type = 'identity'
239 |
240 | for layer in self.parametrized_layers:
241 | init_layer(layer)
242 |
243 | def forward(self, x):
244 |
245 | short_out = x if self.shortcut_type == 'identity' else self.shortcut(x)
246 | out = self.C1(x)
247 | out = self.BN1(out)
248 | out = self.relu(out)
249 | out = self.C2(out)
250 | out = self.BN2(out)
251 | out = self.relu(out)
252 | out = self.C3(out)
253 | out = self.BN3(out)
254 | out = out + short_out
255 |
256 | out = self.relu(out)
257 | return out
258 |
259 |
260 | class ConvNet(nn.Module):
261 | def __init__(self, depth, flatten=True):
262 | super(ConvNet, self).__init__()
263 | trunk = []
264 | for i in range(depth):
265 | indim = 3 if i == 0 else 64
266 | outdim = 64
267 | B = ConvBlock(indim, outdim, pool=(i < 4)) # pooling only in the first 4 layers
268 | trunk.append(B)
269 |
270 | if flatten:
271 | trunk.append(Flatten())
272 |
273 | self.trunk = nn.Sequential(*trunk)
274 | self.final_feat_dim = 1600
275 |
276 | def forward(self, x):
277 | out = self.trunk(x)
278 | return out
279 |
280 |
281 | class ConvNetNopool(nn.Module):
282 | # RelationNet-style 4-layer conv with pooling only in the first two layers and no pooling elsewhere
283 | def __init__(self, depth, flatten=True):
284 | super(ConvNetNopool, self).__init__()
285 | trunk = []
286 | for i in range(depth):
287 | indim = 3 if i == 0 else 64
288 | outdim = 64
289 | B = ConvBlock(indim, outdim, pool=(i in [0, 1]),
290 | padding=0 if i in [0, 1] else 1) # only the first two layers have pooling and use no padding
291 | trunk.append(B)
292 |
293 | if flatten:
294 | lastpool = nn.AdaptiveAvgPool2d((1, 1))
295 | trunk.append(lastpool)
296 | trunk.append(Flatten())
297 | self.final_feat_dim = 64
298 | else:
299 | lastpool = nn.AdaptiveAvgPool2d((7, 7))
300 | trunk.append(lastpool)
301 | self.final_feat_dim = [64, 7, 7]
302 |
303 | self.trunk = nn.Sequential(*trunk)
304 |
305 | def forward(self, x):
306 | out = self.trunk(x)
307 | return out
308 |
309 |
310 | class ConvNetS(nn.Module): # For omniglot, only 1 input channel, output dim is 64
311 | def __init__(self, depth, flatten=True):
312 | super(ConvNetS, self).__init__()
313 | trunk = []
314 | for i in range(depth):
315 | indim = 1 if i == 0 else 64
316 | outdim = 64
317 | B = ConvBlock(indim, outdim, pool=(i < 4)) # pooling only in the first 4 layers
318 | trunk.append(B)
319 |
320 | if flatten:
321 | trunk.append(Flatten())
322 |
323 | self.trunk = nn.Sequential(*trunk)
324 | self.final_feat_dim = 64
325 |
326 | def forward(self, x):
327 | out = x[:, 0:1, :, :] # only use the first channel
328 | out = self.trunk(out)
329 | return out
330 |
331 |
332 | class ConvNetSNopool(nn.Module):
333 | # RelationNet-style 4-layer conv with pooling only in the first two layers; for Omniglot, 1 input channel, output dim is [64, 5, 5]
334 | def __init__(self, depth):
335 | super(ConvNetSNopool, self).__init__()
336 | trunk = []
337 | for i in range(depth):
338 | indim = 1 if i == 0 else 64
339 | outdim = 64
340 | B = ConvBlock(indim, outdim, pool=(i in [0, 1]),
341 | padding=0 if i in [0, 1] else 1) # only the first two layers have pooling and use no padding
342 | trunk.append(B)
343 |
344 | self.trunk = nn.Sequential(*trunk)
345 | self.final_feat_dim = [64, 5, 5]
346 |
347 | def forward(self, x):
348 | out = x[:, 0:1, :, :] # only use the first channel
349 | out = self.trunk(out)
350 | return out
351 |
352 |
353 | class ResNet(nn.Module):
354 | maml = False # Default
355 |
356 | def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=True):
357 | # list_of_num_layers specifies number of layers in each stage
358 | # list_of_out_dims specifies number of output channel for each stage
359 | super(ResNet, self).__init__()
360 | assert len(list_of_num_layers) == 4, 'Can have only four stages'
361 | if self.maml:
362 | conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3,
363 | bias=False)
364 | bn1 = BatchNorm2d_fw(64)
365 | else:
366 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
367 | bias=False)
368 | bn1 = nn.BatchNorm2d(64)
369 |
370 | relu = nn.ReLU()
371 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
372 |
373 | init_layer(conv1)
374 | init_layer(bn1)
375 |
376 | trunk = [conv1, bn1, relu, pool1]
377 |
378 | indim = 64
379 | for i in range(4):
380 |
381 | for j in range(list_of_num_layers[i]):
382 | half_res = (i >= 1) and (j == 0)
383 | B = block(indim, list_of_out_dims[i], half_res)
384 | trunk.append(B)
385 | indim = list_of_out_dims[i]
386 |
387 | if flatten:
388 | avgpool = nn.AvgPool2d(7)
389 | trunk.append(avgpool)
390 | trunk.append(Flatten())
391 | self.final_feat_dim = indim
392 | else:
393 | self.final_feat_dim = [indim, 7, 7]
394 |
395 | self.trunk = nn.Sequential(*trunk)
396 |
397 | def forward(self, x):
398 | out = self.trunk(x)
399 | return out
400 |
401 |
402 | def Conv4():
403 | return ConvNetNopool(4, flatten=True)
404 |
405 |
406 | def Conv6():
407 | return ConvNet(6)
408 |
409 |
410 | def Conv4NP(flatten=False):
411 | return ConvNetNopool(4, flatten=flatten)
412 |
413 |
414 | def Conv6NP(flatten=False):
415 | return ConvNetNopool(6, flatten=flatten)
416 |
417 |
418 | def Conv4S():
419 | return ConvNetS(4)
420 |
421 |
422 | def Conv4SNP():
423 | return ConvNetSNopool(4)
424 |
425 |
426 | def ResNet10(flatten=True):
427 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten)
428 |
429 |
430 | def ResNet10NP(flatten=False):
431 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten)
432 |
433 |
434 | def ResNet18(flatten=True):
435 | return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], flatten)
436 |
437 |
438 | def ResNet34(flatten=True):
439 | return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], flatten)
440 |
441 |
442 | def ResNet50(flatten=True):
443 | return ResNet(BottleneckBlock, [3, 4, 6, 3], [256, 512, 1024, 2048], flatten)
444 |
445 |
446 | def ResNet101(flatten=True):
447 | return ResNet(BottleneckBlock, [3, 4, 23, 3], [256, 512, 1024, 2048], flatten)
448 |
449 |
450 |
451 |
452 |
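453 | if __name__ == '__main__':
454 |     # Minimal shape check (illustrative sketch): the 84x84 input size is an assumption
455 |     # based on the Conv4-style backbones used by the few-shot scripts in this repo.
456 |     # Conv4NP keeps the spatial map, so the expected output is (N, 64, 7, 7);
457 |     # Conv4 flattens it to a 64-d vector and ResNet10 to a 512-d vector.
458 |     net = Conv4NP()
459 |     with torch.no_grad():
460 |         feat = net(torch.randn(2, 3, 84, 84))
461 |     print(feat.shape)  # expected: torch.Size([2, 64, 7, 7])
462 |     print(Conv4().final_feat_dim, ResNet10().final_feat_dim)  # 64 512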
--------------------------------------------------------------------------------