├── methods ├── pt_map │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── additional_transforms.py │ │ ├── datamgr.py │ │ └── dataset.py │ ├── io_utils.py │ ├── save_plk.py │ ├── evaluation │ │ ├── save_plk.py │ │ └── test_standard.py │ ├── FSLTask.py │ ├── pt_map_loss.py │ └── test_standard.py ├── prototypical │ ├── __init__.py │ └── proto_loss.py └── __init__.py ├── .idea ├── .gitignore ├── vcs.xml ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── encodings.xml ├── modules.xml ├── SOT.iml └── misc.xml ├── bpa_workflow.png ├── models ├── __init__.py ├── dropblock.py ├── resnet12.py ├── res_mixup_model.py └── wrn_mixup_model.py ├── bpa ├── __init__.py ├── ot.py └── balanced_pairwise_affinities.py ├── datasets ├── __init__.py ├── README.md ├── samplers.py ├── cifar.py ├── mini_imagenet.py ├── cub.py └── get_cifar_fs.py ├── .gitignore ├── README.md ├── train.py └── utils.py /methods/pt_map/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /bpa_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielShalam/BPA/HEAD/bpa_workflow.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import res_mixup_model 2 | from . import wrn_mixup_model 3 | -------------------------------------------------------------------------------- /methods/prototypical/__init__.py: -------------------------------------------------------------------------------- 1 | from methods.prototypical.proto_loss import ProtoLoss 2 | -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from methods.pt_map.pt_map_loss import PTMAPLoss 2 | from methods.prototypical import ProtoLoss 3 | -------------------------------------------------------------------------------- /methods/pt_map/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datamgr 2 | from . import dataset 3 | from . 
import additional_transforms 4 | -------------------------------------------------------------------------------- /bpa/__init__.py: -------------------------------------------------------------------------------- 1 | from bpa.balanced_pairwise_affinities import BPA 2 | from bpa.ot import log_sinkhorn, batched_log_sinkhorn 3 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.cifar import CIFAR 2 | from datasets.cub import CUB 3 | from datasets.mini_imagenet import MiniImageNet -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/SOT.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | -------------------------------------------------------------------------------- /methods/pt_map/data/additional_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | from PIL import ImageEnhance 10 | 11 | transformtypedict = dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, 12 | Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color) 13 | 14 | 15 | class ImageJitter(object): 16 | def __init__(self, transformdict): 17 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict] 18 | 19 | def __call__(self, img): 20 | out = img 21 | randtensor = torch.rand(len(self.transforms)) 22 | 23 | for i, (transformer, alpha) in enumerate(self.transforms): 24 | r = alpha * (randtensor[i] * 2.0 - 1.0) + 1 25 | out = transformer(out).enhance(r).convert('RGB') 26 | 27 | return out 28 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | Follow the instructions to prepare the datasets. 
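Once a dataset is prepared, it can be loaded through the matching class in `datasets/`. Below is a minimal usage sketch (not part of the repo itself): the constructor signature is taken from `datasets/mini_imagenet.py`, and the paths follow the layout described in the next section.

```python
from datasets.mini_imagenet import MiniImageNet

# data_path points at the ./miniimagenet/ folder laid out below
train_set = MiniImageNet(data_path='./miniimagenet', setname='train',
                         backbone='resnet12', augment=True)
image, label = train_set[0]
```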
4 | 5 | ## Miniimagenet: 6 | 7 | ``` 8 | ./miniimagenet/ 9 | └── ./miniimagenet/split/ # split files 10 | ├── ./miniimagenet/split/train.csv 11 | ├── ./miniimagenet/split/val.csv 12 | └── ./miniimagenet/split/test.csv 13 | └── ./miniimagenet/images/ # all images 14 | └── ./miniimagenet/images/<image_name>.jpg 15 | 16 | ``` 17 | 18 | Download the dataset from [here](https://cseweb.ucsd.edu/~weijian/static/datasets/mini-ImageNet/MiniImagenet.tar.gz). 19 | 20 | Download the train/val/test split from this [repo](https://github.com/twitter-research/meta-learning-lstm/). 21 | 22 | ## CIFAR-FS: 23 | 24 | ``` 25 | ./cifar_fs/ 26 | └── ./cifar_fs/train/ # train images 27 | └── ./cifar_fs/train/<class_name>/<image_name>.jpg 28 | └── ./cifar_fs/val/ # val images 29 | └── ./cifar_fs/val/<class_name>/<image_name>.jpg 30 | └── ./cifar_fs/test/ # test images 31 | └── ./cifar_fs/test/<class_name>/<image_name>.jpg 32 | ``` 33 | 34 | Run the following command to automatically download the dataset: 35 | 36 | ``` 37 | python get_cifar_fs.py 38 | ``` 39 | 40 | Modify "datapath" inside the script to choose a different location. -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 32 | -------------------------------------------------------------------------------- /methods/prototypical/proto_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from bpa import BPA 5 | 6 | 7 | class ProtoLoss(nn.Module): 8 | """ 9 | Prototypical loss function 10 | """ 11 | def __init__(self, args: dict, bpa: BPA = None): 12 | super().__init__() 13 | self.num_shot = args['num_shot'] 14 | self.num_query = args['num_query'] 15 | self.way_dict = dict(train=args['train_way'], val=args['val_way']) 16 | self.temperature = args['temperature'] 17 | self.BPA = bpa 18 | self.num_labeled = None 19 | 20 | @staticmethod 21 | def get_accuracy(probas: torch.Tensor, labels: torch.Tensor): 22 | y_hat = probas.argmin(dim=-1) # argmin: despite the name, 'probas' holds squared distances here, so smaller is better 23 | matches = labels.eq(y_hat).float() 24 | m = matches.mean().item() 25 | # pm = matches.std(unbiased=False).item() * 1.96 26 | return m 27 | 28 | def forward(self, X: torch.Tensor, labels: torch.Tensor, mode: str): 29 | num_way = self.way_dict[mode] 30 | self.num_labeled = num_way * self.num_shot 31 | 32 | # apply the BPA transform 33 | if self.BPA is not None: 34 | X = self.BPA(X, y=labels[:self.num_labeled]) 35 | 36 | # split to support and queries 37 | X_support, X_query = X.split((self.num_labeled, X.size(0)-self.num_labeled), dim=0) 38 | 39 | # compute centroids 40 | # -> assuming input data sorted as [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, ...]
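#    reshape -> [shot, way, dim], transpose(0, 1) -> [way, shot, dim], mean(dim=1) then averages the shots of each class into one centroid: [way, dim]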
41 | X_centroid = X_support.reshape(self.num_shot, num_way, -1).transpose(0, 1).mean(dim=1) 42 | 43 | # compute distances between queries and the centroids 44 | D = (X_query.unsqueeze(1) - X_centroid.unsqueeze(0)).norm(dim=2).pow(2) 45 | D = D / self.temperature 46 | 47 | return -D, ProtoLoss.get_accuracy(D, labels[self.num_labeled:]) # accuracy against the query labels only (D has one row per query) 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | secret/ 132 | checkpoints/ 133 | wandb/ 134 | experiment_figs/ 135 | commands.txt 136 | main_ablation.py -------------------------------------------------------------------------------- /models/dropblock.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.distributions import Bernoulli 5 | 6 | 7 | class DropBlock(nn.Module): 8 | def __init__(self, block_size): 9 | super(DropBlock, self).__init__() 10 | 11 | self.block_size = block_size 12 | 13 | def forward(self, x, gamma): 14 | # shape: (bsize, channels, height, width) 15 | 16 | if self.training: 17 | batch_size, channels, height, width = x.shape 18 | bernoulli = Bernoulli(gamma) 19 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))) 20 | if torch.cuda.is_available(): 21 | mask = mask.cuda() 22 | block_mask = self._compute_block_mask(mask) 23 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 24 | count_ones = block_mask.sum() 25 | 26 | return block_mask * x * (countM / count_ones) 27 | else: 28 | return x 29 | 30 | def _compute_block_mask(self, mask): 31 | left_padding = int((self.block_size-1) / 2) 32 | right_padding = int(self.block_size / 2) 33 | 34 | batch_size, channels, height, width = mask.shape 35 | non_zero_idxs = mask.nonzero() 36 | nr_blocks = non_zero_idxs.shape[0] 37 | 38 | offsets = torch.stack( 39 | [ 40 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding, 41 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding 42 | ] 43 | ).t() 44 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).long(), offsets.long()), 1) 45 | if torch.cuda.is_available(): 46 | offsets = offsets.cuda() 47 | 48 | if nr_blocks > 0: 49 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1) 50 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 51 | offsets = offsets.long() 52 | 53 | block_idxs = non_zero_idxs + offsets 54 | #block_idxs += left_padding 55 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 56 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1. 
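# at this point every sampled center has been expanded into a block_size x block_size square of ones inside the padded mask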
57 | else: 58 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 59 | 60 | block_mask = 1 - padded_mask#[:height, :width] 61 | return block_mask -------------------------------------------------------------------------------- /bpa/ot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def log_sum_exp(u: torch.Tensor, dim: int): 5 | # Reduce log sum exp along axis 6 | u_max, __ = u.max(dim=dim, keepdim=True) 7 | log_sum_exp_u = torch.log(torch.exp(u - u_max).sum(dim)) + u_max.sum(dim) 8 | return log_sum_exp_u 9 | 10 | 11 | def log_sinkhorn(M: torch.Tensor, reg: float, num_iters: int): 12 | """ 13 | Log-space-sinkhorn algorithm for better stability. 14 | """ 15 | if M.dim() > 2: 16 | return batched_log_sinkhorn(M=M, reg=reg, num_iters=num_iters) 17 | 18 | # Initialize dual variable v (u is implicitly defined in the loop) 19 | log_v = torch.zeros(M.size()[1]).to(M.device) # ==torch.log(torch.ones(m.size()[1])) 20 | 21 | # Exponentiate the pairwise distance matrix 22 | log_K = -M / reg 23 | 24 | # Main loop 25 | for i in range(num_iters): 26 | # Match r marginals 27 | log_u = - log_sum_exp(log_K + log_v[None, :], dim=1) 28 | 29 | # Match c marginals 30 | log_v = - log_sum_exp(log_u[:, None] + log_K, dim=0) 31 | 32 | # Compute optimal plan, cost, return everything 33 | log_P = log_u[:, None] + log_K + log_v[None, :] 34 | return log_P 35 | 36 | 37 | def batched_log_sinkhorn(M, reg: float, num_iters: int): 38 | """ 39 | Batched version of log-space-sinkhorn. 40 | """ 41 | batch_size, x_points, _ = M.shape 42 | # both marginals are fixed with equal weights 43 | mu = torch.empty(batch_size, x_points, dtype=torch.float, 44 | requires_grad=False).fill_(1.0 / x_points).squeeze().to(M.device) 45 | nu = torch.empty(batch_size, x_points, dtype=torch.float, 46 | requires_grad=False).fill_(1.0 / x_points).squeeze().to(M.device) 47 | 48 | u = torch.zeros_like(mu) 49 | v = torch.zeros_like(nu) 50 | # To check if algorithm terminates because of threshold 51 | # or max iterations reached 52 | actual_nits = 0 53 | # Stopping criterion 54 | thresh = 1e-1 55 | 56 | def C(M, u, v, reg): 57 | """Modified cost for logarithmic updates""" 58 | return (-M + u.unsqueeze(-1) + v.unsqueeze(-2)) / reg 59 | 60 | # Sinkhorn iterations 61 | for i in range(num_iters): 62 | u1 = u # useful to check the update 63 | u = reg * (torch.log(mu + 1e-8) - torch.logsumexp(C(M, u, v, reg), dim=-1)) + u 64 | v = reg * (torch.log(nu + 1e-8) - torch.logsumexp(C(M, u, v, reg).transpose(-2, -1), dim=-1)) + v 65 | err = (u - u1).abs().sum(-1).mean() 66 | 67 | actual_nits += 1 68 | if err.item() < thresh: 69 | break 70 | 71 | U, V = u, v 72 | # Transport plan pi = diag(a)*K*diag(b) 73 | log_p = C(M, U, V, reg) 74 | return log_p 75 | -------------------------------------------------------------------------------- /methods/pt_map/data/datamgr.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | import torch 3 | import torchvision.transforms as transforms 4 | from . 
import additional_transforms as add_transforms 5 | from .dataset import SimpleDataset 6 | from abc import abstractmethod 7 | 8 | 9 | class TransformLoader: 10 | def __init__(self, image_size, 11 | normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 12 | jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)): 13 | self.image_size = image_size 14 | self.normalize_param = normalize_param 15 | self.jitter_param = jitter_param 16 | 17 | def parse_transform(self, transform_type): 18 | if transform_type == 'ImageJitter': 19 | method = add_transforms.ImageJitter(self.jitter_param) 20 | return method 21 | method = getattr(transforms, transform_type) 22 | if transform_type == 'RandomSizedCrop': 23 | return method(self.image_size) 24 | elif transform_type == 'CenterCrop': 25 | return method(self.image_size) 26 | elif transform_type == 'Resize': 27 | return method([int(self.image_size * 1.15), int(self.image_size * 1.15)]) 28 | elif transform_type == 'Normalize': 29 | return method(**self.normalize_param) 30 | else: 31 | return method() 32 | 33 | def get_composed_transform(self, aug=False): 34 | if aug: 35 | transform_list = ['RandomSizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize'] 36 | else: 37 | transform_list = ['Resize', 'CenterCrop', 'ToTensor', 'Normalize'] 38 | 39 | transform_funcs = [self.parse_transform(x) for x in transform_list] 40 | transform = transforms.Compose(transform_funcs) 41 | return transform 42 | 43 | 44 | class DataManager: 45 | @abstractmethod 46 | def get_data_loader(self, data_file, aug): 47 | pass 48 | 49 | 50 | class SimpleDataManager(DataManager): 51 | def __init__(self, image_size, batch_size): 52 | super(SimpleDataManager, self).__init__() 53 | self.batch_size = batch_size 54 | self.trans_loader = TransformLoader(image_size) 55 | 56 | def get_data_loader(self, data_file, aug): # parameters that would change on train/val set 57 | transform = self.trans_loader.get_composed_transform(aug) 58 | dataset = SimpleDataset(data_file, transform) 59 | data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=12, pin_memory=True) 60 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 61 | 62 | return data_loader 63 | 64 | -------------------------------------------------------------------------------- /datasets/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class CategoriesSampler: 6 | 7 | def __init__(self, set_name, labels, num_episodes, 8 | num_way, num_shot, num_query, const_loader, 9 | replace=True): 10 | 11 | self.set_name = set_name 12 | self.num_way = num_way 13 | self.num_shot = num_shot 14 | self.num_query = num_query 15 | self.num_episodes = num_episodes 16 | self.const_loader = const_loader # same tasks in different epochs. 
good for validation 17 | self.replace = replace # sample few-shot tasks with replacement (same class can appear twice or more) 18 | 19 | self.m_ind = [] 20 | self.batches = [] 21 | 22 | labels = np.array(labels) 23 | for i in range(max(labels) + 1): 24 | ind = np.argwhere(labels == i).reshape(-1) 25 | ind = torch.from_numpy(ind) 26 | self.m_ind.append(ind) 27 | 28 | self.classes = np.arange(len(self.m_ind)) 29 | 30 | if self.const_loader: 31 | for i_batch in range(self.num_episodes): 32 | batch = [] 33 | # -- faster loading with np.choice -- # 34 | # classes = torch.randperm(len(self.m_ind))[:self.num_way] 35 | classes = np.random.choice(self.classes, size=self.num_way, replace=self.replace) 36 | for c in classes: 37 | l = self.m_ind[c] 38 | pos = np.random.choice(np.arange(l.shape[0]), 39 | size=self.num_shot + self.num_query, 40 | replace=False) 41 | batch.append(l[pos]) 42 | 43 | batch = torch.from_numpy(np.stack(batch)).t().reshape(-1) 44 | self.batches.append(batch) 45 | 46 | def __len__(self): 47 | return self.num_episodes 48 | 49 | def __iter__(self): 50 | if not self.const_loader: 51 | for batch_idx in range(self.num_episodes): 52 | batch = [] 53 | # classes = torch.randperm(len(self.m_ind))[:self.num_way] 54 | classes = np.random.choice(self.classes, size=self.num_way, replace=self.replace) 55 | for c in classes: 56 | l = self.m_ind[c] 57 | pos = np.random.choice(np.arange(l.shape[0]), 58 | size=self.num_shot + self.num_query, 59 | replace=False) 60 | batch.append(l[pos]) 61 | 62 | batch = torch.from_numpy(np.stack(batch)).t().reshape(-1) 63 | yield batch 64 | else: 65 | for batch_idx in range(self.num_episodes): 66 | yield self.batches[batch_idx] 67 | -------------------------------------------------------------------------------- /datasets/cifar.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from PIL import Image 3 | import os 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | 7 | 8 | class CIFAR(Dataset): 9 | 10 | def __init__(self, data_path: str, setname: str, backbone: str, augment: bool): 11 | d = osp.join(data_path, setname) 12 | dirs = [os.path.join(d, o) for o in os.listdir(d) if os.path.isdir(os.path.join(d, o))] 13 | 14 | data = [] 15 | label = [] 16 | lb = -1 17 | 18 | for d in dirs: 19 | lb += 1 20 | for image_name in os.listdir(d): 21 | path = osp.join(d, image_name) 22 | data.append(path) 23 | label.append(lb) 24 | 25 | self.data = data 26 | self.label = label 27 | 28 | mean = [x / 255.0 for x in [129.37731888, 124.10583864, 112.47758569]] 29 | std = [x / 255.0 for x in [68.20947949, 65.43124043, 70.45866994]] 30 | normalize = transforms.Normalize(mean=mean, std=std) 31 | 32 | self.image_size = 32 33 | if augment and setname == 'train': 34 | transforms_list = [ 35 | transforms.RandomResizedCrop(self.image_size), 36 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 37 | transforms.RandomHorizontalFlip(), 38 | transforms.ToTensor(), 39 | ] 40 | else: 41 | transforms_list = [ 42 | transforms.Resize((self.image_size, self.image_size)), 43 | transforms.ToTensor(), 44 | ] 45 | 46 | self.transform = transforms.Compose( 47 | transforms_list + [normalize] 48 | ) 49 | 50 | def __len__(self): 51 | return len(self.data) 52 | 53 | def __getitem__(self, i): 54 | path, label = self.data[i], self.label[i] 55 | image = self.transform(Image.open(path).convert('RGB')) 56 | return image, label 57 | 58 | 59 | def get_transform(img_size: int, split_name: str): 60 |
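# standalone transform builder mirroring the CIFAR class above (same CIFAR-FS channel statistics)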
mean = [x / 255.0 for x in [129.37731888, 124.10583864, 112.47758569]] 61 | std = [x / 255.0 for x in [68.20947949, 65.43124043, 70.45866994]] 62 | normalize = transforms.Normalize(mean=mean, std=std) 63 | 64 | if split_name == 'train': 65 | return transforms.Compose([ 66 | # transforms.RandomResizedCrop((img_size, img_size), scale=(0.05, 1.0)), 67 | transforms.RandomCrop(32, padding=4), 68 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 69 | transforms.RandomHorizontalFlip(), 70 | transforms.ToTensor(), 71 | normalize 72 | ]) 73 | 74 | else: 75 | return transforms.Compose([ 76 | transforms.Resize((img_size, img_size)), 77 | transforms.ToTensor(), 78 | normalize 79 | ]) 80 | -------------------------------------------------------------------------------- /methods/pt_map/io_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | import argparse 5 | 6 | 7 | def parse_args(script): 8 | parser = argparse.ArgumentParser(description='few-shot script %s' % (script)) 9 | parser.add_argument('--dataset', default='miniImagenet', help='CUB/miniImagenet') 10 | parser.add_argument('--model', default='WideResNet28_10', help='model: WideResNet28_10/ResNet{18}') 11 | parser.add_argument('--method', default='S2M2_R', help='rotation/S2M2_R') 12 | parser.add_argument('--train_aug', action='store_true', 13 | help='perform data augmentation or not during training ') # still required for 14 | # save_features.py and test.py to find the model path correctly 15 | if script == 'train': 16 | parser.add_argument('--num_classes', default=200, type=int, 17 | help='total number of classes') # make it larger than the maximum label value in base class 18 | parser.add_argument('--save_freq', default=10, type=int, help='Save frequency') 19 | parser.add_argument('--start_epoch', default=0, type=int, help='Starting epoch') 20 | parser.add_argument('--stop_epoch', default=400, type=int, 21 | help='Stopping epoch') # for meta-learning methods, each epoch contains 100 episodes. 22 | # The default epoch number is dataset dependent. 
See train.py 23 | parser.add_argument('--resume', action='store_true', 24 | help='continue from previous trained model with largest epoch') 25 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 26 | parser.add_argument('--batch_size', default=16, type=int, help='batch size ') 27 | parser.add_argument('--test_batch_size', default=2, type=int, help='batch size ') 28 | parser.add_argument('--alpha', default=2.0, type=float, help='for S2M2 training ') 29 | elif script == 'test': 30 | parser.add_argument('--num_classes', default=200, type=int, help='total number of classes') 31 | parser.add_argument('--model_dir', type=str, help='the pretrained model path ') 32 | parser.add_argument('--file_name', type=str, help='where the features will be saved ') 33 | parser.add_argument('--json_dir', type=str, default='./', help='') 34 | 35 | return parser.parse_args() 36 | 37 | 38 | def get_assigned_file(checkpoint_dir, num): 39 | assign_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(num)) 40 | return assign_file 41 | 42 | 43 | def get_resume_file(checkpoint_dir): 44 | filelist = glob.glob(os.path.join(checkpoint_dir, '*.tar')) 45 | if len(filelist) == 0: 46 | return None 47 | 48 | filelist = [x for x in filelist if os.path.basename(x) != 'best.tar'] 49 | epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist]) 50 | max_epoch = np.max(epochs) 51 | resume_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(max_epoch)) 52 | return resume_file 53 | 54 | 55 | def get_best_file(checkpoint_dir): 56 | best_file = os.path.join(checkpoint_dir, 'best.tar') 57 | if os.path.isfile(best_file): 58 | return best_file 59 | else: 60 | return get_resume_file(checkpoint_dir) 61 | -------------------------------------------------------------------------------- /datasets/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from PIL import Image, ImageFilter 3 | from torch.utils.data import Dataset 4 | from torchvision import transforms 5 | import numpy as np 6 | 7 | 8 | class MiniImageNet(Dataset): 9 | 10 | def __init__(self, data_path: str, setname: str, backbone: str, augment: bool): 11 | try: 12 | csv_path = osp.join(data_path, setname + '.csv') 13 | lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:] 14 | except FileNotFoundError: 15 | csv_path = osp.join(data_path, 'split', setname + '.csv') 16 | lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:] 17 | 18 | data = [] 19 | label = [] 20 | lb = -1 21 | 22 | self.wnids = [] 23 | for l in lines: 24 | name, wnid = l.split(',') 25 | path = osp.join(data_path, 'images', name) 26 | if wnid not in self.wnids: 27 | self.wnids.append(wnid) 28 | lb += 1 29 | data.append(path) 30 | label.append(lb) 31 | 32 | self.data = data 33 | self.label = label 34 | 35 | self.image_size = 84 36 | if augment: 37 | # augment only if training and args.augment set to true 38 | transforms_list = [ 39 | transforms.RandomResizedCrop(self.image_size), 40 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | ] 44 | else: 45 | transforms_list = [ 46 | transforms.Resize(92), 47 | transforms.CenterCrop(self.image_size), 48 | transforms.ToTensor(), 49 | ] 50 | 51 | # Transformation 52 | backbone = backbone.lower() 53 | if backbone == 'convnet': 54 | self.transform = transforms.Compose( 55 | transforms_list + [ 56 | transforms.Normalize(np.array([0.485, 0.456,
0.406]), 57 | np.array([0.229, 0.224, 0.225])) 58 | ]) 59 | elif backbone == 'resnet12': 60 | self.transform = transforms.Compose( 61 | transforms_list + [ 62 | transforms.Normalize(np.array([x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]), 63 | np.array([x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]])) 64 | ]) 65 | elif 'resnet18' in backbone or 'wrn' in backbone or 'vit' in backbone: # substring checks, so e.g. 'wrn28_10' matches 'wrn' 66 | self.transform = transforms.Compose( 67 | transforms_list + [ 68 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 69 | std=[0.229, 0.224, 0.225]) 70 | ]) 71 | else: 72 | raise ValueError('Non-supported Network Types. Please Revise Data Pre-Processing Scripts.') 73 | 74 | def __len__(self): 75 | return len(self.data) 76 | 77 | def __getitem__(self, i): 78 | path, label = self.data[i], self.label[i] 79 | orig = self.transform(Image.open(path).convert('RGB')) 80 | return orig, label 81 | -------------------------------------------------------------------------------- /methods/pt_map/data/dataset.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | import torch 3 | from PIL import Image 4 | import json 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | import os 8 | 9 | 10 | def identity(x): 11 | return x 12 | 13 | 14 | class SimpleDataset: 15 | def __init__(self, data_file, transform, target_transform=identity): 16 | with open(data_file, 'r') as f: 17 | self.meta = json.load(f) 18 | self.dir = '/'.join(data_file.split('/')[:-1]) 19 | print(self.dir) 20 | self.transform = transform 21 | self.target_transform = target_transform 22 | 23 | def __getitem__(self, i): 24 | image_path = os.path.join(self.dir, 'images', self.meta['image_names'][i]) 25 | img = Image.open(image_path).convert('RGB') 26 | img = self.transform(img) 27 | target = self.target_transform(self.meta['image_labels'][i]) 28 | return img, target 29 | 30 | def __len__(self): 31 | return len(self.meta['image_names']) 32 | 33 | 34 | class SetDataset: 35 | def __init__(self, data_file, batch_size, transform): 36 | with open(data_file, 'r') as f: 37 | self.meta = json.load(f) 38 | 39 | self.cl_list = np.unique(self.meta['image_labels']).tolist() 40 | 41 | self.sub_meta = {} 42 | for cl in self.cl_list: 43 | self.sub_meta[cl] = [] 44 | 45 | for x, y in zip(self.meta['image_names'], self.meta['image_labels']): 46 | self.sub_meta[y].append(x) 47 | 48 | self.sub_dataloader = [] 49 | sub_data_loader_params = dict(batch_size=batch_size, 50 | shuffle=True, 51 | num_workers=0, # use main thread only or may receive multiple batches 52 | pin_memory=False) 53 | for cl in self.cl_list: 54 | sub_dataset = SubDataset(self.sub_meta[cl], cl, transform=transform) 55 | self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params)) 56 | 57 | def __getitem__(self, i): 58 | return next(iter(self.sub_dataloader[i])) 59 | 60 | def __len__(self): 61 | return len(self.cl_list) 62 | 63 | 64 | class SubDataset: 65 | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity): 66 | self.sub_meta = sub_meta 67 | self.cl = cl 68 | self.transform = transform 69 | self.target_transform = target_transform 70 | 71 | def __getitem__(self, i): 72 | # print( '%d -%d' %(self.cl,i)) 73 | image_path = os.path.join(self.sub_meta[i]) 74 | img = Image.open(image_path).convert('RGB') 75 | img = self.transform(img) 76 | target =
self.target_transform(self.cl) 77 | return img, target 78 | 79 | def __len__(self): 80 | return len(self.sub_meta) 81 | 82 | 83 | class EpisodicBatchSampler(object): 84 | def __init__(self, n_classes, n_way, n_episodes): 85 | self.n_classes = n_classes 86 | self.n_way = n_way 87 | self.n_episodes = n_episodes 88 | 89 | def __len__(self): 90 | return self.n_episodes 91 | 92 | def __iter__(self): 93 | for i in range(self.n_episodes): 94 | yield torch.randperm(self.n_classes)[:self.n_way] 95 | -------------------------------------------------------------------------------- /methods/pt_map/save_plk.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import collections 4 | import pickle 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.nn as nn 8 | from data.datamgr import SimpleDataManager 9 | from io_utils import parse_args 10 | import sys 11 | 12 | from models import wrn_mixup_model, res_mixup_model 13 | 14 | use_gpu = torch.cuda.is_available() 15 | 16 | 17 | class WrappedModel(nn.Module): 18 | def __init__(self, module): 19 | super(WrappedModel, self).__init__() 20 | self.module = module 21 | 22 | def forward(self, x): 23 | return self.module(x) 24 | 25 | 26 | def save_pickle(file, data): 27 | with open(file, 'wb') as f: 28 | pickle.dump(data, f) 29 | 30 | 31 | def load_pickle(file): 32 | with open(file, 'rb') as f: 33 | return pickle.load(f) 34 | 35 | 36 | def extract_feature(val_loader, model, checkpoint_dir, tag='last'): 37 | save_dir = '{}/{}'.format(checkpoint_dir, tag) 38 | if os.path.isfile(save_dir + '/output.plk'): 39 | data = load_pickle(save_dir + '/output.plk') 40 | return data 41 | else: 42 | if not os.path.isdir(save_dir): 43 | os.makedirs(save_dir) 44 | 45 | model.eval() 46 | with torch.no_grad(): 47 | 48 | output_dict = collections.defaultdict(list) 49 | 50 | for i, (inputs, labels) in enumerate(val_loader): 51 | print(f"{i}/{len(val_loader)}") 52 | # compute output 53 | inputs = inputs.cuda() 54 | labels = labels.cuda() 55 | outputs = model(inputs, return_logits=False).cpu().data.numpy() 56 | for out, label in zip(outputs, labels): 57 | output_dict[label.item()].append(out) 58 | 59 | all_info = output_dict 60 | save_pickle(save_dir + '/output.plk', all_info) 61 | return 62 | 63 | 64 | def main(): 65 | args = parse_args('test') 66 | loadfile = args.json_dir + f'{args.dataset.lower()}_novel.json' 67 | 68 | if args.dataset.lower() == 'miniimagenet' or args.dataset.lower() == 'cub': 69 | datamgr = SimpleDataManager(84, batch_size=64) 70 | novel_loader = datamgr.get_data_loader(loadfile, aug=False) 71 | else: 72 | raise ValueError 73 | 74 | model_file = os.path.join(args.model_dir, args.file_name) 75 | if args.model == 'WideResNet28_10': 76 | model = wrn_mixup_model.wrn28_10(num_classes=args.num_classes) 77 | elif args.model == 'ResNet18': 78 | model = res_mixup_model.resnet18(num_classes=args.num_classes) 79 | else: 80 | raise ValueError 81 | 82 | model = model.cuda() 83 | cudnn.benchmark = True 84 | 85 | if model_file.endswith('.tar'): 86 | checkpoint = torch.load(model_file) 87 | state = checkpoint['state'] 88 | state_keys = list(state.keys()) 89 | callwrap = False 90 | if 'module' in state_keys[0]: 91 | callwrap = True 92 | if callwrap: 93 | model = WrappedModel(model) 94 | model_dict_load = model.state_dict() 95 | model_dict_load.update(state) 96 | model.load_state_dict(model_dict_load) 97 | else: 98 | model.load_state_dict(torch.load(model_file)) 99 | 100 | 
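# switch to eval mode so dropout is disabled and BatchNorm uses its running statistics during feature extraction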
model.eval() 101 | extract_feature(novel_loader, model, args.model_dir, tag='last') 102 | print("features saved!") 103 | 104 | 105 | if __name__ == '__main__': 106 | main() 107 | -------------------------------------------------------------------------------- /datasets/cub.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import PIL 3 | from PIL import Image 4 | 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | 9 | ROOT_PATH = 'C:/temp/datasets/cub' 10 | IMAGE_PATH = osp.join(ROOT_PATH, 'images') 11 | 12 | 13 | # This is for the CUB dataset 14 | # Note that we assume the CUB images are cropped based on the given bounding boxes 15 | # The concept labels are based on the attribute values; they are kept for further use (and not used in this work) 16 | 17 | class CUB(Dataset): 18 | 19 | def __init__(self, setname, args, augment=False): 20 | txt_path = osp.join(ROOT_PATH, setname + '.csv') 21 | 22 | self.data, self.label = self.parse_csv(txt_path) 23 | self.num_class = np.unique(np.array(self.label)).shape[0] 24 | 25 | self.image_size = 84 26 | if augment and setname == 'train': 27 | transforms_list = [ 28 | transforms.RandomResizedCrop(self.image_size), 29 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 30 | transforms.RandomHorizontalFlip(), 31 | transforms.ToTensor(), 32 | ] 33 | else: 34 | transforms_list = [ 35 | transforms.Resize(92), 36 | transforms.CenterCrop(self.image_size), 37 | transforms.ToTensor(), 38 | ] 39 | 40 | # Transformation 41 | if args.backbone_class == 'ConvNet': 42 | self.transform = transforms.Compose( 43 | transforms_list + [ 44 | transforms.Normalize(np.array([0.485, 0.456, 0.406]), 45 | np.array([0.229, 0.224, 0.225])) 46 | ]) 47 | elif args.backbone_class == 'Res12': 48 | self.transform = transforms.Compose( 49 | transforms_list + [ 50 | transforms.Normalize(np.array([x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]), 51 | np.array([x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]])) 52 | ]) 53 | elif args.backbone_class == 'Res18': 54 | self.transform = transforms.Compose( 55 | transforms_list + [ 56 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 57 | std=[0.229, 0.224, 0.225]) 58 | ]) 59 | elif args.backbone_class == 'WRN': 60 | self.transform = transforms.Compose( 61 | transforms_list + [ 62 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 63 | std=[0.229, 0.224, 0.225]) 64 | ]) 65 | else: 66 | raise ValueError('Non-supported Network Types.
Please Revise Data Pre-Processing Scripts.') 67 | 68 | def parse_csv(self, txt_path): 69 | data = [] 70 | label = [] 71 | lb = -1 72 | self.wnids = [] 73 | lines = [x.strip() for x in open(txt_path, 'r').readlines()][1:] 74 | 75 | for l in lines: 76 | context = l.split(',') 77 | name = context[0] 78 | wnid = context[1] 79 | path = osp.join(IMAGE_PATH, name) 80 | if wnid not in self.wnids: 81 | self.wnids.append(wnid) 82 | lb += 1 83 | 84 | data.append(path) 85 | label.append(lb) 86 | 87 | return data, label 88 | 89 | def __len__(self): 90 | return len(self.data) 91 | 92 | def __getitem__(self, i): 93 | data, label = self.data[i], self.label[i] 94 | if getattr(self, 'use_im_cache', False): # 'use_im_cache' is never set in __init__; default to reading images from disk 95 | image = self.transform(data) 96 | else: 97 | image = self.transform(Image.open(data).convert('RGB')) 98 | return image, label 99 | -------------------------------------------------------------------------------- /methods/pt_map/evaluation/save_plk.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import csv 5 | import os 6 | import collections 7 | import pickle 8 | import random 9 | 10 | import numpy as np 11 | import torch 12 | from torch.autograd import Variable 13 | import torch.backends.cudnn as cudnn 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | from io_utils import parse_args 19 | from data.datamgr import SimpleDataManager, SetDataManager 20 | import configs 21 | 22 | import wrn_mixup_model 23 | import res_mixup_model 24 | 25 | import torch.nn.functional as F 26 | 27 | from io_utils import parse_args, get_resume_file, get_assigned_file 28 | from os import path 29 | 30 | use_gpu = torch.cuda.is_available() 31 | 32 | class WrappedModel(nn.Module): 33 | def __init__(self, module): 34 | super(WrappedModel, self).__init__() 35 | self.module = module 36 | def forward(self, x): 37 | return self.module(x) 38 | 39 | def save_pickle(file, data): 40 | with open(file, 'wb') as f: 41 | pickle.dump(data, f) 42 | 43 | def load_pickle(file): 44 | with open(file, 'rb') as f: 45 | return pickle.load(f) 46 | 47 | def extract_feature(val_loader, model, checkpoint_dir, tag='last'): 48 | save_dir = '{}/{}'.format(checkpoint_dir, tag) 49 | if os.path.isfile(save_dir + '/output.plk'): 50 | data = load_pickle(save_dir + '/output.plk') 51 | return data 52 | else: 53 | if not os.path.isdir(save_dir): 54 | os.makedirs(save_dir) 55 | 56 | #model.eval() 57 | with torch.no_grad(): 58 | 59 | output_dict = collections.defaultdict(list) 60 | 61 | for i, (inputs, labels) in enumerate(val_loader): 62 | # compute output 63 | inputs = inputs.cuda() 64 | labels = labels.cuda() 65 | outputs,_ = model(inputs) 66 | outputs = outputs.cpu().data.numpy() 67 | 68 | for out, label in zip(outputs, labels): 69 | output_dict[label.item()].append(out) 70 | 71 | all_info = output_dict 72 | save_pickle(save_dir + '/output.plk', all_info) 73 | return all_info 74 | 75 | 76 | if __name__ == '__main__': 77 | params = parse_args('test') 78 | 79 | loadfile = configs.data_dir[params.dataset] + 'novel.json' 80 | 81 | if params.dataset == 'miniImagenet' or params.dataset == 'CUB': 82 | datamgr = SimpleDataManager(84, batch_size = 256) 83 | else: 84 | raise ValueError(params.dataset) 85 | 86 | novel_loader = datamgr.get_data_loader(loadfile, aug = False) 87 | 88 | #checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(configs.save_dir, params.dataset, params.model,
params.method) 89 | #modelfile = get_resume_file(checkpoint_dir) 90 | 91 | # --- BPA Loading --- 92 | checkpoint_dir = "./checkpoints/miniImagenet/" 93 | modelfile = os.path.join(checkpoint_dir, 'ptmap_bpa.pth') 94 | # --- BPA Loading --- 95 | 96 | if params.model == 'WideResNet28_10': 97 | model = wrn_mixup_model.wrn28_10(num_classes=params.num_classes) 98 | elif params.model == 'ResNet18': 99 | model = res_mixup_model.resnet18(num_classes=params.num_classes) 100 | else: 101 | raise ValueError(params.model) 102 | 103 | model.cuda() 104 | cudnn.benchmark = True 105 | 106 | if modelfile.endswith('.tar'): 107 | checkpoint = torch.load(modelfile) 108 | state = checkpoint['state'] 109 | state_keys = list(state.keys()) 110 | 111 | callwrap = False 112 | if 'module' in state_keys[0]: 113 | callwrap = True 114 | if callwrap: 115 | model = WrappedModel(model) 116 | model_dict_load = model.state_dict() 117 | model_dict_load.update(state) 118 | model.load_state_dict(model_dict_load) 119 | else: 120 | model.load_state_dict(torch.load(modelfile), strict = False) 121 | 122 | model.eval() 123 | output_dict = extract_feature(novel_loader, model, checkpoint_dir, tag='last') 124 | print("features saved!") 125 | -------------------------------------------------------------------------------- /datasets/get_cifar_fs.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Arnout Devos 3 | 2018/12/06 4 | MIT License 5 | Script for downloading and reorganizing CIFAR few-shot from CIFAR-100 according 6 | to split specifications in Bertinetto et al. '18. 7 | Run this file as follows: 8 | python get_cifar_fs.py 9 | """ 10 | 11 | import pickle 12 | import os 13 | import numpy as np 14 | from tqdm import tqdm 15 | import requests 16 | import math 17 | import tarfile, sys 18 | from PIL import Image 19 | import glob 20 | import shutil 21 | 22 | def download_file(url, filename): 23 | """ 24 | Helper method handling downloading large files from `url` to `filename`. Returns a pointer to `filename`. 25 | """ 26 | chunkSize = 1024 27 | r = requests.get(url, stream=True) 28 | with open(filename, 'wb') as f: 29 | pbar = tqdm(unit="B", total=int(r.headers['Content-Length'])) 30 | for chunk in r.iter_content(chunk_size=chunkSize): 31 | if chunk: # filter out keep-alive new chunks 32 | pbar.update(len(chunk)) 33 | f.write(chunk) 34 | return filename 35 | 36 | 37 | if not os.path.exists("cifar-100-python.tar.gz"): 38 | print("Downloading cifar-100-python.tar.gz\n") 39 | download_file('http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz', 'cifar-100-python.tar.gz') 40 | print("Downloading done.\n") 41 | else: 42 | print("Dataset already downloaded.
Did not download twice.\n") 43 | 44 | 45 | tarname = "cifar-100-python.tar.gz" 46 | print("Untarring: {}".format(tarname)) 47 | tar = tarfile.open(tarname) 48 | tar.extractall() 49 | tar.close() 50 | 51 | datapath = "cifar-100-python" 52 | 53 | print("Extracting jpg images and classes from pickle files") 54 | 55 | # in CIFAR 100, the files are given in a train and test format 56 | for batch in ['test', 'train']: 57 | 58 | print("Handling pickle file: {}".format(batch)) 59 | 60 | # Create variable which is the exact path to the file 61 | fpath = os.path.join(datapath, batch) 62 | 63 | # Unpickle the file, and its metadata (classnames) 64 | f = open(fpath, 'rb') 65 | labels = pickle.load(open(os.path.join(datapath, 'meta'), 'rb'), encoding="ASCII") 66 | d = pickle.load(f, encoding='bytes') 67 | 68 | # decode utf8 encoded keys, and copy files into new dictionary d_decoded 69 | d_decoded = {} 70 | for k, v in d.items(): 71 | d_decoded[k.decode('utf8')] = v 72 | 73 | d = d_decoded 74 | f.close() 75 | 76 | #for i, filename in enumerate(d['filenames']): 77 | i = 0 78 | for filename in tqdm(d['filenames']): 79 | folder = os.path.join('images', labels['fine_label_names'][d['fine_labels'][i]]) 80 | 81 | png_path = os.path.join(folder, filename.decode()) 82 | jpg_path = os.path.splitext(png_path)[0]+".jpg" 83 | 84 | if os.path.exists(jpg_path): 85 | continue 86 | else: 87 | os.makedirs(folder, exist_ok=True) 88 | q = d['data'][i] 89 | with open(jpg_path, 'wb') as outfile: 90 | #png.from_array(q.reshape((32, 32, 3), order='F').swapaxes(0,1), mode='RGB').save(outfile) 91 | img = Image.fromarray(q.reshape((32, 32, 3), order='F').swapaxes(0,1), 'RGB') 92 | img.save(outfile) 93 | 94 | i += 1 95 | 96 | print("Removing pickle files") 97 | shutil.rmtree('cifar-100-python', ignore_errors=True) 98 | 99 | print("Depending on the split files, organize train, val and test sets") 100 | for datatype in ['train', 'val', 'test']: 101 | os.makedirs(os.path.join('cifar-fs', datatype), exist_ok=True) 102 | with open(os.path.join('cifar-fs-splits', datatype + '.txt'), 'r') as f: 103 | content = f.readlines() 104 | # Remove whitespace characters like `\n` at the end of each line 105 | classes = [x.strip() for x in content] 106 | 107 | for img_class in classes: 108 | if os.path.exists(os.path.join('cifar-fs', datatype, img_class)): 109 | continue 110 | else: 111 | cur_dir = os.path.join('cifar-fs', datatype) 112 | os.makedirs(cur_dir, exist_ok=True) 113 | os.system('mv images/' + img_class + ' ' + cur_dir) 114 | 115 | print("Removing original CIFAR 100 images") 116 | shutil.rmtree('images', ignore_errors=True) 117 | 118 | print("Removing tar file") 119 | os.remove('cifar-100-python.tar.gz') 120 | -------------------------------------------------------------------------------- /bpa/balanced_pairwise_affinities.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | from bpa import ot 7 | 8 | 9 | class BPA(nn.Module): 10 | supported_distances = ['cosine', 'euclidean'] 11 | 12 | def __init__(self, 13 | distance_metric: str = 'cosine', 14 | ot_reg: float = 0.1, 15 | sinkhorn_iterations: int = 10, 16 | sigmoid: bool = False, 17 | mask_diag: bool = True, 18 | max_scale: bool = True): 19 | """ 20 | :param distance_metric - Distance used to compute the cost matrix ('cosine' or 'euclidean'). 21 | :param ot_reg - Sinkhorn entropy regularization (lambda). For few-shot classification, 0.1-0.2 works best.
22 | :param sinkhorn_iterations - Maximum number of sinkhorn iterations. 23 | :param sigmoid - Whether to apply sigmoid(log_p) instead of the usual exp(log_p). 24 | :param mask_diag - Set to true to apply diagonal masking before and after the OT. 25 | :param max_scale - Re-scale the BPA values to range [0,1]. 26 | """ 27 | super().__init__() 28 | 29 | assert distance_metric.lower() in BPA.supported_distances and sinkhorn_iterations > 0 30 | 31 | self.sinkhorn_iterations = sinkhorn_iterations 32 | self.distance_metric = distance_metric.lower() 33 | self.mask_diag = mask_diag 34 | self.sigmoid = sigmoid 35 | self.ot_reg = ot_reg 36 | self.max_scale = max_scale 37 | self.diagonal_val = 1e5 # value to mask self-values with 38 | 39 | def mask_diagonal(self, M: Tensor, value: float): 40 | """ 41 | Fill the diagonal of a given matrix (or a batch of them) with given value 42 | """ 43 | if self.mask_diag: 44 | if M.dim() > 2: 45 | M[torch.eye(M.shape[1]).repeat(M.shape[0], 1, 1).bool()] = value 46 | else: 47 | M.fill_diagonal_(value) 48 | return M 49 | 50 | def adjust_labeled(self, x: Tensor, y: Tensor): 51 | """ 52 | Adjust BPA scores using additional labels (e.g. the support set in few-shot classification). 53 | We do so by filling the final values of labeled pairs with 0 and 1, according to whether they share the same class 54 | """ 55 | task_dim = x.ndim - 2 56 | labels_one_hot = F.one_hot(y, num_classes=y.max().item() + 1).float() 57 | mask = (labels_one_hot @ labels_one_hot.transpose(-2, -1)).bool() # mask[i,j] 1 if y[i] == y[j] 58 | # pad mask 59 | pad_size = x.size(task_dim) - mask.size(task_dim) 60 | pad = (0, pad_size, 0, pad_size) 61 | # (padding_left, padding_right, padding_top, padding_bottom) 62 | x.masked_fill_(F.pad(mask, pad, "constant", False), value=1) # mask known positives 63 | x.masked_fill_(F.pad(~mask, pad, "constant", False), value=0) # mask known negatives 64 | return x 65 | 66 | def compute_cost_matrix(self, x: Tensor) -> Tensor: 67 | """ 68 | Compute the cost matrix under euclidean or cosine distances 69 | """ 70 | # Euclidean distances 71 | if self.distance_metric == 'euclidean': 72 | # dim_offset = 0 if x.dim() <= 2 else 1 73 | # pairwise_dist = (x.unsqueeze(1+dim_offset) - x.unsqueeze(0+dim_offset)).norm(dim=-1).pow(2) 74 | pairwise_dist = torch.cdist(x, x, p=2) 75 | pairwise_dist = pairwise_dist / pairwise_dist.max() # scale distances to [0, 1] 76 | 77 | # Cosine distances 78 | elif self.distance_metric == 'cosine': 79 | x_norm = F.normalize(x, dim=-1, p=2) 80 | pairwise_dist = 1 - (x_norm @ x_norm.transpose(-2, -1)) 81 | return pairwise_dist 82 | 83 | def forward(self, x: Tensor, y: Tensor = None) -> Tensor: 84 | """ 85 | Compute the BPA feature transform 86 | """ 87 | # get masked cost matrix 88 | C = self.compute_cost_matrix(x) 89 | C = self.mask_diagonal(C, value=self.diagonal_val) 90 | 91 | # compute self-OT 92 | x_bpa = ot.log_sinkhorn(C, reg=self.ot_reg, num_iters=self.sinkhorn_iterations) 93 | if self.sigmoid: 94 | x_bpa = torch.sigmoid(x_bpa) 95 | else: 96 | x_bpa = torch.exp(x_bpa) 97 | 98 | # divide the BPA matrix by its maximum value to scale its range into [0, 1] 99 | if self.max_scale: 100 | z_max = x_bpa.max().item() if x_bpa.dim() <= 2 else x_bpa.amax(dim=(1, 2), keepdim=True) 101 | x_bpa = x_bpa / z_max 102 | 103 | # adjust labeled samples (e.g.
support set) if given labels 104 | if y is not None: 105 | x_bpa = self.adjust_labeled(x_bpa, y) 106 | 107 | # set self-values to 1 108 | return self.mask_diagonal(x_bpa, value=1) 109 | 110 | -------------------------------------------------------------------------------- /models/resnet12.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from .dropblock import DropBlock 5 | 6 | # This ResNet network was designed following the practice of the following papers: 7 | # TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and 8 | # A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018). 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False) 15 | 16 | 17 | class BasicBlock(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, block_size=1): 21 | super(BasicBlock, self).__init__() 22 | self.conv1 = conv3x3(inplanes, planes) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.relu = nn.LeakyReLU(0.1) 25 | self.conv2 = conv3x3(planes, planes) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.conv3 = conv3x3(planes, planes) 28 | self.bn3 = nn.BatchNorm2d(planes) 29 | self.maxpool = nn.MaxPool2d(stride) 30 | self.downsample = downsample 31 | self.stride = stride 32 | self.drop_rate = drop_rate 33 | self.num_batches_tracked = 0 34 | self.drop_block = drop_block 35 | self.block_size = block_size 36 | self.DropBlock = DropBlock(block_size=self.block_size) 37 | 38 | def forward(self, x): 39 | self.num_batches_tracked += 1 40 | 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv3(out) 52 | out = self.bn3(out) 53 | 54 | if self.downsample is not None: 55 | residual = self.downsample(x) 56 | out += residual 57 | out = self.relu(out) 58 | out = self.maxpool(out) 59 | 60 | if self.drop_rate > 0: 61 | if self.drop_block == True: 62 | feat_size = out.size()[2] 63 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate) 64 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2 65 | out = self.DropBlock(out, gamma=gamma) 66 | else: 67 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) 68 | 69 | return out 70 | 71 | 72 | class ResNet(nn.Module): 73 | 74 | def __init__(self, block=BasicBlock, keep_prob=1.0, avg_pool=True, dropout=0.1, dropblock_size=5): 75 | self.inplanes = 3 76 | drop_rate = dropout 77 | super(ResNet, self).__init__() 78 | 79 | self.layer1 = self._make_layer(block, 64, stride=2, drop_rate=drop_rate) 80 | self.layer2 = self._make_layer(block, 160, stride=2, drop_rate=drop_rate) 81 | self.layer3 = self._make_layer(block, 320, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 82 | self.layer4 = self._make_layer(block, 640, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 83 | if avg_pool: 84 | self.avgpool = nn.AvgPool2d(5, stride=1) 85 | self.keep_prob = keep_prob 86 | self.keep_avg_pool = avg_pool 87 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False) 88 | 
self.drop_rate = drop_rate 89 | 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 93 | elif isinstance(m, nn.BatchNorm2d): 94 | nn.init.constant_(m.weight, 1) 95 | nn.init.constant_(m.bias, 0) 96 | 97 | def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1): 98 | downsample = None 99 | if stride != 1 or self.inplanes != planes * block.expansion: 100 | downsample = nn.Sequential( 101 | nn.Conv2d(self.inplanes, planes * block.expansion, 102 | kernel_size=1, stride=1, bias=False), 103 | nn.BatchNorm2d(planes * block.expansion), 104 | ) 105 | 106 | layers = [] 107 | layers.append(block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size)) 108 | self.inplanes = planes * block.expansion 109 | 110 | return nn.Sequential(*layers) 111 | 112 | def forward(self, x): 113 | x = self.layer1(x) 114 | x = self.layer2(x) 115 | x = self.layer3(x) 116 | x = self.layer4(x) 117 | if self.keep_avg_pool: 118 | x = self.avgpool(x) 119 | x = x.view(x.size(0), -1) 120 | return x 121 | 122 | 123 | def Res12(keep_prob=1.0, avg_pool=False, **kwargs): 124 | """Constructs a ResNet-12 model. 125 | """ 126 | model = ResNet(BasicBlock, keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 127 | return model -------------------------------------------------------------------------------- /methods/pt_map/FSLTask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import torch 5 | 6 | # ======================================================== 7 | # Useful paths 8 | data_features_path = {"miniimagenet": "./checkpoints/miniimagenet/WideResNet28_10_S2M2_R/last/output.plk", 9 | "miniimagenet_dct": "./checkpoints/miniimagenet/WideResNet28_10_S2M2_R_5way_1shot_aug/last/output_both.plk", 10 | "cub": "./checkpoints/CUB/sill_features/5way_1shot_0.94734.plk", 11 | "cifar": "./checkpoints/cifar/sill_features/5way_5shot_0.91092.plk", 12 | "cross": "./checkpoints/cross/WideResNet28_10_S2M2_R/last/output.plk"} 13 | _cacheDir = None 14 | _maxRuns = 10000 15 | _min_examples = -1 16 | 17 | # ======================================================== 18 | # Module internal functions and variables 19 | 20 | _randStates = None 21 | _rsCfg = None 22 | 23 | 24 | def _load_pickle(file): 25 | with open(file, 'rb') as f: 26 | data = pickle.load(f) 27 | labels = [np.full(shape=len(data[key]), fill_value=key) for key in data] 28 | data = [features for key in data for features in data[key]] 29 | dataset = dict() 30 | dataset['data'] = torch.FloatTensor(np.stack(data, axis=0)) 31 | dataset['labels'] = torch.LongTensor(np.concatenate(labels)) 32 | return dataset 33 | 34 | 35 | # ========================================================= 36 | # Callable variables and functions from outside the module 37 | 38 | data = None 39 | labels = None 40 | dsName = None 41 | 42 | 43 | def loadDataSet(dsname, root: str = None, features_path: str = None): 44 | global dsName, data, labels, _randStates, _rsCfg, _min_examples, _cacheDir 45 | dsName = dsname 46 | _randStates = None 47 | _rsCfg = None 48 | _cacheDir = root + '/methods/pt_map/cache' 49 | 50 | # Loading data from files on computer 51 | if features_path is None or features_path == '': 52 | if dsname not in data_features_path: 53 | raise NameError('Unknown dataset: {}'.format(dsname)) 54 | features_path = data_features_path[dsname] 55 | 56 | dataset =
_load_pickle(features_path)
57 | 
58 | # Compute the number of items per class in the dataset
59 | _min_examples = dataset["labels"].shape[0]
60 | for i in range(dataset["labels"].shape[0]):
61 | if torch.where(dataset["labels"] == dataset["labels"][i])[0].shape[0] > 0:
62 | _min_examples = min(_min_examples, torch.where(
63 | dataset["labels"] == dataset["labels"][i])[0].shape[0])
64 | print("Guaranteed number of items per class: {:d}\n".format(_min_examples))
65 | 
66 | # Generate data tensors
67 | data = torch.zeros((0, _min_examples, dataset["data"].shape[1]))
68 | labels = dataset["labels"].clone()
69 | while labels.shape[0] > 0:
70 | indices = torch.where(dataset["labels"] == labels[0])[0]
71 | data = torch.cat([data, dataset["data"][indices, :][:_min_examples].view(1, _min_examples, -1)], dim=0)
72 | indices = torch.where(labels != labels[0])[0]
73 | labels = labels[indices]
74 | print("Total of {:d} classes, {:d} elements each, with dimension {:d}\n".format(data.shape[0], data.shape[1], data.shape[2]))
75 | 
76 | 
77 | def GenerateRun(iRun, cfg, regenRState=False, generate=True):
78 | global _randStates, data, _min_examples
79 | if not regenRState:
80 | np.random.set_state(_randStates[iRun])
81 | 
82 | classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]]
83 | shuffle_indices = np.arange(_min_examples)
84 | dataset = None
85 | if generate:
86 | dataset = torch.zeros((cfg['ways'], cfg['shot'] + cfg['queries'], data.shape[2]))
87 | for i in range(cfg['ways']):
88 | shuffle_indices = np.random.permutation(shuffle_indices)
89 | if generate:
90 | dataset[i] = data[classes[i], shuffle_indices, :][:cfg['shot'] + cfg['queries']]
91 | 
92 | return dataset
93 | 
94 | 
95 | def setRandomStates(cfg):
96 | global _randStates, _maxRuns, _rsCfg
97 | if _rsCfg == cfg:
98 | return
99 | 
100 | rsFile = os.path.join(_cacheDir, "RandStates_{}_s{}_q{}_w{}".format(dsName, cfg['shot'], cfg['queries'], cfg['ways']))
101 | if not os.path.exists(rsFile):
102 | print("{} does not exist, regenerating it...".format(rsFile))
103 | np.random.seed(0)
104 | _randStates = []
105 | for iRun in range(_maxRuns):
106 | _randStates.append(np.random.get_state())
107 | GenerateRun(iRun, cfg, regenRState=True, generate=False)
108 | torch.save(_randStates, rsFile)
109 | else:
110 | print("Reloading random states from file...")
111 | _randStates = torch.load(rsFile)
112 | _rsCfg = cfg
113 | 
114 | 
115 | def GenerateRunSet(start=None, end=None, cfg=None):
116 | global dataset, _maxRuns
117 | if start is None:
118 | start = 0
119 | if end is None:
120 | end = _maxRuns
121 | if cfg is None:
122 | cfg = {"shot": 1, "ways": 5, "queries": 15}
123 | 
124 | setRandomStates(cfg)
125 | print("Generating tasks {} to {}".format(start, end))
126 | 
127 | dataset = []
128 | for iRun in range(end - start):
129 | dataset.append(GenerateRun(start + iRun, cfg))
130 | 
131 | return torch.stack(dataset)
132 | 
--------------------------------------------------------------------------------
/methods/pt_map/pt_map_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | 
5 | """
6 | Implementation of PT-MAP as a differentiable module.
7 | Original code in https://github.com/yhu01/PT-MAP
8 | """
9 | 
10 | 
11 | def centerDatas(X: torch.Tensor, n_lsamples: int):
12 | """
13 | Center labeled and un-labeled data separately.
14 | """ 15 | X[:n_lsamples, :] = X[:n_lsamples, :] - X[:n_lsamples, :].mean(0, keepdim=True) 16 | X[n_lsamples:, :] = X[n_lsamples:, :] - X[n_lsamples:, :].mean(0, keepdim=True) 17 | return X 18 | 19 | # --------- GaussianModel 20 | 21 | class GaussianModel: 22 | def __init__(self, num_way: int, num_shot: int, num_query: int, lam: float): 23 | self.num_way = num_way 24 | self.num_shot = num_shot 25 | self.num_query = num_query 26 | self.n_lsamples = num_way * num_shot 27 | self.n_usamples = num_way * num_query 28 | self.lam = lam 29 | self.mus = None # shape [n_ways][feat_dim] 30 | 31 | def cuda(self): 32 | self.mus = self.mus.cuda() 33 | 34 | def init_from_labelled(self, X: torch.Tensor): 35 | self.mus = X.reshape(self.num_shot + self.num_query, self.num_way, -1)[:self.num_shot, ].mean(0) 36 | 37 | def update_from_estimate(self, estimate, alpha): 38 | Dmus = estimate - self.mus 39 | self.mus = self.mus + alpha * Dmus 40 | 41 | def compute_optimal_transport(self, M: torch.Tensor, r: torch.Tensor, c: torch.Tensor, epsilon: float = 1e-6): 42 | n_runs, n, m = M.shape 43 | P = torch.exp(-self.lam * M) 44 | P = P / P.view((n_runs, -1)).sum(1).unsqueeze(1).unsqueeze(1) 45 | u = torch.zeros((n_runs, n), device='cuda') 46 | maxiters = 1000 47 | iters = 1 48 | # normalize this matrix 49 | while torch.max(torch.abs(u - P.sum(-1))) > epsilon: 50 | u = P.sum(dim=-1) 51 | P *= (r / u).view((n_runs, -1, 1)) 52 | P *= (c / P.sum(1)).view((n_runs, 1, -1)) 53 | if iters == maxiters: 54 | break 55 | iters += 1 56 | 57 | if n_runs == 1: 58 | return P.squeeze(0) 59 | return P 60 | 61 | def get_probas(self, X: torch.Tensor, labels: torch.Tensor): 62 | # compute squared dist to centroids [n_samples][n_ways] 63 | dist = (X.unsqueeze(1) - self.mus.unsqueeze(0)).norm(dim=2).pow(2) 64 | p_xj = torch.zeros_like(dist) 65 | r = torch.ones(1, self.num_way * self.num_query, device='cuda') 66 | c = torch.ones(1, self.num_way, device='cuda') * self.num_query 67 | p_xj_test = self.compute_optimal_transport(dist.unsqueeze(0)[:, self.n_lsamples:], r, c, epsilon=1e-6) 68 | p_xj[self.n_lsamples:] = p_xj_test 69 | 70 | p_xj[:self.n_lsamples].fill_(0) 71 | p_xj[:self.n_lsamples].scatter_(1, labels[:self.n_lsamples].unsqueeze(1), 1) 72 | return p_xj 73 | 74 | def estimate_from_mask(self, X: torch.Tensor, mask: torch.Tensor): 75 | emus = mask.T.matmul(X).div(mask.sum(dim=0).unsqueeze(1)) 76 | return emus 77 | 78 | 79 | # ========================================= 80 | # MAP 81 | # ========================================= 82 | 83 | class MAP: 84 | def __init__(self, labels, alpha: float, num_labeled: int, n_runs=1): 85 | self.alpha = alpha 86 | self.num_labeled = num_labeled 87 | self.s_labels = labels[:self.num_labeled] 88 | self.q_labels = labels[self.num_labeled:] 89 | self.n_runs = n_runs 90 | 91 | def get_accuracy(self, probas: torch.Tensor): 92 | y_hat = probas[self.num_labeled:].argmax(dim=-1) 93 | matches = self.q_labels.eq(y_hat).float() 94 | m = matches.mean().item() 95 | pm = matches.std(unbiased=False).item() * 1.96 96 | return m, pm 97 | 98 | def perform_epoch(self, model: GaussianModel, X: torch.Tensor): 99 | p_xj = model.get_probas(X=X, labels=self.s_labels) 100 | m_estimates = model.estimate_from_mask(X=X, mask=p_xj) 101 | # update centroids 102 | model.update_from_estimate(m_estimates, self.alpha) 103 | 104 | def loop(self, X: torch.Tensor, model: GaussianModel, n_epochs: int = 20): 105 | for epoch in range(1, n_epochs + 1): 106 | self.perform_epoch(model=model, X=X) 107 | # get final accuracy and return it 108 | 
P = model.get_probas(X=X, labels=self.s_labels)
109 | return P
110 | 
111 | 
112 | class PTMAPLoss(nn.Module):
113 | def __init__(self, args: dict, lam: float = 10, alpha: float = 0.2, n_epochs: int = 10, bpa=None):
114 | super().__init__()
115 | self.way_dict = dict(train=args['train_way'], val=args['val_way'])
116 | self.num_shot = args['num_shot']
117 | self.num_query = args['num_query']
118 | self.lam = lam
119 | self.alpha = alpha
120 | self.n_epochs = n_epochs
121 | self.num_labeled = None
122 | self.BPA = bpa
123 | 
124 | def scale(self, X: torch.Tensor, mode: str):
125 | # normalize, center, then normalize again (first normalization is skipped at train time)
126 | if mode != 'train':
127 | X = F.normalize(X, p=2, dim=-1)
128 | X = centerDatas(X, self.num_labeled)
129 | X = F.normalize(X, p=2, dim=-1)
130 | return X
131 | 
132 | def forward(self, X: torch.Tensor, labels: torch.Tensor, mode: str):
133 | num_way = self.way_dict[mode]
134 | self.num_labeled = num_way * self.num_shot
135 | 
136 | assert X.min() >= 0, \
137 | f"Error: PT-MAP requires non-negative features, but X.min()={X.min()}. " \
138 | "Add a ReLU-like activation or switch to the WRN backbone."
139 | 
140 | # power transform (the PT part) and scaling
141 | X = torch.pow(X + 1e-6, 0.5)
142 | Z = self.scale(X, mode=mode)
143 | 
144 | # apply the BPA transform
145 | if self.BPA is not None:
146 | Z = self.BPA(Z, labels[:self.num_labeled] if mode != "train" else None)
147 | # Z = self.scale(Z, mode=mode)
148 | 
149 | # MAP
150 | gaussian_model = GaussianModel(
151 | num_way=num_way,
152 | num_shot=self.num_shot,
153 | num_query=self.num_query,
154 | lam=self.lam)
155 | gaussian_model.init_from_labelled(Z)
156 | 
157 | optim = MAP(labels=labels, alpha=self.alpha, num_labeled=self.num_labeled)
158 | probs = optim.loop(
159 | Z, model=gaussian_model, n_epochs=self.n_epochs
160 | )
161 | accuracy, std = optim.get_accuracy(probs)
162 | 
163 | return torch.log(probs[self.num_labeled:] + 1e-6), accuracy
164 | 
--------------------------------------------------------------------------------
/models/res_mixup_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import random
6 | from torch.nn.utils.weight_norm import WeightNorm
7 | from torchvision.models.resnet import Bottleneck
8 | 
9 | 
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 | 
15 | 
16 | def conv1x1(in_planes, out_planes, stride=1):
17 | """1x1 convolution"""
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
19 | 
20 | 
21 | def mixup_data(x, y, lam):
22 | '''Compute the mixup data.
Return mixed inputs, pairs of targets, and lambda''' 23 | 24 | batch_size = x.size()[0] 25 | index = torch.randperm(batch_size) 26 | if torch.cuda.is_available(): 27 | index = index.cuda() 28 | mixed_x = lam * x + (1 - lam) * x[index, :] 29 | y_a, y_b = y, y[index] 30 | 31 | return mixed_x, y_a, y_b, lam 32 | 33 | 34 | class distLinear(nn.Module): 35 | def __init__(self, indim, outdim): 36 | super(distLinear, self).__init__() 37 | self.L = nn.Linear(indim, outdim, bias=False) 38 | self.class_wise_learnable_norm = True # See the issue#4&8 in the github 39 | if self.class_wise_learnable_norm: 40 | WeightNorm.apply(self.L, 'weight', dim=0) # split the weight update component to direction and norm 41 | 42 | if outdim <= 200: 43 | self.scale_factor = 2 # a fixed scale factor to scale the output of cos value into a reasonably large input for softmax 44 | else: 45 | self.scale_factor = 10 # in omniglot, a larger scale factor is required to handle >1000 output classes. 46 | 47 | def forward(self, x): 48 | x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x) 49 | x_normalized = x.div(x_norm + 0.00001) 50 | if not self.class_wise_learnable_norm: 51 | L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data) 52 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001) 53 | cos_dist = self.L( 54 | x_normalized) # matrix product by forward function, but when using WeightNorm, this also multiply the cosine distance by a class-wise learnable norm, see the issue#4&8 in the github 55 | scores = self.scale_factor * (cos_dist) 56 | 57 | return scores 58 | 59 | 60 | class BasicBlock(nn.Module): 61 | expansion = 1 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(BasicBlock, self).__init__() 65 | self.conv1 = conv3x3(inplanes, planes, stride) 66 | self.bn1 = nn.BatchNorm2d(planes) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.conv2 = conv3x3(planes, planes) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | identity = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | 83 | if self.downsample is not None: 84 | identity = self.downsample(x) 85 | 86 | out += identity 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class ResNet(nn.Module): 93 | 94 | def __init__(self, block, layers, num_classes=200, zero_init_residual=False): 95 | super(ResNet, self).__init__() 96 | self.inplanes = 64 97 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, 98 | bias=False) 99 | self.bn1 = nn.BatchNorm2d(64) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.layer1 = self._make_layer(block, 64, layers[0]) 102 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 103 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 104 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 105 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 106 | 107 | self.fc = distLinear(512 * block.expansion, num_classes) 108 | 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 112 | elif isinstance(m, nn.BatchNorm2d): 113 | nn.init.constant_(m.weight, 1) 114 | nn.init.constant_(m.bias, 0) 115 | 116 | # Zero-initialize the last BN in each residual branch, 117 | # so that the residual branch starts with zeros, and each residual block behaves 
like an identity. 118 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 119 | if zero_init_residual: 120 | for m in self.modules(): 121 | if isinstance(m, Bottleneck): 122 | nn.init.constant_(m.bn3.weight, 0) 123 | elif isinstance(m, BasicBlock): 124 | nn.init.constant_(m.bn2.weight, 0) 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | conv1x1(self.inplanes, planes * block.expansion, stride), 131 | nn.BatchNorm2d(planes * block.expansion), 132 | ) 133 | 134 | layers = [] 135 | layers.append(block(self.inplanes, planes, stride, downsample)) 136 | self.inplanes = planes * block.expansion 137 | for _ in range(1, blocks): 138 | layers.append(block(self.inplanes, planes)) 139 | 140 | return nn.Sequential(*layers) 141 | 142 | def forward(self, x, target=None, mixup=False, mixup_hidden=True, mixup_alpha=None, lam=0.4): 143 | if target is not None: 144 | if mixup_hidden: 145 | layer_mix = random.randint(0, 5) 146 | elif mixup: 147 | layer_mix = 0 148 | else: 149 | layer_mix = None 150 | 151 | out = x 152 | 153 | if layer_mix == 0: 154 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 155 | 156 | out = self.relu(self.bn1(self.conv1(x))) 157 | out = self.layer1(out) 158 | 159 | if layer_mix == 1: 160 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 161 | 162 | out = self.layer2(out) 163 | 164 | if layer_mix == 2: 165 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 166 | 167 | out = self.layer3(out) 168 | 169 | if layer_mix == 3: 170 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 171 | 172 | out = self.layer4(out) 173 | 174 | if layer_mix == 4: 175 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 176 | 177 | out = self.avgpool(out) 178 | out = out.view(out.size(0), -1) 179 | out1 = self.fc.forward(out) 180 | 181 | if layer_mix == 5: 182 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 183 | 184 | return out, out1, target_a, target_b 185 | else: 186 | out = self.conv1(x) 187 | out = self.bn1(out) 188 | out = self.relu(out) 189 | 190 | out = self.layer1(out) 191 | out = self.layer2(out) 192 | out = self.layer3(out) 193 | out = self.layer4(out) 194 | out = self.avgpool(out) 195 | out = out.view(out.size(0), -1) 196 | out1 = self.fc.forward(out) 197 | return out, out1 198 | 199 | 200 | def resnet18(**kwargs): 201 | """Constructs a ResNet-18 model. 202 | """ 203 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 204 | return model 205 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BPA: The balanced-pairwise-affinities Feature Transform (ICML 2024) 2 | 3 | This repository contains the official PyTorch implementation of **BPA** (formerly SOT) — the **B**alanced-**P**airwise-**A**ffinities feature transform — as described in our paper [*The Balanced-Pairwise-Affinities Feature Transform*](https://arxiv.org/abs/2407.01467), presented at ICML 2024. 4 | 5 | ![BPA](bpa_workflow.png?raw=true) 6 | 7 | BPA enhances the representation of a set of input features to support downstream tasks such as matching or grouping. 8 | 9 | The transformed features capture both **direct** pairwise similarity and **third-party agreement** — how other instances in the set influence similarity between a given pair. 
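For intuition only, here is a minimal, self-contained sketch of this idea. It assumes a cosine cost, uniform marginals and illustrative hyper-parameters, and is not the package implementation (see `bpa/balanced_pairwise_affinities.py` and the Quick Start below for the real API):

```python
import torch
import torch.nn.functional as F

def bpa_sketch(x: torch.Tensor, ot_reg: float = 0.1, n_iters: int = 10) -> torch.Tensor:
    """Toy BPA: pairwise cosine costs, balanced by log-space Sinkhorn iterations."""
    z = F.normalize(x, dim=-1)
    cost = 1.0 - z @ z.T                    # [n, n] pairwise cosine cost
    cost.fill_diagonal_(cost.max().item())  # discourage trivial self-matches
    log_p = -cost / ot_reg
    for _ in range(n_iters):                # alternate row/column normalization
        log_p = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)
        log_p = log_p - torch.logsumexp(log_p, dim=0, keepdim=True)
    return log_p.exp()                      # each row is the new embedding of one instance

features = torch.randn(100, 128)
affinities = bpa_sketch(features)           # shape [100, 100]
```

After a few iterations the matrix is approximately doubly stochastic, so every instance is re-encoded by how its unit of mass spreads over the rest of the set, which is where the third-party agreement comes from.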
This enables BPA to model higher-order relations effectively. 10 | 11 | --- 12 | 13 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/the-self-optimal-transport-feature-transform/few-shot-image-classification-on-cifar-fs-5)](https://paperswithcode.com/sota/few-shot-image-classification-on-cifar-fs-5?p=the-self-optimal-transport-feature-transform) 14 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/the-self-optimal-transport-feature-transform/few-shot-image-classification-on-cifar-fs-5-1)](https://paperswithcode.com/sota/few-shot-image-classification-on-cifar-fs-5-1?p=the-self-optimal-transport-feature-transform) 15 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/the-self-optimal-transport-feature-transform/few-shot-image-classification-on-cub-200-5-1)](https://paperswithcode.com/sota/few-shot-image-classification-on-cub-200-5-1?p=the-self-optimal-transport-feature-transform) 16 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/the-self-optimal-transport-feature-transform/few-shot-image-classification-on-cub-200-5)](https://paperswithcode.com/sota/few-shot-image-classification-on-cub-200-5?p=the-self-optimal-transport-feature-transform) 17 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/the-self-optimal-transport-feature-transform/few-shot-image-classification-on-mini-2)](https://paperswithcode.com/sota/few-shot-image-classification-on-mini-2?p=the-self-optimal-transport-feature-transform) 18 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/the-self-optimal-transport-feature-transform/few-shot-image-classification-on-mini-3)](https://paperswithcode.com/sota/few-shot-image-classification-on-mini-3?p=the-self-optimal-transport-feature-transform) 19 | 20 | --- 21 | 22 | ## Few-Shot Classification Results 23 | 24 | | Dataset | Method | 5-Way 1-Shot | 5-Way 5-Shot | 25 | |--------------|------------------------|--------------|--------------| 26 | | MiniImagenet | PTMAP-BPAp | 83.19 | 89.56 | 27 | | | PTMAP-BPAt | 84.18 | 90.51 | 28 | | | PTMAP-SF-BPA | 85.59 | 91.34 | 29 | | CIFAR-FS | PTMAP-BPAp | 87.37 | 91.12 | 30 | | | PTMAP-SF-BPA | 89.94 | 92.83 | 31 | | CUB | PTMAP-BPAp | 91.90 | 94.63 | 32 | | | PTMAP-SF-BPA | 95.80 | 97.12 | 33 | 34 | --- 35 | 36 | ## Setup 37 | 38 | ### Datasets 39 | Instructions for downloading and preparing the few-shot classification datasets are available in the `datasets` directory. 40 | 41 | ### Pretrained Models 42 | Most results in the paper use fine-tuned models: 43 | 44 | - **WideResNet-28**: [PT-MAP checkpoint](https://drive.google.com/file/d/1wVJlDnU00Gurs0pw54ZMqf4XsWhJWHIh/view) 45 | - **ResNet-12**: [FEAT checkpoint](https://github.com/Sha-Lab/FEAT) 46 | 47 | ### BPA Checkpoint 48 | We provide a checkpoint for [PTMAP-BPAt](https://drive.google.com/file/d/1wjh_EBQPYYHFjqoqlKCitcG9mjgkWFRw/view?usp=sharing), which yields: 49 | 50 | - **84.69%** accuracy for 5-way 1-shot (vs. 84.18% in the paper) 51 | - **90.30%** accuracy for 5-way 5-shot (vs. 90.51% in the paper) 52 | 53 | --- 54 | 55 | ## Usage 56 | 57 | ### Quick Start 58 | 59 | BPA can be applied in just two lines of code: 60 | 61 | ```python 62 | import torch 63 | from bpa import BPA 64 | 65 | x = torch.randn(100, 128) # [n_samples, dim] 66 | x = BPA()(x) # Output shape: [n_samples, n_samples] 67 | ``` 68 | 69 | --- 70 | 71 | ### Training PT-MAP-BPAt on MiniImagenet 72 | 73 | 1. 
[Download](https://drive.google.com/file/d/1wVJlDnU00Gurs0pw54ZMqf4XsWhJWHIh/view) the pretrained WRN feature extractor.
74 | 2. Create an empty `checkpoints` directory.
75 | 3. Extract the downloaded file into the `checkpoints` folder.
76 | 
77 | Run the following command to train BPA with PT-MAP (`<path-to-data>` is a placeholder for your dataset root):
78 | 
79 | ```bash
80 | python train.py \
81 | --sink_iters 5 \
82 | --distance_metric cosine \
83 | --ot_reg 0.2 \
84 | --method pt_map_bpa \
85 | --backbone wrn \
86 | --augment false \
87 | --lr 5e-5 \
88 | --weight_decay 0. \
89 | --max_epochs 50 \
90 | --train_way 10 \
91 | --val_way 5 \
92 | --num_shot 5 \
93 | --num_query 15 \
94 | --train_episodes 200 \
95 | --eval_episodes 400 \
96 | --checkpoint_dir ./checkpoints/pt_map_bpa \
97 | --data_path <path-to-data> \
98 | --pretrained_path ./checkpoints/miniImagenet/WideResNet28_10_S2M2_R/470.tar
99 | ```
100 | 
101 | Alternatively, use the [trained checkpoint](https://drive.google.com/file/d/1wjh_EBQPYYHFjqoqlKCitcG9mjgkWFRw/view?usp=sharing).
102 | 
103 | ---
104 | 
105 | ### Evaluation
106 | 
107 | You can evaluate using the [original PT-MAP repository](https://github.com/yhu01/PT-MAP):
108 | 
109 | 1. Clone the PT-MAP repository.
110 | 2. Navigate to `methods/pt_map/evaluation/` and replace the files with the ones from this repository.
111 | 3. Edit the following:
112 | - Set `checkpoint_dir` in `save_plk.py` to the location of your BPA checkpoint.
113 | - Update `_datasetFeaturesFiles` in `FSLTask.py` to point to your feature files.
114 | 4. Create feature files by running:
115 | ```bash
116 | python save_plk.py --dataset miniImagenet --method S2M2_R --model WideResNet28_10
117 | ```
118 | 5. Then run the evaluation:
119 | ```bash
120 | python test_standard.py
121 | ```
122 | 
123 | ---
124 | 
125 | #### Alternatively, evaluate directly within this repository
126 | 
127 | You can run evaluation directly using the same script as training, with a few extra flags (placeholders in angle brackets):
128 | 
129 | ```bash
130 | python train.py \
131 | --eval true \
132 | --pretrained_path <path-to-checkpoint> \
133 | --backbone <backbone> \
134 | --test_episodes 2000
135 | ```
136 | 
137 | Make sure to include the same arguments you used during training (e.g., `--method`, `--data_path`, etc.) so the evaluation runs consistently.
138 | 
139 | ---
140 | 
141 | ### Logging
142 | We support logging and visualization via Weights & Biases (wandb), which lets you track training and evaluation metrics in real time.
143 | 
144 | **To enable logging:**
145 | 
146 | 1. Install the `wandb` package:
147 | ```bash
148 | pip install wandb
149 | ```
150 | 
151 | 2. Add the following flags to your command:
152 | ```bash
153 | --wandb true --project <project-name> --entity <entity-name>
154 | ```
155 | 
156 | Replace `<project-name>` with your wandb project name, and `<entity-name>` with your wandb username or team name.
157 | 
158 | ---
159 | 
160 | ## Citation
161 | 
162 | 

163 | 
164 | #### If you find this repository useful in your research, please cite:
165 | ```bibtex
166 | @article{shalam2024balanced,
167 |   title={The Balanced-Pairwise-Affinities Feature Transform},
168 |   author={Shalam, Daniel and Korman, Simon},
169 |   journal={arXiv preprint arXiv:2407.01467},
170 |   year={2024}
171 | }
172 | ```

173 | 174 | --- 175 | 176 | ## Acknowledgment 177 | [Leveraging the Feature Distribution in Transfer-based Few-Shot Learning](https://github.com/yhu01/PT-MAP) 178 | 179 | [S2M2 Charting the Right Manifold: Manifold Mixup for Few-shot Learning](https://arxiv.org/pdf/1907.12087.pdf) 180 | 181 | [Few-Shot Learning via Embedding Adaptation with Set-to-Set Functions](https://arxiv.org/pdf/1812.03664.pdf) 182 | -------------------------------------------------------------------------------- /methods/pt_map/evaluation/test_standard.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import pickle 3 | import random 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import torch 7 | from torch.autograd import Variable 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import math 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | from numpy import linalg as LA 15 | from tqdm.notebook import tqdm 16 | 17 | from bpa import BPA 18 | 19 | 20 | use_gpu = torch.cuda.is_available() 21 | 22 | 23 | # ======================================== 24 | # loading datas 25 | 26 | 27 | def centerDatas(datas): 28 | datas[:, :n_lsamples] = datas[:, :n_lsamples, :] - datas[:, :n_lsamples].mean(1, keepdim=True) 29 | datas[:, :n_lsamples] = datas[:, :n_lsamples, :] / torch.norm(datas[:, :n_lsamples, :], 2, 2)[:, :, None] 30 | datas[:, n_lsamples:] = datas[:, n_lsamples:, :] - datas[:, n_lsamples:].mean(1, keepdim=True) 31 | datas[:, n_lsamples:] = datas[:, n_lsamples:, :] / torch.norm(datas[:, n_lsamples:, :], 2, 2)[:, :, None] 32 | 33 | return datas 34 | 35 | def scaleEachUnitaryDatas(datas): 36 | 37 | norms = datas.norm(dim=2, keepdim=True) 38 | return datas/norms 39 | 40 | 41 | def QRreduction(datas): 42 | 43 | ndatas = torch.qr(datas.permute(0,2,1)).R 44 | ndatas = ndatas.permute(0,2,1) 45 | return ndatas 46 | 47 | 48 | class Model: 49 | def __init__(self, n_ways): 50 | self.n_ways = n_ways 51 | 52 | # --------- GaussianModel 53 | class GaussianModel(Model): 54 | def __init__(self, n_ways, lam): 55 | super(GaussianModel, self).__init__(n_ways) 56 | self.mus = None # shape [n_runs][n_ways][n_nfeat] 57 | self.lam = lam 58 | 59 | def clone(self): 60 | other = GaussianModel(self.n_ways) 61 | other.mus = self.mus.clone() 62 | return self 63 | 64 | def cuda(self): 65 | self.mus = self.mus.cuda() 66 | 67 | def initFromLabelledDatas(self): 68 | self.mus = ndatas.reshape(n_runs, n_shot+n_queries,n_ways, n_nfeat)[:,:n_shot,].mean(1) 69 | 70 | def updateFromEstimate(self, estimate, alpha): 71 | 72 | Dmus = estimate - self.mus 73 | self.mus = self.mus + alpha * (Dmus) 74 | 75 | def compute_optimal_transport(self, M, r, c, epsilon=1e-6): 76 | 77 | r = r.cuda() 78 | c = c.cuda() 79 | n_runs, n, m = M.shape 80 | P = torch.exp(- self.lam * M) 81 | P /= P.view((n_runs, -1)).sum(1).unsqueeze(1).unsqueeze(1) 82 | 83 | u = torch.zeros(n_runs, n).cuda() 84 | maxiters = 1000 85 | iters = 1 86 | # normalize this matrix 87 | while torch.max(torch.abs(u - P.sum(2))) > epsilon: 88 | u = P.sum(2) 89 | P *= (r / u).view((n_runs, -1, 1)) 90 | P *= (c / P.sum(1)).view((n_runs, 1, -1)) 91 | if iters == maxiters: 92 | break 93 | iters = iters + 1 94 | return P, torch.sum(P * M) 95 | 96 | def getProbas(self): 97 | # compute squared dist to centroids [n_runs][n_samples][n_ways] 98 | dist = (ndatas.unsqueeze(2)-self.mus.unsqueeze(1)).norm(dim=3).pow(2) 99 | 100 | p_xj = torch.zeros_like(dist) 101 | r = 
torch.ones(n_runs, n_usamples) 102 | c = torch.ones(n_runs, n_ways) * n_queries 103 | 104 | p_xj_test, _ = self.compute_optimal_transport(dist[:, n_lsamples:], r, c, epsilon=1e-6) 105 | p_xj[:, n_lsamples:] = p_xj_test 106 | 107 | p_xj[:,:n_lsamples].fill_(0) 108 | p_xj[:,:n_lsamples].scatter_(2,labels[:,:n_lsamples].unsqueeze(2), 1) 109 | 110 | return p_xj 111 | 112 | def estimateFromMask(self, mask): 113 | 114 | emus = mask.permute(0,2,1).matmul(ndatas).div(mask.sum(dim=1).unsqueeze(2)) 115 | 116 | return emus 117 | 118 | 119 | # ========================================= 120 | # MAP 121 | # ========================================= 122 | 123 | class MAP: 124 | def __init__(self, alpha=None): 125 | 126 | self.verbose = False 127 | self.progressBar = False 128 | self.alpha = alpha 129 | 130 | def getAccuracy(self, probas): 131 | olabels = probas.argmax(dim=2) 132 | matches = labels.eq(olabels).float() 133 | acc_test = matches[:,n_lsamples:].mean(1) 134 | 135 | m = acc_test.mean().item() 136 | pm = acc_test.std().item() *1.96 / math.sqrt(n_runs) 137 | return m, pm 138 | 139 | def performEpoch(self, model, epochInfo=None): 140 | 141 | p_xj = model.getProbas() 142 | self.probas = p_xj 143 | 144 | if self.verbose: 145 | print("accuracy from filtered probas", self.getAccuracy(self.probas)) 146 | 147 | m_estimates = model.estimateFromMask(self.probas) 148 | 149 | # update centroids 150 | model.updateFromEstimate(m_estimates, self.alpha) 151 | 152 | if self.verbose: 153 | op_xj = model.getProbas() 154 | acc = self.getAccuracy(op_xj) 155 | print("output model accuracy", acc) 156 | 157 | def loop(self, model, n_epochs=20): 158 | 159 | self.probas = model.getProbas() 160 | if self.verbose: 161 | print("initialisation model accuracy", self.getAccuracy(self.probas)) 162 | 163 | if self.progressBar: 164 | if type(self.progressBar) == bool: 165 | pb = tqdm(total = n_epochs) 166 | else: 167 | pb = self.progressBar 168 | 169 | for epoch in range(1, n_epochs+1): 170 | if self.verbose: 171 | print("----- epoch[{:3d}] lr_p: {:0.3f} lr_m: {:0.3f}".format(epoch, self.alpha)) 172 | self.performEpoch(model, epochInfo=(epoch, n_epochs)) 173 | if (self.progressBar): pb.update() 174 | 175 | # get final accuracy and return it 176 | op_xj = model.getProbas() 177 | acc = self.getAccuracy(op_xj) 178 | return acc 179 | 180 | 181 | if __name__ == '__main__': 182 | # ---- data loading 183 | n_shot = 5 184 | n_ways = 5 185 | n_queries = 15 186 | n_runs=10000 187 | n_lsamples = n_ways * n_shot 188 | n_usamples = n_ways * n_queries 189 | n_samples = n_lsamples + n_usamples 190 | 191 | import FSLTask 192 | cfg = {'shot':n_shot, 'ways':n_ways, 'queries':n_queries} 193 | FSLTask.loadDataSet("miniimagenet") 194 | FSLTask.setRandomStates(cfg) 195 | ndatas = FSLTask.GenerateRunSet(cfg=cfg) 196 | ndatas = ndatas.permute(0,2,1,3).reshape(n_runs, n_samples, -1) 197 | labels = torch.arange(n_ways).view(1,1,n_ways).expand(n_runs,n_shot+n_queries,n_ways).clone().view(n_runs, n_samples) 198 | 199 | # Power transform 200 | beta = 0.5 201 | ndatas[:,] = torch.pow(ndatas[:,]+1e-6, beta) 202 | 203 | ndatas = QRreduction(ndatas) 204 | n_nfeat = ndatas.size(2) 205 | 206 | ndatas = scaleEachUnitaryDatas(ndatas) 207 | # trans-mean-sub 208 | ndatas = centerDatas(ndatas) 209 | 210 | USE_BPA = True # comment it to run vanilla PT-MAP 211 | if USE_BPA: 212 | ndatas = BPA(ot_reg = 0.2)(ndatas) # BPA insertion 213 | 214 | # for PT-MAP, we need to scale and normalize again 215 | ndatas = scaleEachUnitaryDatas(ndatas) 216 | ndatas = centerDatas(ndatas) 
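# NOTE: when USE_BPA is enabled above, ndatas becomes an
# [n_runs, n_samples, n_samples] affinity tensor rather than raw features,
# which is why n_nfeat is re-read below; the extra scale/center pass restores
# the unit-norm, zero-mean regime that PT-MAP expects. A quick sanity check
# (an illustrative assumption, not part of the original script):
#   assert ndatas.shape[1] == ndatas.shape[2] == n_samples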
217 | 218 | # expected results for 1 and 5 shots | W/o BPA: 1=82.11, 5=88.57 | W/ BPA_p: 1=82.62 , 5=89.14 | W/ BPA_t: 1=84.69 , 5=90.30 | 219 | 220 | n_nfeat = ndatas.size(2) 221 | print("size of the datas...", ndatas.size()) 222 | 223 | # switch to cuda 224 | ndatas = ndatas.cuda() 225 | labels = labels.cuda() 226 | 227 | #MAP 228 | lam = 10 229 | model = GaussianModel(n_ways, lam) 230 | model.initFromLabelledDatas() 231 | 232 | alpha = 0.2 233 | optim = MAP(alpha) 234 | 235 | optim.verbose=False 236 | optim.progressBar=True 237 | 238 | acc_test = optim.loop(model, n_epochs=20) 239 | 240 | print("final accuracy found {:0.2f} +- {:0.2f}".format(*(100*x for x in acc_test))) 241 | 242 | 243 | 244 | -------------------------------------------------------------------------------- /models/wrn_mixup_model.py: -------------------------------------------------------------------------------- 1 | ### dropout has been removed in this code. original code had dropout##### 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | import sys, os 9 | import numpy as np 10 | import random 11 | 12 | act = torch.nn.ReLU() 13 | 14 | import math 15 | from torch.nn.utils.weight_norm import WeightNorm 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 20 | super(BasicBlock, self).__init__() 21 | self.bn1 = nn.BatchNorm2d(in_planes) 22 | self.relu1 = nn.ReLU(inplace=True) 23 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(out_planes) 26 | self.relu2 = nn.ReLU(inplace=True) 27 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 28 | padding=1, bias=False) 29 | self.droprate = dropRate 30 | self.equalInOut = (in_planes == out_planes) 31 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 32 | padding=0, bias=False) or None 33 | 34 | def forward(self, x): 35 | if not self.equalInOut: 36 | x = self.relu1(self.bn1(x)) 37 | else: 38 | out = self.relu1(self.bn1(x)) 39 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 40 | if self.droprate > 0: 41 | out = F.dropout(out, p=self.droprate, training=self.training) 42 | out = self.conv2(out) 43 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 44 | 45 | 46 | class distLinear(nn.Module): 47 | def __init__(self, indim, outdim): 48 | super(distLinear, self).__init__() 49 | self.L = nn.Linear(indim, outdim, bias=False) 50 | self.class_wise_learnable_norm = True # See the issue#4&8 in the github 51 | if self.class_wise_learnable_norm: 52 | WeightNorm.apply(self.L, 'weight', dim=0) # split the weight update component to direction and norm 53 | 54 | if outdim <= 200: 55 | self.scale_factor = 2 # a fixed scale factor to scale the output of cos value into a reasonably large input for softmax 56 | else: 57 | self.scale_factor = 10 # in omniglot, a larger scale factor is required to handle >1000 output classes. 
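# distLinear is a cosine-similarity classifier: forward() L2-normalizes its
# input and, because WeightNorm is applied on dim=0, each class weight keeps a
# learnable per-class norm, so the returned scores are scaled cosine distances.
# Hypothetical usage, with illustrative dimensions (not taken from this repo):
#   clf = distLinear(indim=640, outdim=64)
#   scores = clf(torch.randn(8, 640))  # -> [8, 64] scaled cosine scores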
58 | 59 | def forward(self, x): 60 | x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x) 61 | x_normalized = x.div(x_norm + 0.00001) 62 | if not self.class_wise_learnable_norm: 63 | L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data) 64 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001) 65 | cos_dist = self.L( 66 | x_normalized) # matrix product by forward function, but when using WeightNorm, this also multiply the cosine distance by a class-wise learnable norm, see the issue#4&8 in the github 67 | scores = self.scale_factor * (cos_dist) 68 | 69 | return scores 70 | 71 | 72 | class NetworkBlock(nn.Module): 73 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 74 | super(NetworkBlock, self).__init__() 75 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 76 | 77 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 78 | layers = [] 79 | for i in range(int(nb_layers)): 80 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 81 | return nn.Sequential(*layers) 82 | 83 | def forward(self, x): 84 | return self.layer(x) 85 | 86 | 87 | def to_one_hot(inp, num_classes): 88 | y_onehot = torch.FloatTensor(inp.size(0), num_classes) 89 | if torch.cuda.is_available(): 90 | y_onehot = y_onehot.cuda() 91 | 92 | y_onehot.zero_() 93 | x = inp.type(torch.LongTensor) 94 | if torch.cuda.is_available(): 95 | x = x.cuda() 96 | 97 | x = torch.unsqueeze(x, 1) 98 | y_onehot.scatter_(1, x, 1) 99 | 100 | return Variable(y_onehot, requires_grad=False) 101 | # return y_onehot 102 | 103 | 104 | def mixup_data(x, y, lam): 105 | '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda''' 106 | 107 | batch_size = x.size()[0] 108 | index = torch.randperm(batch_size) 109 | if torch.cuda.is_available(): 110 | index = index.cuda() 111 | mixed_x = lam * x + (1 - lam) * x[index, :] 112 | y_a, y_b = y, y[index] 113 | 114 | return mixed_x, y_a, y_b, lam 115 | 116 | 117 | class WideResNet(nn.Module): 118 | def __init__(self, depth=28, widen_factor=10, num_classes=200, loss_type='dist', per_img_std=False, stride=1, 119 | dropRate=0.5): 120 | flatten = True 121 | super(WideResNet, self).__init__() 122 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 123 | assert ((depth - 4) % 6 == 0) 124 | n = (depth - 4) / 6 125 | block = BasicBlock 126 | # 1st conv before any network block 127 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 128 | padding=1, bias=False) 129 | # 1st block 130 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, stride, dropRate) 131 | # 2nd block 132 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 133 | # 3rd block 134 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 135 | # global average pooling and linear 136 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 137 | self.relu = nn.ReLU(inplace=True) 138 | self.nChannels = nChannels[3] 139 | 140 | self.num_classes = num_classes 141 | if flatten: 142 | self.final_feat_dim = 640 143 | for m in self.modules(): 144 | if isinstance(m, nn.Conv2d): 145 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 146 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 147 | elif isinstance(m, nn.BatchNorm2d): 148 | m.weight.data.fill_(1) 149 | m.bias.data.zero_() 150 | 151 | def forward(self, x, target=None, mixup=False, mixup_hidden=True, mixup_alpha=None, lam=0.4, return_logits=True): 152 | if target is not None: 153 | if mixup_hidden: 154 | layer_mix = random.randint(0, 3) 155 | elif mixup: 156 | layer_mix = 0 157 | else: 158 | layer_mix = None 159 | 160 | out = x 161 | 162 | target_a = target_b = target 163 | 164 | if layer_mix == 0: 165 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 166 | 167 | out = self.conv1(out) 168 | out = self.block1(out) 169 | 170 | if layer_mix == 1: 171 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 172 | 173 | out = self.block2(out) 174 | 175 | if layer_mix == 2: 176 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 177 | 178 | out = self.block3(out) 179 | if layer_mix == 3: 180 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 181 | 182 | out = self.relu(self.bn1(out)) 183 | out = F.avg_pool2d(out, out.size()[2:]) 184 | out = out.view(out.size(0), -1) 185 | if not return_logits: 186 | return out, target_a, target_b 187 | 188 | out1 = self.linear(out) 189 | return out, out1, target_a, target_b 190 | else: 191 | out = x 192 | out = self.conv1(out) 193 | out = self.block1(out) 194 | out = self.block2(out) 195 | out = self.block3(out) 196 | out = self.relu(self.bn1(out)) 197 | out = F.avg_pool2d(out, out.size()[2:]) 198 | out = out.view(out.size(0), -1) 199 | # if not return_logits: 200 | return out 201 | 202 | # out1 = self.linear(out) 203 | # return out, out1 204 | 205 | 206 | def wrn28_10(num_classes=200, loss_type='dist', dropout=0): 207 | model = WideResNet(depth=28, widen_factor=10, num_classes=num_classes, loss_type=loss_type, per_img_std=False, 208 | stride=1, dropRate=dropout) 209 | return model 210 | -------------------------------------------------------------------------------- /methods/pt_map/test_standard.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | 4 | from tqdm import tqdm 5 | import sys 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from methods.pt_map import FSLTask 11 | from bpa import BPA 12 | 13 | 14 | def bool_flag(s): 15 | """ 16 | Parse boolean arguments from the command line. 
17 | """ 18 | FALSY_STRINGS = {"off", "false", "0"} 19 | TRUTHY_STRINGS = {"on", "true", "1"} 20 | if s.lower() in FALSY_STRINGS: 21 | return False 22 | elif s.lower() in TRUTHY_STRINGS: 23 | return True 24 | else: 25 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 26 | 27 | 28 | def centerDatas(datas): 29 | datas[:, :n_lsamples] = datas[:, :n_lsamples, :] - datas[:, :n_lsamples].mean(1, keepdim=True) 30 | datas[:, n_lsamples:] = datas[:, n_lsamples:, :] - datas[:, n_lsamples:].mean(1, keepdim=True) 31 | return datas 32 | 33 | 34 | def scaleEachUnitaryDatas(datas): 35 | norms = datas.norm(dim=-1, keepdim=True) 36 | return datas / norms 37 | 38 | 39 | def QRreduction(datas): 40 | ndatas = torch.linalg.qr(datas.permute(0, 2, 1)).R 41 | ndatas = ndatas.permute(0, 2, 1) 42 | return ndatas 43 | 44 | 45 | class Model: 46 | def __init__(self, n_ways): 47 | self.n_ways = n_ways 48 | 49 | 50 | # --------- GaussianModel 51 | class GaussianModel(Model): 52 | def __init__(self, n_ways, lam, distance_metric: str = 'euclidean'): 53 | super(GaussianModel, self).__init__(n_ways) 54 | self.mus = None # shape [n_runs][n_ways][n_nfeat] 55 | self.lam = lam 56 | self.distance_metric = distance_metric 57 | 58 | def clone(self): 59 | other = GaussianModel(self.n_ways) 60 | other.mus = self.mus.clone() 61 | return self 62 | 63 | def cuda(self): 64 | self.mus = self.mus.cuda() 65 | 66 | def initFromLabelledDatas(self): 67 | self.mus = ndatas.reshape(n_runs, n_shot + n_queries, n_ways, n_nfeat)[:, :n_shot, ].mean(1) 68 | 69 | def updateFromEstimate(self, estimate, alpha): 70 | 71 | Dmus = estimate - self.mus 72 | self.mus = self.mus + alpha * Dmus 73 | 74 | def compute_optimal_transport(self, M, r, c, epsilon=1e-6): 75 | r = r.cuda() 76 | c = c.cuda() 77 | n_runs, n, m = M.shape 78 | P = torch.exp(- self.lam * M) 79 | P /= P.view((n_runs, -1)).sum(1).unsqueeze(1).unsqueeze(1) 80 | u = torch.zeros(n_runs, n).cuda() 81 | maxiters = 1000 82 | iters = 1 83 | # normalize this matrix 84 | while torch.max(torch.abs(u - P.sum(2))) > epsilon: 85 | u = P.sum(2) 86 | P *= (r / u).view((n_runs, -1, 1)) 87 | P *= (c / P.sum(1)).view((n_runs, 1, -1)) 88 | if iters == maxiters: 89 | break 90 | iters = iters + 1 91 | return P 92 | 93 | @staticmethod 94 | def _pairwise_dist(a, b): 95 | return (a.unsqueeze(2) - b.unsqueeze(1)).norm(dim=3).pow(2) 96 | 97 | def getProbas(self): 98 | global ndatas, n_nfeat 99 | # compute squared dist to centroids [n_runs][n_samples][n_ways] 100 | if self.distance_metric == 'cosine': 101 | dist = 1-torch.bmm(F.normalize(ndatas), F.normalize(self.mus.transpose(1, 2))) 102 | elif self.distance_metric == 'ce': 103 | dist = -torch.bmm(torch.log(ndatas + 1e-5), self.mus.transpose(1, 2)) 104 | else: 105 | dist = self._pairwise_dist(ndatas, self.mus) 106 | 107 | p_xj = torch.zeros_like(dist) 108 | r = torch.ones(n_runs, n_usamples, device='cuda') 109 | c = torch.ones(n_runs, n_ways, device='cuda') * n_queries 110 | p_xj_test = self.compute_optimal_transport(dist[:, n_lsamples:], r, c, epsilon=1e-4) 111 | p_xj[:, n_lsamples:] = p_xj_test 112 | 113 | p_xj[:, :n_lsamples].fill_(0) 114 | p_xj[:, :n_lsamples].scatter_(2, labels[:, :n_lsamples].unsqueeze(2), 1) 115 | 116 | return p_xj 117 | 118 | def estimateFromMask(self, mask): 119 | emus = mask.permute(0, 2, 1).matmul(ndatas).div(mask.sum(dim=1).unsqueeze(2)) 120 | return emus 121 | 122 | 123 | # ========================================= 124 | # MAP 125 | # ========================================= 126 | 127 | class MAP: 128 | def 
__init__(self, alpha=None, verbose: bool = False, progressBar: bool = False): 129 | self.verbose = verbose 130 | self.progressBar = progressBar 131 | self.alpha = alpha 132 | 133 | def getAccuracy(self, probas): 134 | olabels = probas.argmax(dim=2) 135 | matches = labels.eq(olabels).float() 136 | acc_test = matches[:, n_lsamples:].mean(1) 137 | 138 | m = acc_test.mean().item() 139 | pm = acc_test.std().item() * 1.96 / math.sqrt(n_runs) 140 | return m, pm 141 | 142 | def performEpoch(self, model, epochInfo=None): 143 | 144 | p_xj = model.getProbas() 145 | self.probas = p_xj 146 | 147 | m_estimates = model.estimateFromMask(self.probas) 148 | # update centroids 149 | model.updateFromEstimate(m_estimates, self.alpha) 150 | 151 | if self.verbose: 152 | op_xj = model.getProbas() 153 | acc = self.getAccuracy(op_xj) 154 | print("output model accuracy", acc) 155 | 156 | def loop(self, model, n_epochs=20): 157 | self.probas = model.getProbas() 158 | if self.verbose: 159 | print("initialisation model accuracy", self.getAccuracy(self.probas)) 160 | 161 | if self.progressBar: 162 | if type(self.progressBar) == bool: 163 | pb = tqdm(total=n_epochs) 164 | else: 165 | pb = self.progressBar 166 | 167 | for epoch in range(1, n_epochs + 1): 168 | self.performEpoch(model, epochInfo=(epoch, n_epochs)) 169 | if self.progressBar: pb.update() 170 | 171 | # get final accuracy and return it 172 | op_xj = model.getProbas() 173 | acc = self.getAccuracy(op_xj) 174 | return acc 175 | 176 | 177 | def get_args(): 178 | """ Description: Parses arguments at command line. """ 179 | parser = argparse.ArgumentParser() 180 | parser.add_argument('--root', type=str, default='C:/Users/dani3/Documents/GitHub/SOT/') 181 | parser.add_argument('--features_path', type=str, 182 | default='/checkpoints/wrn/miniImagenet/WideResNet28_10_S2M2_R/last/output.plk') 183 | parser.add_argument('--dataset', type=str, default='miniimagenet', choices=['miniimagenet']) 184 | parser.add_argument('--num_way', type=int, default=5) 185 | parser.add_argument('--num_shot', type=int, default=5) 186 | parser.add_argument('--num_query', type=int, default=15) 187 | parser.add_argument('--num_runs', type=int, default=10000) 188 | parser.add_argument('--num_repeat', type=int, default=1, 189 | help='repeat the evaluation n times for averaging purposes.') 190 | parser.add_argument('--verbose', type=bool_flag, default=False) 191 | 192 | # BPA args 193 | parser.add_argument('--ot_reg', type=float, default=0.1) 194 | parser.add_argument('--sink_iters', type=int, default=10) 195 | parser.add_argument('--distance_metric', type=str, default='cosine') 196 | parser.add_argument('--norm_type', type=str, default='sinkhorn') 197 | parser.add_argument('--mask_diag', type=bool_flag, default=True) 198 | return parser.parse_args() 199 | 200 | 201 | if __name__ == '__main__': 202 | # ---- data loading 203 | args = get_args() 204 | n_shot = args.num_shot 205 | n_ways = args.num_way 206 | n_queries = args.num_query 207 | n_runs = args.num_runs 208 | n_lsamples = n_ways * n_shot 209 | n_usamples = n_ways * n_queries 210 | n_samples = n_lsamples + n_usamples 211 | 212 | cfg = {'shot': n_shot, 'ways': n_ways, 'queries': n_queries} 213 | FSLTask.loadDataSet(args.dataset, root=args.root, features_path=args.root + args.features_path) 214 | FSLTask.setRandomStates(cfg) 215 | ndatas = FSLTask.GenerateRunSet(cfg=cfg, end=n_runs) 216 | ndatas = ndatas.permute(0, 2, 1, 3).reshape(n_runs, n_samples, -1) 217 | labels = torch.arange(n_ways).view(1, 1, n_ways).expand(n_runs, n_shot + n_queries, 
n_ways).clone().view(n_runs, 218 | n_samples) 219 | labels = labels.cuda() 220 | ndatas = ndatas.cuda() 221 | 222 | # Power transform + QR + Normalize 223 | ndatas[:, ] = torch.pow(ndatas[:, ] + 1e-6, 0.5) 224 | ndatas = QRreduction(ndatas) 225 | ndatas = scaleEachUnitaryDatas(ndatas) 226 | # trans-mean-sub 227 | ndatas = centerDatas(ndatas) 228 | _ndatas = scaleEachUnitaryDatas(ndatas) 229 | # # transform data 230 | bpa = BPA( 231 | args.distance_metric, 232 | ot_reg=args.ot_reg, 233 | sinkhorn_iterations=args.sink_iters, 234 | mask_diag=args.mask_diag, 235 | ) 236 | 237 | for dm in ['euclidean']: 238 | print(f"DM {dm}") 239 | for mask_diag in [False, True]: 240 | bpa.mask_diag = mask_diag 241 | print(f"sot mask_diag {bpa.mask_diag }") 242 | # for max_temp in [False, True]: 243 | # print(f"sot max_temp {max_temp}") 244 | for reg in [0.1, 0.2, 0.3, 0.4, 0.5]: 245 | bpa.ot_reg = reg 246 | print(f"sot lambda {bpa.ot_reg}") 247 | 248 | ndatas = bpa(_ndatas) 249 | n_nfeat = ndatas.size(2) 250 | print("size of the datas...", ndatas.size()) 251 | 252 | # MAP 253 | model = GaussianModel(n_ways=n_ways, lam=10, distance_metric=dm) 254 | model.initFromLabelledDatas() 255 | 256 | optim = MAP(alpha=0.2, verbose=args.verbose) 257 | acc_test = optim.loop(model, n_epochs=20) 258 | 259 | print("final accuracy found {:0.2f} +- {:0.2f}".format(*(100 * x for x in acc_test))) 260 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | from time import time 5 | 6 | import torch 7 | 8 | import utils 9 | from bpa import BPA 10 | 11 | 12 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument('--seed', type=int, default=1, 19 | help="""Random seed.""") 20 | 21 | parser.add_argument('--root_path', type=str, default='./', 22 | help=""" Path to project root directory. """) 23 | parser.add_argument('--checkpoint_dir', type=str, default=None, 24 | help=""" Where to save model checkpoints. If None, it will automatically created. """) 25 | parser.add_argument('--dataset', type=str, default='miniimagenet', 26 | choices=['miniimagenet', 'cifar']) 27 | parser.add_argument('--data_path', type=str, default='./datasets/few_shot/miniimagenet', 28 | help="""Path to dataset root directory.""") 29 | 30 | parser.add_argument('--backbone', type=str, default='wrn', 31 | help="""Define which backbone network to use. """) 32 | parser.add_argument('--pretrained_path', type=str, default=False, 33 | help=""" Path to pretrained model, used for testing/fine-tuning. """) 34 | 35 | parser.add_argument('--eval', type=utils.bool_flag, default=False, 36 | help=""" If true, make evaluation on the *test set*. 37 | The amount of test episodes controlled by --test_episodes=<>""") 38 | parser.add_argument('--eval_freq', type=int, default=1, 39 | help=""" Evaluate training every n epochs. """) 40 | parser.add_argument('--eval_first', type=utils.bool_flag, default=False, 41 | help=""" Set to true to evaluate the model before training. Useful for fine-tuning. """) 42 | parser.add_argument('--num_workers', type=int, default=8) 43 | 44 | # wandb specific arguments 45 | parser.add_argument('--wandb', type=utils.bool_flag, default=False, 46 | help=""" Log data into wandb. 
""") 47 | parser.add_argument('--project', type=str, default='BPA', 48 | help=""" Project name in wandb. """) 49 | parser.add_argument('--entity', type=str, default='', 50 | help=""" Your wandb entity name. """) 51 | 52 | # few-shot specific arguments 53 | parser.add_argument('--method', type=str, default='pt_map_bpa', 54 | choices=['proto', 'proto_bpa', 'pt_map', 'pt_map_bpa'], 55 | help="""Specify which few-shot classifier to use.""") 56 | parser.add_argument('--train_way', type=int, default=5, 57 | help=""" Number of classes for each training task. """) 58 | parser.add_argument('--val_way', type=int, default=5, 59 | help=""" Number of classes for each validation/test task. """) 60 | parser.add_argument('--num_shot', type=int, default=5, 61 | help=""" Number of (labeled) support samples for each class. """) 62 | parser.add_argument('--num_query', type=int, default=15, 63 | help=""" Number of (un-labeled) query samples for each class. """) 64 | parser.add_argument('--train_episodes', type=int, default=200, 65 | help=""" Number of few-shot tasks for each epoch. """) 66 | parser.add_argument('--eval_episodes', type=int, default=400, 67 | help=""" Number of tasks to evaluate. """) 68 | parser.add_argument('--test_episodes', type=int, default=10000, 69 | help=""" Number of tasks to evaluate. """) 70 | parser.add_argument('--temperature', type=float, default=1., 71 | help=""" Temperature for ProtoNet. """) 72 | 73 | # training specific arguments 74 | parser.add_argument('--max_epochs', type=int, default=25, 75 | help="""Number of training/finetuning epochs. """) 76 | parser.add_argument('--optimizer', type=str, default='adam', 77 | help="""Optimizer""", choices=['adam', 'adamw', 'sgd']) 78 | parser.add_argument('--lr', type=float, default=5e-5, 79 | help="""Learning rate. """) 80 | parser.add_argument('--weight_decay', type=float, default=1e-4, 81 | help="""Weight decay. """) 82 | parser.add_argument('--dropout', type=float, default=0., 83 | help=""" Dropout probability. """) 84 | parser.add_argument('--momentum', type=float, default=0.9, 85 | help="""Momentum of SGD optimizer. """) 86 | parser.add_argument('--scheduler', type=str, default='step', 87 | help="""Learning rate scheduler. To disable the scheduler, use scheduler=''. """) 88 | parser.add_argument('--step_size', type=int, default=5, 89 | help="""Step size (in epochs) of StepLR scheduler. """) 90 | parser.add_argument('--gamma', type=float, default=0.5, 91 | help="""Gamma of StepLR scheduler. """) 92 | parser.add_argument('--augment', type=utils.bool_flag, default=False, 93 | help=""" Apply data augmentation. """) 94 | 95 | # BPA specific arguments 96 | parser.add_argument('--ot_reg', type=float, default=0.1, 97 | help=""" Sinkhorn entropy regularization. 98 | For few-shot methods, 0.1-0.2 seems to work best. 99 | For larger tasks (~10,000) samples, try to increase this value. """) 100 | parser.add_argument('--sink_iters', type=int, default=20, 101 | help=""" Number of Sinkhorn iterations. 102 | Usually small number (~ 5-10) is sufficient. """) 103 | parser.add_argument('--distance_metric', type=str, default='cosine', 104 | help=""" Distance metric for the OT cost matrix. """, 105 | choices=['cosine', 'euclidean']) 106 | parser.add_argument('--mask_diag', type=utils.bool_flag, default=True, 107 | help=""" If true, mask diagonal (self) values before and after the OT. """) 108 | parser.add_argument('--max_scale', type=utils.bool_flag, default=True, 109 | help=""" Scaling range of the BPA values to [0,1]. 
110 | This should always be True. """) 111 | 112 | return parser.parse_args() 113 | 114 | 115 | def main(): 116 | args = get_args() 117 | utils.set_seed(seed=args.seed) 118 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 119 | output_dir = utils.get_output_dir(args=args) 120 | 121 | # define datasets and loaders 122 | args.set_episodes = dict(train=args.train_episodes, val=args.eval_episodes, test=args.test_episodes) 123 | if not args.eval: 124 | train_dataloader = utils.get_dataloader(set_name='train', args=args, constant=False) 125 | val_dataloader = utils.get_dataloader(set_name='val', args=args, constant=True) 126 | else: 127 | val_dataloader = utils.get_dataloader(set_name='test', args=args, constant=False) 128 | train_dataloader = None 129 | 130 | # define model and load pretrained weights if available 131 | model = utils.get_model(args.backbone, args) 132 | model = model.to(device) 133 | utils.load_weights(model, args.pretrained_path) 134 | 135 | # BPA and few-shot classification method (e.g. proto, pt-map...) 136 | bpa = None 137 | if 'bpa' in args.method.lower(): 138 | bpa = BPA( 139 | distance_metric=args.distance_metric, 140 | ot_reg=args.ot_reg, 141 | mask_diag=args.mask_diag, 142 | sinkhorn_iterations=args.sink_iters, 143 | max_scale=args.max_scale 144 | ) 145 | fewshot_method = utils.get_method(args=args, bpa=bpa) 146 | 147 | # few-shot labels 148 | train_labels = utils.get_fs_labels(args.method, args.train_way, args.num_query, args.num_shot) 149 | val_labels = utils.get_fs_labels(args.method, args.val_way, args.num_query, args.num_shot) 150 | 151 | # initialized wandb 152 | if args.wandb: 153 | utils.init_wandb(exp_name=output_dir.split('/')[-1] if output_dir[-1] != '/' else output_dir.split('/')[-2], 154 | args=args) 155 | 156 | # define loss 157 | criterion = utils.get_criterion_by_method(method=args.method) 158 | 159 | # Test-set evaluation 160 | if args.eval: 161 | print(f"Evaluate model for {args.test_episodes} episodes... ") 162 | loss, acc = eval_one_epoch(model, val_dataloader, fewshot_method, criterion, val_labels, 0, args, set_name='test') 163 | print("Final evaluation results:\nAccuracy={:.4f}, Loss={:.4f}".format(acc, loss)) 164 | exit(1) 165 | 166 | # define optimizer and scheduler 167 | optimizer, lr_scheduler = utils.get_optimizer_and_lr_scheduler(args=args, params=model.parameters()) 168 | 169 | # evaluate model before training 170 | if args.eval_first: 171 | print("Evaluate model before training... ") 172 | eval_one_epoch(model, val_dataloader, fewshot_method, criterion, val_labels, -1, args, set_name='val') 173 | 174 | # train 175 | print("Start training...") 176 | best_acc = 0. 
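# Checkpoint policy for the loop below: min_loss.pth tracks the lowest
# validation loss and max_acc.pth the highest validation accuracy, while
# last.pth always stores the most recent weights. Note the elif below: at
# most one of the two "best" checkpoints is updated after a given evaluation.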
177 | best_loss = math.inf 178 | for epoch in range(args.max_epochs): 179 | print("[Epoch {}/{}]...".format(epoch, args.max_epochs)) 180 | 181 | # train 182 | train_one_epoch(model, train_dataloader, optimizer, fewshot_method, criterion, train_labels, epoch, args) 183 | if lr_scheduler is not None: 184 | lr_scheduler.step() 185 | 186 | # evaluate 187 | if epoch % args.eval_freq == 0: 188 | eval_loss, eval_acc = eval_one_epoch(model, val_dataloader, fewshot_method, criterion, val_labels, 189 | epoch, args, set_name='val') 190 | # save best model 191 | if eval_loss < best_loss: 192 | best_loss = eval_loss 193 | torch.save(model.state_dict(), os.path.join(output_dir, 'min_loss.pth')) 194 | elif eval_acc > best_acc: 195 | best_acc = eval_acc 196 | torch.save(model.state_dict(), os.path.join(output_dir, 'max_acc.pth')) 197 | 198 | # save last checkpoint 199 | torch.save(model.state_dict(), os.path.join(output_dir, 'last.pth')) 200 | 201 | 202 | def train_one_epoch(model, dataloader, optimizer, fewshot_method, criterion, labels, epoch, args): 203 | metric_logger = utils.MetricLogger(delimiter=" ") 204 | header = 'Train Epoch: [{}/{}]'.format(epoch, args.max_epochs) 205 | log_freq = 50 206 | n_batches = len(dataloader) 207 | 208 | model.train() 209 | for batch_idx, (images, _) in enumerate(metric_logger.log_every(dataloader, log_freq, header=header)): 210 | images = images.to(device) 211 | # extract features 212 | features = model(images) 213 | # few-shot classifier 214 | probas, accuracy = fewshot_method(features, labels=labels, mode='train') 215 | q_labels = labels if len(labels) == len(probas) else labels[-len(probas):] 216 | # loss 217 | loss = criterion(probas, q_labels) 218 | 219 | optimizer.zero_grad() 220 | loss.backward() 221 | optimizer.step() 222 | 223 | metric_logger.update(loss=loss.detach().item(), accuracy=accuracy) 224 | 225 | if batch_idx % log_freq == 0: 226 | utils.wandb_log( 227 | { 228 | 'train/step': batch_idx + (epoch * n_batches), 229 | 'train/loss_step': loss.item(), 230 | 'train/accuracy_step': accuracy, 231 | } 232 | ) 233 | 234 | print("Averaged stats:", metric_logger) 235 | utils.wandb_log( 236 | { 237 | 'lr': optimizer.param_groups[0]['lr'], 238 | 'train/epoch': epoch, 239 | 'train/loss': metric_logger.loss.global_avg, 240 | 'train/accuracy': metric_logger.accuracy.global_avg, 241 | } 242 | ) 243 | return metric_logger 244 | 245 | 246 | @torch.no_grad() 247 | def eval_one_epoch(model, dataloader, fewshot_method, criterion, labels, epoch, args, set_name): 248 | metric_logger = utils.MetricLogger(delimiter=" ") 249 | header = 'Validation:' if set_name == "val" else 'Test:' 250 | log_freq = 50 251 | 252 | n_batches = len(dataloader) 253 | model.eval() 254 | for batch_idx, (images, _) in enumerate(metric_logger.log_every(dataloader, log_freq, header=header)): 255 | images = images.to(device) 256 | # extract features 257 | features = model(images) 258 | # few-shot classifier 259 | probas, accuracy = fewshot_method(X=features, labels=labels, mode='val') 260 | q_labels = labels if len(labels) == len(probas) else labels[-len(probas):] 261 | # loss 262 | loss = criterion(probas, q_labels) 263 | metric_logger.update(loss=loss.detach().item(), accuracy=accuracy) 264 | 265 | print("Averaged stats:", metric_logger) 266 | utils.wandb_log( 267 | { 268 | '{}/epoch'.format(set_name): epoch, 269 | '{}/loss'.format(set_name): metric_logger.loss.global_avg, 270 | '{}/accuracy'.format(set_name): metric_logger.accuracy.global_avg, 271 | } 272 | ) 273 | return 
/utils.py:
--------------------------------------------------------------------------------
import datetime
import os
import argparse
import random
import time
from collections import defaultdict, deque

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader

from models.wrn_mixup_model import wrn28_10
from models.resnet12 import Res12
from datasets import MiniImageNet, CIFAR, CUB
from datasets.samplers import CategoriesSampler
from methods import PTMAPLoss, ProtoLoss

try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False


MODELS = dict(
    wrn=wrn28_10, resnet12=Res12
)
DATASETS = dict(
    miniimagenet=MiniImageNet, cifar=CIFAR, cub=CUB
)
METHODS = dict(
    pt_map=PTMAPLoss, pt_map_bpa=PTMAPLoss, proto=ProtoLoss, proto_bpa=ProtoLoss,
)


def get_model(model_name: str, args):
    """
    Get the backbone model.
    """
    arch = model_name.lower()
    if arch in MODELS:
        model = MODELS[arch](dropout=args.dropout)
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
        return model
    else:
        raise ValueError(f'Model {model_name} not implemented. Available models are: {list(MODELS.keys())}')


def get_dataloader(set_name: str, args: argparse.Namespace, constant: bool = False):
    """
    Get a dataloader with a categorical sampler for few-shot classification.
    """
    num_episodes = args.set_episodes[set_name]
    num_way = args.train_way if set_name == 'train' else args.val_way

    # define dataset, sampler and data loader
    data_set = DATASETS[args.dataset.lower()](
        args.data_path, set_name, args.backbone,
        augment=(set_name == 'train' and args.augment)
    )
    args.img_size = data_set.image_size

    data_sampler = CategoriesSampler(
        set_name, data_set.label, num_episodes, const_loader=constant,
        num_way=num_way, num_shot=args.num_shot, num_query=args.num_query,
        replace=(set_name == 'train'),
    )
    return DataLoader(
        data_set, batch_sampler=data_sampler, num_workers=args.num_workers, pin_memory=not constant
    )


def get_optimizer_and_lr_scheduler(args, params):
    optimizer = get_optimizer(args, params)
    lr_scheduler = get_scheduler(args, optimizer)
    return optimizer, lr_scheduler


def get_optimizer(args, params):
    """
    Get the optimizer.
    """
    if args.optimizer == 'adam':
        return optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        return optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        return optim.SGD(params, lr=args.lr, momentum=args.momentum, nesterov=True, weight_decay=args.weight_decay)
    else:
        raise ValueError(f'Optimizer {args.optimizer} not available.')


def get_scheduler(args, optimizer: optim.Optimizer):
    """
    Get the learning-rate scheduler.
    """
    if not args.scheduler or args.scheduler == '':
        return None
    elif args.scheduler == 'step':
        return optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=args.step_size, gamma=args.gamma)
    else:
        raise ValueError(f'LR scheduler {args.scheduler} is not available.')
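
# Editor's note: a minimal usage sketch of the two helpers above. The argument
# values are illustrative placeholders, not the repository's defaults.
def _demo_optimizer_and_scheduler():
    demo_args = argparse.Namespace(optimizer='sgd', lr=0.1, momentum=0.9, weight_decay=5e-4,
                                   scheduler='step', step_size=20, gamma=0.5)
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer, lr_scheduler = get_optimizer_and_lr_scheduler(demo_args, params)
    optimizer.step()        # called once per batch in train_one_epoch
    lr_scheduler.step()     # called once per epoch in main()
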
98 | """ 99 | if not args.scheduler or args.scheduler == '': 100 | return None 101 | elif args.scheduler == 'step': 102 | return optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=args.step_size, gamma=args.gamma) 103 | else: 104 | raise ValueError(f'Error: LR-scheduler {args.scheduler} is not available.') 105 | 106 | 107 | def get_method(args, bpa=None): 108 | """ 109 | Get the few-shot classification method (e.g. pt_map). 110 | """ 111 | 112 | if args.method.lower() in METHODS.keys(): 113 | return METHODS[args.method.lower()](args=vars(args), bpa=bpa) 114 | else: 115 | raise ValueError(f'Not implemented method. available methods are: {METHODS.keys()}') 116 | 117 | 118 | def get_criterion_by_method(method: str): 119 | """ 120 | Get loss function based on the method. 121 | """ 122 | 123 | if 'pt_map' in method: 124 | return torch.nn.NLLLoss() 125 | elif 'proto' in method: 126 | return torch.nn.CrossEntropyLoss() 127 | else: 128 | raise ValueError(f'Not implemented criterion for this method. available methods are: {list(METHODS.keys())}') 129 | 130 | 131 | def init_wandb(exp_name: str, args): 132 | """ 133 | Initialize and returns wandb logger if args.wandb is True. 134 | """ 135 | if not args.wandb: 136 | return None 137 | assert HAS_WANDB, "Install wandb via - 'pip install wandb' in order to use wandb logging. " 138 | logger = wandb.init(project=args.project, entity=args.entity, name=exp_name, config=vars(args)) 139 | # define which metrics will be plotted against it 140 | logger.define_metric("train_loss", step_metric="epoch") 141 | logger.define_metric("train_accuracy", step_metric="epoch") 142 | logger.define_metric("val_loss", step_metric="epoch") 143 | logger.define_metric("val_accuracy", step_metric="epoch") 144 | return logger 145 | 146 | 147 | def wandb_log(results: dict): 148 | """ 149 | Log step to the logger without print. 150 | """ 151 | if HAS_WANDB and wandb.run is not None: 152 | wandb.log(results) 153 | 154 | 155 | def get_output_dir(args: argparse): 156 | """ 157 | Initialize the output dir. 158 | """ 159 | 160 | if args.checkpoint_dir is None: 161 | checkpoint_dir = os.path.join(args.root_path, 'checkpoints', args.dataset.lower(), args.backbone.lower(), args.method.lower()) 162 | 163 | name_str = f'-n_way={args.train_way}' \ 164 | f'-n_shot={args.num_shot}' \ 165 | f'-lr={args.lr}' \ 166 | f'-scheduler={args.scheduler}' \ 167 | f'-dropout={args.dropout}' 168 | 169 | checkpoint_dir = os.path.join(checkpoint_dir, name_str) 170 | else: 171 | checkpoint_dir = args.checkpoint_dir 172 | 173 | if args.eval: 174 | return checkpoint_dir 175 | 176 | while os.path.exists(checkpoint_dir): 177 | checkpoint_dir += f'-{np.random.randint(100)}' 178 | 179 | os.makedirs(checkpoint_dir, exist_ok=True) 180 | 181 | # write args to a file 182 | with open(os.path.join(checkpoint_dir, "args.txt"), 'w') as f: 183 | for key, value in vars(args).items(): 184 | f.write('%s:%s\n' % (key, value)) 185 | 186 | print("=> Checkpoints will be saved at:\n", checkpoint_dir) 187 | 188 | return checkpoint_dir 189 | 190 | 191 | def load_weights(model: torch.nn.Module, pretrained_path: str): 192 | """ 193 | Load pretrained weights from given path. 
194 | """ 195 | if not pretrained_path: 196 | return model 197 | 198 | print(f'Loading weights from {pretrained_path}') 199 | state_dict = torch.load(pretrained_path) 200 | sd_keys = list(state_dict.keys()) 201 | if 'state' in sd_keys: 202 | state_dict = state_dict['state'] 203 | for k in list(state_dict.keys()): 204 | if k.startswith('module.'): 205 | state_dict["{}".format(k[len('module.'):])] = state_dict[k] 206 | del state_dict[k] 207 | 208 | model.load_state_dict(state_dict, strict=False) 209 | 210 | elif 'params' in sd_keys: 211 | state_dict = state_dict['params'] 212 | for k in list(state_dict.keys()): 213 | if k.startswith('encoder.'): 214 | state_dict["{}".format(k[len('encoder.'):])] = state_dict[k] 215 | 216 | del state_dict[k] 217 | 218 | model.load_state_dict(state_dict, strict=True) 219 | else: 220 | model.load_state_dict(state_dict) 221 | 222 | print("Weights loaded successfully ") 223 | return model 224 | 225 | 226 | def get_fs_labels(method: str, num_way: int, num_query: int, num_shot: int): 227 | """ 228 | Prepare few-shot labels. For example for 5-way, 1-shot, 2-query: [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, ...] 229 | """ 230 | n_samples = num_shot + num_query if 'map' in method else num_query 231 | labels = torch.arange(num_way, dtype=torch.int16).repeat(n_samples).type(torch.LongTensor) 232 | 233 | if torch.cuda.is_available(): 234 | return labels.cuda() 235 | else: 236 | return labels 237 | 238 | 239 | def bool_flag(s): 240 | """ 241 | Parse boolean arguments from the command line. 242 | """ 243 | FALSY_STRINGS = {"off", "false", "0"} 244 | TRUTHY_STRINGS = {"on", "true", "1"} 245 | if s.lower() in FALSY_STRINGS: 246 | return False 247 | elif s.lower() in TRUTHY_STRINGS: 248 | return True 249 | else: 250 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 251 | 252 | 253 | def print_and_log(results: dict, n: int = 0): 254 | """ 255 | Print and log current results. 256 | """ 257 | for key in results.keys(): 258 | # average by n if needed (n > 0) 259 | if n > 0 and 'time' not in key and '/epoch' not in key: 260 | results[key] = results[key] / n 261 | 262 | # print and log 263 | print(f'{key}: {results[key]:.4f}') 264 | 265 | if wandb.run is not None: 266 | wandb.log(results) 267 | 268 | 269 | def set_seed(seed: int): 270 | """ 271 | seed. 272 | """ 273 | random.seed(seed) 274 | np.random.seed(seed) 275 | torch.random.manual_seed(seed) 276 | torch.cuda.manual_seed(seed) 277 | 278 | 279 | class bcolors: 280 | HEADER = '\033[95m' 281 | OKBLUE = '\033[94m' 282 | OKCYAN = '\033[96m' 283 | OKGREEN = '\033[92m' 284 | WARNING = '\033[93m' 285 | FAIL = '\033[91m' 286 | ENDC = '\033[0m' 287 | BOLD = '\033[1m' 288 | UNDERLINE = '\033[4m' 289 | 290 | 291 | class SmoothedValue(object): 292 | """Track a series of values and provide access to smoothed values over a 293 | window or the global series average. 
294 | """ 295 | 296 | def __init__(self, window_size=20, fmt=None): 297 | if fmt is None: 298 | fmt = "{median:.6f} ({global_avg:.6f})" 299 | self.deque = deque(maxlen=window_size) 300 | self.total = 0.0 301 | self.count = 0 302 | self.fmt = fmt 303 | 304 | def update(self, value, n=1): 305 | self.deque.append(value) 306 | self.count += n 307 | self.total += value * n 308 | 309 | @property 310 | def median(self): 311 | d = torch.tensor(list(self.deque)) 312 | return d.median().item() 313 | 314 | @property 315 | def avg(self): 316 | d = torch.tensor(list(self.deque), dtype=torch.float32) 317 | return d.mean().item() 318 | 319 | @property 320 | def global_avg(self): 321 | return self.total / self.count 322 | 323 | @property 324 | def max(self): 325 | return max(self.deque) 326 | 327 | @property 328 | def value(self): 329 | return self.deque[-1] 330 | 331 | def __str__(self): 332 | return self.fmt.format( 333 | median=self.median, 334 | avg=self.avg, 335 | global_avg=self.global_avg, 336 | max=self.max, 337 | value=self.value) 338 | 339 | 340 | class MetricLogger(object): 341 | def __init__(self, delimiter="\t"): 342 | self.meters = defaultdict(SmoothedValue) 343 | self.delimiter = delimiter 344 | 345 | def update(self, **kwargs): 346 | for k, v in kwargs.items(): 347 | if isinstance(v, torch.Tensor): 348 | v = v.item() 349 | assert isinstance(v, (float, int)) 350 | self.meters[k].update(v) 351 | 352 | def __getattr__(self, attr): 353 | if attr in self.meters: 354 | return self.meters[attr] 355 | if attr in self.__dict__: 356 | return self.__dict__[attr] 357 | raise AttributeError("'{}' object has no attribute '{}'".format( 358 | type(self).__name__, attr)) 359 | 360 | def __str__(self): 361 | loss_str = [] 362 | for name, meter in self.meters.items(): 363 | loss_str.append( 364 | "{}: {}".format(name, str(meter)) 365 | ) 366 | return self.delimiter.join(loss_str) 367 | 368 | def add_meter(self, name, meter): 369 | self.meters[name] = meter 370 | 371 | def log_every(self, iterable, print_freq, header=None): 372 | i = 0 373 | if not header: 374 | header = '' 375 | 376 | start_time = time.time() 377 | end = time.time() 378 | iter_time = SmoothedValue(fmt='{avg:.6f}') 379 | data_time = SmoothedValue(fmt='{avg:.6f}') 380 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 381 | if torch.cuda.is_available(): 382 | log_msg = self.delimiter.join([ 383 | header, 384 | '[{0' + space_fmt + '}/{1}]', 385 | 'eta: {eta}', 386 | '{meters}', 387 | 'time: {time}', 388 | 'data: {data}', 389 | 'mem: {memory:.0f} ' 390 | 'mem reserved: {memory_res:.0f} ' 391 | ]) 392 | else: 393 | log_msg = self.delimiter.join([ 394 | header, 395 | '[{0' + space_fmt + '}/{1}]', 396 | 'eta: {eta}', 397 | '{meters}', 398 | 'time: {time}', 399 | 'data: {data}' 400 | ]) 401 | MB = 1024.0 * 1024.0 402 | for obj in iterable: 403 | data_time.update(time.time() - end) 404 | yield obj 405 | iter_time.update(time.time() - end) 406 | len_iterable = len(iterable) 407 | if i % print_freq == 0 or i == len_iterable - 1: 408 | eta_seconds = iter_time.global_avg * (len_iterable - i) 409 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 410 | if torch.cuda.is_available(): 411 | print(log_msg.format( 412 | i, len_iterable, eta=eta_string, 413 | meters=str(self), 414 | time=str(iter_time), data=str(data_time), 415 | memory=torch.cuda.memory_allocated() / MB, 416 | memory_res=torch.cuda.memory_reserved() / MB)) 417 | else: 418 | print(log_msg.format( 419 | i, len_iterable, eta=eta_string, 420 | meters=str(self), 421 | 
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.6f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
--------------------------------------------------------------------------------
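
Editor's note: a minimal usage sketch of `MetricLogger`/`SmoothedValue` outside the training script, assuming the file above is importable as `utils`. `log_every` wraps any sized iterable and prints progress, ETA and the smoothed meters every `print_freq` steps:

```python
import time
from utils import MetricLogger  # assumes utils.py above is on the import path

logger = MetricLogger(delimiter=" ")
for step, _ in enumerate(logger.log_every(range(100), print_freq=50, header='Demo:')):
    time.sleep(0.01)  # stand-in for a forward/backward pass
    logger.update(loss=1.0 / (step + 1), accuracy=step / 100.0)
print("Averaged stats:", logger)
```

Each keyword passed to `update` becomes a `SmoothedValue` meter, so `logger.loss.global_avg` and `logger.accuracy.global_avg` give the epoch-level averages, exactly as `train_one_epoch` and `eval_one_epoch` use them.
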