├── .gitignore
├── LICENSE
├── README.md
├── args_mini.py
├── creat_inpaint_data.py
├── create_inpaint_data_mini.py
├── create_pureinpainting_data.py
├── fig
│   └── framework.png
├── losses.py
├── test.py
├── torchFewShot
│   ├── __init__.py
│   ├── data_manager.py
│   ├── data_manager_image_inpainting_data.py
│   ├── data_manager_imageori.py
│   ├── dataset_loader
│   │   ├── __init__.py
│   │   ├── test_image_ori_loader.py
│   │   ├── test_inpainting_loader.py
│   │   ├── test_loader.py
│   │   ├── train_image_ori_loader.py
│   │   ├── train_inpainting_loader.py
│   │   └── train_loader.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── miniImageNet.py
│   ├── losses.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cam.py
│   │   ├── channel_wise_attention.py
│   │   ├── net.py
│   │   ├── net_related.py
│   │   ├── related_net.py
│   │   ├── related_net_spatial_attention.py
│   │   └── resnet12.py
│   ├── optimizers.py
│   ├── transforms.py
│   └── utils
│       ├── __init__.py
│       ├── avgmeter.py
│       ├── iotools.py
│       ├── logger.py
│       └── torchtools.py
├── train.py
├── train.sh
└── train_with_inpaint_read_from_data_fixed.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 ljjcoder 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CSEI 2 | 3 | Code for *Learning Intact Features by Erasing-Inpainting for Few-shot Classification* (AAAI 2021); the paper is available [here](https://www.aaai.org/AAAI21Papers/AAAI-540.LiJ.pdf) 4 | ## Introduction 5 | 6 | In this paper, we propose to learn intact features by erasing-inpainting for few-shot 7 | classification. Specifically, we argue that intact features of target objects are more transferable, and we therefore 8 | propose a novel cross-set erasing-inpainting (CSEI) method. CSEI processes the images in the support set using erasing 9 | and inpainting, and then uses them to augment the query set of the same task. Consequently, the feature embedding produced 10 | by our method contains more complete information about the target objects. In addition, we propose task-specific feature modulation to make the features adaptive to 11 | the current task. 12 | 13 | ## Dataset 14 | 15 | [Mini-ImageNet](https://drive.google.com/file/d/1KfrNQgOLKLjaD0h1U6dfkmtaobVw2Qkt/view?usp=sharing) 16 | ## Our framework 17 | 18 | ![framework](fig/framework.png) 19 | 20 | ## Acknowledgments 21 | 22 | This code is based on the implementation of [**Cross Attention Network for Few-shot Classification**](https://github.com/blue-blue272/fewshot-CAN).
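### Method sketch

The erasing-inpainting step described in the introduction can be summarized in a few lines. This is only an illustrative sketch, not the repository's actual pipeline: `compute_cam` and `inpaint_model` are hypothetical callables standing in for the CAM extraction and the EdgeConnect inpainter used in `create_inpaint_data_mini.py`, and the threshold of 200 mirrors the mask binarization in that script's `returnCAM` helper.

```
import numpy as np

def erase_inpaint(support_images, compute_cam, inpaint_model, thresh=200):
    """Sketch of cross-set erasing-inpainting (CSEI): erase the most
    discriminative region of each support image and re-fill it with an
    inpainting model; the results augment the query set of the task."""
    augmented = []
    for img in support_images:                         # img: [H, W, 3] uint8 array
        cam = compute_cam(img)                         # hypothetical callable: CAM in [0, 255], shape [H, W]
        mask = (cam > thresh).astype(np.uint8)         # binarize: erase only the strongest response
        erased = img * (1 - mask[..., None])           # zero out the class-discriminative region
        augmented.append(inpaint_model(erased, mask))  # hypothetical inpainting callable
    return augmented
```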
23 | 24 | ### Training 25 | 26 | ``` 27 | sh train.sh 28 | ``` 29 | ### If you use this code/method or find it helpful, please cite: 30 | 31 | ``` 32 | @inproceedings{li2021learning, 33 | title={Learning Intact Features by Erasing-Inpainting for Few-shot Classification}, 34 | author={Li, Junjie and Wang, Zilei and Hu, Xiaoming}, 35 | booktitle={AAAI}, 36 | year={2021} 37 | } 38 | ``` 39 | -------------------------------------------------------------------------------- /args_mini.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchFewShot 3 | 4 | def argument_parser(): 5 | 6 | parser = argparse.ArgumentParser(description='Train image model with cross entropy loss') 7 | # ************************************************************ 8 | # Datasets (general) 9 | # ************************************************************ 10 | parser.add_argument('-d', '--dataset', type=str, default='miniImageNet') 11 | parser.add_argument('--load', default=False) 12 | 13 | parser.add_argument('-j', '--workers', default=4, type=int, 14 | help="number of data loading workers (default: 4)") 15 | parser.add_argument('--height', type=int, default=84, 16 | help="height of an image (default: 84)") 17 | parser.add_argument('--width', type=int, default=84, 18 | help="width of an image (default: 84)") 19 | 20 | # ************************************************************ 21 | # Optimization options 22 | # ************************************************************ 23 | parser.add_argument('--optim', type=str, default='sgd', 24 | help="optimization algorithm (see optimizers.py)") 25 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 26 | help="initial learning rate") 27 | parser.add_argument('--weight-decay', default=5e-04, type=float, 28 | help="weight decay (default: 5e-04)") 29 | 30 | parser.add_argument('--max-epoch', default=95, type=int, 31 | help="maximum epochs to run") 32 | parser.add_argument('--start-epoch', default=0, type=int, 33 | help="manual epoch number (useful on restarts)") 34 | parser.add_argument('--stepsize', default=[60], nargs='+', type=int, 35 | help="stepsize to decay learning rate") 36 | parser.add_argument('--LUT_lr', default=[(60, 0.1), (70, 0.006), (85, 0.0012), (95, 0.00024)], 37 | help="multistep schedule to decay the learning rate") 38 | 39 | parser.add_argument('--train-batch', default=4, type=int, 40 | help="train batch size") 41 | parser.add_argument('--test-batch', default=4, type=int, 42 | help="test batch size") 43 | 44 | # ************************************************************ 45 | # Architecture settings 46 | # ************************************************************ 47 | parser.add_argument('--num_classes', type=int, default=64) 48 | parser.add_argument('--scale_cls', type=int, default=7) 49 | 50 | # ************************************************************ 51 | # Miscs 52 | # ************************************************************ 53 | parser.add_argument('--save-dir', type=str, default='./result/miniImageNet/CAM/1-shot-seed112/') 54 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 55 | parser.add_argument('--gpu-devices', default='2', type=str) 56 | 57 | # ************************************************************ 58 | # Few-shot settings 59 | # ************************************************************ 60 | parser.add_argument('--nKnovel', type=int, default=5, 61 | help='number of novel categories') 62 | parser.add_argument('--use_similarity', type=int,
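# when set to 1, create_inpaint_data_mini.py switches to the similarity-based model (torchFewShot/models/net_similary.py); 0 keeps the default net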
default=0, 63 | help='using similarity cam') 64 | parser.add_argument('--nExemplars', type=int, default=5, 65 | help='number of training examples per novel category.') 66 | parser.add_argument('--Classic', type=int, default=0, 67 | help='train a classic classifier') 68 | 69 | parser.add_argument('--train_nTestNovel', type=int, default=6 * 5, 70 | help='number of test examples for all the novel categories when training') 71 | parser.add_argument('--train_epoch_size', type=int, default=1200, 72 | help='number of batches per epoch when training') 73 | parser.add_argument('--nTestNovel', type=int, default=15 * 5, 74 | help='number of test examples for all the novel categories') 75 | parser.add_argument('--epoch_size', type=int, default=600, 76 | help='number of batches per epoch') 77 | 78 | parser.add_argument('--phase', default='test', type=str, 79 | help='use the test or val dataset for early stopping') 80 | parser.add_argument('--seed', type=int, default=2333) 81 | 82 | return parser 83 | 84 | -------------------------------------------------------------------------------- /create_inpaint_data_mini.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import os 5 | import sys 6 | import time 7 | from PIL import Image 8 | import datetime 9 | import argparse 10 | import os.path as osp 11 | import numpy as np 12 | import random 13 | import cv2 14 | from imageio import imread  # scipy.misc.imread was removed from SciPy; imageio offers the same interface 15 | from skimage.feature import canny 16 | from skimage.color import rgb2gray, gray2rgb 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.backends.cudnn as cudnn 21 | from torch.utils.data import DataLoader 22 | from torch.optim import lr_scheduler 23 | import torch.nn.functional as F 24 | import torchvision.transforms.functional as Funljj 25 | sys.path.append('./torchFewShot') 26 | 27 | #from args_tiered import argument_parser 28 | from args_mini import argument_parser  # args_xent is not part of this repo; args_mini provides the parser 29 | #from torchFewShot.models.net import Model 30 | 31 | from torchFewShot.models.models_gnn import create_models 32 | from torchFewShot.data_manager_imageori import DataManager 33 | #from torchFewShot.data_manager import DataManager 34 | from torchFewShot.losses import CrossEntropyLoss 35 | from torchFewShot.optimizers import init_optimizer 36 | import transforms as T 37 | from torchFewShot.utils.iotools import save_checkpoint, check_isfile 38 | from torchFewShot.utils.avgmeter import AverageMeter 39 | from torchFewShot.utils.logger import Logger 40 | from torchFewShot.utils.torchtools import one_hot, adjust_learning_rate 41 | 42 | sys.path.append('/home/lijunjie/edge-connect-master') 43 | from shutil import copyfile 44 | from src.config import Config 45 | from src.edge_connect_few_shot import EdgeConnect 46 | 47 | #config = load_config(mode) 48 | config_path = os.path.join('/home/lijunjie/edge-connect-master/checkpoints/places2_authormodel', 'config.yml') 49 | config = Config(config_path) 50 | config.TEST_FLIST = '/home/lijunjie/edge-connect-master/examples/test_result/' 51 | config.TEST_MASK_FLIST = '/home/lijunjie/edge-connect-master/examples/places2/masks' 52 | config.RESULTS = './checkpoints/EC_test' 53 | config.MODE = 2 54 | if config.MODE == 2: 55 | config.MODEL = 3 56 | config.INPUT_SIZE = 0 57 | config.mask_id=2 58 | #if args.input is not None: 59 | #config.TEST_FLIST = args.input 60 | 61 | #if args.mask is not None: 62 | #config.TEST_MASK_FLIST = args.mask 63 | 64 | #if args.edge is not None: 65 | #config.TEST_EDGE_FLIST =
args.edge 66 | 67 | #if args.output is not None: 68 | #config.RESULTS = args.output 69 | #exit(0) 70 | 71 | 72 | parser = argument_parser() 73 | args = parser.parse_args() 74 | #print(args.use_similarity) 75 | #exit(0) 76 | if args.use_similarity: 77 | from torchFewShot.models.net_similary import Model 78 | else: 79 | from torchFewShot.models.net import Model_mltizhixin , Model_tradi 80 | #print('enter ori net') 81 | #exit(0) 82 | 83 | only_test=False 84 | def returnCAM(feature_conv, weight_softmax, class_idx,output_cam ): 85 | # generate the class activation maps upsample to 256x256 86 | size_upsample = (84, 84) 87 | nc, h, w = feature_conv.shape 88 | #output_cam = [] 89 | #print(class_idx) 90 | #exit(0) 91 | #print(class_idx, nc, h, w,weight_softmax[class_idx[0]].shape) 92 | #print(feature_conv.shape) 93 | #print(class_idx) 94 | #exit(0) 95 | for idx in class_idx[0]: 96 | #idx=int(idx) 97 | #print(idx) 98 | #exit(0) 99 | #print( weight_softmax[idx].shape,feature_conv.reshape((nc, h*w)).shape) 100 | #exit(0) 101 | cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h*w))) 102 | cam = cam.reshape(h, w) 103 | cam = cam - np.min(cam) 104 | cam_img = cam / np.max(cam) 105 | #cam_img = np.uint8((255 * cam_img)>200)*255 106 | cam_img = np.uint8(255 * cam_img) 107 | cam_img_resize=cv2.resize(cam_img, size_upsample) 108 | cam_img_resize = np.uint8((cam_img_resize)>200)*255 109 | #cv2.imwrite('./mask.jpg',cam_img*255) 110 | #exit(0) 111 | #print(cam_img.sum()) 112 | #exit(0) 113 | #cam_img = np.uint8(255 * cam_img) 114 | mask_tensor=Funljj.to_tensor(Image.fromarray(cam_img_resize)).float() 115 | #print(mask_tensor.sum()) 116 | #exit(0) 117 | output_cam.append(mask_tensor) 118 | return output_cam 119 | def main(): 120 | 121 | torch.manual_seed(args.seed) 122 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 123 | use_gpu = torch.cuda.is_available() 124 | config.DEVICE = torch.device("cuda") 125 | torch.backends.cudnn.benchmark = True 126 | #torch.manual_seed(config.SEED) 127 | 128 | #torch.cuda.manual_seed_all(config.SEED) 129 | np.random.seed(args.seed) 130 | random.seed(args.seed) 131 | sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt')) 132 | print("==========\nArgs:{}\n==========".format(args)) 133 | 134 | if use_gpu: 135 | print("Currently using GPU {}".format(args.gpu_devices)) 136 | cudnn.benchmark = True 137 | torch.cuda.manual_seed_all(args.seed) 138 | else: 139 | print("Currently using CPU (GPU is highly recommended)") 140 | 141 | print('Initializing image data manager') 142 | dm = DataManager(args, use_gpu) 143 | trainloader, testloader = dm.return_dataloaders() 144 | model_edge = EdgeConnect(config) 145 | model_edge.load() 146 | print('\nstart testing...\n') 147 | #model_edge.test() 148 | #print(args.scale_cls,args.num_classes) 149 | #exit(0) 150 | #GNN_model=create_models(args,512) 151 | #print(args.use_similarity) 152 | #exit(0) 153 | if args.use_similarity: 154 | GNN_model=create_models(args,512) 155 | model = Model(args,GNN_model,scale_cls=args.scale_cls, num_classes=args.num_classes) 156 | else: 157 | model = Model_mltizhixin(scale_cls=args.scale_cls, num_classes=args.num_classes) 158 | model_tradclass = Model_tradi(scale_cls=args.scale_cls, num_classes=args.num_classes) 159 | params_tradclass = torch.load('../fewshot-CAN-master/result/%s/CAM/1-shot-seed112_classic_classifier_avg_nouse_CAN/%s' % (args.dataset, 'best_model.pth.tar')) 160 | model_tradclass.load_state_dict(params_tradclass['state_dict']) 161 | #params = 
torch.load('result/%s/CAM/1-shot-seed112_inpaint_use_CAM/%s' % (args.dataset, 'checkpoint_inpaint67.pth.tar')) 162 | #model.load_state_dict(params['state_dict']) 163 | #print('enter model_tradclass') 164 | #exit(0) 165 | if False: 166 | params = torch.load('result/%s/CAM/1-shot-seed112/%s' % (args.dataset, 'best_model.pth.tar')) 167 | params_tradclass = torch.load('result/%s/CAM/1-shot-seed112_classic_classifier_global_avg/%s' % (args.dataset, 'checkpoint_inpaint67.pth.tar')) 168 | print(type(params)) 169 | #exit(0) 170 | #for key in params.keys(): 171 | #print(type(key)) 172 | #exit(0) 173 | #model.load_state_dict(params['state_dict']) 174 | model_tradclass.load_state_dict(params_tradclass['state_dict']) 175 | #exit(0) 176 | #for ind,i in model.state_dict().items(): 177 | #print (ind,i.shape) 178 | #exit(0) 179 | params = list(model_tradclass.parameters()) 180 | #fc_params=params[-2] 181 | weight_softmax = np.squeeze(params[-2].data.numpy()) 182 | #print(weight_softmax.shape,type(params[-2]),params[-2].shape,params[-2].data.shape) 183 | #exit(0) 184 | criterion = CrossEntropyLoss() 185 | optimizer = init_optimizer(args.optim, model.parameters(), args.lr, args.weight_decay) 186 | #optimizer_tradclass = init_optimizer(args.optim, model_tradclass.parameters(), args.lr, args.weight_decay) 187 | #model_tradclass 188 | 189 | if use_gpu: 190 | model = model.cuda() 191 | model_tradclass = model_tradclass.cuda() 192 | 193 | start_time = time.time() 194 | train_time = 0 195 | best_acc = -np.inf 196 | best_epoch = 0 197 | print("==> Start training") 198 | 199 | for epoch in range(args.max_epoch): 200 | if not args.Classic: 201 | learning_rate = adjust_learning_rate(optimizer, epoch, args.LUT_lr) 202 | else: 203 | optimizer_tradclass = init_optimizer(args.optim, model_tradclass.parameters(), args.lr, args.weight_decay) 204 | learning_rate = adjust_learning_rate(optimizer_tradclass, epoch, args.LUT_lr) 205 | #print('enter optimizer_tradclass') 206 | #exit(0) 207 | 208 | start_train_time = time.time() 209 | #exit(0) 210 | #print(not True) 211 | #exit(0) 212 | if not only_test: 213 | #print(';;;;;;;;;;;') 214 | #exit(0) 215 | if not args.Classic: 216 | print('enter train code') 217 | train(epoch,model_edge, model, model_tradclass,weight_softmax, criterion, optimizer, trainloader, learning_rate, use_gpu) 218 | #print('oooo') 219 | else: 220 | acc=train(epoch,model_edge, model_tradclass, criterion, optimizer_tradclass, trainloader, learning_rate, use_gpu) 221 | 222 | train_time += round(time.time() - start_train_time) 223 | 224 | if epoch == 0 or epoch > (args.stepsize[0]-1) or (epoch + 1) % 10 == 0: 225 | print('enter test code') 226 | #exit(0) 227 | if not args.Classic: 228 | #acc = test(model_edge, model, model_tradclass,weight_softmax, testloader, use_gpu) 229 | acc = test_ori(model, testloader, use_gpu) 230 | is_best = acc > best_acc 231 | #else: 232 | 233 | 234 | #print(acc) 235 | #exit(0) 236 | if is_best: 237 | best_acc = acc 238 | best_epoch = epoch + 1 239 | if not only_test: 240 | if not args.Classic: 241 | save_checkpoint({ 242 | 'state_dict': model.state_dict(), 243 | 'acc': acc, 244 | 'epoch': epoch, 245 | }, is_best, osp.join(args.save_dir, 'checkpoint_inpaint' + str(epoch + 1) + '.pth.tar')) 246 | if args.Classic: 247 | save_checkpoint({ 248 | 'state_dict': model_tradclass.state_dict(), 249 | 'acc': acc, 250 | 'epoch': epoch, 251 | }, is_best, osp.join(args.save_dir, 'checkpoint_classic' + str(epoch + 1) + '.pth.tar')) 252 | 253 | print("==> Test 5-way Best accuracy {:.2%}, achieved at 
epoch {}".format(best_acc, best_epoch)) 254 | 255 | elapsed = round(time.time() - start_time) 256 | elapsed = str(datetime.timedelta(seconds=elapsed)) 257 | train_time = str(datetime.timedelta(seconds=train_time)) 258 | print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time)) 259 | print("==========\nArgs:{}\n==========".format(args)) 260 | 261 | from skimage.feature import canny 262 | from skimage.color import rgb2gray, gray2rgb 263 | def load_edge(img, mask): 264 | sigma = 2 265 | index=1 266 | # in test mode images are masked (with masked regions); 267 | # passing the 'mask' parameter prevents canny from detecting edges inside the masked regions 268 | mask = (1 - mask / 255).astype(bool)  # invert the mask (np.bool is deprecated, use bool) 269 | #mask = (1 - mask / 255).astype(bool) 270 | # canny 271 | if True: 272 | # no edge 273 | if sigma == -1: 274 | return np.zeros(img.shape).astype(float) 275 | 276 | # random sigma 277 | if sigma == 0: 278 | sigma = random.randint(1, 4) 279 | 280 | return canny(img, sigma=sigma, mask=mask).astype(float) 281 | 282 | # external (unreachable branch kept from the original class method; 'self' is undefined here) 283 | else: 284 | imgh, imgw = img.shape[0:2] 285 | edge = imread(self.edge_data[index]) 286 | edge = self.resize(edge, imgh, imgw) 287 | 288 | # non-max suppression 289 | if self.nms == 1: 290 | edge = edge * canny(img, sigma=sigma, mask=mask) 291 | 292 | return edge 293 | 294 | 295 | def read_image(img_path): 296 | """Keep reading the image until it succeeds. 297 | This avoids IOErrors caused by heavy IO load.""" 298 | got_img = False 299 | if not osp.exists(img_path): 300 | raise IOError("{} does not exist".format(img_path)) 301 | while not got_img: 302 | try: 303 | img = Image.open(img_path).convert('RGB') 304 | got_img = True 305 | except IOError: 306 | print("IOError incurred when reading '{}'. Will redo. Don't worry.
Just chill.".format(img_path)) 307 | pass 308 | return img 309 | transform_test = T.Compose([ 310 | T.Resize((args.height, args.width), interpolation=3), 311 | T.ToTensor(), 312 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 313 | ]) 314 | def train(epoch,model_edge, model, model_tradclass,weight_softmax, criterion, optimizer, trainloader, learning_rate, use_gpu): 315 | 316 | if not os.path.isdir("/data4/lijunjie/mini-imagenet-tools/processed_images_84/train_1"): 317 | os.mkdir("/data4/lijunjie/mini-imagenet-tools/processed_images_84/train_1") 318 | os.mkdir("/data4/lijunjie/mini-imagenet-tools/processed_images_84/train_2") 319 | os.mkdir("/data4/lijunjie/mini-imagenet-tools/processed_images_84/train_3") 320 | os.mkdir("/data4/lijunjie/mini-imagenet-tools/processed_images_84/train_4") 321 | os.mkdir("/data4/lijunjie/mini-imagenet-tools/processed_images_84/train_5") 322 | os.mkdir("/data4/lijunjie/mini-imagenet-tools/processed_images_84/train_6") 323 | os.mkdir("/data4/lijunjie/mini-imagenet-tools/processed_images_84/train_7") 324 | os.mkdir("/data4/lijunjie/mini-imagenet-tools/processed_images_84/train_8") 325 | losses = AverageMeter() 326 | batch_time = AverageMeter() 327 | data_time = AverageMeter() 328 | std=np.expand_dims(np.array([0.229, 0.224, 0.225]),axis=1) 329 | std=np.expand_dims(std,axis=2) 330 | mean=np.expand_dims(np.array([0.485, 0.456, 0.406]),axis=1) 331 | mean=np.expand_dims(mean,axis=2) 332 | model.eval() 333 | #model_edge.eval() 334 | model_tradclass.eval() 335 | end = time.time() 336 | #print('llllllllllllll','located in train_with_inpaint_final.py at 264') 337 | #exit(0) 338 | for root, dirs, _ in os.walk('/data4/lijunjie/mini-imagenet-tools/processed_images_84/train'): 339 | #for f in files: 340 | #print(os.path.join(root, f)) 341 | 342 | for d in dirs: 343 | path=os.path.join(root, d) 344 | path_1=path.replace('train','train_1') 345 | path_2=path.replace('train','train_2') 346 | path_3=path.replace('train','train_3') 347 | path_4=path.replace('train','train_4') 348 | path_5=path.replace('train','train_5') 349 | path_6=path.replace('train','train_6') 350 | path_7=path.replace('train','train_7') 351 | path_8=path.replace('train','train_8') 352 | if not os.path.isdir(path_1): 353 | os.mkdir(path_1) 354 | os.mkdir(path_2) 355 | os.mkdir(path_3) 356 | os.mkdir(path_4) 357 | os.mkdir(path_5) 358 | os.mkdir(path_6) 359 | os.mkdir(path_7) 360 | os.mkdir(path_8) 361 | files = os.listdir(path) 362 | #images=[] 363 | #imgs_gray=[] 364 | #Xt_img_ori=[] 365 | Paths=[] 366 | Paths.append(path_1) 367 | Paths.append(path_2) 368 | Paths.append(path_3) 369 | Paths.append(path_4) 370 | Paths.append(path_5) 371 | Paths.append(path_6) 372 | Paths.append(path_7) 373 | Paths.append(path_8) 374 | for file in files: 375 | images=[] 376 | imgs_gray=[] 377 | Xt_img_ori=[] 378 | img_ori = read_image(os.path.join(path, file)) 379 | #print(file) 380 | #exit(0) 381 | masked_img=np.array(img_ori)#*(1-mask_3)+mask_3*255 382 | masked_img=Image.fromarray(masked_img) 383 | masked_img_tensor=Funljj.to_tensor(masked_img).float() 384 | Xt_img_ori.append(masked_img_tensor) 385 | img = transform_test(img_ori) 386 | img_gray = rgb2gray(np.array(img_ori)) 387 | img_gray=Image.fromarray(img_gray) 388 | img_gray_tensor=Funljj.to_tensor(img_gray).float() 389 | imgs_gray.append(img_gray_tensor) 390 | images.append(img) 391 | images = torch.stack(images, dim=0) 392 | imgs_gray = torch.stack(imgs_gray, dim=0) 393 | Xt_img_ori = torch.stack(Xt_img_ori, dim=0) 394 | if use_gpu: 395 | 
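# move the transformed image, its grayscale copy, and the unnormalized original onto the GPU before CAM extraction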
images_train = images.cuda() 396 | imgs_gray = imgs_gray.cuda() 397 | Xt_img_ori = Xt_img_ori.cuda() 398 | 399 | with torch.no_grad(): 400 | ytest,feature= model_tradclass(images_train.reshape(1,1,3,84,84), images_train.reshape(1,1,3,84,84),images_train.reshape(1,1,3,84,84), images_train.reshape(1,1,3,84,84)) 401 | feature_cpu=feature.detach().cpu().numpy() 402 | probs, idx = ytest.detach().sort(1, True) 403 | probs = probs.cpu().numpy() 404 | idx = idx.cpu().numpy() 405 | #print(pids) 406 | #print(idx[:,0,0,0]) 407 | #print(idx.shape) 408 | #exit(0) 409 | #print(feature.shape) 410 | #exit(0) 411 | masks=[] 412 | edges=[] 413 | #output_cam=[] 414 | for i in range(feature.shape[0]): 415 | CAMs=returnCAM(feature_cpu[i], weight_softmax, [idx[i,:8,0,0]],masks) 416 | #for j in range(4): 417 | #print(CAMs[j].shape,CAMs[j].max(),CAMs[j].min(),CAMs[j].sum()) 418 | #exit(0) 419 | masks=CAMs 420 | #print(len(masks),masks[0].shape) 421 | masks_tensor = torch.stack(masks, dim=0) 422 | Xt_masks = masks_tensor.reshape(1,1,8,1,84,84)#[:,:,0] 423 | Xt_img_ori_repeat=Xt_img_ori.reshape(1,1,1,3,84,84) 424 | 425 | Xt_img_ori_repeat = Xt_img_ori_repeat.repeat(1,1,8,1,1,1) 426 | Xt_img_gray_repeat=imgs_gray.reshape(1,1,1,1,84,84) 427 | 428 | Xt_img_gray_repeat = Xt_img_gray_repeat.repeat(1,1,7,1,1,1) 429 | #print(Xt_img_ori.shape,Xt_masks.shape) 430 | #exit(0) 431 | mask_numpy=np.uint8(Xt_masks.numpy()*255) 432 | print(mask_numpy.shape) 433 | #exit(0) 434 | Xt_img_gray_numpy=np.uint8(imgs_gray.cpu().numpy()*255).reshape(1,1,1,84,84) 435 | #print(Xt_img_gray_numpy.shape) 436 | for i in range(1): 437 | for j in range(1): 438 | for k in range(7): 439 | edge_PIL=Image.fromarray(load_edge(Xt_img_gray_numpy[i,j,0], mask_numpy[i,j,k,0])) 440 | print(mask_numpy[i,j,k,0].sum()/255,'llll') 441 | #exit(0) 442 | edges.append(Funljj.to_tensor(edge_PIL).float()) 443 | edges = torch.stack(edges, dim=0) 444 | edge_sh=edges#.reshape(4,5,1,84,84) 445 | #print(edge_sh.shape,Xt_img_gray_repeat.shape,masks_tensor.shape) 446 | #exit(0) 447 | #exit(0) 448 | #model_edge.test(Xt_img_ori,edge_sh,Xt_img_gray,Xt_masks) 449 | with torch.no_grad(): 450 | inpaint_img=model_edge.test(Xt_img_ori_repeat.reshape(8,3,84,84),edge_sh,Xt_img_gray_repeat.reshape(8,1,84,84),masks_tensor) 451 | inpaint_img_np=inpaint_img.detach().cpu().numpy() 452 | Xt_img_ori_np=Xt_img_ori_repeat.detach().cpu().numpy() 453 | #print(inpaint_img_np.shape) 454 | #exit(0) 455 | for id in range(8): 456 | images_temp_train1=inpaint_img_np[id,:,:] 457 | Xt_img_ori_repeat1=Xt_img_ori_np.reshape(8,3,84,84)[id,:,:] 458 | print(Xt_img_ori_repeat1.shape) 459 | #images_temp_train=images_temp_train1*std+mean 460 | images_ori_train=images_temp_train1.transpose((1,2,0))[:,:,::-1] 461 | Xt_img_ori_repeat1=Xt_img_ori_repeat1.transpose((1,2,0))[:,:,::-1] 462 | images_ori_train=np.uint8(images_ori_train*255) 463 | Xt_img_ori_repeat1=np.uint8(Xt_img_ori_repeat1*255) 464 | cv2.imwrite(Paths[id]+'/'+file, images_ori_train) 465 | #cv2.imwrite('./result/inpaint_img/'+str(i)+'_'+str(id)+'_ori.jpg', Xt_img_ori_repeat1) 466 | 467 | 468 | 469 | 470 | if __name__ == '__main__': 471 | main() 472 | -------------------------------------------------------------------------------- /fig/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljjcoder/CSEI/cfe671441a98ae057f9bfa5dc4f251683db469b1/fig/framework.png -------------------------------------------------------------------------------- /losses.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | 12 | def l2_loss(feat): 13 | return feat.pow(2).sum()/(2.0*feat.size(0)) 14 | 15 | 16 | def get_one_hot(labels, num_classes): 17 | 18 | one_hot = Variable(torch.arange(0, num_classes).float()).unsqueeze(0).expand(labels.size(0), num_classes)  # torch.range is deprecated; arange(0, n) covers 0..n-1 19 | 20 | # if (type(labels.data) is torch.cuda.FloatTensor) or (type(labels.data) is torch.cuda.LongTensor): 21 | one_hot = one_hot.to(labels.device) 22 | 23 | #print(labels.unsqueeze(1).expand_as(one_hot).float(),(type(labels.data) is torch.cuda.FloatTensor) or (type(labels.data) is torch.cuda.LongTensor)) 24 | #exit(0) 25 | one_hot = one_hot.eq(labels.unsqueeze(1).expand_as(one_hot).float()).float() 26 | return one_hot 27 | 28 | class BatchSGMLoss(nn.Module): 29 | def __init__(self, num_classes): 30 | super(BatchSGMLoss, self).__init__() 31 | self.softmax = nn.Softmax(dim=1)  # scores are [N, num_classes] 32 | self.num_classes = num_classes 33 | def forward(self,feats, scores, classifier_weight, labels): 34 | one_hot = get_one_hot(labels, self.num_classes) 35 | p = self.softmax(scores) 36 | if type(scores.data) is torch.cuda.FloatTensor: 37 | p = p.cuda() 38 | 39 | 40 | G = (one_hot-p).transpose(0,1).mm(feats) 41 | G = G.div(feats.size(0)) 42 | return G.pow(2).sum() 43 | 44 | 45 | class SGMLoss(nn.Module): 46 | def __init__(self, num_classes): 47 | super(SGMLoss, self).__init__() 48 | self.softmax = nn.Softmax(dim=1)  # scores are [N, num_classes] 49 | self.num_classes = num_classes 50 | 51 | def forward(self,feats, scores, classifier_weight, labels): 52 | 53 | one_hot = get_one_hot(labels, self.num_classes) 54 | #print(labels[0],one_hot[0][labels[0]],one_hot.shape) 55 | #exit(0) 56 | p = self.softmax(scores) 57 | #print(p.shape,feats.size(0),'ooooooooo') 58 | #exit(0) 59 | if type(scores.data) is torch.cuda.FloatTensor: 60 | p = p.cuda() 61 | pereg_wt = (one_hot - p).pow(2).sum(1) 62 | #print(pereg_wt.shape) 63 | #exit(0) 64 | sqrXnorm = feats.pow(2).sum(1) 65 | loss = pereg_wt.mul(sqrXnorm).mean() 66 | return loss 67 | 68 | 69 | class GenericLoss: 70 | def __init__(self,aux_loss_type, aux_loss_wt, num_classes): 71 | aux_loss_fns = dict(l2=l2_loss, sgm=SGMLoss(num_classes), batchsgm=BatchSGMLoss(num_classes)) 72 | #print(aux_loss_fns,aux_loss_fns[aux_loss_type]) 73 | #print(aux_loss_type) 74 | #exit(0) 75 | self.aux_loss_fn = aux_loss_fns[aux_loss_type] 76 | self.aux_loss_type = aux_loss_type 77 | self.cross_entropy_loss = nn.CrossEntropyLoss() 78 | #print(aux_loss_wt) 79 | #exit(0) 80 | self.aux_loss_wt = aux_loss_wt 81 | 82 | def __call__(self, classifier_weight,scores,feats, y_var): 83 | #scores, feats = model(x_var) 84 | # print(scores.shape,feats.shape,scores.shape,'located in loss.py at 82') 85 | #exit(0) 86 | #if self.aux_loss_type in ['l2']: 87 | #aux_loss = self.aux_loss_fn(feats) 88 | #else: 89 | #classifier_weight = model.module.get_classifier_weight() 90 | #print(y_var.shape,classifier_weight.shape) 91 | #exit(0) 92 | #print(feats.shape, scores.shape, classifier_weight.shape, y_var.shape) 93 | batch,cnum,h,w=feats.shape 94 | num_class=scores.shape[1] 95 | y_var=y_var.reshape(batch,1).repeat(1,h*w).reshape(-1) 96 | classifier_weight=classifier_weight.reshape(classifier_weight.shape[0],classifier_weight.shape[1]) 97 |
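# flatten the [batch, cnum, h, w] feature map so that each of the h*w spatial positions is scored as its own example by the SGM loss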
feats=feats.reshape(batch,cnum,-1).permute(0,2,1) 98 | feats=feats.reshape(-1,cnum) 99 | scores=scores.reshape(batch,num_class,-1).permute(0,2,1) 100 | scores=scores.reshape(-1,num_class) 101 | #print(feats.shape, scores.shape, classifier_weight.shape, y_var.shape) 102 | #exit(0) 103 | aux_loss = self.aux_loss_fn(feats, scores, classifier_weight, y_var) 104 | #orig_loss = self.cross_entropy_loss(scores, y_var) 105 | #print('love yym') 106 | 107 | #exit(0) 108 | return 0.002* aux_loss 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import os 5 | import sys 6 | import time 7 | import datetime 8 | import argparse 9 | import os.path as osp 10 | import numpy as np 11 | import random 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.backends.cudnn as cudnn 16 | from torch.utils.data import DataLoader 17 | from torch.optim import lr_scheduler 18 | import torch.nn.functional as F 19 | sys.path.append('./torchFewShot') 20 | 21 | from torchFewShot.models.net_related import Model 22 | from torchFewShot.data_manager_image_inpainting_data import DataManager 23 | from torchFewShot.losses import CrossEntropyLoss 24 | from torchFewShot.optimizers import init_optimizer 25 | 26 | from torchFewShot.utils.iotools import save_checkpoint, check_isfile 27 | from torchFewShot.utils.avgmeter import AverageMeter 28 | from torchFewShot.utils.logger import Logger 29 | from torchFewShot.utils.torchtools import one_hot, adjust_learning_rate 30 | 31 | parser = argparse.ArgumentParser(description='Test image model with 5-way classification') 32 | # Datasets 33 | parser.add_argument('-d', '--dataset', type=str, default='miniImageNet') 34 | parser.add_argument('--load', default=False) 35 | parser.add_argument('-j', '--workers', default=4, type=int, 36 | help="number of data loading workers (default: 4)") 37 | parser.add_argument('--height', type=int, default=84, 38 | help="height of an image (default: 84)") 39 | parser.add_argument('--width', type=int, default=84, 40 | help="width of an image (default: 84)") 41 | # Optimization options 42 | parser.add_argument('--train-batch', default=4, type=int, 43 | help="train batch size") 44 | parser.add_argument('--test-batch', default=8, type=int, 45 | help="test batch size") 46 | # Architecture 47 | parser.add_argument('--num_classes', type=int, default=64) 48 | parser.add_argument('--scale_cls', type=int, default=7) 49 | parser.add_argument('--save-dir', type=str, default='') 50 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 51 | # FewShot settting 52 | parser.add_argument('--nKnovel', type=int, default=5, 53 | help='number of novel categories') 54 | parser.add_argument('--nExemplars', type=int, default=1, 55 | help='number of training examples per novel category.') 56 | parser.add_argument('--train_nTestNovel', type=int, default=6 * 5, 57 | help='number of test examples for all the novel category when training') 58 | parser.add_argument('--train_epoch_size', type=int, default=1200, 59 | help='number of episodes per epoch when training') 60 | parser.add_argument('--nTestNovel', type=int, default=15 * 5, 61 | help='number of test examples for all the novel category') 62 | parser.add_argument('--epoch_size', type=int, default=600, 63 | help='number of batches per epoch') 64 | # Miscs 65 | 
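# --phase selects which split the test loader uses: 'test' or 'val' (val is useful for early stopping; see DataManager.return_dataloaders)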
parser.add_argument('--phase', default='test', type=str) 66 | parser.add_argument('--seed', type=int, default=1) 67 | parser.add_argument('--gpu-devices', default='1', type=str) 68 | 69 | args = parser.parse_args() 70 | 71 | 72 | def main(): 73 | torch.manual_seed(args.seed) 74 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 75 | use_gpu = torch.cuda.is_available() 76 | 77 | #sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt')) 78 | #print("==========\nArgs:{}\n==========".format(args)) 79 | 80 | if use_gpu: 81 | print("Currently using GPU {}".format(args.gpu_devices)) 82 | cudnn.benchmark = True 83 | torch.cuda.manual_seed_all(args.seed) 84 | else: 85 | print("Currently using CPU (GPU is highly recommended)") 86 | 87 | print('Initializing image data manager') 88 | dm = DataManager(args, use_gpu) 89 | trainloader, testloader = dm.return_dataloaders() 90 | 91 | model = Model(scale_cls=args.scale_cls, num_classes=args.num_classes) 92 | # load the model 93 | #best_path 94 | checkpoint = torch.load("/home/yfchen/ljj_code/spatial_test/result/miniImageNet/CAM/1-shot-seed112_inpaint_support_fuse_Cam_surport_from_64.99_test_fixed_GPU0_2333/best_model.pth.tar") 95 | model.load_state_dict(checkpoint['state_dict']) 96 | print("Loaded checkpoint from '{}'".format(args.resume)) 97 | 98 | if use_gpu: 99 | model = model.cuda() 100 | 101 | acc_5 = test_ori_5(model, testloader, use_gpu) 102 | 103 | 104 | def test_ori_5(model, testloader, use_gpu,topK=28): 105 | accs = AverageMeter() 106 | test_accuracies = [] 107 | final_accs = AverageMeter() 108 | final_test_accuracies = [] 109 | #params = torch.load(best_path) 110 | #model.load_state_dict(params['state_dict'], strict=True) 111 | model.eval() 112 | 113 | with torch.no_grad(): 114 | #for batch_idx , (images_train, labels_train,Xt_img_ori,Xt_img_gray, images_test, labels_test) in enumerate(testloader): 115 | for batch_idx , (images_train, images_train2,images_train3,images_train4,images_train5,labels_train, images_test, labels_test) in enumerate(testloader): 116 | shape_test=images_train.shape[0] 117 | images_train1=images_train.reshape(shape_test,-1,1,3,84,84) 118 | images_train2=images_train2.reshape(shape_test,-1,1,3,84,84) 119 | images_train3=images_train3.reshape(shape_test,-1,1,3,84,84) 120 | images_train4=images_train4.reshape(shape_test,-1,1,3,84,84) 121 | images_train5=images_train5.reshape(shape_test,-1,1,3,84,84) 122 | 123 | labels_train_5 = labels_train.reshape(shape_test,-1,1)#[:,:,0] 124 | 125 | labels_train_5 = labels_train_5.repeat(1,1,5) 126 | labels_train = labels_train_5.reshape(shape_test,-1) 127 | images_train_5=torch.cat((images_train1, images_train2,images_train3,images_train4,images_train5), 2) 128 | images_train=images_train_5.reshape(shape_test,-1,3,84,84) 129 | if use_gpu: 130 | images_train = images_train.cuda() 131 | #images_train_5 = images_train_5.cuda() 132 | images_test = images_test.cuda() 133 | 134 | end = time.time() 135 | #print(images_train.shape,labels_train.shape) 136 | #exit() 137 | batch_size, num_train_examples, channels, height, width = images_train.size() 138 | num_test_examples = images_test.size(1) 139 | 140 | labels_train_1hot = one_hot(labels_train).cuda() 141 | labels_test_1hot = one_hot(labels_test).cuda() 142 | #print(images_train.shape,images_test.shape) 143 | cls_scores ,cls_scores_final= model(images_train, images_test, labels_train_1hot, labels_test_1hot,topK) 144 | #print(cls_scores.shape,cls_scores_final.shape) 145 | #exit(0) 146 | cls_scores = cls_scores.view(batch_size * 
num_test_examples, -1) 147 | cls_scores_final = cls_scores_final.view(batch_size * num_test_examples, -1) 148 | labels_test = labels_test.view(batch_size * num_test_examples) 149 | 150 | _, preds = torch.max(cls_scores.detach().cpu(), 1) 151 | _, preds_final = torch.max(cls_scores_final.detach().cpu(), 1) 152 | acc = (torch.sum(preds == labels_test.detach().cpu()).float()) / labels_test.size(0) 153 | accs.update(acc.item(), labels_test.size(0)) 154 | 155 | acc_final = (torch.sum(preds_final == labels_test.detach().cpu()).float()) / labels_test.size(0) 156 | final_accs.update(acc_final.item(), labels_test.size(0)) 157 | 158 | gt = (preds == labels_test.detach().cpu()).float() 159 | gt = gt.view(batch_size, num_test_examples).numpy() #[b, n] 160 | 161 | gt_final = (preds_final == labels_test.detach().cpu()).float() 162 | gt_final = gt_final.view(batch_size, num_test_examples).numpy() #[b, n] 163 | 164 | acc = np.sum(gt, 1) / num_test_examples 165 | acc = np.reshape(acc, (batch_size)) 166 | test_accuracies.append(acc) 167 | 168 | acc_final = np.sum(gt_final, 1) / num_test_examples 169 | acc_final = np.reshape(acc_final, (batch_size)) 170 | final_test_accuracies.append(acc_final) 171 | 172 | accuracy = accs.avg 173 | test_accuracies = np.array(test_accuracies) 174 | test_accuracies = np.reshape(test_accuracies, -1) 175 | stds = np.std(test_accuracies, 0) 176 | ci95 = 1.96 * stds / np.sqrt(args.epoch_size) 177 | 178 | accuracy_final = final_accs.avg 179 | test_accuracies_final = np.array(final_test_accuracies) 180 | test_accuracies_final = np.reshape(test_accuracies_final, -1) 181 | stds_final = np.std(test_accuracies_final, 0) 182 | ci95_final = 1.96 * stds_final / np.sqrt(args.epoch_size) 183 | print('Accuracy: {:.2%}, 95% CI: {:.2%}'.format(accuracy, ci95)) 184 | print('Final accuracy: {:.2%}, 95% CI: {:.2%}'.format(accuracy_final, ci95_final)) 185 | return accuracy 186 | 187 | 188 | if __name__ == '__main__': 189 | main() 190 | -------------------------------------------------------------------------------- /torchFewShot/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /torchFewShot/data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | import transforms as T 7 | import datasets 8 | import dataset_loader 9 | 10 | class DataManager(object): 11 | """ 12 | Few shot data manager 13 | """ 14 | 15 | def __init__(self, args, use_gpu): 16 | super(DataManager, self).__init__() 17 | self.args = args 18 | self.use_gpu = use_gpu 19 | 20 | print("Initializing dataset {}".format(args.dataset)) 21 | dataset = datasets.init_imgfewshot_dataset(name=args.dataset) 22 | #print(args.load) 23 | #exit(0) 24 | if args.load: 25 | transform_train = T.Compose([ 26 | T.RandomCrop(84, padding=8), 27 | T.RandomHorizontalFlip(), 28 | T.ToTensor(), 29 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 30 | T.RandomErasing(0.5) 31 | ]) 32 | 33 | transform_test = T.Compose([ 34 | T.ToTensor(), 35 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 36 | ]) 37 | 38 | else: 39 | #print(args.height,args.width) 40 | #exit(0) 41 | transform_train = T.Compose([ 42 | T.Resize((args.height, args.width), interpolation=3), 43 | T.RandomCrop(args.height, padding=8), 44 | T.RandomHorizontalFlip(),
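# the geometric augmentations above operate on PIL images; ToTensor, Normalize, and RandomErasing below expect tensors (as in torchvision)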
45 | T.ToTensor(), 46 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 47 | T.RandomErasing(0.5) 48 | ]) 49 | 50 | transform_test = T.Compose([ 51 | T.Resize((args.height, args.width), interpolation=3), 52 | T.ToTensor(), 53 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 54 | ]) 55 | 56 | pin_memory = True if use_gpu else False 57 | 58 | self.trainloader = DataLoader( 59 | dataset_loader.init_loader(name='train_loader', 60 | dataset=dataset.train, 61 | labels2inds=dataset.train_labels2inds, 62 | labelIds=dataset.train_labelIds, 63 | nKnovel=args.nKnovel, 64 | nExemplars=args.nExemplars, 65 | nTestNovel=args.train_nTestNovel, 66 | epoch_size=args.train_epoch_size, 67 | transform=transform_train, 68 | load=args.load, 69 | ), 70 | batch_size=args.train_batch, shuffle=True, num_workers=args.workers, 71 | pin_memory=pin_memory, drop_last=True, 72 | ) 73 | 74 | self.valloader = DataLoader( 75 | dataset_loader.init_loader(name='test_loader', 76 | dataset=dataset.val, 77 | labels2inds=dataset.val_labels2inds, 78 | labelIds=dataset.val_labelIds, 79 | nKnovel=args.nKnovel, 80 | nExemplars=args.nExemplars, 81 | nTestNovel=args.nTestNovel, 82 | epoch_size=args.epoch_size, 83 | transform=transform_test, 84 | load=args.load, 85 | ), 86 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 87 | pin_memory=pin_memory, drop_last=False, 88 | ) 89 | self.testloader = DataLoader( 90 | dataset_loader.init_loader(name='test_loader', 91 | dataset=dataset.test, 92 | labels2inds=dataset.test_labels2inds, 93 | labelIds=dataset.test_labelIds, 94 | nKnovel=args.nKnovel, 95 | nExemplars=args.nExemplars, 96 | nTestNovel=args.nTestNovel, 97 | epoch_size=args.epoch_size, 98 | transform=transform_test, 99 | load=args.load, 100 | ), 101 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 102 | pin_memory=pin_memory, drop_last=False, 103 | ) 104 | 105 | def return_dataloaders(self): 106 | if self.args.phase == 'test': 107 | return self.trainloader, self.testloader 108 | elif self.args.phase == 'val': 109 | return self.trainloader, self.valloader 110 | -------------------------------------------------------------------------------- /torchFewShot/data_manager_image_inpainting_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | import transforms as T 7 | import datasets 8 | import dataset_loader 9 | 10 | class DataManager(object): 11 | """ 12 | Few shot data manager 13 | """ 14 | 15 | def __init__(self, args, use_gpu): 16 | super(DataManager, self).__init__() 17 | self.args = args 18 | self.use_gpu = use_gpu 19 | 20 | print("Initializing dataset {}".format(args.dataset)) 21 | dataset = datasets.init_imgfewshot_dataset(name=args.dataset) 22 | if args.nExemplars==1: 23 | self.seed=233 24 | else: 25 | self.seed=36 26 | if args.load: 27 | transform_train = T.Compose([ 28 | T.RandomCrop(84, padding=8), 29 | T.RandomHorizontalFlip(), 30 | T.ToTensor(), 31 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 32 | T.RandomErasing(0.5) 33 | ]) 34 | 35 | transform_test = T.Compose([ 36 | T.ToTensor(), 37 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 38 | ]) 39 | 40 | else: 41 | #print(args.height,args.width) 42 | #exit(0) 43 | transform_train = T.Compose([ 44 | T.Resize((args.height, args.width), interpolation=3), 45 | 
T.RandomCrop(args.height, padding=8), 46 | T.RandomHorizontalFlip(), 47 | T.ToTensor(), 48 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 49 | T.RandomErasing(0.5) 50 | ]) 51 | 52 | transform_test = T.Compose([ 53 | T.Resize((args.height, args.width), interpolation=3), 54 | T.ToTensor(), 55 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 56 | ]) 57 | 58 | pin_memory = True if use_gpu else False 59 | 60 | self.trainloader = DataLoader( 61 | dataset_loader.init_loader(name='train_inpainting_loader', 62 | dataset=dataset.train, 63 | labels2inds=dataset.train_labels2inds, 64 | labelIds=dataset.train_labelIds, 65 | nKnovel=args.nKnovel, 66 | nExemplars=args.nExemplars, 67 | nTestNovel=args.train_nTestNovel, 68 | epoch_size=args.train_epoch_size, 69 | transform=transform_train, 70 | load=args.load, 71 | ), 72 | batch_size=args.train_batch, shuffle=True, num_workers=args.workers, 73 | pin_memory=pin_memory, drop_last=True, 74 | ) 75 | 76 | self.valloader = DataLoader( 77 | dataset_loader.init_loader(name='test_inpainting_loader', 78 | dataset=dataset.val, 79 | labels2inds=dataset.val_labels2inds, 80 | labelIds=dataset.val_labelIds, 81 | nKnovel=args.nKnovel, 82 | nExemplars=args.nExemplars, 83 | nTestNovel=args.nTestNovel, 84 | epoch_size=args.epoch_size, 85 | transform=transform_test, 86 | load=args.load, 87 | ), 88 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 89 | pin_memory=pin_memory, drop_last=False, 90 | ) 91 | self.testloader = DataLoader( 92 | dataset_loader.init_loader(name='test_inpainting_loader', 93 | dataset=dataset.test, 94 | labels2inds=dataset.test_labels2inds, 95 | labelIds=dataset.test_labelIds, 96 | nKnovel=args.nKnovel, 97 | nExemplars=args.nExemplars, 98 | nTestNovel=args.nTestNovel, 99 | epoch_size=args.epoch_size, 100 | transform=transform_test, 101 | load=args.load, 102 | seed=self.seed, 103 | ), 104 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 105 | pin_memory=pin_memory, drop_last=False, 106 | ) 107 | 108 | def return_dataloaders(self): 109 | if self.args.phase == 'test': 110 | return self.trainloader, self.testloader 111 | elif self.args.phase == 'val': 112 | return self.trainloader, self.valloader 113 | -------------------------------------------------------------------------------- /torchFewShot/data_manager_imageori.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | import transforms as T 7 | import datasets 8 | import dataset_loader 9 | 10 | class DataManager(object): 11 | """ 12 | Few shot data manager 13 | """ 14 | 15 | def __init__(self, args, use_gpu): 16 | super(DataManager, self).__init__() 17 | self.args = args 18 | self.use_gpu = use_gpu 19 | 20 | print("Initializing dataset {}".format(args.dataset)) 21 | dataset = datasets.init_imgfewshot_dataset(name=args.dataset) 22 | #print(args.load) 23 | #exit(0) 24 | if args.load: 25 | transform_train = T.Compose([ 26 | T.RandomCrop(84, padding=8), 27 | T.RandomHorizontalFlip(), 28 | T.ToTensor(), 29 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 30 | T.RandomErasing(0.5) 31 | ]) 32 | 33 | transform_test = T.Compose([ 34 | T.ToTensor(), 35 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 36 | ]) 37 | 38 | else: 39 | #print(args.height,args.width) 40 | #exit(0) 41 | transform_train = T.Compose([ 42 | 
T.Resize((args.height, args.width), interpolation=3), 43 | T.RandomCrop(args.height, padding=8), 44 | T.RandomHorizontalFlip(), 45 | T.ToTensor(), 46 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 47 | T.RandomErasing(0.5) 48 | ]) 49 | 50 | transform_test = T.Compose([ 51 | T.Resize((args.height, args.width), interpolation=3), 52 | T.ToTensor(), 53 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 54 | ]) 55 | 56 | pin_memory = True if use_gpu else False 57 | 58 | self.trainloader = DataLoader( 59 | dataset_loader.init_loader(name='train_imgori_loader', 60 | dataset=dataset.train, 61 | labels2inds=dataset.train_labels2inds, 62 | labelIds=dataset.train_labelIds, 63 | nKnovel=args.nKnovel, 64 | nExemplars=args.nExemplars, 65 | nTestNovel=args.train_nTestNovel, 66 | epoch_size=args.train_epoch_size, 67 | transform=transform_train, 68 | load=args.load, 69 | ), 70 | batch_size=args.train_batch, shuffle=True, num_workers=args.workers, 71 | pin_memory=pin_memory, drop_last=True, 72 | ) 73 | 74 | self.valloader = DataLoader( 75 | dataset_loader.init_loader(name='test__imgori_loader', 76 | dataset=dataset.val, 77 | labels2inds=dataset.val_labels2inds, 78 | labelIds=dataset.val_labelIds, 79 | nKnovel=args.nKnovel, 80 | nExemplars=args.nExemplars, 81 | nTestNovel=args.nTestNovel, 82 | epoch_size=args.epoch_size, 83 | transform=transform_test, 84 | load=args.load, 85 | ), 86 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 87 | pin_memory=pin_memory, drop_last=False, 88 | ) 89 | self.testloader = DataLoader( 90 | dataset_loader.init_loader(name='test__imgori_loader', 91 | dataset=dataset.test, 92 | labels2inds=dataset.test_labels2inds, 93 | labelIds=dataset.test_labelIds, 94 | nKnovel=args.nKnovel, 95 | nExemplars=args.nExemplars, 96 | nTestNovel=args.nTestNovel, 97 | epoch_size=args.epoch_size, 98 | transform=transform_test, 99 | load=args.load, 100 | ), 101 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 102 | pin_memory=pin_memory, drop_last=False, 103 | ) 104 | 105 | def return_dataloaders(self): 106 | if self.args.phase == 'test': 107 | return self.trainloader, self.testloader 108 | elif self.args.phase == 'val': 109 | return self.trainloader, self.valloader -------------------------------------------------------------------------------- /torchFewShot/dataset_loader/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .train_loader import FewShotDataset_train 4 | from .test_loader import FewShotDataset_test 5 | from .train_image_ori_loader import FewShotDataset_train_imgori 6 | from .test_image_ori_loader import FewShotDataset_test_imgori 7 | from .train_inpainting_loader import FewShotDataset_train_inpainting 8 | from .test_inpainting_loader import FewShotDataset_test_inpainting 9 | 10 | __loader_factory = { 11 | 'train_loader': FewShotDataset_train, 12 | 'test_loader': FewShotDataset_test, 13 | 'train_imgori_loader': FewShotDataset_train_imgori, 14 | 'test__imgori_loader': FewShotDataset_test_imgori, 15 | 'train_inpainting_loader': FewShotDataset_train_inpainting, 16 | 'test_inpainting_loader': FewShotDataset_test_inpainting, 17 | } 18 | 19 | 20 | 21 | def get_names(): 22 | return list(__loader_factory.keys()) 23 | 24 | 25 | def init_loader(name, *args, **kwargs): 26 | if name not in list(__loader_factory.keys()): 27 | raise KeyError("Unknown model: {}".format(name)) 28 | return __loader_factory[name](*args, 
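# forwards the DataManager's dataset, episode-sampling, and transform arguments to the selected loader class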
**kwargs) 29 | 30 | -------------------------------------------------------------------------------- /torchFewShot/dataset_loader/test_image_ori_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | import os.path as osp 9 | import lmdb 10 | import io 11 | import random 12 | from skimage.feature import canny 13 | from skimage.color import rgb2gray, gray2rgb 14 | 15 | import torch 16 | from torch.utils.data import Dataset 17 | import torchvision.transforms.functional as F 18 | 19 | def read_image(img_path): 20 | """Keep reading the image until it succeeds. 21 | This avoids IOErrors caused by heavy IO load.""" 22 | got_img = False 23 | if not osp.exists(img_path): 24 | raise IOError("{} does not exist".format(img_path)) 25 | while not got_img: 26 | try: 27 | img = Image.open(img_path).convert('RGB') 28 | got_img = True 29 | except IOError: 30 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 31 | pass 32 | return img 33 | 34 | 35 | class FewShotDataset_test_imgori(Dataset): 36 | """Few-shot episodic dataset. 37 | 38 | Returns a task (Xtrain, Ytrain, Xtest, Ytest) to classify: 39 | Xtrain: [nKnovel*nExemplars, c, h, w]. 40 | Ytrain: [nKnovel*nExemplars]. 41 | Xtest: [nTestNovel, c, h, w]. 42 | Ytest: [nTestNovel]. 43 | """ 44 | 45 | def __init__(self, 46 | dataset, # dataset of [(img_path, cats), ...]. 47 | labels2inds, # labels of index {(cats: index1, index2, ...)}. 48 | labelIds, # train labels [0, 1, 2, 3, ...,]. 49 | nKnovel=5, # number of novel categories. 50 | nExemplars=1, # number of training examples per novel category. 51 | nTestNovel=2*5, # number of test examples for all the novel categories. 52 | epoch_size=2000, # number of tasks per epoch. 53 | transform=None, 54 | load=True, 55 | **kwargs 56 | ): 57 | 58 | self.dataset = dataset 59 | self.labels2inds = labels2inds 60 | self.labelIds = labelIds 61 | #print(labelIds) 62 | #print(len(labelIds)) 63 | #exit(0) 64 | self.nKnovel = nKnovel 65 | self.transform = transform 66 | 67 | self.nExemplars = nExemplars 68 | self.nTestNovel = nTestNovel 69 | self.epoch_size = epoch_size 70 | self.load = load 71 | self.edge=1 72 | self.sigma=2 73 | #print(self.nExemplars,self.nTestNovel,self.epoch_size) #5,75,2000 74 | #exit(0) 75 | seed = 112 76 | random.seed(seed) 77 | np.random.seed(seed) 78 | 79 | self.Epoch_Exemplar = [] 80 | self.Epoch_Tnovel = [] 81 | for i in range(epoch_size): 82 | Tnovel, Exemplar = self._sample_episode() 83 | self.Epoch_Exemplar.append(Exemplar) 84 | self.Epoch_Tnovel.append(Tnovel) 85 | 86 | def __len__(self): 87 | return self.epoch_size 88 | 89 | def _sample_episode(self): 90 | """Samples the indices for one training episode. 91 | Returns: 92 | Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label) 93 | Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples.
(sample_index, label) 94 | """ 95 | 96 | Knovel = random.sample(self.labelIds, self.nKnovel) 97 | #print(Knovel) 98 | #exit(0) 99 | nKnovel = len(Knovel) 100 | assert((self.nTestNovel % nKnovel) == 0) 101 | nEvalExamplesPerClass = int(self.nTestNovel / nKnovel) 102 | #print(nEvalExamplesPerClass) 103 | #exit(0) 104 | Tnovel = [] 105 | Exemplars = [] 106 | for Knovel_idx in range(len(Knovel)): 107 | ids = (nEvalExamplesPerClass + self.nExemplars) 108 | img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids) 109 | 110 | imgs_tnovel = img_ids[:nEvalExamplesPerClass] 111 | imgs_emeplars = img_ids[nEvalExamplesPerClass:] 112 | #print(imgs_tnovel) 113 | #exit(0) 114 | Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel] 115 | Exemplars += [(img_id, Knovel_idx) for img_id in imgs_emeplars] 116 | assert(len(Tnovel) == self.nTestNovel) 117 | assert(len(Exemplars) == nKnovel * self.nExemplars) 118 | random.shuffle(Exemplars) 119 | random.shuffle(Tnovel) 120 | 121 | return Tnovel, Exemplars 122 | def load_edge(self, img, mask): 123 | sigma = self.sigma 124 | index=1 125 | # in test mode images are masked (with masked regions), 126 | # using 'mask' parameter prevents canny to detect edges for the masked regions 127 | #mask = None if self.training else (1 - mask / 255).astype(np.bool) 128 | 129 | # canny 130 | if self.edge == 1: 131 | # no edge 132 | if sigma == -1: 133 | return np.zeros(img.shape).astype(np.float) 134 | 135 | # random sigma 136 | if sigma == 0: 137 | sigma = random.randint(1, 4) 138 | 139 | return canny(img, sigma=sigma, mask=mask).astype(np.float) 140 | 141 | # external 142 | else: 143 | imgh, imgw = img.shape[0:2] 144 | edge = imread(self.edge_data[index]) 145 | edge = self.resize(edge, imgh, imgw) 146 | 147 | # non-max suppression 148 | if self.nms == 1: 149 | edge = edge * canny(img, sigma=sigma, mask=mask) 150 | 151 | return edge 152 | def _creatExamplesTensorData(self, examples): 153 | """ 154 | Creats the examples image label tensor data. 155 | 156 | Args: 157 | examples: a list of 2-element tuples. (sample_index, label). 
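                e.g. [(1042, 0), (77, 3), ...] -- the index points into self.dataset and the
                label is the episode-local class id (illustrative values, not from the repo).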
158 | 159 | Returns: 160 | images: a tensor [nExemplars, c, h, w] 161 | labels: a tensor [nExemplars] 162 | """ 163 | 164 | images = [] 165 | labels = [] 166 | images_ori=[] 167 | edges=[] 168 | imgs_gray=[] 169 | masks=[] 170 | cls = [] 171 | for (img_idx, label) in examples: 172 | img_ori = self.dataset[img_idx][0] 173 | #print(img) 174 | #exit(0) 175 | if self.load: 176 | img_ori = Image.fromarray(img_ori) 177 | else: 178 | img_ori = read_image(img_ori) 179 | #print(img.size) 180 | #print(np.array(img).shape) 181 | #exit(0) 182 | if self.transform is not None: 183 | img = self.transform(img_ori) 184 | img_gray = rgb2gray(np.array(img_ori)) 185 | #edge = self.load_edge(img_gray, None) 186 | #edge_tensor=Image.fromarray( edge) 187 | #edge_tensor=F.to_tensor( edge).float() 188 | #print(img.shape,'located in test_loader.py at 146') 189 | #exit(0) 190 | img_gray=Image.fromarray(img_gray) 191 | img_gray_tensor=F.to_tensor(img_gray).float() 192 | imgs_gray.append(img_gray_tensor) 193 | images.append(img) 194 | labels.append(label) 195 | 196 | masked_img=np.array(img_ori)#*(1-mask_3)+mask_3*255 197 | masked_img=Image.fromarray(masked_img) 198 | masked_img_tensor=F.to_tensor(masked_img).float() 199 | images_ori.append(masked_img_tensor) 200 | 201 | #edges.append(edge) 202 | images = torch.stack(images, dim=0) 203 | labels = torch.LongTensor(labels) 204 | images_ori = torch.stack(images_ori, dim=0) 205 | #edges = torch.stack(edges, dim=0) 206 | imgs_gray = torch.stack(imgs_gray, dim=0) 207 | 208 | return images, labels,images_ori,imgs_gray 209 | 210 | def __getitem__(self, index): 211 | Tnovel = self.Epoch_Tnovel[index] 212 | Exemplars = self.Epoch_Exemplar[index] 213 | Xt, Yt,xtori,xt_imgs_gray = self._creatExamplesTensorData(Exemplars) 214 | Xe, Ye,xeori,xe_imgs_gray = self._creatExamplesTensorData(Tnovel) 215 | return Xt, Yt , xtori ,xt_imgs_gray, Xe, Ye#,xeori,xe_edges 216 | 217 | -------------------------------------------------------------------------------- /torchFewShot/dataset_loader/test_inpainting_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | import transforms as T 9 | import os.path as osp 10 | import sys 11 | sys.path.append("/ghome/lijj/python_package/") 12 | import lmdb 13 | import io 14 | import random 15 | 16 | import torch 17 | from torch.utils.data import Dataset 18 | 19 | 20 | def read_image(img_path): 21 | """Keep reading image until succeed. 22 | This can avoid IOError incurred by heavy IO process.""" 23 | got_img = False 24 | if not osp.exists(img_path): 25 | raise IOError("{} does not exist".format(img_path)) 26 | while not got_img: 27 | try: 28 | img = Image.open(img_path).convert('RGB') 29 | got_img = True 30 | except IOError: 31 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 32 | pass 33 | return img 34 | 35 | 36 | class FewShotDataset_test_inpainting(Dataset): 37 | """Few shot epoish Dataset 38 | 39 | Returns a task (Xtrain, Ytrain, Xtest, Ytest) to classify' 40 | Xtrain: [nKnovel*nExpemplars, c, h, w]. 41 | Ytrain: [nKnovel*nExpemplars]. 42 | Xtest: [nTestNovel, c, h, w]. 43 | Ytest: [nTestNovel]. 44 | """ 45 | 46 | def __init__(self, 47 | dataset, # dataset of [(img_path, cats), ...]. 48 | labels2inds, # labels of index {(cats: index1, index2, ...)}. 
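                 #   e.g. {0: [0, 1, ..., 599], 1: [600, ...]} as built by miniImageNet._process_dir (illustrative shape).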
49 | labelIds, # train labels [0, 1, 2, 3, ...,]. 50 | nKnovel=5, # number of novel categories. 51 | nExemplars=1, # number of training examples per novel category. 52 | nTestNovel=2*5, # number of test examples for all the novel categories. 53 | epoch_size=2000, # number of tasks per eooch. 54 | transform=None, 55 | load=True, 56 | seed=223, 57 | **kwargs 58 | ): 59 | 60 | self.dataset = dataset 61 | self.labels2inds = labels2inds 62 | self.labelIds = labelIds 63 | #print(labelIds) 64 | #print(len(labelIds)) 65 | #exit(0) 66 | self.nKnovel = nKnovel 67 | self.transform = transform 68 | 69 | self.nExemplars = nExemplars 70 | self.nTestNovel = nTestNovel 71 | self.epoch_size = epoch_size 72 | self.load = load 73 | #print(self.nExemplars,self.nTestNovel,self.epoch_size) #5,75,2000 74 | #exit(0) 75 | seed = seed 76 | random.seed(seed) 77 | np.random.seed(seed) 78 | self.transform_test = T.Compose([ 79 | T.Resize((84, 84), interpolation=3), 80 | T.RandomCrop(84, padding=8), 81 | T.RandomHorizontalFlip(), 82 | T.ToTensor(), 83 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 84 | ]) 85 | self.Epoch_Exemplar = [] 86 | self.Epoch_Tnovel = [] 87 | for i in range(epoch_size): 88 | Tnovel, Exemplar = self._sample_episode() 89 | self.Epoch_Exemplar.append(Exemplar) 90 | self.Epoch_Tnovel.append(Tnovel) 91 | 92 | def __len__(self): 93 | return self.epoch_size 94 | 95 | def _sample_episode(self): 96 | """sampels a training epoish indexs. 97 | Returns: 98 | Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label) 99 | Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples. (sample_index, label) 100 | """ 101 | 102 | Knovel = random.sample(self.labelIds, self.nKnovel) 103 | #print(Knovel) 104 | #exit(0) 105 | nKnovel = len(Knovel) 106 | assert((self.nTestNovel % nKnovel) == 0) 107 | nEvalExamplesPerClass = int(self.nTestNovel / nKnovel) 108 | #print(nEvalExamplesPerClass) 109 | #exit(0) 110 | Tnovel = [] 111 | Exemplars = [] 112 | for Knovel_idx in range(len(Knovel)): 113 | ids = (nEvalExamplesPerClass + self.nExemplars) 114 | img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids) 115 | 116 | imgs_tnovel = img_ids[:nEvalExamplesPerClass] 117 | imgs_emeplars = img_ids[nEvalExamplesPerClass:] 118 | #print(imgs_tnovel) 119 | #exit(0) 120 | Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel] 121 | Exemplars += [(img_id, Knovel_idx) for img_id in imgs_emeplars] 122 | assert(len(Tnovel) == self.nTestNovel) 123 | assert(len(Exemplars) == nKnovel * self.nExemplars) 124 | random.shuffle(Exemplars) 125 | random.shuffle(Tnovel) 126 | 127 | return Tnovel, Exemplars 128 | 129 | def _creatExamplesTensorData(self, examples): 130 | """ 131 | Creats the examples image label tensor data. 132 | 133 | Args: 134 | examples: a list of 2-element tuples. (sample_index, label). 
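                Note: the first returned tensor uses the `transform` passed in at construction,
                while images2..images5 are four extra randomly augmented views produced by
                `self.transform_test` (RandomCrop + horizontal flip), so five image tensors
                are stacked per call.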
135 | 136 | Returns: 137 | images: a tensor [nExemplars, c, h, w] 138 | labels: a tensor [nExemplars] 139 | """ 140 | 141 | images = [] 142 | 143 | images2 = [] 144 | images3 = [] 145 | images4 = [] 146 | images5 = [] 147 | labels = [] 148 | for (img_idx, label) in examples: 149 | img = self.dataset[img_idx][0] 150 | #print(img) 151 | ##exit(0) 152 | if self.load: 153 | img = Image.fromarray(img) 154 | else: 155 | img = read_image(img) 156 | #print(img.size) 157 | #print(np.array(img).shape) 158 | #exit(0) 159 | if self.transform is not None: 160 | img1 = self.transform(img) 161 | 162 | img2 = self.transform_test(img) 163 | img3 = self.transform_test(img) 164 | img4 = self.transform_test(img) 165 | img5 = self.transform_test(img) 166 | #print((img2-img1).abs().sum(),(img3-img1).abs().sum(),(img2-img3).abs().sum()) 167 | #print(img.shape,'located in test_loader.py at 146') 168 | #exit(0) 169 | images.append(img1) 170 | 171 | images2.append(img2) 172 | images3.append(img3) 173 | images4.append(img4) 174 | images5.append(img5) 175 | labels.append(label) 176 | images = torch.stack(images, dim=0) 177 | 178 | images2 = torch.stack(images2, dim=0) 179 | images3 = torch.stack(images3, dim=0) 180 | images4 = torch.stack(images4, dim=0) 181 | images5 = torch.stack(images5, dim=0) 182 | labels = torch.LongTensor(labels) 183 | return images, images2,images3,images4,images5,labels 184 | 185 | def __getitem__(self, index): 186 | Tnovel = self.Epoch_Tnovel[index] 187 | Exemplars = self.Epoch_Exemplar[index] 188 | Xt, Xt2,Xt3,Xt4,Xt5,Yt = self._creatExamplesTensorData(Exemplars) 189 | Xe,Xe2,Xe3,Xe4,Xe5, Ye = self._creatExamplesTensorData(Tnovel) 190 | return Xt, Xt2,Xt3,Xt4,Xt5,Yt, Xe, Ye 191 | 192 | -------------------------------------------------------------------------------- /torchFewShot/dataset_loader/test_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | import os.path as osp 9 | import lmdb 10 | import io 11 | import random 12 | 13 | import torch 14 | from torch.utils.data import Dataset 15 | 16 | 17 | def read_image(img_path): 18 | """Keep reading image until succeed. 19 | This can avoid IOError incurred by heavy IO process.""" 20 | got_img = False 21 | if not osp.exists(img_path): 22 | raise IOError("{} does not exist".format(img_path)) 23 | while not got_img: 24 | try: 25 | img = Image.open(img_path).convert('RGB') 26 | got_img = True 27 | except IOError: 28 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 29 | pass 30 | return img 31 | 32 | 33 | class FewShotDataset_test(Dataset): 34 | """Few shot epoish Dataset 35 | 36 | Returns a task (Xtrain, Ytrain, Xtest, Ytest) to classify' 37 | Xtrain: [nKnovel*nExpemplars, c, h, w]. 38 | Ytrain: [nKnovel*nExpemplars]. 39 | Xtest: [nTestNovel, c, h, w]. 40 | Ytest: [nTestNovel]. 41 | """ 42 | 43 | def __init__(self, 44 | dataset, # dataset of [(img_path, cats), ...]. 45 | labels2inds, # labels of index {(cats: index1, index2, ...)}. 46 | labelIds, # train labels [0, 1, 2, 3, ...,]. 47 | nKnovel=5, # number of novel categories. 48 | nExemplars=1, # number of training examples per novel category. 49 | nTestNovel=2*5, # number of test examples for all the novel categories. 50 | epoch_size=2000, # number of tasks per eooch. 
51 | transform=None, 52 | load=True, 53 | **kwargs 54 | ): 55 | 56 | self.dataset = dataset 57 | self.labels2inds = labels2inds 58 | self.labelIds = labelIds 59 | #print(labelIds) 60 | #print(len(labelIds)) 61 | #exit(0) 62 | self.nKnovel = nKnovel 63 | self.transform = transform 64 | 65 | self.nExemplars = nExemplars 66 | self.nTestNovel = nTestNovel 67 | self.epoch_size = epoch_size 68 | self.load = load 69 | #print(self.nExemplars,self.nTestNovel,self.epoch_size) #5,75,2000 70 | #exit(0) 71 | seed = 112 72 | random.seed(seed) 73 | np.random.seed(seed) 74 | 75 | self.Epoch_Exemplar = [] 76 | self.Epoch_Tnovel = [] 77 | for i in range(epoch_size): 78 | Tnovel, Exemplar = self._sample_episode() 79 | self.Epoch_Exemplar.append(Exemplar) 80 | self.Epoch_Tnovel.append(Tnovel) 81 | 82 | def __len__(self): 83 | return self.epoch_size 84 | 85 | def _sample_episode(self): 86 | """sampels a training epoish indexs. 87 | Returns: 88 | Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label) 89 | Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples. (sample_index, label) 90 | """ 91 | 92 | Knovel = random.sample(self.labelIds, self.nKnovel) 93 | #print(Knovel) 94 | #exit(0) 95 | nKnovel = len(Knovel) 96 | assert((self.nTestNovel % nKnovel) == 0) 97 | nEvalExamplesPerClass = int(self.nTestNovel / nKnovel) 98 | #print(nEvalExamplesPerClass) 99 | #exit(0) 100 | Tnovel = [] 101 | Exemplars = [] 102 | for Knovel_idx in range(len(Knovel)): 103 | ids = (nEvalExamplesPerClass + self.nExemplars) 104 | img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids) 105 | 106 | imgs_tnovel = img_ids[:nEvalExamplesPerClass] 107 | imgs_emeplars = img_ids[nEvalExamplesPerClass:] 108 | #print(imgs_tnovel) 109 | #exit(0) 110 | Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel] 111 | Exemplars += [(img_id, Knovel_idx) for img_id in imgs_emeplars] 112 | assert(len(Tnovel) == self.nTestNovel) 113 | assert(len(Exemplars) == nKnovel * self.nExemplars) 114 | random.shuffle(Exemplars) 115 | random.shuffle(Tnovel) 116 | 117 | return Tnovel, Exemplars 118 | 119 | def _creatExamplesTensorData(self, examples): 120 | """ 121 | Creats the examples image label tensor data. 122 | 123 | Args: 124 | examples: a list of 2-element tuples. (sample_index, label). 
125 | 126 | Returns: 127 | images: a tensor [nExemplars, c, h, w] 128 | labels: a tensor [nExemplars] 129 | """ 130 | 131 | images = [] 132 | labels = [] 133 | for (img_idx, label) in examples: 134 | img = self.dataset[img_idx][0] 135 | #print(img) 136 | ##exit(0) 137 | if self.load: 138 | img = Image.fromarray(img) 139 | else: 140 | img = read_image(img) 141 | #print(img.size) 142 | #print(np.array(img).shape) 143 | #exit(0) 144 | if self.transform is not None: 145 | img = self.transform(img) 146 | #print(img.shape,'located in test_loader.py at 146') 147 | #exit(0) 148 | images.append(img) 149 | labels.append(label) 150 | images = torch.stack(images, dim=0) 151 | labels = torch.LongTensor(labels) 152 | return images, labels 153 | 154 | def __getitem__(self, index): 155 | Tnovel = self.Epoch_Tnovel[index] 156 | Exemplars = self.Epoch_Exemplar[index] 157 | Xt, Yt = self._creatExamplesTensorData(Exemplars) 158 | Xe, Ye = self._creatExamplesTensorData(Tnovel) 159 | return Xt, Yt, Xe, Ye 160 | 161 | -------------------------------------------------------------------------------- /torchFewShot/dataset_loader/train_image_ori_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import cv2 6 | import os 7 | from PIL import Image 8 | import numpy as np 9 | import os.path as osp 10 | import lmdb 11 | import io 12 | import random 13 | from scipy.misc import imread 14 | from skimage.feature import canny 15 | from skimage.color import rgb2gray, gray2rgb 16 | 17 | 18 | import torch 19 | from torch.utils.data import Dataset 20 | import torchvision.transforms.functional as F 21 | 22 | 23 | def read_image(img_path): 24 | """Keep reading image until succeed. 25 | This can avoid IOError incurred by heavy IO process.""" 26 | got_img = False 27 | if not osp.exists(img_path): 28 | raise IOError("{} does not exist".format(img_path)) 29 | while not got_img: 30 | try: 31 | img = Image.open(img_path).convert('RGB') 32 | got_img = True 33 | except IOError: 34 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 35 | pass 36 | return img 37 | 38 | 39 | class FewShotDataset_train_imgori(Dataset): 40 | """Few shot epoish Dataset 41 | 42 | Returns a task (Xtrain, Ytrain, Xtest, Ytest, Ycls) to classify' 43 | Xtrain: [nKnovel*nExpemplars, c, h, w]. 44 | Ytrain: [nKnovel*nExpemplars]. 45 | Xtest: [nTestNovel, c, h, w]. 46 | Ytest: [nTestNovel]. 47 | Ycls: [nTestNovel]. 48 | """ 49 | 50 | def __init__(self, 51 | dataset, # dataset of [(img_path, cats), ...]. 52 | labels2inds, # labels of index {(cats: index1, index2, ...)}. 53 | labelIds, # train labels [0, 1, 2, 3, ...,]. 54 | nKnovel=5, # number of novel categories. 55 | nExemplars=1, # number of training examples per novel category. 56 | nTestNovel=6*5, # number of test examples for all the novel categories. 57 | epoch_size=2000, # number of tasks per eooch. 
58 | transform=None, 59 | load=False, 60 | **kwargs 61 | ): 62 | 63 | self.dataset = dataset 64 | #print(self.dataset) 65 | #exit(0) 66 | self.labels2inds = labels2inds 67 | self.labelIds = labelIds 68 | self.nKnovel = nKnovel 69 | self.transform = transform 70 | #print(len(self.labels2inds) ,len(self.labelIds),self.nKnovel,nExemplars) 71 | #exit(0) 72 | self.edge=1 73 | self.sigma=2 74 | self.nExemplars = nExemplars#5 75 | self.nTestNovel = nTestNovel 76 | self.epoch_size = epoch_size 77 | self.load = load 78 | 79 | def __len__(self): 80 | return self.epoch_size 81 | 82 | def _sample_episode(self): 83 | """sampels a training epoish indexs. 84 | Returns: 85 | Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label) 86 | Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples. (sample_index, label) 87 | """ 88 | 89 | Knovel = random.sample(self.labelIds, self.nKnovel) 90 | nKnovel = len(Knovel) 91 | assert((self.nTestNovel % nKnovel) == 0) 92 | nEvalExamplesPerClass = int(self.nTestNovel / nKnovel) # 6 93 | 94 | Tnovel = [] 95 | Exemplars = [] 96 | for Knovel_idx in range(len(Knovel)): 97 | ids = (nEvalExamplesPerClass + self.nExemplars) 98 | img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids) 99 | #print(img_ids) 100 | #exit(0) 101 | 102 | imgs_tnovel = img_ids[:nEvalExamplesPerClass] 103 | imgs_emeplars = img_ids[nEvalExamplesPerClass:] 104 | 105 | Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel] 106 | Exemplars += [(img_id, Knovel_idx) for img_id in imgs_emeplars] 107 | assert(len(Tnovel) == self.nTestNovel) 108 | assert(len(Exemplars) == nKnovel * self.nExemplars) 109 | #print(Exemplars) 110 | #exit(0) 111 | random.shuffle(Exemplars) 112 | random.shuffle(Tnovel) 113 | 114 | return Tnovel, Exemplars 115 | def load_edge(self, img, mask): 116 | sigma = self.sigma 117 | index=1 118 | # in test mode images are masked (with masked regions), 119 | # using 'mask' parameter prevents canny to detect edges for the masked regions 120 | mask = None #if self.training else (1 - mask / 255).astype(np.bool) 121 | #mask =(1 - mask / 255).astype(np.bool) 122 | # canny 123 | if self.edge == 1: 124 | # no edge 125 | if sigma == -1: 126 | return np.zeros(img.shape).astype(np.float) 127 | 128 | # random sigma 129 | if sigma == 0: 130 | sigma = random.randint(1, 4) 131 | 132 | return canny(img, sigma=sigma, mask=mask).astype(np.float) 133 | 134 | # external 135 | else: 136 | imgh, imgw = img.shape[0:2] 137 | edge = imread(self.edge_data[index]) 138 | edge = self.resize(edge, imgh, imgw) 139 | 140 | # non-max suppression 141 | if self.nms == 1: 142 | edge = edge * canny(img, sigma=sigma, mask=mask) 143 | 144 | return edge 145 | def _creatExamplesTensorData(self, examples): 146 | """ 147 | Creats the examples image label tensor data. 148 | 149 | Args: 150 | examples: a list of 2-element tuples. (sample_index, label). 
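(sample_index, label), e.g. (1042, 0) (illustrative index/label pair).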
151 | 152 | Returns: 153 | images: a tensor [nExemplars, c, h, w] 154 | labels: a tensor [nExemplars] 155 | cls: a tensor [nExemplars] 156 | """ 157 | 158 | images = [] 159 | labels = [] 160 | images_ori=[] 161 | edges=[] 162 | imgs_gray=[] 163 | masks=[] 164 | cls = [] 165 | #self.mask_root="/home/lijunjie/edge-connect-master/examples/fuse.png" 166 | #self.mask_root="/home/lijunjie/Pconv/PConv-Keras_mask01/stroke_4/img00000266_mask.png" 167 | #self.mask_root="/home/lijunjie/Pconv/PConv-Keras_mask01/stroke_4/img00021416_mask.png" 168 | #self.mask_root="/home/lijunjie/edge-connect-master/examples/Places365_val_00006822_mask.png" 169 | #index_mask=self.data[index].rfind('/') 170 | #name_mask=self.data[index][index_mask+1:len(self.data[index])] 171 | #self.mask_root='/home/lijunjie/generative_inpainting-master_global_local/data/random_rec_places_10000/'+name_mask 172 | #mask_root="/home/lijunjie/Pconv/PConv-Keras_mask01/stroke_4/"+mask_name[self.mask_id]+'.png' 173 | #mask = imread(self.mask_data[index]) 174 | 175 | #mask_3 = (imread(self.mask_root)/255).astype(dtype=np.uint8) 176 | #if len(mask_3.shape) < 3: 177 | #mask_3 = gray2rgb(mask_3) 178 | #print(mask_3.shape) 179 | 180 | #cv2.imwrite('./test_masked.png',img) 181 | for (img_idx, label) in examples: 182 | img_ori, ids = self.dataset[img_idx] 183 | #img_ori='/home/lijunjie/edge-connect-master/examples/test_result/input_000.png' 184 | #exit(0) 185 | if self.load: 186 | img_ori = Image.fromarray(img_ori) 187 | else: 188 | img_ori = read_image(img_ori) 189 | if self.transform is not None: 190 | img = self.transform(img_ori) 191 | img_gray = rgb2gray(np.array(img_ori)) 192 | #print(img_gray.shape) 193 | #exit(0) 194 | masked_img=np.array(img_ori)#*(1-mask_3)+mask_3*255 195 | masked_img=Image.fromarray(masked_img) 196 | masked_img_tensor=F.to_tensor(masked_img).float() 197 | #print(masked_img_tensor.shape) 198 | #exit(0) 199 | images_ori.append(masked_img_tensor) 200 | #cv2.imwrite('./test_masked.png',np.array(img_ori)*(1-mask_3)+mask_3*255) 201 | #mask = rgb2gray(mask_3) 202 | #mask = (mask > 0).astype(np.uint8) * 255 203 | #print(mask.shape)#(84,84) 204 | #exit(0) 205 | #mask_tensor=F.to_tensor(Image.fromarray(mask)).float() 206 | #masks.append(mask_tensor) 207 | #edge = self.load_edge(img_gray, None) 208 | #edge = self.load_edge(img_gray, mask) 209 | #print(edge.dtype,'lllkkkk') 210 | #exit(0) 211 | 212 | #edge_tensor=Image.fromarray( edge) 213 | #edge_tensor=F.to_tensor( edge).float() 214 | img_gray=Image.fromarray(img_gray) 215 | img_gray_tensor=F.to_tensor(img_gray).float() 216 | imgs_gray.append(img_gray_tensor) 217 | images.append(img) 218 | labels.append(label) 219 | #edges.append(edge_tensor) 220 | cls.append(ids) 221 | #print(type(images[0])) 222 | images = torch.stack(images, dim=0) 223 | #masks = torch.stack(masks, dim=0) 224 | #print(masks.shape,'llll')#(5,1,84,84) 225 | 226 | #print(images.shape) 227 | #exit(0) 228 | labels = torch.LongTensor(labels) 229 | images_ori = torch.stack(images_ori, dim=0) 230 | #edges = torch.stack(edges, dim=0) 231 | cls = torch.LongTensor(cls) 232 | imgs_gray = torch.stack(imgs_gray, dim=0) 233 | #print(imgs_gray.shape,'ljj') 234 | #exit(0) 235 | return images, labels, cls,images_ori,imgs_gray#,masks 236 | 237 | 238 | def __getitem__(self, index): 239 | Tnovel, Exemplars = self._sample_episode() 240 | Xt, Yt, Ytc,Xt_img_ori,Xt_img_gray = self._creatExamplesTensorData(Exemplars) 241 | Xe, Ye, Yec,Xe_img_ori,Xe_img_gray= self._creatExamplesTensorData(Tnovel) 242 | return Xt, 
Yt,Ytc,Xt_img_ori,Xt_img_gray , Xe, Ye, Yec#,Xe_img_ori,Xe_edges,Xe_img_gray,Xe_masks 243 | -------------------------------------------------------------------------------- /torchFewShot/dataset_loader/train_inpainting_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import os 6 | import cv2 7 | from PIL import Image 8 | import numpy as np 9 | import os.path as osp 10 | import sys 11 | sys.path.append("/ghome/lijj/python_package/") 12 | import lmdb 13 | import io 14 | import random 15 | 16 | import torch 17 | from torch.utils.data import Dataset 18 | 19 | 20 | def read_image(img_path): 21 | """Keep reading image until succeed. 22 | This can avoid IOError incurred by heavy IO process.""" 23 | got_img = False 24 | if not osp.exists(img_path): 25 | raise IOError("{} does not exist".format(img_path)) 26 | while not got_img: 27 | try: 28 | img = Image.open(img_path).convert('RGB') 29 | #img=img.resize((100, 100),Image.ANTIALIAS) 30 | got_img = True 31 | except IOError: 32 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 33 | pass 34 | return img 35 | 36 | 37 | class FewShotDataset_train_inpainting(Dataset): 38 | """Few shot epoish Dataset 39 | 40 | Returns a task (Xtrain, Ytrain, Xtest, Ytest, Ycls) to classify' 41 | Xtrain: [nKnovel*nExpemplars, c, h, w]. 42 | Ytrain: [nKnovel*nExpemplars]. 43 | Xtest: [nTestNovel, c, h, w]. 44 | Ytest: [nTestNovel]. 45 | Ycls: [nTestNovel]. 46 | """ 47 | 48 | def __init__(self, 49 | dataset, # dataset of [(img_path, cats), ...]. 50 | labels2inds, # labels of index {(cats: index1, index2, ...)}. 51 | labelIds, # train labels [0, 1, 2, 3, ...,]. 52 | nKnovel=5, # number of novel categories. 53 | nExemplars=1, # number of training examples per novel category. 54 | nTestNovel=6*5, # number of test examples for all the novel categories. 55 | epoch_size=2000, # number of tasks per eooch. 56 | transform=None, 57 | load=False, 58 | **kwargs 59 | ): 60 | 61 | self.dataset = dataset 62 | #print(self.dataset) 63 | #exit(0) 64 | self.labels2inds = labels2inds 65 | self.labelIds = labelIds 66 | self.nKnovel = nKnovel 67 | self.transform = transform 68 | #print(len(self.labels2inds) ,len(self.labelIds),self.nKnovel,nExemplars) 69 | #exit(0) 70 | 71 | self.nExemplars = nExemplars#5 72 | self.nTestNovel = nTestNovel 73 | self.epoch_size = epoch_size 74 | self.load = load 75 | 76 | def __len__(self): 77 | return self.epoch_size 78 | 79 | def _sample_episode(self): 80 | """sampels a training epoish indexs. 81 | Returns: 82 | Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label) 83 | Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples. 
(sample_index, label) 84 | """ 85 | 86 | Knovel = random.sample(self.labelIds, self.nKnovel) 87 | nKnovel = len(Knovel) 88 | assert((self.nTestNovel % nKnovel) == 0) 89 | nEvalExamplesPerClass = int(self.nTestNovel / nKnovel) # 6 90 | 91 | Tnovel = [] 92 | Exemplars = [] 93 | for Knovel_idx in range(len(Knovel)): 94 | ids = (nEvalExamplesPerClass + self.nExemplars) 95 | img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids) 96 | #print(img_ids) 97 | #exit(0) 98 | 99 | imgs_tnovel = img_ids[:nEvalExamplesPerClass] 100 | imgs_emeplars = img_ids[nEvalExamplesPerClass:] 101 | 102 | Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel] 103 | Exemplars += [(img_id, Knovel_idx) for img_id in imgs_emeplars] 104 | assert(len(Tnovel) == self.nTestNovel) 105 | assert(len(Exemplars) == nKnovel * self.nExemplars) 106 | #print(Exemplars) 107 | #exit(0) 108 | random.shuffle(Exemplars) 109 | random.shuffle(Tnovel) 110 | 111 | return Tnovel, Exemplars 112 | 113 | def _creatExamplesTensorData(self, examples): 114 | """ 115 | Creats the examples image label tensor data. 116 | 117 | Args: 118 | examples: a list of 2-element tuples. (sample_index, label). 119 | 120 | Returns: 121 | images: a tensor [nExemplars, c, h, w] 122 | labels: a tensor [nExemplars] 123 | cls: a tensor [nExemplars] 124 | """ 125 | 126 | images = [] 127 | images1 = [] 128 | images2 = [] 129 | images3 = [] 130 | images4 = [] 131 | 132 | images5 = [] 133 | images6 = [] 134 | images7 = [] 135 | images8 = [] 136 | labels = [] 137 | cls = [] 138 | for (img_idx, label) in examples: 139 | img, ids = self.dataset[img_idx] 140 | #print(img) 141 | #exit(0) 142 | img1=img.replace('train','train_1') 143 | img2=img.replace('train','train_2') 144 | img3=img.replace('train','train_3') 145 | img4=img.replace('train','train_4') 146 | img5=img.replace('train','train_5/train_5') 147 | img6=img.replace('train','train_6/train_6') 148 | img7=img.replace('train','train_7/train_7') 149 | img8=img.replace('train','train_8/train_8') 150 | 151 | #img5=img.replace('train','train_5') 152 | #img6=img.replace('train','train_6') 153 | #img7=img.replace('train','train_7') 154 | #img8=img.replace('train','train_8') 155 | 156 | 157 | # print(img1,img2,img3,img4) 158 | #exit(0) 159 | #print(img8) 160 | 161 | #imtest=cv2.imread(img8) 162 | #print(imtest.shape) 163 | #exit(0) 164 | if self.load: 165 | img = Image.fromarray(img) 166 | else: 167 | img = read_image(img) 168 | img1 = read_image(img1) 169 | img2 = read_image(img2) 170 | img3 = read_image(img3) 171 | img4 = read_image(img4) 172 | 173 | img5 = read_image(img5) 174 | img6 = read_image(img6) 175 | img7 = read_image(img7) 176 | img8 = read_image(img8) 177 | #img=img.resize((100, 100),Image.ANTIALIAS) 178 | 179 | #img1=img1.resize((100, 100),Image.ANTIALIAS) 180 | #img2=img2.resize((100, 100),Image.ANTIALIAS) 181 | #img3=img3.resize((100, 100),Image.ANTIALIAS) 182 | #img4=img4.resize((100, 100),Image.ANTIALIAS) 183 | #img5=img5.resize((100, 100),Image.ANTIALIAS) 184 | #img6=img6.resize((100, 100),Image.ANTIALIAS) 185 | #img7=img7.resize((100, 100),Image.ANTIALIAS) 186 | #img8=img8.resize((100, 100),Image.ANTIALIAS) 187 | #print('pppppppppp') 188 | #exit(0) 189 | if self.transform is not None: 190 | img = self.transform(img) 191 | img1 = self.transform(img1) 192 | img2 = self.transform(img2) 193 | img3 = self.transform(img3) 194 | img4 = self.transform(img4) 195 | 196 | img5 = self.transform(img5) 197 | img6 = self.transform(img6) 198 | img7 = self.transform(img7) 199 | img8 = self.transform(img8) 200 | 
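            # img1..img8 are pre-generated erased-and-inpainted variants of the same
            # image, resolved by the 'train' -> 'train_k' path substitutions above
            # (e.g. .../train/<class>/x.jpg -> .../train_3/<class>/x.jpg, illustrative path).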
images.append(img) 201 | images1.append(img1) 202 | images2.append(img2) 203 | images3.append(img3) 204 | images4.append(img4) 205 | 206 | images5.append(img5) 207 | images6.append(img6) 208 | images7.append(img7) 209 | images8.append(img8) 210 | labels.append(label) 211 | cls.append(ids) 212 | images = torch.stack(images, dim=0) 213 | images1 = torch.stack(images1, dim=0) 214 | images2 = torch.stack(images2, dim=0) 215 | images3 = torch.stack(images3, dim=0) 216 | images4 = torch.stack(images4, dim=0) 217 | 218 | images5 = torch.stack(images5, dim=0) 219 | images6 = torch.stack(images6, dim=0) 220 | images7 = torch.stack(images7, dim=0) 221 | images8 = torch.stack(images8, dim=0) 222 | labels = torch.LongTensor(labels) 223 | cls = torch.LongTensor(cls) 224 | return images,images1,images2,images3,images4,images5,images6,images7,images8, labels, cls 225 | 226 | 227 | def __getitem__(self, index): 228 | Tnovel, Exemplars = self._sample_episode() 229 | Xt,Xt1,Xt2,Xt3,Xt4,Xt5,Xt6,Xt7,Xt8, Yt, Ytc = self._creatExamplesTensorData(Exemplars) 230 | Xe, Xe1, Xe2, Xe3, Xe4, Xe5, Xe6, Xe7, Xe8, Ye, Yec = self._creatExamplesTensorData(Tnovel) 231 | return Xt,Xt1,Xt2,Xt3,Xt4,Xt5,Xt6,Xt7,Xt8, Yt, Ytc, Xe, Xe1, Xe2, Xe3, Xe4, Ye, Yec 232 | -------------------------------------------------------------------------------- /torchFewShot/dataset_loader/train_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | import os.path as osp 9 | import lmdb 10 | import io 11 | import random 12 | 13 | import torch 14 | from torch.utils.data import Dataset 15 | 16 | 17 | def read_image(img_path): 18 | """Keep reading image until succeed. 19 | This can avoid IOError incurred by heavy IO process.""" 20 | got_img = False 21 | if not osp.exists(img_path): 22 | raise IOError("{} does not exist".format(img_path)) 23 | while not got_img: 24 | try: 25 | img = Image.open(img_path).convert('RGB') 26 | got_img = True 27 | except IOError: 28 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 29 | pass 30 | return img 31 | 32 | 33 | class FewShotDataset_train(Dataset): 34 | """Few shot epoish Dataset 35 | 36 | Returns a task (Xtrain, Ytrain, Xtest, Ytest, Ycls) to classify' 37 | Xtrain: [nKnovel*nExpemplars, c, h, w]. 38 | Ytrain: [nKnovel*nExpemplars]. 39 | Xtest: [nTestNovel, c, h, w]. 40 | Ytest: [nTestNovel]. 41 | Ycls: [nTestNovel]. 42 | """ 43 | 44 | def __init__(self, 45 | dataset, # dataset of [(img_path, cats), ...]. 46 | labels2inds, # labels of index {(cats: index1, index2, ...)}. 47 | labelIds, # train labels [0, 1, 2, 3, ...,]. 48 | nKnovel=5, # number of novel categories. 49 | nExemplars=1, # number of training examples per novel category. 50 | nTestNovel=6*5, # number of test examples for all the novel categories. 51 | epoch_size=2000, # number of tasks per eooch. 
52 | transform=None, 53 | load=False, 54 | **kwargs 55 | ): 56 | 57 | self.dataset = dataset 58 | #print(self.dataset) 59 | #exit(0) 60 | self.labels2inds = labels2inds 61 | self.labelIds = labelIds 62 | self.nKnovel = nKnovel 63 | self.transform = transform 64 | #print(len(self.labels2inds) ,len(self.labelIds),self.nKnovel,nExemplars) 65 | #exit(0) 66 | 67 | self.nExemplars = nExemplars#5 68 | self.nTestNovel = nTestNovel 69 | self.epoch_size = epoch_size 70 | self.load = load 71 | 72 | def __len__(self): 73 | return self.epoch_size 74 | 75 | def _sample_episode(self): 76 | """sampels a training epoish indexs. 77 | Returns: 78 | Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label) 79 | Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples. (sample_index, label) 80 | """ 81 | 82 | Knovel = random.sample(self.labelIds, self.nKnovel) 83 | nKnovel = len(Knovel) 84 | assert((self.nTestNovel % nKnovel) == 0) 85 | nEvalExamplesPerClass = int(self.nTestNovel / nKnovel) # 6 86 | 87 | Tnovel = [] 88 | Exemplars = [] 89 | for Knovel_idx in range(len(Knovel)): 90 | ids = (nEvalExamplesPerClass + self.nExemplars) 91 | img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids) 92 | #print(img_ids) 93 | #exit(0) 94 | 95 | imgs_tnovel = img_ids[:nEvalExamplesPerClass] 96 | imgs_emeplars = img_ids[nEvalExamplesPerClass:] 97 | 98 | Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel] 99 | Exemplars += [(img_id, Knovel_idx) for img_id in imgs_emeplars] 100 | assert(len(Tnovel) == self.nTestNovel) 101 | assert(len(Exemplars) == nKnovel * self.nExemplars) 102 | #print(Exemplars) 103 | #exit(0) 104 | random.shuffle(Exemplars) 105 | random.shuffle(Tnovel) 106 | 107 | return Tnovel, Exemplars 108 | 109 | def _creatExamplesTensorData(self, examples): 110 | """ 111 | Creats the examples image label tensor data. 112 | 113 | Args: 114 | examples: a list of 2-element tuples. (sample_index, label). 
115 | 116 | Returns: 117 | images: a tensor [nExemplars, c, h, w] 118 | labels: a tensor [nExemplars] 119 | cls: a tensor [nExemplars] 120 | """ 121 | 122 | images = [] 123 | labels = [] 124 | cls = [] 125 | for (img_idx, label) in examples: 126 | img, ids = self.dataset[img_idx] 127 | if self.load: 128 | img = Image.fromarray(img) 129 | else: 130 | img = read_image(img) 131 | if self.transform is not None: 132 | img = self.transform(img) 133 | images.append(img) 134 | labels.append(label) 135 | cls.append(ids) 136 | images = torch.stack(images, dim=0) 137 | labels = torch.LongTensor(labels) 138 | cls = torch.LongTensor(cls) 139 | return images, labels, cls 140 | 141 | 142 | def __getitem__(self, index): 143 | Tnovel, Exemplars = self._sample_episode() 144 | Xt, Yt, Ytc = self._creatExamplesTensorData(Exemplars) 145 | Xe, Ye, Yec = self._creatExamplesTensorData(Tnovel) 146 | return Xt, Yt, Xe, Ye, Yec 147 | -------------------------------------------------------------------------------- /torchFewShot/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .miniImageNet import miniImageNet 6 | 7 | 8 | 9 | __imgfewshot_factory = { 10 | 'miniImageNet': miniImageNet, 11 | } 12 | 13 | 14 | def get_names(): 15 | return list(__imgfewshot_factory.keys()) 16 | 17 | 18 | def init_imgfewshot_dataset(name, **kwargs): 19 | if name not in list(__imgfewshot_factory.keys()): 20 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, list(__imgfewshot_factory.keys()))) 21 | #print('name is: ',name,'kwargs: ',**kwargs) 22 | #exit(0) 23 | return __imgfewshot_factory[name](**kwargs) 24 | 25 | -------------------------------------------------------------------------------- /torchFewShot/datasets/miniImageNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import torch 7 | 8 | class miniImageNet(object): 9 | """ 10 | Dataset statistics: 11 | # 64 * 600 (train) + 16 * 600 (val) + 20 * 600 (test) 12 | """ 13 | #dataset_dir = '/data4/lijunjie/mini-imagenet-tools/processed_images/' 14 | dataset_dir = "/home/yfchen/ljj_code/mini_imagenet/" 15 | def __init__(self): 16 | super(miniImageNet, self).__init__() 17 | self.train_dir = os.path.join(self.dataset_dir, 'train') 18 | self.val_dir = os.path.join(self.dataset_dir, 'val') 19 | self.test_dir = os.path.join(self.dataset_dir, 'test') 20 | 21 | self.train, self.train_labels2inds, self.train_labelIds = self._process_dir(self.train_dir) 22 | self.val, self.val_labels2inds, self.val_labelIds = self._process_dir(self.val_dir) 23 | self.test, self.test_labels2inds, self.test_labelIds = self._process_dir(self.test_dir) 24 | 25 | self.num_train_cats = len(self.train_labelIds) 26 | num_total_cats = len(self.train_labelIds) + len(self.val_labelIds) + len(self.test_labelIds) 27 | num_total_imgs = len(self.train + self.val + self.test) 28 | 29 | print("=> MiniImageNet loaded") 30 | print("Dataset statistics:") 31 | print(" ------------------------------") 32 | print(" subset | # cats | # images") 33 | print(" ------------------------------") 34 | print(" train | {:5d} | {:8d}".format(len(self.train_labelIds), len(self.train))) 35 | print(" val | {:5d} | {:8d}".format(len(self.val_labelIds), 
len(self.val)))
36 |         print("  test     | {:5d} | {:8d}".format(len(self.test_labelIds), len(self.test)))
37 |         print("  ------------------------------")
38 |         print("  total    | {:5d} | {:8d}".format(num_total_cats, num_total_imgs))
39 |         print("  ------------------------------")
40 | 
41 |     def _check_before_run(self):
42 |         """Check if all files are available before going deeper"""
43 |         if not os.path.exists(self.dataset_dir):  # fixed: this file imports os only, so the original `osp.exists` would raise NameError
44 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
45 |         if not os.path.exists(self.train_dir):
46 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
47 |         if not os.path.exists(self.val_dir):
48 |             raise RuntimeError("'{}' is not available".format(self.val_dir))
49 |         if not os.path.exists(self.test_dir):
50 |             raise RuntimeError("'{}' is not available".format(self.test_dir))
51 | 
52 |     def _process_dir(self, dir_path):
53 |         cat_container = sorted(os.listdir(dir_path))
54 |         cats2label = {cat: label for label, cat in enumerate(cat_container)}
55 |         #print(cat_container,cats2label)
56 |         #exit(0)
57 |         dataset = []
58 |         labels = []
59 |         for cat in cat_container:
60 |             for img_path in sorted(os.listdir(os.path.join(dir_path, cat))):
61 |                 if '.jpg' not in img_path:
62 |                     continue
63 |                 label = cats2label[cat]
64 |                 dataset.append((os.path.join(dir_path, cat, img_path), label))
65 |                 labels.append(label)
66 | 
67 |         labels2inds = {}
68 |         for idx, label in enumerate(labels):
69 |             if label not in labels2inds:
70 |                 labels2inds[label] = []
71 |             labels2inds[label].append(idx)
72 | 
73 |         labelIds = sorted(labels2inds.keys())
74 |         return dataset, labels2inds, labelIds
75 | 
--------------------------------------------------------------------------------
/torchFewShot/losses.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | class CrossEntropyLoss(nn.Module):
8 |     def __init__(self):
9 |         super(CrossEntropyLoss, self).__init__()
10 |         self.logsoftmax = nn.LogSoftmax(dim=1)
11 | 
12 |     def forward(self, inputs, targets):
13 |         inputs = inputs.view(inputs.size(0), inputs.size(1), -1)
14 | 
15 |         log_probs = self.logsoftmax(inputs)
16 |         targets = torch.zeros(inputs.size(0), inputs.size(1)).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
17 |         targets = targets.unsqueeze(-1)
18 |         targets = targets.cuda()
19 |         loss = (- targets * log_probs).mean(0).sum()
20 |         return loss / inputs.size(2)
21 | 
--------------------------------------------------------------------------------
/torchFewShot/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/torchFewShot/models/cam.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | 
4 | import torch
5 | import math
6 | from torch import nn
7 | from torch.nn import functional as F
8 | 
9 | 
10 | class ConvBlock(nn.Module):
11 |     """Basic convolutional block:
12 |     convolution + batch normalization.
13 | 
14 |     Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
15 |     - in_c (int): number of input channels.
16 |     - out_c (int): number of output channels.
17 |     - k (int or tuple): kernel size.
18 |     - s (int or tuple): stride.
19 |     - p (int or tuple): padding.
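    Note: conv + BN only; no activation is included here (CAM applies ReLU after conv1 itself).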
20 | """ 21 | def __init__(self, in_c, out_c, k, s=1, p=0): 22 | super(ConvBlock, self).__init__() 23 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) 24 | self.bn = nn.BatchNorm2d(out_c) 25 | 26 | def forward(self, x): 27 | return self.bn(self.conv(x)) 28 | 29 | 30 | class CAM(nn.Module): 31 | def __init__(self): 32 | super(CAM, self).__init__() 33 | self.conv1 = ConvBlock(36, 6, 1) 34 | self.conv2 = nn.Conv2d(6, 36, 1, stride=1, padding=0) 35 | 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 39 | m.weight.data.normal_(0, math.sqrt(2. / n)) 40 | 41 | def get_attention(self, a): 42 | input_a = a#[4,5,30,36,36] 43 | #print(torch.mean(input_a[:,:,:,0,0], -1).shape) 44 | a = a.mean(3) #[4,5,30,36] 45 | #print(a.shape) 46 | #exit(0) 47 | a = a.transpose(1, 3) #[4,36,30,5] 48 | a = F.relu(self.conv1(a))#[4,6,30,5] 49 | a = self.conv2(a) #[4,36,30,5] 50 | a = a.transpose(1, 3)#[4,5,30,36] 51 | a = a.unsqueeze(3) #[4,5,30,1,36] 52 | 53 | a = torch.mean(input_a * a, -1)#[4,5,30,36] 54 | a = F.softmax(a / 0.025, dim=-1) + 1 55 | return a 56 | 57 | def forward(self, f1, f2,test=False): 58 | b, n1, c, h, w = f1.size() 59 | n2 = f2.size(1) 60 | 61 | f1 = f1.view(b, n1, c, -1) #[4,5,512,36] 62 | f2 = f2.view(b, n2, c, -1) #[4,30,512,36] 63 | #print(f1.shape,f2.shape) 64 | #exit(0) 65 | f1_norm = F.normalize(f1, p=2, dim=2, eps=1e-12) 66 | f2_norm = F.normalize(f2, p=2, dim=2, eps=1e-12) 67 | 68 | f1_norm = f1_norm.transpose(2, 3).unsqueeze(2)#[4,5,1,36,512] 69 | 70 | f2_norm = f2_norm.unsqueeze(1)#[4,1,30,512,36] 71 | #print(f1_norm.shape,f2_norm.shape) 72 | #exit(0) 73 | a1 = torch.matmul(f1_norm, f2_norm) #[4,5,30,36,36] 74 | #print('The shape of a1 before get_attention: ', a1.shape,'located in cam.py at 72') 75 | #exit(0) 76 | a2 = a1.transpose(3, 4) #[4,5,30,36,36] 77 | 78 | a1 = self.get_attention(a1)#[4,5,30,36] 79 | a2 = self.get_attention(a2)#[4,5,30,36] 80 | #print('The shape of a1 after get_attention: ', a1.shape,'located in cam.py at 72') 81 | #print('The shape of a2 after get_attention: ', a2.shape,'located in cam.py at 72') 82 | #print(f1.unsqueeze(2).shape,a1.unsqueeze(3).shape)#[4,5,1,512,36],#[4,5,30,1,36] 83 | #exit(0) 84 | f1 = f1.unsqueeze(2) * a1.unsqueeze(3) 85 | f1 = f1.view(b, n1, n2, c, h, w)#[4,5,30,512,6,6] 86 | f2 = f2.unsqueeze(1) * a2.unsqueeze(3) 87 | f2 = f2.view(b, n1, n2, c, h, w)#[4,5,30,512,6,6] 88 | #print(f1.shape,f2.shape,'located in cam.py at 88') 89 | #exit(0) 90 | if test: 91 | return f1.transpose(1, 2), f2.transpose(1, 2),a1.view(b, n1, n2, h, w),a2.view(b, n1, n2, h, w) 92 | else: 93 | return f1.transpose(1, 2), f2.transpose(1, 2) 94 | class CAM_similarity(nn.Module): 95 | def __init__(self): 96 | super(CAM_similarity, self).__init__() 97 | self.conv1 = ConvBlock(18468, 6, 1) 98 | self.conv2 = nn.Conv2d(6, 36, 1, stride=1, padding=0) 99 | 100 | for m in self.modules(): 101 | if isinstance(m, nn.Conv2d): 102 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 103 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 104 | 105 | def get_attention(self, a,features): 106 | input_a = a#[4,5,30,36,36] 107 | #print(torch.mean(input_a[:,:,:,0,0], -1).shape) 108 | a = a.mean(3) #[4,5,30,36] 109 | #print(a.shape,'located in cam at 109 in CAM_similarity') 110 | #exit(0) 111 | b,train_n,test_n,c,sptial=features.size() 112 | a=a.unsqueeze(3) 113 | a=torch.cat([a,features],3) 114 | a=a.view(b,train_n,test_n,-1)#[4,5,30,512,6,6] 115 | #print(a.shape) 116 | #exit(0) 117 | a = a.transpose(1, 3) #[4,36,30,5] 118 | a = F.relu(self.conv1(a))#[4,6,30,5] 119 | a = self.conv2(a) #[4,36,30,5] 120 | a = a.transpose(1, 3)#[4,5,30,36] 121 | a = a.unsqueeze(3) #[4,5,30,1,36] 122 | 123 | a = torch.mean(input_a * a, -1)#[4,5,30,36] 124 | a = F.softmax(a / 0.025, dim=-1) + 1 125 | return a 126 | 127 | def forward(self, f1, f2,test=False): 128 | b, n1, c, h, w = f1.size() 129 | n2 = f2.size(1) 130 | 131 | f1 = f1.view(b, n1, c, -1) #[4,5,512,36] 132 | f2 = f2.view(b, n2, c, -1) 133 | #print(f1.shape,f2.shape) 134 | #exit(0) 135 | f1_norm = F.normalize(f1, p=2, dim=2, eps=1e-12)#[4,5,512,36] 136 | f2_norm = F.normalize(f2, p=2, dim=2, eps=1e-12)#[4,30,512,36] 137 | #print(f1_norm.shape,f2_norm.shape) 138 | #added by ljj 139 | f1_expand=f1_norm.unsqueeze(2).expand(f1_norm.shape[0],f1_norm.shape[1],f2_norm.shape[1],f1_norm.shape[2],f1_norm.shape[3]) 140 | f2_expand=f2_norm.unsqueeze(1).expand(f1_norm.shape[0],f1_norm.shape[1],f2_norm.shape[1],f1_norm.shape[2],f1_norm.shape[3]) 141 | #print(f1_norm.shape,f2_norm.shape) 142 | #for i in range(30): 143 | #print((f1_expand[:,:,i,:,:]-f1_norm).abs().sum()) 144 | #for i in range(5): 145 | # print((f2_expand[:,i,:,:,:]-f2_norm).abs().sum()) 146 | #exit(0) 147 | #print(f1_expand.shape,f2_expand.shape) 148 | f1_f2=f2_expand#-f2_expand 149 | f2_f1=f1_expand#-f1_expand 150 | #added end 151 | #print(f1_f2.shape) 152 | #exit(0) 153 | f1_norm = f1_norm.transpose(2, 3).unsqueeze(2) # 154 | 155 | f2_norm = f2_norm.unsqueeze(1) 156 | #print(f1_norm.shape,f2_norm.shape) 157 | #exit(0) 158 | a1 = torch.matmul(f1_norm, f2_norm) #[4,5,30,36,36] 159 | #print('The shape of a1 before get_attention: ', a1.shape,'located in cam.py at 72') 160 | #exit(0) 161 | a2 = a1.transpose(3, 4) #[4,5,30,36,36] 162 | 163 | a1 = self.get_attention(a1,f1_f2)#[4,5,30,36] 164 | a2 = self.get_attention(a2,f2_f1)#[4,5,30,36] 165 | #print('The shape of a1 after get_attention: ', a1.shape,'located in cam.py at 72') 166 | #print('The shape of a2 after get_attention: ', a2.shape,'located in cam.py at 72') 167 | #print(f1.unsqueeze(2).shape,a1.unsqueeze(3).shape)#[4,5,1,512,36],#[4,5,30,1,36] 168 | #exit(0) 169 | f1 = f1.unsqueeze(2) * a1.unsqueeze(3) 170 | f1 = f1.view(b, n1, n2, c, h, w)#[4,5,30,512,6,6] 171 | f2 = f2.unsqueeze(1) * a2.unsqueeze(3) 172 | f2 = f2.view(b, n1, n2, c, h, w)#[4,5,30,512,6,6] 173 | #print(f1.shape,f2.shape,'located in cam.py at 88') 174 | #exit(0) 175 | if test: 176 | return f1.transpose(1, 2), f2.transpose(1, 2),a1.view(b, n1, n2, h, w),a2.view(b, n1, n2, h, w) 177 | else: 178 | return f1.transpose(1, 2), f2.transpose(1, 2) -------------------------------------------------------------------------------- /torchFewShot/models/channel_wise_attention.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=16): 6 | super(SELayer, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 8 | self.fc = nn.Sequential( 9 | nn.Linear(channel, channel // reduction, 
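            # squeeze: bottleneck to channel // reduction units; the second Linear below
            # excites back to a fixed 512 channels, so SELayer always gates a 512-channel map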
bias=False), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(channel // reduction, 512, bias=False), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, x): 16 | b, c, _, _ = x.size() 17 | y = self.avg_pool(x).view(b, c) 18 | #print(y.shape,self.avg_pool(x).shape) 19 | #exit() 20 | y = self.fc(y).view(b, 512, 1, 1) 21 | #print(y.shape,'lkkk') 22 | #exit() 23 | return y -------------------------------------------------------------------------------- /torchFewShot/models/net.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import sys 6 | 7 | sys.path.append(r"./torchFewShot/models/") 8 | 9 | from resnet12 import resnet12 10 | from cam import CAM,CAM_similarity 11 | 12 | 13 | class Model(nn.Module): 14 | def __init__(self, scale_cls, num_classes=64): 15 | super(Model, self).__init__() 16 | self.scale_cls = scale_cls 17 | 18 | self.base = resnet12() 19 | self.cam = CAM() 20 | 21 | self.nFeat = self.base.nFeat 22 | self.clasifier = nn.Conv2d(self.nFeat, num_classes, kernel_size=1) 23 | 24 | def test(self, ftrain, ftest): 25 | ftest = ftest.mean(4) 26 | ftest = ftest.mean(4) 27 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1, eps=1e-12) 28 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12) 29 | scores = self.scale_cls * torch.sum(ftest * ftrain, dim=-1) 30 | return scores 31 | 32 | def forward(self, xtrain, xtest, ytrain, ytest,test_fg=False): 33 | batch_size, num_train = xtrain.size(0), xtrain.size(1) 34 | #print(xtrain.shape,xtest.shape)#[4,25,3,84,84],[4,30,3,84,84] 35 | #exit(0) 36 | num_test = xtest.size(1) 37 | K = ytrain.size(2) 38 | ytrain = ytrain.transpose(1, 2)#[4,5,25] 39 | #print(batch_size) 40 | xtrain = xtrain.view(-1, xtrain.size(2), xtrain.size(3), xtrain.size(4))#[100,3,84,84] 41 | xtest = xtest.view(-1, xtest.size(2), xtest.size(3), xtest.size(4))#[120,3,84,84] 42 | x = torch.cat((xtrain, xtest), 0)#(220,3,84,84) 43 | #print(x.shape,xtrain.shape,ytrain.shape,xtest.shape,num_train,'llll',) 44 | #exit(0) 45 | f = self.base(x)#[220,512,6,6] 46 | #print(f.shape,'located in net.py at 42')#[220,512,6,6] 47 | #exit(0) 48 | #print(ytrain[0,:,1:10]) 49 | 50 | #exit(0) 51 | ftrain = f[:batch_size * num_train] 52 | ftrain = ftrain.view(batch_size, num_train, -1)#[4,25,18432] 53 | #print(ftrain.shape) 54 | #exit(0) 55 | ftrain = torch.bmm(ytrain, ftrain)#(4,5,18432),it is matrix multiply [4,5,25],[4,25,18432] 56 | #print(ytrain.sum(dim=2, keepdim=True).shape) 57 | #print(ftrain.div(ytrain.sum(dim=2, keepdim=True)) 58 | #print(ytrain.sum(dim=2, keepdim=True).expand_as(ftrain)[0,:3,:10]) 59 | #exit(0) 60 | ftrain = ftrain.div(ytrain.sum(dim=2, keepdim=True).expand_as(ftrain))# each row is the mean value of all support images of crospond class [4,5,18432] 61 | ftrain = ftrain.view(batch_size, -1, *f.size()[1:])#[4,5,512,6,6] 62 | #print(ftrain.shape) 63 | #exit(0) 64 | ftest = f[batch_size * num_train:] 65 | ftest = ftest.view(batch_size, num_test, *f.size()[1:])#[4,30,512,6,6] 66 | #print(ftest.shape,ftrain.shape,'lllllllllllllll') 67 | #exit(0) 68 | if not test_fg: 69 | ftrain, ftest = self.cam(ftrain, ftest,test_fg)##[4,30,5,512,6,6],[4,30,5,512,6,6] 70 | else: 71 | ftrain, ftest,a1,a2 = self.cam(ftrain, ftest,test_fg) 72 | ftrain = ftrain.mean(4) 73 | ftrain = ftrain.mean(4)#[4,30,5,512] 74 | 75 | if not self.training: 76 | if test_fg: 77 | return self.test(ftrain, ftest),a1,a2 78 | else: 79 | return self.test(ftrain, ftest) 80 | 81 | 
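        # training branch: per-position cosine similarity between the attended query
        # maps and the class prototypes yields dense scores [b, num_test, K, 6, 6];
        # the one-hot ytest then selects each query's own-class feature map for the
        # auxiliary 64-way global classifier.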
ftest_norm = F.normalize(ftest, p=2, dim=3, eps=1e-12)#[4,30,5,512,6,6] 82 | ftrain_norm = F.normalize(ftrain, p=2, dim=3, eps=1e-12)#[4,30,5,512] 83 | #print(ftest_norm.shape,ftrain_norm.shape,'located in net.py at 74') 84 | #exit(0) 85 | ftrain_norm = ftrain_norm.unsqueeze(4) 86 | ftrain_norm = ftrain_norm.unsqueeze(5)#[4,30,5,512,1,1] 87 | #print(ftest_norm.shape,self.scale_cls,K) 88 | #exit(0) 89 | cls_scores = self.scale_cls * torch.sum(ftest_norm * ftrain_norm, dim=3)#[4,30,5,6,6] 90 | #print(cls_scores.shape,'located in net.py at 79') 91 | #exit(0) 92 | cls_scores = cls_scores.view(batch_size * num_test, *cls_scores.size()[2:])#[120,5,6,6] 93 | #print(cls_scores.shape,'located in net.py at 79') 94 | #exit(0) 95 | ftest = ftest.view(batch_size, num_test, K, -1)#[4,30,5,18432] 96 | ftest = ftest.transpose(2, 3) #[4,30,18432,5] 97 | ytest = ytest.unsqueeze(3)#[4,30,5,1] 98 | #print(ytest.shape,ftest.shape) 99 | #exit(0) 100 | ftest = torch.matmul(ftest, ytest) 101 | ftest = ftest.view(batch_size * num_test, -1, 6, 6) 102 | #print(ftest.shape) 103 | #exit(0) 104 | ytest = self.clasifier(ftest) 105 | 106 | return ytest, cls_scores 107 | 108 | class Model_mltizhixin(nn.Module): 109 | def __init__(self, scale_cls, num_classes=64): 110 | super(Model_mltizhixin, self).__init__() 111 | self.scale_cls = scale_cls 112 | 113 | self.base = resnet12() 114 | self.cam = CAM() 115 | 116 | self.nFeat = self.base.nFeat 117 | self.clasifier = nn.Conv2d(self.nFeat, num_classes, kernel_size=1) 118 | 119 | def test(self, ftrain, ftest): 120 | ftest = ftest.mean(4) 121 | ftest = ftest.mean(4) 122 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1, eps=1e-12) 123 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12) 124 | scores = self.scale_cls * torch.sum(ftest * ftrain, dim=-1) 125 | return scores 126 | 127 | def forward(self, xtrain, xtest, ytrain, ytest,test_fg=False): 128 | batch_size, num_train = xtrain.size(0), xtrain.size(1) 129 | #print(xtrain.shape,xtest.shape)#[4,25,3,84,84],[4,30,3,84,84] 130 | #exit(0) 131 | #cls_scores=0 132 | num_test = xtest.size(1) 133 | K = ytrain.size(2) 134 | ytest = ytest.unsqueeze(3) 135 | ytrain = ytrain.transpose(1, 2)#[4,5,25] 136 | #print(batch_size) 137 | xtrain = xtrain.view(-1, xtrain.size(2), xtrain.size(3), xtrain.size(4))#[100,3,84,84] 138 | xtest = xtest.view(-1, xtest.size(2), xtest.size(3), xtest.size(4))#[120,3,84,84] 139 | x = torch.cat((xtrain, xtest), 0)#(220,3,84,84) 140 | #print(x.shape,xtrain.shape,ytrain.shape,xtest.shape,num_train,'llll',) 141 | #exit(0) 142 | f = self.base(x)#[220,512,6,6] 143 | #print(f.shape,'located in net.py at 42')#[220,512,6,6] 144 | #exit(0) 145 | #print(ytrain[0,:,1:10]) 146 | 147 | #exit(0) 148 | ftrain_all = f[:batch_size * num_train].view(batch_size, 5, -1,512,6,6) 149 | ftest_all = f[batch_size * num_train:] 150 | ftest_all = ftest_all.view(batch_size, num_test, *f.size()[1:])#[4,30,512,6,6] 151 | #print(ftrain.shape) 152 | #exit(0) 153 | ytest_p=[] 154 | for i in range(5): 155 | ftrain = ftrain_all[:,:,i].view(-1,512,6,6) 156 | ftrain = ftrain.view(batch_size, 5, -1)#[4,25,18432] 157 | #print(ftrain.shape) 158 | #exit(0) 159 | ftrain = torch.bmm(ytrain, ftrain)#(4,5,18432),it is matrix multiply [4,5,25],[4,25,18432] 160 | #print(ytrain.sum(dim=2, keepdim=True).shape) 161 | #print(ftrain.div(ytrain.sum(dim=2, keepdim=True)) 162 | #print(ytrain.sum(dim=2, keepdim=True).expand_as(ftrain)[0,:3,:10]) 163 | #exit(0) 164 | ftrain = ftrain.div(ytrain.sum(dim=2, keepdim=True).expand_as(ftrain))# each row is the 
mean value of all support images of crospond class [4,5,18432] 165 | ftrain = ftrain.view(batch_size, -1, *f.size()[1:])#[4,5,512,6,6] 166 | #print(ftrain.shape) 167 | #exit(0) 168 | #ftest = f[batch_size * num_train:] 169 | #ftest = ftest.view(batch_size, num_test, *f.size()[1:])#[4,30,512,6,6] 170 | ftest =ftest_all 171 | #print(ftest.shape,ftrain.shape,'lllllllllllllll') 172 | #exit(0) 173 | if not test_fg: 174 | ftrain, ftest_att = self.cam(ftrain, ftest,test_fg)##[4,30,5,512,6,6],[4,30,5,512,6,6] 175 | else: 176 | ftrain, ftest_att ,a1,a2 = self.cam(ftrain, ftest,test_fg) 177 | ftrain = ftrain.mean(4) 178 | ftrain = ftrain.mean(4)#[4,30,5,512] 179 | 180 | if not self.training: 181 | if test_fg: 182 | return self.test(ftrain, ftest_att),a1,a2 183 | else: 184 | return self.test(ftrain, ftest_att) 185 | 186 | ftest_norm = F.normalize(ftest_att, p=2, dim=3, eps=1e-12)#[4,30,5,512,6,6] 187 | ftrain_norm = F.normalize(ftrain, p=2, dim=3, eps=1e-12)#[4,30,5,512] 188 | #print(ftest_norm.shape,ftrain_norm.shape,'located in net.py at 74') 189 | #exit(0) 190 | ftrain_norm = ftrain_norm.unsqueeze(4) 191 | ftrain_norm = ftrain_norm.unsqueeze(5)#[4,30,5,512,1,1] 192 | #print(ftest_norm.shape,self.scale_cls,K) 193 | #exit(0) 194 | if i==0: 195 | cls_scores =self.scale_cls * torch.sum(ftest_norm * ftrain_norm, dim=3)#[4,30,5,6,6] 196 | #print(cls_scores.shape,'located in net.py at 79') 197 | #exit(0) 198 | #cls_scores = cls_scores.view(batch_size * num_test, *cls_scores.size()[2:])#[120,5,6,6] 199 | #print(cls_scores.shape,'located in net.py at 79') 200 | #exit(0) 201 | ftest_att = ftest_att.view(batch_size, num_test, K, -1)#[4,30,5,18432] 202 | ftest_att = ftest_att.transpose(2, 3) #[4,30,18432,5] 203 | #ytest = ytest.unsqueeze(3)#[4,30,5,1] 204 | #print(ytest.shape,ftest.shape) 205 | #exit(0) 206 | 207 | ftest_att = torch.matmul(ftest_att, ytest) 208 | ftest_att = ftest_att.view(batch_size * num_test, -1, 6, 6) 209 | #print(ftest.shape) 210 | #exit(0) 211 | ytest_p.append( self.clasifier(ftest_att)) 212 | #print(ytest_p.shape,cls_scores.shape,'llll') 213 | 214 | #exit(0) 215 | cls_scores = cls_scores.view(batch_size * num_test, *cls_scores.size()[2:])#[120,5,6,6] 216 | y_p = torch.stack(ytest_p, dim=0).view(5*120,64,6,6) 217 | #print(y_p.shape) 218 | #exit(0) 219 | return y_p, cls_scores 220 | 221 | class Model_tradi(nn.Module): 222 | def __init__(self, scale_cls, num_classes=64): 223 | super(Model_tradi, self).__init__() 224 | 225 | self.scale_cls = scale_cls 226 | 227 | self.base = resnet12() 228 | self.cam = CAM() 229 | 230 | self.nFeat = self.base.nFeat 231 | self.clasifier = nn.Conv2d(self.nFeat, num_classes, kernel_size=1) 232 | #print(self.training) 233 | #exit(0) 234 | def test(self, ftrain, ftest): 235 | ftest = ftest.mean(4) 236 | ftest = ftest.mean(4) 237 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1, eps=1e-12) 238 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12) 239 | scores = self.scale_cls * torch.sum(ftest * ftrain, dim=-1) 240 | return scores 241 | 242 | def forward(self, xtrain, xtest, ytrain, ytest,test_fg=False): 243 | batch_size, num_train = xtrain.size(0), xtrain.size(1) 244 | #print(xtrain.shape,xtest.shape)#[4,25,3,84,84],[4,30,3,84,84] 245 | #exit(0) 246 | num_test = xtest.size(1) 247 | K = ytrain.size(2) 248 | ytrain = ytrain.transpose(1, 2)#[4,5,25] 249 | #print(batch_size) 250 | xtrain = xtrain.view(-1, xtrain.size(2), xtrain.size(3), xtrain.size(4))#[100,3,84,84] 251 | 252 | xtest = xtest.view(-1, xtest.size(2), xtest.size(3), 
xtest.size(4))#[120,3,84,84] 253 | #x = torch.cat((xtrain, xtest), 0)#(220,3,84,84) 254 | #print(x.shape,xtrain.shape,xtest.shape,num_train) 255 | #exit(0) 256 | f = self.base(xtest)#[220,512,6,6] 257 | #print(f.shape,'located in net.py at 42')#[220,512,6,6] 258 | #exit(0) 259 | #print(ytrain[0,:,1:10]) 260 | 261 | #exit(0) 262 | #ftrain = f[:batch_size * num_train] 263 | #ftrain = ftrain.view(batch_size, num_train, -1)#[4,25,18432] 264 | #print(ftrain.shape) 265 | #exit(0) 266 | #ftrain = torch.bmm(ytrain, ftrain)#(4,5,18432),it is matrix multiply [4,5,25],[4,25,18432] 267 | #print(ytrain.sum(dim=2, keepdim=True).shape) 268 | #print(ftrain.div(ytrain.sum(dim=2, keepdim=True)) 269 | #print(ytrain.sum(dim=2, keepdim=True).expand_as(ftrain)[0,:3,:10]) 270 | #exit(0) 271 | #ftrain = ftrain.div(ytrain.sum(dim=2, keepdim=True).expand_as(ftrain))# each row is the mean value of all support images of crospond class [4,5,18432] 272 | #ftrain = ftrain.view(batch_size, -1, *f.size()[1:])#[4,5,512,6,6] 273 | #print(ftrain.shape) 274 | #exit(0) 275 | ftest = f#[batch_size * num_train:] 276 | ftest = ftest.view(batch_size, num_test, *f.size()[1:])#[4,30,512,6,6] 277 | #print(ftest.shape,ftrain.shape,'lllllllllllllll##########') 278 | #exit(0) 279 | #if not test_fg: 280 | #ftrain, ftest = self.cam(ftrain, ftest,test_fg)##[4,30,5,512,6,6],[4,30,5,512,6,6] 281 | #else: 282 | #ftrain, ftest,a1,a2 = self.cam(ftrain, ftest,test_fg) 283 | #ftrain = ftrain.mean(4) 284 | #ftrain = ftrain.mean(4)#[4,30,5,512] 285 | #print(ftest.shape,ftrain.shape,self.training,'lllllllllllllll##########') 286 | #if not self.training: 287 | #if test_fg: 288 | #return self.test(ftrain, ftest),a1,a2 289 | #else: 290 | #return self.test(ftrain, ftest) 291 | #print(ftest.shape,ftrain.shape,'lllllllllllllll##########') 292 | #exit(0) 293 | #ftest_norm = F.normalize(ftest, p=2, dim=3, eps=1e-12)#[4,30,5,512,6,6] 294 | #ftrain_norm = F.normalize(ftrain, p=2, dim=3, eps=1e-12)#[4,30,5,512] 295 | #print(ftest_norm.shape,ftrain_norm.shape,'located in net.py at 74') 296 | #exit(0) 297 | #ftrain_norm = ftrain_norm.unsqueeze(4) 298 | #ftrain_norm = ftrain_norm.unsqueeze(5)#[4,30,5,512,1,1] 299 | #print(ftest_norm.shape,self.scale_cls,K) 300 | #exit(0) 301 | #cls_scores = self.scale_cls * torch.sum(ftest_norm * ftrain_norm, dim=3)#[4,30,5,6,6] 302 | #print(cls_scores.shape,'located in net.py at 79') 303 | #exit(0) 304 | #cls_scores = cls_scores.view(batch_size * num_test, *cls_scores.size()[2:])#[120,5,6,6] 305 | #print(cls_scores.shape,'located in net.py at 79') 306 | #exit(0) 307 | #ftest = ftest.view(batch_size, num_test, K, -1)#[4,30,5,18432] 308 | #ftest = ftest.transpose(2, 3) #[4,30,18432,5] 309 | #ytest = ytest.unsqueeze(3)#[4,30,5,1] 310 | #print(ytest.shape,ftest.shape) 311 | #exit(0) 312 | #ftest = torch.matmul(ftest, ytest) 313 | ftest = ftest.view(batch_size * num_test, -1, 6, 6) 314 | feature=ftest 315 | ftest = F.avg_pool2d(ftest, kernel_size=6, stride=1) 316 | #print(ftest.shape,'using avg_pool') 317 | #print(ftest.shape,'no using avg_pool') 318 | #exit(0) 319 | ytest = self.clasifier(ftest) 320 | #print(ytest) 321 | #exit(0) 322 | 323 | return ytest,feature#, cls_scores -------------------------------------------------------------------------------- /torchFewShot/models/net_related.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import sys 6 | 7 | sys.path.append(r"./torchFewShot/models/") 8 | 
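# The sys.path.append above puts torchFewShot/models/ on the import path, which is
# why the flat imports below resolve when training is launched from the repository
# root. A hedged, package-style equivalent (assuming the repo root is on PYTHONPATH)
# would be:
#
#   from torchFewShot.models.resnet12 import resnet12
#   from torchFewShot.models.cam import CAM, CAM_similarity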
9 | from resnet12 import resnet12 10 | #from related_net import fusenet 11 | from related_net_spatial_attention import fusenet 12 | from cam import CAM,CAM_similarity 13 | from torchFewShot.utils.torchtools import one_hot_36 14 | #from contrast_loss import TripletLoss 15 | 16 | 17 | class Model(nn.Module): 18 | def __init__(self, scale_cls,only_CSEI, num_classes=64): 19 | super(Model, self).__init__() 20 | self.scale_cls = scale_cls 21 | 22 | self.base = resnet12() 23 | self.only_CSEI=only_CSEI 24 | if not self.only_CSEI: 25 | self.fusenet=fusenet() 26 | self.cam = CAM() 27 | 28 | self.nFeat = self.base.nFeat 29 | self.clasifier = nn.Conv2d(self.nFeat, num_classes, kernel_size=1) 30 | #self.clasifier1 = nn.Conv2d(self.nFeat, num_classes, kernel_size=1) 31 | #self.contrastLoss=TripletLoss() 32 | 33 | def test_ori(self, ftrain, ftest): 34 | ftest = ftest.mean(4) 35 | ftest = ftest.mean(4) 36 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1, eps=1e-12) 37 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12) 38 | scores = self.scale_cls * torch.sum(ftest * ftrain, dim=-1) 39 | #print(ftest.shape,ftrain.shape) 40 | #exit(0) # debug print/exit disabled: they ran before the return and aborted every call 41 | return scores 42 | 43 | def test(self, ftrain, ftest): 44 | ftest_mean = ftest.mean(4) 45 | ftest_mean = ftest_mean.mean(4).unsqueeze(3) 46 | ftrain_mean = ftrain.mean(4) 47 | ftrain_mean = ftrain_mean.mean(4).unsqueeze(3) 48 | ftest=ftest.view(4, 75, 5, 512,-1)# hard-coded: 4 episodes, 75 queries, 5 ways, 512-d features over the 6x6 (=36) map 49 | ftrain=ftrain.view(4, 75, 5, 512,-1) 50 | ftest=ftest.transpose(3, 4) 51 | ftrain=ftrain.transpose(3, 4) 52 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1, eps=1e-12) 53 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12) 54 | ftest_mean = F.normalize(ftest_mean, p=2, dim=ftest_mean.dim()-1, eps=1e-12) 55 | ftrain_mean = F.normalize(ftrain_mean, p=2, dim=ftrain_mean.dim()-1, eps=1e-12) 56 | #print(ftest_mean.shape,ftrain_mean.shape)#[4,75,5,1,512] 57 | #print(ftest.shape,ftrain.shape)#[4, 75, 5, 36, 512] 58 | #exit(0) 59 | #print(ftest.shape,ftrain.shape,ftest_mean.shape,ftrain_mean.shape) 60 | scores = self.scale_cls * torch.sum((ftest_mean * ftrain_mean).squeeze(), dim=-1) 61 | ftrain_scores = self.scale_cls * torch.sum(ftest_mean * ftrain, dim=-1) 62 | ftest_scores = self.scale_cls * torch.sum(ftest * ftrain_mean, dim=-1) 63 | #print(ftest.shape,ftrain.shape) 64 | #print(ftrain_scores.shape,ftest_scores.shape) 65 | 66 | _,train_ind=torch.max(ftrain_scores,3) 67 | _,test_ind=torch.max(ftest_scores,3) 68 | #print(train_ind) 69 | #exit(0) 70 | train_ind=train_ind.view(-1) 71 | test_ind=test_ind.view(-1) 72 | train_one_hot=one_hot_36(train_ind).view(4,75,5,-1).unsqueeze(4) 73 | test_one_hot=one_hot_36(test_ind).view(4,75,5,-1).unsqueeze(4) 74 | scores_final = self.scale_cls * torch.sum(((ftest*test_one_hot.cuda()).sum(3) * ((ftrain* train_one_hot.cuda()).sum(3))), dim=-1)# keep only the single best-matching spatial position on each side 75 | #print( test_one_hot.shape,train_one_hot.shape) #[4, 75, 5, 36] 76 | #exit(0) 77 | return scores,scores_final 78 | 79 | def test_topK(self, ftrain, ftest,K): 80 | shape_train=ftrain.shape 81 | #print(ftrain.shape,ftest.shape) 82 | #exit(0) 83 | ftest_mean = ftest.mean(4) 84 | ftest_mean = ftest_mean.mean(4).unsqueeze(3) 85 | ftrain_mean = ftrain.mean(4) 86 | ftrain_mean = ftrain_mean.mean(4).unsqueeze(3) 87 | #print(ftrain_mean.shape) 88 | ftest=ftest.view(shape_train[0], 75, 5, 512,-1) 89 | ftrain=ftrain.view(shape_train[0], 75, 5, 512,-1) 90 | ftest=ftest.transpose(3, 4) 91 | ftrain=ftrain.transpose(3, 4) 92 | #print(ftrain.shape) 93 | #exit(0) 94 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1,
eps=1e-12) 95 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12) 96 | ftest_mean = F.normalize(ftest_mean, p=2, dim=ftest_mean.dim()-1, eps=1e-12) 97 | ftrain_mean = F.normalize(ftrain_mean, p=2, dim=ftrain_mean.dim()-1, eps=1e-12) 98 | #print(ftest_mean.shape,ftrain_mean.shape)#[4,75,5,1,512] 99 | #print(ftest.shape,ftrain.shape)#[4, 75, 5, 36, 512] 100 | #exit(0) 101 | #print(ftest.shape,ftrain.shape,ftest_mean.shape,ftrain_mean.shape) 102 | scores = self.scale_cls * torch.sum((ftest_mean * ftrain_mean).squeeze(), dim=-1) 103 | ftrain_scores = self.scale_cls * torch.sum(ftest_mean * ftrain, dim=-1) 104 | ftest_scores = self.scale_cls * torch.sum(ftest * ftrain_mean, dim=-1) 105 | #print(ftest.shape,ftrain.shape) 106 | #print(ftrain_scores.shape,ftest_scores.shape) 107 | #print(self.scale_cls) 108 | #print(scores.shape,'scores') 109 | #K=1 110 | #_,train_ind=torch.max(ftrain_scores,3) 111 | #_,test_ind=torch.max(ftest_scores,3) 112 | _,train_ind=torch.topk(ftrain_scores, K, dim=3 ) 113 | _,test_ind=torch.topk(ftest_scores, K, dim=3 ) 114 | #print(train_ind) 115 | #print(train_ind.shape) 116 | #exit(0) 117 | train_ind=train_ind.view(-1) 118 | test_ind=test_ind.view(-1) 119 | 120 | train_one_hot=one_hot_36(train_ind).view(shape_train[0],75,5,K,-1).sum(3).unsqueeze(4) 121 | test_one_hot=one_hot_36(test_ind).view(shape_train[0],75,5,K,-1).sum(3).unsqueeze(4) 122 | #print(train_one_hot.shape) 123 | #print(train_one_hot[0,0,0,:]) 124 | #exit(0) 125 | ftest_fuse=((ftest*test_one_hot.cuda()).sum(3))/K 126 | ftrain_fuse=((ftrain*train_one_hot.cuda()).sum(3))/K 127 | ftest_fuse_mean = F.normalize(ftest_fuse, p=2, dim=ftest_fuse.dim()-1, eps=1e-12) 128 | ftrain_fuse_mean = F.normalize(ftrain_fuse, p=2, dim=ftrain_fuse.dim()-1, eps=1e-12) 129 | #print(ftest_fuse.shape) 130 | #print(ftrain_fuse.shape) 131 | #exit(0) 132 | #scores_final = self.scale_cls * torch.sum(((ftest*test_one_hot.cuda()).sum(3) * ((ftrain* train_one_hot.cuda()).sum(3))), dim=-1) 133 | scores_final = self.scale_cls * torch.sum((ftest_fuse_mean * ftrain_fuse_mean), dim=-1) 134 | #print( test_one_hot.shape,train_one_hot.shape) #[4, 75, 5, 36] 135 | #exit(0) 136 | return scores,scores_final 137 | 138 | def test_select_topK(self, ftrain, ftest,K): 139 | ftest_mean = ftest.mean(4) 140 | ftest_mean = ftest_mean.mean(4).unsqueeze(3) 141 | ftrain_mean = ftrain.mean(4) 142 | ftrain_mean = ftrain_mean.mean(4).unsqueeze(3) 143 | #print(ftrain_mean.shape) 144 | ftest=ftest.view(4, 75, 5, 512,-1) 145 | ftrain=ftrain.view(4, 75, 5, 512,-1) 146 | ftest=ftest.transpose(3, 4) 147 | ftrain=ftrain.transpose(3, 4) 148 | #print(ftrain.shape) 149 | #exit(0) 150 | ftest = F.normalize(ftest, p=2, dim=ftest.dim()-1, eps=1e-12) 151 | ftrain = F.normalize(ftrain, p=2, dim=ftrain.dim()-1, eps=1e-12) 152 | ftest_mean = F.normalize(ftest_mean, p=2, dim=ftest_mean.dim()-1, eps=1e-12) 153 | ftrain_mean = F.normalize(ftrain_mean, p=2, dim=ftrain_mean.dim()-1, eps=1e-12) 154 | #print(ftest_mean.shape,ftrain_mean.shape)#[4,75,5,1,512] 155 | #print(ftest.shape,ftrain.shape)#[4, 75, 5, 36, 512] 156 | #exit(0) 157 | #print(ftest.shape,ftrain.shape,ftest_mean.shape,ftrain_mean.shape) 158 | scores = self.scale_cls * torch.sum((ftest_mean * ftrain_mean).squeeze(), dim=-1) 159 | ftrain_scores = self.scale_cls * torch.sum(ftest_mean * ftrain, dim=-1) 160 | ftest_scores = self.scale_cls * torch.sum(ftest * ftrain_mean, dim=-1) 161 | #print(ftest.shape,ftrain.shape) 162 | #print(ftrain_scores.shape,ftest_scores.shape) 163 | #print(self.scale_cls) 164 | 
#print(scores.shape,'scores') 165 | #K=1 166 | #_,train_ind=torch.max(ftrain_scores,3) 167 | #_,test_ind=torch.max(ftest_scores,3) 168 | scores_list=[] 169 | for i in range(35): 170 | #print(i) # debug progress print disabled (would fire 35 times per call) 171 | _,train_ind=torch.topk(ftrain_scores, i+1, dim=3 ) 172 | for j in range(35): 173 | #print(j) # debug progress print disabled (would fire 1225 times per call) 174 | _,test_ind=torch.topk(ftest_scores, j+1, dim=3 ) 175 | #print(train_ind) 176 | #print(train_ind.shape) 177 | #exit(0) 178 | train_ind=train_ind.view(-1) 179 | test_ind=test_ind.view(-1) 180 | 181 | train_one_hot=one_hot_36(train_ind).view(4,75,5,i+1,-1).sum(3).unsqueeze(4) 182 | test_one_hot=one_hot_36(test_ind).view(4,75,5,j+1,-1).sum(3).unsqueeze(4) 183 | #print(train_one_hot.shape) 184 | #print(train_one_hot[0,0,0,:]) 185 | #exit(0) 186 | #print(ftest.shape,test_one_hot.shape) 187 | ftest_fuse=((ftest*test_one_hot.cuda()).sum(3))/(j+1) 188 | ftrain_fuse=((ftrain*train_one_hot.cuda()).sum(3))/(i+1) 189 | ftest_fuse_mean = F.normalize(ftest_fuse, p=2, dim=ftest_fuse.dim()-1, eps=1e-12) 190 | ftrain_fuse_mean = F.normalize(ftrain_fuse, p=2, dim=ftrain_fuse.dim()-1, eps=1e-12) 191 | #print(ftest_fuse.shape) 192 | #print(ftrain_fuse.shape) 193 | #exit(0) 194 | #scores_final = self.scale_cls * torch.sum(((ftest*test_one_hot.cuda()).sum(3) * ((ftrain* train_one_hot.cuda()).sum(3))), dim=-1) 195 | scores_final = self.scale_cls * torch.sum((ftest_fuse_mean * ftrain_fuse_mean), dim=-1) 196 | scores_list.append(scores_final.view(4,75,5,1)) 197 | #print(scores_final.shape) 198 | #exit(0) 199 | #print( test_one_hot.shape,train_one_hot.shape) #[4, 75, 5, 36] 200 | scores_final,_=torch.cat(scores_list,3).max(3) 201 | #print(scores_final.shape) 202 | #exit() 203 | #exit(0) 204 | return scores,scores_final 205 | 206 | def forward(self, xtrain, xtest, ytrain, ytest,topK=28,test_fg=False): 207 | batch_size, num_train = xtrain.size(0), xtrain.size(1) 208 | #print(xtrain.shape,xtest.shape)#[4,25,3,84,84],[4,30,3,84,84] 209 | #exit(0) 210 | num_test = xtest.size(1) 211 | K = ytrain.size(2) 212 | ytrain = ytrain.transpose(1, 2)#[4,5,25] 213 | #print(batch_size) 214 | xtrain = xtrain.view(-1, xtrain.size(2), xtrain.size(3), xtrain.size(4))#[100,3,84,84] 215 | xtest = xtest.view(-1, xtest.size(2), xtest.size(3), xtest.size(4))#[120,3,84,84] 216 | x = torch.cat((xtrain, xtest), 0)#(220,3,84,84) 217 | #print(x.shape,xtrain.shape,ytrain.shape,xtest.shape,num_train,'llll',) 218 | #exit(0) 219 | f = self.base(x)#[220,512,6,6] 220 | #print(f.shape,'located in net.py at 42')#[220,512,6,6] 221 | #exit(0) 222 | #print(ytrain[0,:,1:10]) 223 | 224 | #exit(0) 225 | ftrain = f[:batch_size * num_train] 226 | channel,wide,height=ftrain.shape[1],ftrain.shape[2],ftrain.shape[3] 227 | #print(channel,wide,height) 228 | #exit(0) 229 | #ftrain_temp=ftrain.view(batch_size, num_train,ftrain.shape[1],ftrain.shape[2],ftrain.shape[3]) 230 | #ftrain=self.fusenet(ftrain_temp,ytrain) 231 | #exit(0) 232 | ftrain = ftrain.view(batch_size, num_train, -1)#[4,25,18432] 233 | 234 | #print(ftrain.shape) 235 | #exit(0) 236 | ftrain = torch.bmm(ytrain, ftrain)#(4,5,18432), matrix multiply: [4,5,25] x [4,25,18432] 237 | #print(ftrain.shape,';;;;;;;') 238 | #print(ytrain.shape) 239 | N_class=ytrain.shape[1] 240 | #exit(0) 241 | 242 | #for 1-shot miniImageNet (65.29), use this code 243 | #ftrain_temp=ftrain.view(batch_size, N_class,channel,wide,height) 244 | #print(ftrain_temp 245 | #ftrain=self.fusenet(ftrain_temp,ytrain).view(batch_size, N_class, -1) 246 | 247 | 248 | #print(ftrain.shape) 249 | #exit(0) 250 | #print(ytrain.sum(dim=2, keepdim=True).shape) 251 |
#print(ftrain.div(ytrain.sum(dim=2, keepdim=True)) 252 | #print(ytrain.sum(dim=2, keepdim=True).expand_as(ftrain)[0,:3,:10]) 253 | #exit(0) 254 | ftrain = ftrain.div(ytrain.sum(dim=2, keepdim=True).expand_as(ftrain))# each row is the mean value of all support images of crospond class [4,5,18432] 255 | #print(ftrain.shape) 256 | #exit(0) 257 | ftrain_temp=ftrain.view(batch_size, N_class,channel,wide,height) 258 | #print(ftrain_temp[0,0,0]) 259 | #print(ftrain_temp 260 | if not self.only_CSEI: 261 | ftrain,spatial,channel_attention=self.fusenet(ftrain_temp,ytrain) 262 | else: 263 | spatial=0 264 | channel_attention=0 265 | #ftrain_related=ftrain.view(-1, 512, 6,6) 266 | #print(ftrain[0,0,0]) 267 | #exit(0) 268 | ftrain = ftrain.view(batch_size, N_class, -1) 269 | ftrain = ftrain.view(batch_size, -1, *f.size()[1:])#[4,5,512,6,6] 270 | ftrain_class=ftrain.view( -1, *f.size()[1:]) 271 | #ftrain_class=ftrain_class.view(4,5,512,-1).transpose(2,3).contiguous() 272 | #print(ftrain_class.shape) 273 | #ftrain_class=ftrain_class.view(4,-1,512) 274 | #loss=self.contrastLoss(ftrain_class,ytrain) 275 | 276 | #print(ftrain_class.shape,loss) 277 | #exit(0) 278 | #print(ftrain_class.shape,'ftrain') 279 | 280 | 281 | #exit(0) 282 | ftest = f[batch_size * num_train:] 283 | ftest = ftest.view(batch_size, num_test, *f.size()[1:])#[4,30,512,6,6] 284 | #ftest=(spatial+1)*ftest 285 | ftest=(channel_attention+1)*ftest 286 | #print(ftest.shape,ftrain.shape,'use ftest_spatial') 287 | #exit(0) 288 | if not test_fg: 289 | ftrain, ftest = self.cam(ftrain, ftest,test_fg)##[4,30,5,512,6,6],[4,30,5,512,6,6] 290 | else: 291 | ftrain, ftest,a1,a2 = self.cam(ftrain, ftest,test_fg) 292 | ftrain_ori=ftrain 293 | ftrain = ftrain.mean(4) 294 | ftrain = ftrain.mean(4)#[4,30,5,512] 295 | #print(test_fg) 296 | #exit() 297 | if not self.training: 298 | if test_fg: 299 | return self.test(ftrain, ftest),a1,a2 300 | else: 301 | return self.test_topK(ftrain_ori, ftest,topK) 302 | ftest_norm = F.normalize(ftest, p=2, dim=3, eps=1e-12)#[4,30,5,512,6,6] 303 | ftrain_norm = F.normalize(ftrain, p=2, dim=3, eps=1e-12)#[4,30,5,512] 304 | #print(ftest_norm.shape,ftrain_norm.shape,'located in net.py at 74') 305 | #exit(0) 306 | ftrain_norm = ftrain_norm.unsqueeze(4) 307 | ftrain_norm = ftrain_norm.unsqueeze(5)#[4,30,5,512,1,1] 308 | #print(ftest_norm.shape,self.scale_cls,K) 309 | #exit(0) 310 | cls_scores = self.scale_cls * torch.sum(ftest_norm * ftrain_norm, dim=3)#[4,30,5,6,6] 311 | #print(cls_scores.shape,'located in net.py at 79') 312 | #exit(0) 313 | cls_scores = cls_scores.view(batch_size * num_test, *cls_scores.size()[2:])#[120,5,6,6] 314 | #print(cls_scores.shape,'located in net.py at 79') 315 | #exit(0) 316 | ftest = ftest.view(batch_size, num_test, K, -1)#[4,30,5,18432] 317 | ftest = ftest.transpose(2, 3) #[4,30,18432,5] 318 | ytest = ytest.unsqueeze(3)#[4,30,5,1] 319 | #print(ytest.shape,ftest.shape) 320 | #exit(0) 321 | ftest = torch.matmul(ftest, ytest) 322 | ftest = ftest.view(batch_size * num_test, -1, 6, 6) 323 | #print(ftest.shape) 324 | #exit(0) 325 | ytest = self.clasifier(ftest) 326 | #ytrain_class = self.clasifier1(ftrain_class) 327 | 328 | params = list(self.clasifier.parameters()) 329 | #print(ytest.shape,len(params),params[0].shape,ftest.shape) 330 | #exit(0) 331 | 332 | return ytest, cls_scores,ftest,params[0],spatial#,loss#,ytrain_class 333 | -------------------------------------------------------------------------------- /torchFewShot/models/related_net.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from cam_surport import CAM,CAM_similarity 6 | from channel_wise_attention import SELayer 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, kernel=3, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | if kernel == 1: 19 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 20 | elif kernel == 3: 21 | self.conv1 = conv3x3(inplanes, planes) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | if kernel == 1: 27 | self.conv3 = nn.Conv2d(planes, planes, kernel_size=1, bias=False) 28 | elif kernel == 3: 29 | self.conv3 = conv3x3(planes, planes) 30 | self.bn3 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv3(out) 47 | out = self.bn3(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, kernel=1, stride=1, downsample=None): 62 | super(Bottleneck, self).__init__() 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 66 | padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * 4) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class RelaNet(nn.Module): 98 | 99 | def __init__(self,input_channel, layer_channels, kernel=3): 100 | self.inplanes = 64 101 | self.kernel = kernel 102 | self.input_channel=input_channel 103 | self.layer_channels=layer_channels 104 | super(RelaNet, self).__init__() 105 | self.convs=[] 106 | for i in range(len(layer_channels)): 107 | if i==0: 108 | self.conv1=nn.Conv2d(self.input_channel, self.layer_channels[i], kernel_size=kernel, stride=1, padding=0, bias=True) 109 | self.bn1 = nn.BatchNorm2d(self.layer_channels[i]) 110 | elif i==1: 111 | self.conv2=nn.Conv2d(self.layer_channels[i-1], self.layer_channels[i], kernel_size=kernel, stride=1, padding=0, bias=True) 112 | self.bn2 = nn.BatchNorm2d(self.layer_channels[i]) 113 | else: 114 | 
self.conv3=nn.Conv2d(self.layer_channels[i-1], 128, kernel_size=kernel, stride=1, padding=0, bias=True) 115 | self.bn3 = nn.BatchNorm2d(128) 116 | for j in range(len(layer_channels)): 117 | if j==0: 118 | self.conv1_attention=nn.Conv2d(640, self.layer_channels[j], kernel_size=kernel, stride=1, padding=0, bias=True) 119 | self.bn1_attention = nn.BatchNorm2d(self.layer_channels[j]) 120 | elif j==1: 121 | self.conv2_attention=nn.Conv2d(self.layer_channels[j-1], self.layer_channels[j], kernel_size=kernel, stride=1, padding=0, bias=True) 122 | self.bn2_attention = nn.BatchNorm2d(self.layer_channels[j]) 123 | else: 124 | self.conv3_attention=nn.Conv2d(self.layer_channels[j-1],512, kernel_size=kernel, stride=1, padding=0, bias=True) 125 | #self.bn3 = nn.BatchNorm2d(self.layer_channels[i]) 126 | #self.bn1 = nn.BatchNorm2d(64) 127 | self.relu = nn.ReLU(inplace=True) 128 | self.cam_support = CAM() 129 | self.se=SELayer(640,8) 130 | #self.sigmoid=torch.Sigmoid() 131 | def forward(self, x,label_1shot): 132 | shape=x.shape 133 | #print(x.shape) 134 | 135 | ftrain = self.cam_support(x, x) 136 | #print(ftrain.shape) 137 | #exit(0) 138 | #print(x.shape,(shape[0],shape[1],1,shape[2],shape[3],shape[4])) 139 | #print(x.shape,(shape[0],1,shape[1],shape[2],shape[3],shape[4])) 140 | #exit(0) 141 | shape=x.shape 142 | #print(x.shape) 143 | #exit(0) 144 | right=x.reshape(shape[0],shape[1],1,shape[2],shape[3],shape[4]).repeat(1,1,shape[1],1,1,1) 145 | left=x.reshape(shape[0],1,shape[1],shape[2],shape[3],shape[4]).repeat(1,shape[1],1,1,1,1) 146 | #concat_feature=torch.cat((left, right), 3).reshape( shape[0]* shape[1]*shape[1],2*shape[2],shape[3],shape[4] ) 147 | #print(x.shape) 148 | #exit(0) 149 | concat_feature=(left+right).reshape( shape[0]* shape[1]*shape[1],shape[2],shape[3],shape[4] ) 150 | #concat_feature1= concat_feature 151 | #print(concat_feature.shape) 152 | #exit(0) 153 | for i in range(3): 154 | if i==0: 155 | concat_feature = self.conv1(concat_feature) 156 | concat_feature= self.bn1(concat_feature) 157 | concat_feature = self.relu(concat_feature) 158 | #print(concat_feature.shape,'llllllllllllllllllll') 159 | elif i==1: 160 | #print(concat_feature.shape,'looooooooooooooo') 161 | concat_feature = self.conv2(concat_feature) 162 | concat_feature= self.bn2(concat_feature) 163 | concat_feature = self.relu(concat_feature) 164 | else: 165 | #print(concat_feature.shape,'pppppppppppppp') 166 | concat_feature = self.conv3(concat_feature) 167 | concat_feature= self.bn3(concat_feature) 168 | concat_feature = self.relu(concat_feature) 169 | #concat_feature=concat_feature.reshape( shape[0],shape[1],shape[1],shape[2],shape[3],shape[4] ) 170 | concat_feature=concat_feature.reshape( shape[0],shape[1],shape[1],128,shape[3],shape[4] ) 171 | concat_feature=concat_feature.mean(2).reshape( -1,640,shape[3],shape[4] ) 172 | #print(concat_feature.shape) 173 | #concat_feature=concat_feature.mean(2) 174 | #concat_feature=concat_feature.unsqueeze(2).unsqueeze(2) 175 | #print(concat_feature.shape) 176 | #concat_feature=concat_feature.reshape( 4,-1,1,1 ) 177 | #exit(0) 178 | #for i in range(0): 179 | #if i==0: 180 | #concat_feature = self.conv1_attention(concat_feature) 181 | #concat_feature= self.bn1_attention(concat_feature) 182 | #concat_feature = self.relu(concat_feature) 183 | #print(concat_feature.shape,'llllllllllllllllllll') 184 | #elif i==1: 185 | #print(concat_feature.shape,'looooooooooooooo') 186 | #concat_feature = self.conv2_attention(concat_feature) 187 | #concat_feature= self.bn2_attention(concat_feature) 188 | 
#concat_feature = self.relu(concat_feature) 189 | #else: 190 | #print(concat_feature.shape,'pppppppppppppp') 191 | #concat_feature = self.conv3_attention(concat_feature) 192 | #concat_feature= self.bn3_attention(concat_feature) 193 | #concat_feature = self.relu(concat_feature) 194 | #spatial=torch.sigmoid( concat_feature).reshape( shape[0],shape[1],512,1,1 ) 195 | #spatial=torch.sigmoid( concat_feature).reshape( shape[0],1,512,1,1 ) 196 | #SELayer 197 | #print(spatial.shape)#[4,5,512,1,1] 198 | #print(spatial[0].max(1)[0]) 199 | #print(spatial[0,1]) 200 | #exit(0) 201 | spatial=self.se(concat_feature).unsqueeze(1) 202 | concat_feature=(spatial+1)*x 203 | #spatial=self.se(concat_feature).unsqueeze(1) 204 | #print(spatial.shape) 205 | #print(x.shape) 206 | #exit(0) 207 | #concat_feature= ftrain 208 | #print(concat_feature.shape) 209 | #exit(0) 210 | return concat_feature,spatial 211 | 212 | 213 | def fusenet(): 214 | model = RelaNet(512,[128,128,512], kernel=1) 215 | return model 216 | -------------------------------------------------------------------------------- /torchFewShot/models/related_net_spatial_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from cam import CAM 6 | from channel_wise_attention import SELayer 7 | #from fuse_net import resnet_fuse 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, kernel=3, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | if kernel == 1: 20 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 21 | elif kernel == 3: 22 | self.conv1 = conv3x3(inplanes, planes) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | if kernel == 1: 28 | self.conv3 = nn.Conv2d(planes, planes, kernel_size=1, bias=False) 29 | elif kernel == 3: 30 | self.conv3 = conv3x3(planes, planes) 31 | self.bn3 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv3(out) 48 | out = self.bn3(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, kernel=1, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | 
residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class RelaNet(nn.Module): 99 | 100 | def __init__(self,input_channel, layer_channels, kernel=3): 101 | self.inplanes = 64 102 | self.kernel = kernel 103 | self.input_channel=input_channel 104 | self.layer_channels=layer_channels 105 | super(RelaNet, self).__init__() 106 | self.convs=[] 107 | for i in range(len(layer_channels)): 108 | if i==0: 109 | self.conv1=nn.Conv2d(self.input_channel, self.layer_channels[i], kernel_size=kernel, stride=1, padding=0, bias=False) 110 | self.bn1 = nn.BatchNorm2d(self.layer_channels[i]) 111 | elif i==1: 112 | self.conv2=nn.Conv2d(self.layer_channels[i-1], self.layer_channels[i], kernel_size=kernel, stride=1, padding=0, bias=False) 113 | self.bn2 = nn.BatchNorm2d(self.layer_channels[i]) 114 | else: 115 | self.conv3=nn.Conv2d(self.layer_channels[i-1], 128, kernel_size=kernel, stride=1, padding=0, bias=False) 116 | self.bn3 = nn.BatchNorm2d(128) 117 | for j in range(len(layer_channels)): 118 | if j==0: 119 | self.conv1_attention=nn.Conv2d(128, self.layer_channels[j], kernel_size=kernel, stride=1, padding=0, bias=False) 120 | self.bn1_attention = nn.BatchNorm2d(self.layer_channels[j]) 121 | elif j==1: 122 | self.conv2_attention=nn.Conv2d(self.layer_channels[j-1], self.layer_channels[j], kernel_size=kernel, stride=1, padding=0, bias=False) 123 | self.bn2_attention = nn.BatchNorm2d(self.layer_channels[j]) 124 | else: 125 | self.conv3_attention=nn.Conv2d(self.layer_channels[j-1],1, kernel_size=kernel, stride=1, padding=0, bias=False) 126 | #self.bn3 = nn.BatchNorm2d(self.layer_channels[i]) 127 | self.conv3_channel_attention=nn.Conv2d(128*5,128, kernel_size=3, stride=1, padding=1, bias=False) 128 | self.bn3_channel_attention = nn.BatchNorm2d(128) 129 | #self.bn1 = nn.BatchNorm2d(64) 130 | self.relu = nn.ReLU(inplace=True) 131 | self.cam_support = CAM() 132 | self.channel_att=SELayer(128,4) 133 | #self.sigmoid=torch.Sigmoid() 134 | def forward(self, x,label_1shot): 135 | shape=x.shape 136 | #print(x.shape) 137 | #exit(0) 138 | #ftrain = self.cam_support(x, x) 139 | #print(ftrain.shape) 140 | #exit(0) 141 | #print(x.shape,(shape[0],shape[1],1,shape[2],shape[3],shape[4])) 142 | #print(x.shape,(shape[0],1,shape[1],shape[2],shape[3],shape[4])) 143 | #exit(0) 144 | shape=x.shape 145 | #print(x.shape) 146 | #exit(0) 147 | right=x.reshape(shape[0],shape[1],1,shape[2],shape[3],shape[4]).repeat(1,1,shape[1],1,1,1) 148 | left=x.reshape(shape[0],1,shape[1],shape[2],shape[3],shape[4]).repeat(1,shape[1],1,1,1,1) 149 | #concat_feature=torch.cat((left, right), 3).reshape( shape[0]* shape[1]*shape[1],2*shape[2],shape[3],shape[4] ) 150 | #print(x.shape) 151 | #exit(0) 152 | concat_feature=(left-right).reshape( shape[0]* shape[1]*shape[1],shape[2],shape[3],shape[4] ) 153 | #concat_feature1= concat_feature 154 | #print(concat_feature.shape) 155 | #exit(0) 156 | for i in range(3): 157 | if i==0: 158 | concat_feature = self.conv1(concat_feature) 159 | concat_feature= self.bn1(concat_feature) 160 | concat_feature = self.relu(concat_feature) 161 | #print(concat_feature.shape,'llllllllllllllllllll') 162 | elif i==1: 163 | #print(concat_feature.shape,'looooooooooooooo') 164 | concat_feature = 
self.conv2(concat_feature) 165 | concat_feature= self.bn2(concat_feature) 166 | concat_feature = self.relu(concat_feature) 167 | else: 168 | #print(concat_feature.shape,'pppppppppppppp') 169 | concat_feature = self.conv3(concat_feature) 170 | concat_feature= self.bn3(concat_feature) 171 | concat_feature = self.relu(concat_feature) 172 | #concat_feature=concat_feature.reshape( shape[0],shape[1],shape[1],shape[2],shape[3],shape[4] ) 173 | concat_feature=concat_feature.reshape( shape[0],shape[1],shape[1],128,shape[3],shape[4] ) 174 | concat_feature=concat_feature.mean(2).reshape( -1,128,shape[3],shape[4]) 175 | # channel-wise attention 176 | concat_feature_gm=concat_feature.reshape(shape[0],shape[1]*128,shape[3],shape[4] )#.mean(1) 177 | concat_feature_gm= self.conv3_channel_attention( concat_feature_gm) 178 | concat_feature_gm=self.bn3_channel_attention( concat_feature_gm) 179 | concat_feature_gm = self.relu(concat_feature_gm) 180 | #print(concat_feature_gm.shape) 181 | #exit(0) 182 | #print(concat_feature.shape) 183 | channel_attention=self.channel_att(concat_feature_gm).unsqueeze(1) 184 | #print(channel_attention.shape,x.shape) 185 | #exit(0) 186 | for i in range(3): 187 | if i==0: 188 | concat_feature = self.conv1_attention(concat_feature) 189 | concat_feature= self.bn1_attention(concat_feature) 190 | concat_feature = self.relu(concat_feature) 191 | #print(concat_feature.shape,'llllllllllllllllllll') 192 | elif i==1: 193 | #print(concat_feature.shape,'looooooooooooooo') 194 | concat_feature = self.conv2_attention(concat_feature) 195 | concat_feature= self.bn2_attention(concat_feature) 196 | concat_feature = self.relu(concat_feature) 197 | else: 198 | #print(concat_feature.shape,'pppppppppppppp') 199 | concat_feature = self.conv3_attention(concat_feature) 200 | #concat_feature= self.bn3_attention(concat_feature) 201 | #concat_feature = self.relu(concat_feature) 202 | #spatial=torch.sigmoid( concat_feature).reshape( shape[0],shape[1],1,shape[3],shape[4] ) 203 | spatial=torch.tanh( concat_feature).reshape( shape[0],shape[1],1,shape[3],shape[4] ) 204 | #print(spatial.shape,'pppp tanh') 205 | #print(spatial[:2]) 206 | #exit(0) 207 | concat_feature=(0.5*spatial+1.5)*x*(1+channel_attention ) 208 | #concat_feature= ftrain 209 | #print(concat_feature.shape) 210 | #exit(0) 211 | return concat_feature,spatial,channel_attention 212 | 213 | 214 | def fusenet(): 215 | model = RelaNet(512,[128,128,512], kernel=1) 216 | return model 217 | -------------------------------------------------------------------------------- /torchFewShot/models/resnet12.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=1, bias=False) 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, inplanes, planes, kernel=3, stride=1, downsample=None): 16 | super(BasicBlock, self).__init__() 17 | if kernel == 1: 18 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 19 | elif kernel == 3: 20 | self.conv1 = conv3x3(inplanes, planes) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | if kernel == 1: 26 | self.conv3 = nn.Conv2d(planes, planes, 
kernel_size=1, bias=False) 27 | elif kernel == 3: 28 | self.conv3 = conv3x3(planes, planes) 29 | self.bn3 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, kernel=1, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, block, layers, kernel=3): 99 | self.inplanes = 64 100 | self.kernel = kernel 101 | super(ResNet, self).__init__() 102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | 106 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 110 | 111 | self.nFeat = 512 * block.expansion 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, self.kernel, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes, self.kernel)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | return x 148 | 149 | 150 | def resnet12(): 151 | model = ResNet(BasicBlock, [1,1,1,1], kernel=3) 152 | return model 153 | -------------------------------------------------------------------------------- /torchFewShot/optimizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def init_optimizer(optim, params, lr, weight_decay): 7 | if optim == 'adam': 8 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay) 9 | elif optim == 'amsgrad': 10 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, amsgrad=True) 11 | elif optim == 'sgd': 12 | return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=True) 13 | elif optim == 'rmsprop': 14 | return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay) 15 | else: 16 | raise KeyError("Unsupported optimizer: {}".format(optim)) 17 | -------------------------------------------------------------------------------- /torchFewShot/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | from torchvision.transforms import * 5 | 6 | from PIL import Image 7 | import random 8 | import numpy as np 9 | import math 10 | import torch 11 | 12 | 13 | class Random2DTranslation(object): 14 | """ 15 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 16 | 17 | Args: 18 | - height (int): target height. 19 | - width (int): target width. 20 | - p (float): probability of performing this transformation. Default: 0.5. 21 | """ 22 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 23 | self.height = height 24 | self.width = width 25 | self.p = p 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | """ 30 | Args: 31 | - img (PIL Image): Image to be cropped. 
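        Example (an illustrative sketch; Compose, RandomHorizontalFlip and ToTensor
        come from the torchvision re-export at the top of this module):

            transform = Compose([
                Random2DTranslation(84, 84, p=0.5),
                RandomHorizontalFlip(),
                ToTensor(),
            ])
            out = transform(img)  # img: PIL Image -> out: float tensor [3, 84, 84]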
32 | """ 33 | if random.uniform(0, 1) > self.p: 34 | return img.resize((self.width, self.height), self.interpolation) 35 | 36 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 37 | resized_img = img.resize((new_width, new_height), self.interpolation) 38 | x_maxrange = new_width - self.width 39 | y_maxrange = new_height - self.height 40 | x1 = int(round(random.uniform(0, x_maxrange))) 41 | y1 = int(round(random.uniform(0, y_maxrange))) 42 | croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 43 | return croped_img 44 | 45 | 46 | class RandomErasing(object): 47 | """ Randomly selects a rectangle region in an image and erases its pixels. 48 | 'Random Erasing Data Augmentation' by Zhong et al. 49 | See https://arxiv.org/pdf/1708.04896.pdf 50 | Args: 51 | probability: The probability that the Random Erasing operation will be performed. 52 | sl: Minimum proportion of erased area against input image. 53 | sh: Maximum proportion of erased area against input image. 54 | r1: Minimum aspect ratio of erased area. 55 | mean: Erasing value. 56 | """ 57 | 58 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 59 | self.probability = probability 60 | self.mean = mean 61 | self.sl = sl 62 | self.sh = sh 63 | self.r1 = r1 64 | 65 | def __call__(self, img): 66 | 67 | if random.uniform(0, 1) > self.probability: 68 | return img 69 | 70 | for attempt in range(100): 71 | area = img.size()[1] * img.size()[2] 72 | 73 | target_area = random.uniform(self.sl, self.sh) * area 74 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 75 | 76 | h = int(round(math.sqrt(target_area * aspect_ratio))) 77 | w = int(round(math.sqrt(target_area / aspect_ratio))) 78 | 79 | if w < img.size()[2] and h < img.size()[1]: 80 | x1 = random.randint(0, img.size()[1] - h) 81 | y1 = random.randint(0, img.size()[2] - w) 82 | if img.size()[0] == 3: 83 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 84 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 85 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 86 | else: 87 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 88 | return img 89 | 90 | return img 91 | -------------------------------------------------------------------------------- /torchFewShot/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /torchFewShot/utils/avgmeter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value. 
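    Example (illustrative):
        >>> losses = AverageMeter()
        >>> losses.update(0.8, n=4)  # a batch of 4 samples with mean value 0.8
        >>> losses.update(0.6, n=4)
        >>> losses.avg               # sum / count = (0.8*4 + 0.6*4) / 8 = 0.7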
7 | 8 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 9 | """ 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /torchFewShot/utils/iotools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import os.path as osp 5 | import errno 6 | import json 7 | import shutil 8 | 9 | import torch 10 | 11 | 12 | def mkdir_if_missing(directory): 13 | if not osp.exists(directory): 14 | try: 15 | os.makedirs(directory) 16 | except OSError as e: 17 | if e.errno != errno.EEXIST: 18 | raise 19 | 20 | 21 | def check_isfile(path): 22 | isfile = osp.isfile(path) 23 | if not isfile: 24 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 25 | return isfile 26 | 27 | 28 | def read_json(fpath): 29 | with open(fpath, 'r') as f: 30 | obj = json.load(f) 31 | return obj 32 | 33 | 34 | def write_json(obj, fpath): 35 | mkdir_if_missing(osp.dirname(fpath)) 36 | with open(fpath, 'w') as f: 37 | json.dump(obj, f, indent=4, separators=(',', ': ')) 38 | 39 | 40 | def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar'): 41 | if len(osp.dirname(fpath)) != 0: 42 | mkdir_if_missing(osp.dirname(fpath)) 43 | torch.save(state, fpath) 44 | if is_best: 45 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) -------------------------------------------------------------------------------- /torchFewShot/utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import sys 4 | import os 5 | import os.path as osp 6 | 7 | from .iotools import mkdir_if_missing 8 | 9 | 10 | class Logger(object): 11 | """ 12 | Write console output to external text file. 13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 14 | """ 15 | def __init__(self, fpath=None, mode='a'): 16 | self.console = sys.stdout 17 | self.file = None 18 | if fpath is not None: 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | self.file = open(fpath, mode) 21 | 22 | def __del__(self): 23 | self.close() 24 | 25 | def __enter__(self): 26 | pass 27 | 28 | def __exit__(self, *args): 29 | self.close() 30 | 31 | def write(self, msg): 32 | self.console.write(msg) 33 | if self.file is not None: 34 | self.file.write(msg) 35 | 36 | def flush(self): 37 | self.console.flush() 38 | if self.file is not None: 39 | self.file.flush() 40 | os.fsync(self.file.fileno()) 41 | 42 | def close(self): 43 | self.console.close() 44 | if self.file is not None: 45 | self.file.close() 46 | -------------------------------------------------------------------------------- /torchFewShot/utils/torchtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def open_all_layers(model): 9 | """ 10 | Open all layers in model for training. 
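    Example (illustrative):
        >>> open_all_layers(model)  # model.train() + requires_grad=True everywhere
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)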
11 | """ 12 | model.train() 13 | for p in model.parameters(): 14 | p.requires_grad = True 15 | 16 | 17 | def open_specified_layers(model, open_layers): 18 | """ 19 | Open specified layers in model for training while keeping 20 | other layers frozen. 21 | 22 | Args: 23 | - model (nn.Module): neural net model. 24 | - open_layers (list): list of layers names. 25 | """ 26 | if isinstance(model, nn.DataParallel): 27 | model = model.module 28 | 29 | for layer in open_layers: 30 | assert hasattr(model, layer), "'{}' is not an attribute of the model, please provide the correct name".format(layer) 31 | 32 | for name, module in model.named_children(): 33 | if name in open_layers: 34 | #print(module) 35 | module.train() 36 | for p in module.parameters(): 37 | p.requires_grad = True 38 | else: 39 | module.eval() 40 | for p in module.parameters(): 41 | p.requires_grad = False 42 | 43 | 44 | 45 | def adjust_learning_rate(optimizer, iters, LUT): 46 | # decay learning rate by 'gamma' for every 'stepsize' 47 | for (stepvalue, base_lr) in LUT: 48 | if iters < stepvalue: 49 | lr = base_lr 50 | break 51 | 52 | for param_group in optimizer.param_groups: 53 | param_group['lr'] = lr 54 | return lr 55 | 56 | 57 | def adjust_lambda(iters, LUT): 58 | for (stepvalue, base_lambda) in LUT: 59 | if iters < stepvalue: 60 | lambda_xent = base_lambda 61 | break 62 | return lambda_xent 63 | 64 | 65 | def set_bn_to_eval(m): 66 | # 1. no update for running mean and var 67 | # 2. scale and shift parameters are still trainable 68 | classname = m.__class__.__name__ 69 | if classname.find('BatchNorm') != -1: 70 | m.eval() 71 | 72 | 73 | def count_num_param(model): 74 | num_param = sum(p.numel() for p in model.parameters()) / 1e+06 75 | if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module): 76 | # we ignore the classifier because it is unused at test time 77 | num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06 78 | return num_param 79 | 80 | 81 | def one_hot(labels_train): 82 | """ 83 | Turn the labels_train to one-hot encoding. 84 | Args: 85 | labels_train: [batch_size, num_train_examples] 86 | Return: 87 | labels_train_1hot: [batch_size, num_train_examples, K] 88 | """ 89 | labels_train = labels_train.cpu() 90 | nKnovel = 1 + labels_train.max() 91 | labels_train_1hot_size = list(labels_train.size()) + [nKnovel,] 92 | labels_train_unsqueeze = labels_train.unsqueeze(dim=labels_train.dim()) 93 | labels_train_1hot = torch.zeros(labels_train_1hot_size).scatter_(len(labels_train_1hot_size) - 1, labels_train_unsqueeze, 1) 94 | return labels_train_1hot 95 | 96 | def one_hot_36(labels_train): 97 | """ 98 | Turn the labels_train to one-hot encoding. 
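    Unlike one_hot above, K is fixed at 36 (one slot per position of the 6x6
    feature map), regardless of the largest label value. Example (illustrative):
        >>> one_hot_36(torch.tensor([0, 35])).shape
        torch.Size([2, 36])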
99 | Args: 100 | labels_train: [batch_size, num_train_examples] 101 | Return: 102 | labels_train_1hot: [batch_size, num_train_examples, K] 103 | """ 104 | labels_train = labels_train.cpu() 105 | #nKnovel = 1 + labels_train.max() 106 | nKnovel = 36 107 | labels_train_1hot_size = list(labels_train.size()) + [nKnovel,] 108 | #print(list(labels_train.size()),labels_train_1hot_size) 109 | #exit(0) 110 | labels_train_unsqueeze = labels_train.unsqueeze(dim=labels_train.dim()) 111 | labels_train_1hot = torch.zeros(labels_train_1hot_size).scatter_(len(labels_train_1hot_size) - 1, labels_train_unsqueeze, 1) 112 | return labels_train_1hot 113 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import os 5 | import sys 6 | import time 7 | import datetime 8 | import argparse 9 | import os.path as osp 10 | import numpy as np 11 | import random 12 | import cv2 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.data import DataLoader 18 | from torch.optim import lr_scheduler 19 | import torch.nn.functional as F 20 | sys.path.append('./torchFewShot') 21 | 22 | #from args_tiered import argument_parser 23 | from args_xent import argument_parser 24 | #from torchFewShot.models.net import Model 25 | 26 | from torchFewShot.models.models_gnn import create_models 27 | from torchFewShot.data_manager import DataManager 28 | from torchFewShot.losses import CrossEntropyLoss 29 | from torchFewShot.optimizers import init_optimizer 30 | 31 | from torchFewShot.utils.iotools import save_checkpoint, check_isfile 32 | from torchFewShot.utils.avgmeter import AverageMeter 33 | from torchFewShot.utils.logger import Logger 34 | from torchFewShot.utils.torchtools import one_hot, adjust_learning_rate 35 | 36 | #from inpainting import EC 37 | 38 | 39 | parser = argument_parser() 40 | args = parser.parse_args() 41 | #print(args.use_similarity) 42 | #exit(0) 43 | if args.use_similarity: 44 | from torchFewShot.models.net_similary import Model 45 | else: 46 | from torchFewShot.models.net import Model 47 | only_test=False 48 | def main(): 49 | torch.manual_seed(args.seed) 50 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 51 | use_gpu = torch.cuda.is_available() 52 | 53 | sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt')) 54 | print("==========\nArgs:{}\n==========".format(args)) 55 | 56 | if use_gpu: 57 | print("Currently using GPU {}".format(args.gpu_devices)) 58 | cudnn.benchmark = True 59 | torch.cuda.manual_seed_all(args.seed) 60 | else: 61 | print("Currently using CPU (GPU is highly recommended)") 62 | 63 | print('Initializing image data manager') 64 | dm = DataManager(args, use_gpu) 65 | trainloader, testloader = dm.return_dataloaders() 66 | #print(args.scale_cls,args.num_classes) 67 | #exit(0) 68 | GNN_model=create_models(args,512) 69 | if args.use_similarity: 70 | model = Model(args,GNN_model,scale_cls=args.scale_cls, num_classes=args.num_classes) 71 | else: 72 | model = Model(scale_cls=args.scale_cls, num_classes=args.num_classes) 73 | if only_test: 74 | params = torch.load('result/%s/CAM/1-shot-seed112/%s' % (args.dataset, 'best_model.pth.tar')) 75 | print(type(params)) 76 | #exit(0) 77 | #for key in params.keys(): 78 | #print(type(key)) 79 | #exit(0) 80 | model.load_state_dict(params['state_dict']) 81 | #exit(0) 82 | criterion = 
CrossEntropyLoss() 83 | optimizer = init_optimizer(args.optim, model.parameters(), args.lr, args.weight_decay) 84 | 85 | if use_gpu: 86 | model = model.cuda() 87 | 88 | start_time = time.time() 89 | train_time = 0 90 | best_acc = -np.inf 91 | best_epoch = 0 92 | print("==> Start training") 93 | 94 | for epoch in range(args.max_epoch): 95 | learning_rate = adjust_learning_rate(optimizer, epoch, args.LUT_lr) 96 | 97 | start_train_time = time.time() 98 | #exit(0) 99 | #print(not True) 100 | #exit(0) 101 | if not only_test: 102 | #print(';;;;;;;;;;;') 103 | #exit(0) 104 | train(epoch, model, criterion, optimizer, trainloader, learning_rate, use_gpu) 105 | train_time += round(time.time() - start_train_time) 106 | 107 | if epoch == 0 or epoch > (args.stepsize[0]-1) or (epoch + 1) % 10 == 0: 108 | print('enter test code') 109 | #exit(0) 110 | acc = test(model, testloader, use_gpu) 111 | is_best = acc > best_acc 112 | #print(acc) 113 | #exit(0) 114 | if is_best: 115 | best_acc = acc 116 | best_epoch = epoch + 1 117 | if not only_test: 118 | save_checkpoint({ 119 | 'state_dict': model.state_dict(), 120 | 'acc': acc, 121 | 'epoch': epoch, 122 | }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar')) 123 | 124 | print("==> Test 5-way Best accuracy {:.2%}, achieved at epoch {}".format(best_acc, best_epoch)) 125 | 126 | elapsed = round(time.time() - start_time) 127 | elapsed = str(datetime.timedelta(seconds=elapsed)) 128 | train_time = str(datetime.timedelta(seconds=train_time)) 129 | print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time)) 130 | print("==========\nArgs:{}\n==========".format(args)) 131 | 132 | 133 | def train(epoch, model, criterion, optimizer, trainloader, learning_rate, use_gpu): 134 | losses = AverageMeter() 135 | batch_time = AverageMeter() 136 | data_time = AverageMeter() 137 | 138 | model.train() 139 | 140 | end = time.time() 141 | for batch_idx, (images_train, labels_train, images_test, labels_test, pids) in enumerate(trainloader): 142 | data_time.update(time.time() - end) 143 | #pids is the all class id 144 | #print(labels_train.shape) 145 | #exit(0) 146 | if use_gpu: 147 | images_train, labels_train = images_train.cuda(), labels_train.cuda() 148 | images_test, labels_test = images_test.cuda(), labels_test.cuda() 149 | pids = pids.cuda() 150 | 151 | batch_size, num_train_examples, channels, height, width = images_train.size() 152 | num_test_examples = images_test.size(1) 153 | 154 | labels_train_1hot = one_hot(labels_train).cuda() 155 | labels_test_1hot = one_hot(labels_test).cuda() 156 | 157 | ytest, cls_scores = model(images_train, images_test, labels_train_1hot, labels_test_1hot)#ytest is all class classification 158 | #cls_scores is N-way classifation 159 | 160 | loss1 = criterion(ytest, pids.view(-1))# 161 | loss2 = criterion(cls_scores, labels_test.view(-1)) 162 | loss = loss1 + 0.5 * loss2 163 | 164 | optimizer.zero_grad() 165 | loss.backward() 166 | optimizer.step() 167 | 168 | losses.update(loss.item(), pids.size(0)) 169 | batch_time.update(time.time() - end) 170 | end = time.time() 171 | 172 | print('Epoch{0} ' 173 | 'lr: {1} ' 174 | 'Time:{batch_time.sum:.1f}s ' 175 | 'Data:{data_time.sum:.1f}s ' 176 | 'Loss:{loss.avg:.4f} '.format( 177 | epoch+1, learning_rate, batch_time=batch_time, 178 | data_time=data_time, loss=losses)) 179 | 180 | 181 | def test(model, testloader, use_gpu): 182 | accs = AverageMeter() 183 | test_accuracies = [] 184 | model.eval() 185 | 186 | with torch.no_grad(): 187 
        # ImageNet normalization statistics, reshaped to (3, 1, 1) so they broadcast
        # over (C, H, W) arrays when de-normalizing images for visualization.
        std = np.expand_dims(np.expand_dims(np.array([0.229, 0.224, 0.225]), axis=1), axis=2)
        mean = np.expand_dims(np.expand_dims(np.array([0.485, 0.456, 0.406]), axis=1), axis=2)

        for batch_idx, (images_train, labels_train, images_test, labels_test) in enumerate(testloader):
            if use_gpu:
                images_train = images_train.cuda()
                images_test = images_test.cuda()

            batch_size, num_train_examples, channels, height, width = images_train.size()
            num_test_examples = images_test.size(1)

            labels_train_1hot = one_hot(labels_train).cuda()
            labels_test_1hot = one_hot(labels_test).cuda()

            # With the last argument set to True the model also returns the attention
            # maps a1 (support side) and a2 (query side), e.g. shaped [4, 5, 75, 6, 6].
            cls_scores, a1, a2 = model(images_train, images_test, labels_train_1hot, labels_test_1hot, True)

            # Min-max normalize each attention map to [0, 1] over its spatial dims.
            max_a1 = a1.max(3)[0].max(3)[0].unsqueeze(3).unsqueeze(3)
            min_a1 = a1.min(3)[0].min(3)[0].unsqueeze(3).unsqueeze(3)
            max_a2 = a2.max(3)[0].max(3)[0].unsqueeze(3).unsqueeze(3)
            min_a2 = a2.min(3)[0].min(3)[0].unsqueeze(3).unsqueeze(3)
            scale_a1 = torch.div(a1 - min_a1, max_a1 - min_a1)
            scale_a2 = torch.div(a2 - min_a2, max_a2 - min_a2)

            # Canvas for a 5-row x 4-column grid of 84x84 tiles with 8-pixel gaps.
            result_support_imgs = np.zeros((84 * 5 + 8 * 4, 84 * 4 + 8 * 3, 3), dtype=np.uint8)

            cls_scores = cls_scores.view(batch_size * num_test_examples, -1)
            labels_test = labels_test.view(batch_size * num_test_examples)

            _, preds = torch.max(cls_scores.detach().cpu(), 1)
            acc = torch.sum(preds == labels_test.detach().cpu()).float() / labels_test.size(0)
            accs.update(acc.item(), labels_test.size(0))

            if only_test:
                for i in range(images_test.shape[0]):
                    for k in range(images_test.shape[1]):
                        for j in range(images_train.shape[1]):
                            images_temp_test = images_test[i, k].cpu().numpy()
                            images_temp_train = images_train[i, j].cpu().numpy()
                            index_support = int(labels_train[i, j])
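                            # Undo the ImageNet normalization (x * std + mean), move
                            # channels last, and flip RGB -> BGR so the arrays match
                            # OpenCV's colour convention before cv2.imwrite is called.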
                            images_temp_test = images_temp_test * std + mean
                            images_ori_test = images_temp_test.transpose((1, 2, 0))[:, :, ::-1]
                            images_temp_train = images_temp_train * std + mean
                            images_ori_train = images_temp_train.transpose((1, 2, 0))[:, :, ::-1]

                            # Upsample the attention maps to image size and render them as JET heatmaps.
                            hot_a1 = cv2.resize(np.uint8(scale_a1[i, index_support, k].cpu().numpy() * 255), (84, 84))
                            hot_a2 = cv2.resize(np.uint8(scale_a2[i, index_support, k].cpu().numpy() * 255), (84, 84))
                            heatmap_a1 = cv2.applyColorMap(hot_a1, cv2.COLORMAP_JET)
                            heatmap_a2 = cv2.applyColorMap(hot_a2, cv2.COLORMAP_JET)

                            images_ori_test = np.uint8(images_ori_test * 255)
                            images_ori_train = np.uint8(images_ori_train * 255)
                            # Blend image and heatmap: a2 overlays the query image, a1 the support image.
                            vis_test = images_ori_test * 0.7 + heatmap_a2 * 0.3
                            vis_train = images_ori_train * 0.7 + heatmap_a1 * 0.3

                            # One grid row per support class: query, support, query overlay, support overlay.
                            row = slice(92 * index_support, 92 * index_support + 84)
                            result_support_imgs[row, :84, :] = images_ori_test
                            result_support_imgs[row, 92:92 + 84, :] = images_ori_train
                            result_support_imgs[row, 92 * 2:92 * 2 + 84, :] = vis_test
                            result_support_imgs[row, 92 * 3:92 * 3 + 84, :] = vis_train

                        label_gt = int(labels_test.numpy()[i * num_test_examples + k])
                        label_pred = int(preds.numpy()[i * num_test_examples + k])
                        cv2.imwrite('./result/vis_images/vis_' + str(batch_idx) + '_' + str(i) + '_' + str(k)
                                    + '_' + str(label_gt) + '_' + str(label_pred) + '.jpg', result_support_imgs)

            gt = (preds == labels_test.detach().cpu()).float()
            gt = gt.view(batch_size, num_test_examples).numpy()  # [b, n]
            acc = np.sum(gt, 1) / num_test_examples
            acc = np.reshape(acc, (batch_size))
            test_accuracies.append(acc)

    accuracy = accs.avg
    test_accuracies = np.reshape(np.array(test_accuracies), -1)
    # 95% confidence interval of the mean accuracy over the test episodes.
    stds = np.std(test_accuracies, 0)
    ci95 = 1.96 * stds / np.sqrt(args.epoch_size)
    print('Accuracy: {:.2%}, ci95: {:.2%}'.format(accuracy, ci95))
    return accuracy


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python ./train_with_inpaint_read_from_data_fixed.py --nExemplars 1 --epoch_size 600 --train_nTestNovel 30 --train-batch 4 --nKnovel 5 --Classic 0 --use_similarity 0 \
    --save-dir ./result/miniImageNet/CAM/1-shot-seed112_inpaint_support_fuse_Cam_surport_from_65.96_test_fixed/
--------------------------------------------------------------------------------
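
A minimal, self-contained sketch of the 95% confidence interval that test() prints
above (standard normal-approximation interval; per_episode_acc is a hypothetical
stand-in for the per-episode accuracies the script collects):

import numpy as np

# Hypothetical per-episode accuracies in [0, 1], one entry per test episode
# (args.epoch_size episodes in the script, e.g. 600 as set in train.sh).
per_episode_acc = np.random.rand(600)

mean_acc = per_episode_acc.mean()
# 1.96 is the z-value of a two-sided 95% interval under a normal approximation.
ci95 = 1.96 * per_episode_acc.std() / np.sqrt(len(per_episode_acc))
print('Accuracy: {:.2%}, ci95: {:.2%}'.format(mean_acc, ci95))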