├── dataset
│   ├── ImagePandasDataset.py
│   ├── NShotTaskSampler.py
│   ├── car.py
│   ├── cub.py
│   ├── dog.py
│   ├── loaders.py
│   ├── miniimagenet.py
│   └── nab.py
├── losses.py
├── models
│   ├── __init__.py.py
│   ├── contrast.py
│   ├── encoder.py
│   └── utils.py
├── pictures
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   └── 4.png
├── readme.md
└── train_contrastive.py

/dataset/ImagePandasDataset.py:
--------------------------------------------------------------------------------
import os
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader


class ImagePandasDataset(Dataset):
    def __init__(self, df,
                 img_key, label_key, img_root="",
                 transform=None,
                 target_transform=None,
                 loader=default_loader):
        '''
        df: pandas dataframe of this dataset
        img_key: column name storing image paths in the dataframe
        label_key: column name storing labels in the dataframe
        transform: preprocessing for img
        target_transform: preprocessing for labels
        '''
        self.df = df.sort_values(by=[img_key])
        self.img_key = img_key
        self.label_key = label_key
        self.img_root = img_root
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader  # fixed: was hard-coded to default_loader, ignoring the argument
        self.num_classes = len(set(self.get_label_idx(i) for i in range(len(self))))

    def __getitem__(self, i):
        '''
        get the (img, label_idx) pair of the i-th data point.
        img is already preprocessed and label_idx starts from 0 incrementally,
        so they can be used as cnn input directly.
        '''
        return {"input": self.get_img(i), "label": self.get_label_idx(i)}

    def get_img_path(self, i):
        '''
        get img_path of the i-th data point
        '''
        return os.path.join(self.img_root, str(self.df.iloc[i][self.img_key]))

    def get_img(self, i):
        '''
        get the img array of the i-th data point.
        self.transform is applied if it exists.
        '''
        img = self.loader(self.get_img_path(i))
        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_label(self, i):
        '''
        get the label of the i-th data point as-is
        '''
        return self.df.iloc[i][self.label_key]

    def get_label_idx(self, i):
        '''
        get the label idx, which starts from 0 incrementally.
        self.target_transform is applied if it exists.
        '''
        label = self.get_label(i)
        if self.target_transform is not None:
            if isinstance(self.target_transform, dict):
                label_idx = self.target_transform[label]
            else:
                label_idx = self.target_transform(label)
        else:
            label_idx = int(label)
        return label_idx

    def __len__(self):
        return len(self.df)
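A minimal usage sketch for the class above (toy paths and labels; the dict form of `target_transform` mirrors what `dataset/loaders.py` builds):

```python
import pandas as pd

from dataset.ImagePandasDataset import ImagePandasDataset

df = pd.DataFrame({"path": ["a/001.jpg", "a/002.jpg", "b/001.jpg"],
                   "label": ["cat", "cat", "dog"]})
# map each category name to a 0-based index, as loaders.py does
label2idx = {label: i for i, label in enumerate(sorted(df["label"].unique()))}
ds = ImagePandasDataset(df=df, img_key="path", label_key="label",
                        img_root="./data/toy", target_transform=label2idx)
# ds[0] -> {"input": <PIL image, or a tensor if transform is set>, "label": 0}
```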
--------------------------------------------------------------------------------
/dataset/NShotTaskSampler.py:
--------------------------------------------------------------------------------
# this is taken from here: https://github.com/oscarknagg/few-shot/blob/2830f29e757a0db35f998112f5c866ab365f6ef2/few_shot/core.py
import warnings

import torch
import numpy as np
from torch.utils.data import Sampler


def make_label_idx2indices(dataset):
    label_idx2indices = {}
    for i in range(len(dataset)):
        label_idx = dataset.get_label_idx(i)
        if label_idx in label_idx2indices:
            label_idx2indices[label_idx].append(i)
        else:
            label_idx2indices[label_idx] = [i]
    return label_idx2indices


class NShotTaskSampler(Sampler):
    def __init__(self,
                 dataset: torch.utils.data.Dataset,
                 episodes_per_epoch: int,
                 n_shot: int,
                 n_way: int,
                 n_query: int,
                 ):
        """
        PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks.
        Each n-shot task contains a "support set" of `k` sets of `n` samples and a "query set" of `k` sets
        of `q` samples. The support set and the query set are all grouped into one Tensor such that the first n * k
        samples are from the support set while the remaining q * k samples are from the query set.
        The support and query sets are sampled such that they are disjoint, i.e. they do not contain overlapping samples.
        # Arguments
            dataset: Instance of torch.utils.data.Dataset from which to draw samples
            episodes_per_epoch: Arbitrary number of batches of n-shot tasks to generate in one epoch
            n_shot: int. Number of samples for each class in the n-shot classification tasks.
            n_way: int. Number of classes in the n-shot classification tasks.
            n_query: int. Number of query samples for each class in the n-shot classification tasks.
        """
        super(NShotTaskSampler, self).__init__(dataset)
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query
        self.episodes_per_epoch = episodes_per_epoch
        self.dataset = dataset
        self.label_idx2indices = make_label_idx2indices(dataset)

    def __len__(self):
        return self.episodes_per_epoch

    def __iter__(self):
        for _ in range(self.episodes_per_epoch):
            # choose classes to use in this episode
            episode_classes = np.random.choice(
                list(self.label_idx2indices.keys()),
                self.n_way, replace=False)

            # choose indices to use in this episode
            # select support + query
            support_indices = []
            query_indices = []
            for cls_idx in episode_classes:
                try:
                    sample_indices = np.random.choice(
                        self.label_idx2indices[cls_idx],
                        self.n_shot + self.n_query, replace=False)
                except ValueError:
                    # raised when a class has fewer than n_shot + n_query samples
                    warnings.warn("Probably not enough samples, so sampling with replacement for this class.")
                    sample_indices = np.random.choice(
                        self.label_idx2indices[cls_idx],
                        self.n_shot + self.n_query, replace=True)

                support_indices += sample_indices[0:self.n_shot].tolist()
                query_indices += sample_indices[self.n_shot:self.n_shot + self.n_query].tolist()

            batch = support_indices + query_indices
            yield np.stack(batch)
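A sketch of how to consume the sampler (assumes `ds` is an `ImagePandasDataset` with at least `n_way` classes and enough images per class):

```python
from torch.utils.data import DataLoader

from dataset.NShotTaskSampler import NShotTaskSampler

n_way, n_shot, n_query = 5, 5, 15
loader = DataLoader(
    ds,  # any ImagePandasDataset, e.g. built by dataset/loaders.py
    batch_sampler=NShotTaskSampler(ds, episodes_per_epoch=10,
                                   n_shot=n_shot, n_way=n_way, n_query=n_query))
batch = next(iter(loader))
x, y = batch["input"], batch["label"]
# all support indices come first, then all query indices
support_x, query_x = x[:n_way * n_shot], x[n_way * n_shot:]
support_y, query_y = y[:n_way * n_shot], y[n_way * n_shot:]
```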
--------------------------------------------------------------------------------
/dataset/car.py:
--------------------------------------------------------------------------------
'''
Every dataset should make df_dataset where

df_dataset = {
    "train": ...,
    "val": ...,
    "test": ...,
}
'''
import os
import glob

import pandas as pd


def setup_df(dataset_root="./data/car/"):
    table = []
    for path in glob.glob(os.path.join(dataset_root, "images/*/*.jpg")):
        label = path.split("/")[-2]
        table.append({"label": label, "path": path, "img_name": path.split("/")[-1]})
    df_all = pd.DataFrame(table)
    df_all = df_all.sort_values(["path"])
    labels = df_all["label"].unique().tolist()
    labels.sort()
    label2split = {}

    for i, label in enumerate(labels):
        if i % 2 == 0:
            label2split[label] = "train"
        elif i % 4 == 3:
            label2split[label] = "val"
        elif i % 4 == 1:
            label2split[label] = "test"
    df_all["split"] = df_all['label'].map(label2split)

    df_dataset = {}
    for split in ["train", "val", "test"]:
        df = df_all.query("split=='%s'" % split)
        df = df.sort_values(by=['path'])
        df_dataset[split] = df

    return df_dataset
--------------------------------------------------------------------------------
/dataset/cub.py:
--------------------------------------------------------------------------------
'''
Every dataset should make df_dataset where

df_dataset = {
    "train": ...,
    "val": ...,
    "test": ...,
}
'''
import os
import glob

import pandas as pd

# download the dataset from http://www.vision.caltech.edu/visipedia/CUB-200-2011.html
# extract it and rename the directory to ./data/cub

# $ wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz
# $ tar -xvzf CUB_200_2011.tgz
# $ mv CUB_200_2011 cub


def setup_df(dataset_root="./data/cub/"):
    table = []
    for path in glob.glob(os.path.join(dataset_root, "images/*/*.jpg")):
        label = path.split("/")[-2]
        table.append({"label": label, "path": path, "img_name": path.split("/")[-1]})
    df_all = pd.DataFrame(table)
    df_all = df_all.sort_values(["path"])
    labels = df_all["label"].unique().tolist()
    labels.sort()
    label2split = {}

    # data split coming from https://github.com/wyharveychen/CloserLookFewShot/blob/4ca19c75147b49a50fee7b2277971a2299c3d231/filelists/CUB/write_CUB_filelist.py#L32-L43
    for i, label in enumerate(labels):
        if i % 2 == 0:
            label2split[label] = "train"
        elif i % 4 == 1:
            label2split[label] = "val"
        elif i % 4 == 3:
            label2split[label] = "test"
    df_all["split"] = df_all['label'].map(label2split)

    df_dataset = {}
    for split in ["train", "val", "test"]:
        df = df_all.query("split=='%s'" % split)
        df = df.sort_values(by=['path'])
        df_dataset[split] = df

    return df_dataset
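A quick sanity check of the modulo split above: for CUB's 200 classes it yields 100 train, 50 val, and 50 test classes:

```python
from collections import Counter

def split_of(i):
    if i % 2 == 0:
        return "train"
    return "val" if i % 4 == 1 else "test"

counts = Counter(split_of(i) for i in range(200))
assert counts == Counter(train=100, val=50, test=50)
```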
--------------------------------------------------------------------------------
/dataset/dog.py:
--------------------------------------------------------------------------------
'''
Every dataset should make df_dataset where

df_dataset = {
    "train": ...,
    "val": ...,
    "test": ...,
}
'''
import os
import glob

import pandas as pd


def setup_df(dataset_root="./data/dog/"):
    table = []
    for path in glob.glob(os.path.join(dataset_root, "images/*/*.jpg")):
        label = path.split("/")[-2]
        table.append({"label": label, "path": path, "img_name": path.split("/")[-1]})
    df_all = pd.DataFrame(table)
    df_all = df_all.sort_values(["path"])
    labels = df_all["label"].unique().tolist()
    labels.sort()
    label2split = {}

    # fixed class split for Stanford Dogs: 80 train / 10 val / 30 test classes
    train_classes = 80
    val_classes = 10
    test_classes = 30

    for i, label in enumerate(labels):
        if i < train_classes:
            label2split[label] = "train"
        elif i < train_classes + val_classes:
            label2split[label] = "val"
        else:
            label2split[label] = "test"

    df_all["split"] = df_all['label'].map(label2split)

    df_dataset = {}
    for split in ["train", "val", "test"]:
        df = df_all.query("split=='%s'" % split)
        df = df.sort_values(by=['path'])
        df_dataset[split] = df

    return df_dataset
--------------------------------------------------------------------------------
/dataset/loaders.py:
--------------------------------------------------------------------------------
from torchvision import transforms
from torch.utils.data import DataLoader

from .ImagePandasDataset import ImagePandasDataset
from .NShotTaskSampler import NShotTaskSampler


def setup_dataset(args):
    # setup dataset as a pandas dataframe
    # note: the package directory is ./dataset, so import "dataset.<name>"
    dataset = getattr(__import__("dataset.%s" % args.dataset), args.dataset)
    dataset_root = "./data/%s" % args.dataset
    if args.dataset_root is not None:
        dataset_root = args.dataset_root
    df_dict = dataset.setup_df(dataset_root)

    dataset_dict = {}
    # key is train/val/test and the value is the corresponding pytorch dataset
    for split, df in df_dict.items():
        target_transform = {label: i for i, label in enumerate(sorted(df["label"].unique()))}
        # target_transform maps a category name to a category idx starting from 0
        if split == "train":
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop((224, 224)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        dataset_dict[split] = ImagePandasDataset(df=df,
                                                 img_key="path",
                                                 label_key="label",
                                                 transform=transform,
                                                 target_transform=target_transform,
                                                 )
    return dataset_dict


def setup_dataloader(args, dataset_dict):
    dataloader_dict = {}
    episodes_dict = {"train": args.episodes_train, "val": args.episodes_val, "test": args.episodes_test}
    for split, dataset in dataset_dict.items():
        episodes = episodes_dict[split]

        if split == "train":
            nway = args.nway
        else:
            nway = args.nway_eval

        dataloader_dict[split] = DataLoader(
            dataset,
            batch_sampler=NShotTaskSampler(
                dataset,
                episodes,
                args.nshot,
                nway,
                args.nquery,
            ),
            num_workers=args.workers,
        )
    return dataloader_dict
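A wiring sketch for the two helpers above (assumes the data is prepared under `./data/cub`; the Namespace mirrors the flags defined in `train_contrastive.py`):

```python
from argparse import Namespace

from dataset.loaders import setup_dataset, setup_dataloader

args = Namespace(dataset="cub", dataset_root=None,
                 nway=5, nway_eval=5, nshot=5, nquery=15,
                 episodes_train=1000, episodes_val=100, episodes_test=1000,
                 workers=4)
dataset_dict = setup_dataset(args)                       # {"train": ..., "val": ..., "test": ...}
dataloader_dict = setup_dataloader(args, dataset_dict)   # episodic DataLoaders
```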
--------------------------------------------------------------------------------
/dataset/miniimagenet.py:
--------------------------------------------------------------------------------
'''
Every dataset should make df_dataset where

df_dataset = {
    "train": ...,
    "val": ...,
    "test": ...,
}
'''
import os

import pandas as pd


def setup_df(dataset_root="./data/miniimagenet/"):
    df_dataset = {}
    img_dir_path = os.path.join(dataset_root, "images/")
    for split in ["train", "val", "test"]:
        df = pd.read_csv(os.path.join(dataset_root, "%s.csv" % split))
        df["path"] = img_dir_path + df["filename"]
        del df["filename"]
        df_dataset[split] = df
    return df_dataset
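The CSV files read above are assumed to follow the common mini-ImageNet split format, one row per image with `filename` and `label` columns:

```python
import pandas as pd

# assumed layout of ./data/miniimagenet/train.csv (and val.csv / test.csv):
#   filename,label
#   n0153282900000005.jpg,n01532829
df = pd.read_csv("./data/miniimagenet/train.csv")
print(df.columns.tolist())  # expected: ["filename", "label"]
```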
--------------------------------------------------------------------------------
/dataset/nab.py:
--------------------------------------------------------------------------------
'''
Every dataset should make df_dataset where

df_dataset = {
    "train": ...,
    "val": ...,
    "test": ...,
}
'''
import os
import glob

import pandas as pd


def setup_df(dataset_root="./data/nab/"):
    table = []
    for path in glob.glob(os.path.join(dataset_root, "images/*/*.jpg")):
        label = path.split("/")[-2]
        table.append({"label": label, "path": path, "img_name": path.split("/")[-1]})
    df_all = pd.DataFrame(table)
    df_all = df_all.sort_values(["path"])
    labels = df_all["label"].unique().tolist()
    labels.sort()
    label2split = {}

    # this split is our own; it has not been used in previous work
    for i, label in enumerate(labels):
        if i % 2 == 0:
            label2split[label] = "train"
        elif i % 4 == 3:
            label2split[label] = "val"
        elif i % 4 == 1:
            label2split[label] = "test"
    df_all["split"] = df_all['label'].map(label2split)

    df_dataset = {}
    for split in ["train", "val", "test"]:
        df = df_all.query("split=='%s'" % split)
        df = df.sort_values(by=['path'])
        df_dataset[split] = df

    return df_dataset
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def contrast_distill(f1, f2):
    """
    Contrastive Distillation
    """
    f1 = F.normalize(f1, dim=1, p=2)
    f2 = F.normalize(f2, dim=1, p=2)
    loss = 2 - 2 * (f1 * f2).sum(dim=-1)
    return loss.mean()


class DistillKL(nn.Module):
    """
    KL divergence for distillation
    """

    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, dim=1)
        p_t = F.softmax(y_t / self.T, dim=1)
        # reduction='sum' replaces the deprecated size_average=False
        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]
        return loss


class ContrastiveLoss(nn.Module):
    """
    Contrastive Loss (based on https://github.com/HobbitLong/SupContrast)
    """

    def __init__(self, temperature=None):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def _compute_logits(self, features_a, features_b, attention=None):
        # global similarity
        if features_a.dim() == 2:
            features_a = F.normalize(features_a, dim=1, p=2)
            features_b = F.normalize(features_b, dim=1, p=2)
            contrast = torch.matmul(features_a, features_b.T)

        # spatial similarity
        elif features_a.dim() == 4:
            contrast = attention(features_a, features_b)

        else:
            raise ValueError

        # note that we use the inverse temperature here
        contrast = contrast * self.temperature
        return contrast

    def forward(self, features_a, features_b=None, labels=None, attention=None):
        device = (torch.device('cuda') if features_a.is_cuda else torch.device('cpu'))
        # guard against labels=None before reading labels.shape
        num_features = features_a.shape[0]
        num_labels = num_features if labels is None else labels.shape[0]

        # using only the current features in a given batch
        if features_b is None:
            features_b = features_a
            # mask to remove self contrasting
            logits_mask = (1. - torch.eye(num_features)).to(device)
        else:
            # contrasting different features (a & b), no need to mask the diagonal
            logits_mask = torch.ones(num_features, num_features).to(device)

        # mask to only maintain positives
        if labels is None:
            # standard self-supervised case
            mask = torch.eye(num_labels, dtype=torch.float32).to(device)
        else:
            labels = labels.contiguous().view(-1, 1)
            mask = torch.eq(labels, labels.T).float().to(device)

        # replicate the mask since the labels are just for N examples
        if num_features != num_labels:
            assert num_labels * 2 == num_features
            mask = mask.repeat(2, 2)

        # compute logits
        contrast = self._compute_logits(features_a, features_b, attention)

        # remove self contrasting
        mask = mask * logits_mask

        # normalization over number of positives
        normalization = mask.sum(1)
        normalization[normalization == 0] = 1.

        # for stability
        logits_max, _ = torch.max(contrast, dim=1, keepdim=True)
        logits = contrast - logits_max.detach()
        exp_logits = torch.exp(logits)

        exp_logits = exp_logits * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positives
        mean_log_prob_pos = (mask * log_prob).sum(1) / normalization
        loss = -mean_log_prob_pos.mean()

        return loss
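A minimal usage sketch for `ContrastiveLoss` (random features as a stand-in for projected embeddings; note that `temperature` acts as an inverse temperature, and stacking two views for N labels exercises the `mask.repeat(2, 2)` branch):

```python
import torch

from losses import ContrastiveLoss

criterion = ContrastiveLoss(temperature=10.)
f1 = torch.randn(8, 128)                # projected features, view 1
f2 = torch.randn(8, 128)                # projected features, view 2
labels = torch.randint(0, 5, (8,))      # one label per example
features = torch.cat([f1, f2], dim=0)   # 16 features for 8 labels
loss = criterion(features, labels=labels)
```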
--------------------------------------------------------------------------------
/models/__init__.py.py:
--------------------------------------------------------------------------------
from .convnet import convnet4
from .resnet import resnet12, resnet18, resnet24
from .resnet import seresnet12
from .wresnet import wrn_28_10
from .resnet_standard import resnet50

model_pool = [
    'convnet4',
    'resnet12',
    'resnet18',
    'resnet24',
    'seresnet12',
    'wrn_28_10',
    'resnet50',
]

model_dict = {
    'wrn_28_10': wrn_28_10,
    'convnet4': convnet4,
    'resnet12': resnet12,
    'resnet18': resnet18,
    'resnet24': resnet24,
    'seresnet12': seresnet12,
    'resnet50': resnet50,
}
--------------------------------------------------------------------------------
/models/contrast.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from models.utils import create_model  # fixed: the module is models/utils.py, not models/util.py


class Projection(nn.Module):
    """
    projection head
    """

    def __init__(self, dim, projection_size, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_size, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size, bias=False))

    def forward(self, x):
        return self.net(x)


class ContrastResNet(nn.Module):
    """
    defines the backbone and projection head
    """

    def __init__(self, opt, n_cls):
        super(ContrastResNet, self).__init__()

        self.encoder = create_model(opt.model, n_cls=n_cls, dataset=opt.dataset)
        dim_in = self.encoder.feat_dim
        projection_size = opt.feat_dim
        self.head = Projection(dim=dim_in, projection_size=projection_size, hidden_size=dim_in)

        self.global_cont_loss = opt.global_cont_loss
        self.spatial_cont_loss = opt.spatial_cont_loss

    def forward(self, x):
        # forward pass through the embedding model; feat is a list of features
        feat, outputs = self.encoder(x, is_feat=True)

        # spatial features before avg pool
        spatial_f = feat[-2]

        # global features after avg pool
        avg_pool_feat = feat[-1]

        # projected global features
        global_f = self.head(avg_pool_feat)
        return outputs, spatial_f, global_f, avg_pool_feat
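A shape sketch for the projection head above (640 is an assumed backbone `feat_dim`, e.g. for resnet12; 128 plays the role of `opt.feat_dim`):

```python
import torch

from models.contrast import Projection

head = Projection(dim=640, projection_size=128, hidden_size=640)
z = head(torch.randn(4, 640))
print(z.shape)  # torch.Size([4, 128])
```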
--------------------------------------------------------------------------------
/models/encoder.py:
--------------------------------------------------------------------------------
import torch
import torchvision

# Flatten, create_model (the ViT builder) and swin are expected to come from
# the local model definitions; they are not defined in this file.


def setup_backbone(name, pretrained=True):
    if name == "resnet18":
        model = torchvision.models.resnet18(pretrained=False)
        model.load_state_dict(torch.load('/groups/public_cluster/home/samuel/projects/MyNet/pretrained_weights/resnet18-5c106cde.pth'))
        model.fc = Flatten()
        return model
    elif name == "ViT":
        model = create_model(num_classes=1000,
                             has_logits=False)
        model.load_state_dict(
            torch.load('/groups/public_cluster/home/samuel/projects/MyNet/pretrained_weights/jx_vit_base_p16_224-80ecf9dd.pth'))
        return model
    elif name == "swin":
        model = swin(num_classes=1000, has_logits=False)
        model.load_state_dict(
            torch.load('/groups/public_cluster/home/samuel/projects/MyNet/pretrained_weights/swin_base_patch4_window7_224.pth')["model"], strict=False)
        return model
    else:
        raise NotImplementedError("this option is not defined")
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
from . import model_dict


def create_model(name, n_cls=None, dataset=None):
    """create a model by name"""
    # n_cls and dataset were previously undefined names here; they are now explicit arguments
    if name.endswith('v2') or name.endswith('v3'):
        model = model_dict[name](num_classes=n_cls)
    elif name.startswith('resnet50'):
        print('use imagenet-style resnet50')
        model = model_dict[name](num_classes=n_cls)
    elif name.startswith('resnet') or name.startswith('seresnet'):
        model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=n_cls)
    elif name.startswith('wrn'):
        model = model_dict[name](num_classes=n_cls)
    elif name.startswith('convnet'):
        model = model_dict[name](num_classes=n_cls)
    else:
        raise NotImplementedError('model {} not supported in dataset {}:'.format(name, dataset))
    return model
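With `n_cls` now an explicit argument (see the fix above), a backbone is built like this (a sketch; 'resnet12' assumes a matching entry in `model_dict`):

```python
from models.utils import create_model

encoder = create_model('resnet12', n_cls=64)  # 64 = number of training classes
```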
--------------------------------------------------------------------------------
/pictures/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GZU-SAMLab/McQNet/0467100d7e0034bbc12e9cc14d72abccfc13c8f5/pictures/1.png
--------------------------------------------------------------------------------
/pictures/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GZU-SAMLab/McQNet/0467100d7e0034bbc12e9cc14d72abccfc13c8f5/pictures/2.png
--------------------------------------------------------------------------------
/pictures/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GZU-SAMLab/McQNet/0467100d7e0034bbc12e9cc14d72abccfc13c8f5/pictures/3.png
--------------------------------------------------------------------------------
/pictures/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GZU-SAMLab/McQNet/0467100d7e0034bbc12e9cc14d72abccfc13c8f5/pictures/4.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# Meta-contrastive Learning with Support-based Query Interaction for Few-shot Fine-grained Visual Classification

> Few-shot fine-grained visual classification combines the data-scarcity challenge of few-shot tasks with the intrinsic inter- and intra-class variance of fine-grained tasks. To alleviate this problem, we propose a support-based query interaction module and a meta-contrastive method. Extensive experiments demonstrate that our method performs well on both common fine-grained datasets and few-shot datasets under different experimental settings.

# Our Major Contributions

> 1. We propose a novel meta-contrastive framework for fine-grained few-shot learning, further improving performance on both few-shot learning and fine-grained tasks. To our knowledge, this may be the first work to apply contrastive learning to fine-grained few-shot tasks.
> 2. We develop a support-based query interaction method that not only establishes the dependencies between support samples and query samples, but also enhances query features by leveraging support features.
> 3. We conduct comprehensive experiments on three benchmark fine-grained datasets and a common few-shot benchmark dataset. Our experimental results achieve state-of-the-art (SOTA) performance on these datasets.

# Architecture of our proposed Network

![image-20231016210225341](./pictures/1.png)

> The network is composed of three main parts: the feature extractor module, the support-based query interaction module (SQIM), and the meta-contrastive module (MCM). SQIM and MCM are the core of the proposed method. SQIM mines latent knowledge by letting query features interact with support features. MCM enhances the model's ability to learn fine-grained few-shot tasks by incorporating the idea of contrastive learning.

# Support-based Query Interaction Module

![image-20231016212731788](./pictures/3.png)

> The detailed flow of the support-based query interaction module.

# Meta-Contrastive Module

![image-20231016212851060](./pictures/4.png)

> The detailed flow of the meta-contrastive module.

# Environment Requirements

> To run the project, you need the following environment configuration:

```
numpy==1.21.5
pandas==1.2.4
resnet==0.1
torch==1.10.2
torchvision==0.11.3
tqdm==4.66.1
```
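One way to install the pinned requirements (a convenience command assembled from the list above):

```
pip install numpy==1.21.5 pandas==1.2.4 resnet==0.1 torch==1.10.2 torchvision==0.11.3 tqdm==4.66.1
```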
# Dataset

Download the dataset from the [CUB-200-2011 page](http://www.vision.caltech.edu/datasets/cub_200_2011/).

> Caltech-UCSD Birds-200-2011 (CUB-200-2011) is an extended version of the CUB-200 dataset, with roughly double the number of images per class and new part location annotations.
> 1. Number of categories: 200
> 2. Number of images: 11,788
> 3. Annotations per image: 15 Part Locations, 312 Binary Attributes, 1 Bounding Box

Or use the Linux command line to download the dataset:
```
wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz
tar -xvzf CUB_200_2011.tgz
```
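After extraction, move the folder to where the loaders expect it (`dataset/cub.py` globs `./data/cub/images/*/*.jpg`):
```
mv CUB_200_2011 ./data/cub
```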
# Code Structure
`dataset/*`: load and preprocess the datasets

`models/*`: define all network models to be used

`losses.py`: computation of the losses

`train_contrastive.py`: training with the contrastive module
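A typical training invocation, with flags taken from `train_contrastive.py` (assumes the CUB data is prepared under `./data/cub`):
```
python train_contrastive.py --dataset cub --backbone resnet18 --nway 5 --nshot 5 --nquery 15 --epochs 10
```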
# Result
![image-20231016210844054](./pictures/2.png)

> 5-way few-shot classification performance (%) on the CUB, NAB, and Stanford Dogs datasets. The ± denotes that the results are reported with 95% confidence intervals over 1000 test episodes. The highest average accuracy in each column is marked in bold.

--------------------------------------------------------------------------------
/train_contrastive.py:
--------------------------------------------------------------------------------
import time
import argparse

import torch
from tqdm import tqdm

# imports from my own scripts
import utils
from utils import AverageMeter  # assumes utils.py provides an AverageMeter helper
from dataset.loaders import setup_dataset, setup_dataloader
from losses import ContrastiveLoss
from models.contrast import ContrastResNet

utils.make_deterministic(123)


def setup_args():
    parser = argparse.ArgumentParser('argument for training')

    # general
    parser.add_argument('--eval-freq', type=int, default=1, help="evaluate every this many epochs")
    parser.add_argument('--print-freq', default=100, type=int, metavar='N', help='print frequency')
    parser.add_argument('--workers', type=int, default=4, help="number of processes for the batch workers")
    parser.add_argument('--epochs', type=int, default=10, help="number of epochs. if 0, evaluation-only mode")
    parser.add_argument('--gpu', type=int, default=0)

    # optimization
    parser.add_argument('--lr', type=float, default=1e-3, help="learning rate. default is 0.001")
    parser.add_argument('--steps', default=[5], nargs='+', type=int, help='decrease lr at these epochs')
    parser.add_argument('--step-factor', type=float, default=0.1, help="factor by which to decrease the learning rate")

    # dataset
    parser.add_argument("--dataset", choices=["cub", "car", "dog", "nab", "miniimagenet"],
                        default="cub",
                        help="Which dataset.")
    parser.add_argument('--backbone', type=str, default="ViT", choices=["ViT", "swin", "resnet18"],
                        help="feature extraction network")
    parser.add_argument('--backbone-pretrained', type=int, default=1,
                        help="use a pretrained model or not for the feature extraction network")

    # backbone / contrast settings consumed by models/contrast.py
    # (these flags were referenced but never defined; the defaults are assumptions)
    parser.add_argument('--model', type=str, default='resnet12', help="encoder name passed to create_model")
    parser.add_argument('--feat_dim', type=int, default=128, help="output dimension of the projection head")
    parser.add_argument('--global_cont_loss', type=int, default=1, help="use the global contrastive loss")
    parser.add_argument('--spatial_cont_loss', type=int, default=0, help="use the spatial contrastive loss")
    parser.add_argument('--temperature', type=float, default=10., help="inverse temperature of the contrastive loss")

    # folder/route
    parser.add_argument('--dataset_root', type=str, default=None,
                        help="Default is None, and ./data/ is used")

    # meta setting
    parser.add_argument('--nway', default=5, type=int,
                        help='number of classes to classify for training. this has to be more than 1 and at most the total number of classes')
    parser.add_argument('--nway-eval', default=5, type=int,
                        help='number of classes to classify for evaluation. this has to be more than 1 and at most the total number of classes')
    parser.add_argument('--nshot', default=5, type=int, help='number of labeled data in each class, same as nsupport')
    parser.add_argument('--nquery', default=15, type=int, help='number of query points per class')

    parser.add_argument('--episodes-train', type=int, default=1000, help="number of episodes per epoch for train")
    parser.add_argument('--episodes-val', type=int, default=100, help="number of episodes for val")
    parser.add_argument('--episodes-test', type=int, default=1000, help="number of episodes for test")

    # loss weights
    parser.add_argument('--lambda_cls', default=0., type=float)
    parser.add_argument('--lambda_global', default=1., type=float)  # referenced in the total loss; default is an assumption

    # other setup
    parser.add_argument('--resume', type=str, default=None, help="metamodel checkpoint to resume")
    parser.add_argument('--resume-optimizer', type=str, default=None, help="optimizer checkpoint to resume")
    parser.add_argument('--saveroot', default="./experiments/", help='Root directory in which to make the output directory')
    parser.add_argument('--saveprefix', default="log", help='prefix to prepend to the name of the log directory')
    parser.add_argument('--saveargs',
                        default=["dataset", "nway", "nshot", "classifier", "backbone"],
                        nargs='+', help='args to append to the name of the log directory')
    return parser.parse_args()


def main(args):
    print(args)

    device = torch.device("cuda:%d" % args.gpu if torch.cuda.is_available() else "cpu")

    # setup dataset and dataloaders
    dataset_dict = setup_dataset(args)
    dataloader_dict = setup_dataloader(args, dataset_dict)

    # CE loss
    criterion_cls = torch.nn.CrossEntropyLoss()

    # create model; n_cls comes from the training split
    model = ContrastResNet(args, n_cls=dataset_dict["train"].num_classes).to(device)

    # contrast loss
    criterion_contrast = ContrastiveLoss(temperature=args.temperature)

    # setup optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.steps, gamma=args.step_factor)

    # main training
    for epoch in range(args.epochs):
        print("epoch: %d --start from 0 and end at %d" % (epoch, args.epochs - 1))

        time1 = time.time()
        train_loss = train_one_epoch(args, dataloader_dict["train"], model, criterion_cls, criterion_contrast, optimizer, device)
        time2 = time.time()

        # step the scheduler after the epoch, not before it
        lr_scheduler.step()

        print('epoch: {}, total time: {:.2f}, train loss: {:.3f}'.format(epoch, time2 - time1, train_loss))


def train_one_epoch(args, dataloader, model, criterion_cls, criterion_contrast, optimizer, device):
    model.train()  # Set model to training mode

    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, loss_ce, loss_con = AverageMeter(), AverageMeter(), AverageMeter()

    # episode shape (not used below, but handy for reshaping into support/query)
    nway = dataloader.batch_sampler.n_way
    nshot = dataloader.batch_sampler.n_shot
    nquery = dataloader.batch_sampler.n_query

    end = time.time()

    # training loop
    for i, data in enumerate(tqdm(dataloader)):
        data_time.update(time.time() - end)

        inputs = data["input"].to(device)
        labels = data["label"].to(device)

        # ===================forward=====================
        outputs, spatial_f, global_f, avg_pool_feat = model(inputs)

        # ===================losses======================
        # standard CE loss
        loss_cls = criterion_cls(outputs, labels)

        # compute contrastive loss
        loss_contrast = criterion_contrast(global_f, labels=labels)

        # compute the total loss
        loss = loss_contrast * args.lambda_global + args.lambda_cls * loss_cls

        # update the loss meters
        losses.update(loss.item())
        loss_con.update(loss_contrast.item())
        loss_ce.update(loss_cls.item())

        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

    return float(losses.avg)


if __name__ == '__main__':
    args = setup_args()
    main(args)
--------------------------------------------------------------------------------