├── README.md
└── main.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic-Few-Shot-Visual-Learning-without-Forgetting
2 | 
3 | A simple PyTorch implementation of the CVPR 2018 paper [Dynamic Few Shot Visual Learning without Forgetting](https://arxiv.org/abs/1804.09458), written as a Jupyter notebook.
4 | 
5 | 
6 | ### Set up the dataset
7 | 
8 | Download the MiniImageNet dataset from [here](https://mega.nz/#!rx0wGQyS!96sFlAr6yyv-9QQPCm5OBFbOm4XSD0t-HlmGaT5GaiE) and modify the data root in the first notebook cell.
9 | 
10 | ### Train the model
11 | 
12 | * Step 1 pretrains the Feature Extractor and the cosine-based Classifier. In this step I set the number of epochs to 30.
13 | * Step 2 continues training the Classifier together with the attention-based Few-shot Weight Generator.
14 | 
15 | ### Evaluate
16 | 
17 | * I get results similar to those published in the paper.
18 | 
--------------------------------------------------------------------------------
/main.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import pickle\n",
11 | "import numpy as np\n",
12 | "import random\n",
13 | "import torch.utils.data as data\n",
14 | "import torchvision.transforms as transforms\n",
15 | "from PIL import Image\n",
16 | "\n",
17 | "import torch\n",
18 | "import torchnet as tnt\n",
19 | "\n",
20 | "#modify the data root\n",
21 | "_MINI_IMAGENET_DATASET_DIR = '../datasets/MiniImagenet'"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "def load_data(file):\n",
31 | " with open(file,'rb') as f:\n",
32 | " data=pickle.load(f,encoding='iso-8859-1')\n",
33 | " return data\n",
34 | "\n",
35 | "def buildLabelIndex(labels):\n",
36 | " label2inds={}\n",
37 | " for idx,label in enumerate(labels):\n",
38 | " if label not in label2inds:\n",
39 | " label2inds[label]=[]\n",
40 | " label2inds[label].append(idx)\n",
41 | " return label2inds"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "class MiniImageNet(data.Dataset):\n",
51 | " def __init__(self,phase='train',do_not_use_random_transf=False):\n",
52 | " self.base_folder='miniImagenet'\n",
53 | " assert(phase=='train' or phase=='val' or phase=='test')\n",
54 | " self.phase=phase\n",
55 | " self.name='MiniImageNet_'+phase\n",
56 | " \n",
57 | " print('Loading mini ImageNet dataset - phase {0}'.format(phase))\n",
58 | " file_train_categories_train_phase = os.path.join(_MINI_IMAGENET_DATASET_DIR,\n",
59 | " 'miniImageNet_category_split_train_phase_train.pickle')\n",
60 | " file_train_categories_val_phase = os.path.join(_MINI_IMAGENET_DATASET_DIR,\n",
61 | " 'miniImageNet_category_split_train_phase_val.pickle')\n",
62 | " file_train_categories_test_phase = os.path.join(_MINI_IMAGENET_DATASET_DIR,\n",
63 | " 'miniImageNet_category_split_train_phase_test.pickle')\n",
64 | " file_val_categories_val_phase = os.path.join(_MINI_IMAGENET_DATASET_DIR,\n",
65 | " 'miniImageNet_category_split_val.pickle')\n",
66 | " file_test_categories_test_phase = os.path.join(_MINI_IMAGENET_DATASET_DIR,\n",
67 | " 'miniImageNet_category_split_test.pickle')\n",
68 | " \n",
69 | " if self.phase=='train':\n",
70 | " #during training phase we only load the training phase images of the training category\n",
71 | " 
data_train=load_data(file_train_categories_train_phase)\n", 72 | " self.data=data_train['data'] #array (n,84,84,3)\n", 73 | " self.labels=data_train['labels'] #list[n]\n", 74 | " \n", 75 | " self.label2ind=buildLabelIndex(self.labels)\n", 76 | " self.labelIds=sorted(self.label2ind.keys())\n", 77 | " self.num_cats=len(self.labelIds)\n", 78 | " self.labelIds_base=self.labelIds\n", 79 | " self.num_cats_base=len(self.labelIds_base)\n", 80 | " elif self.phase=='val' or self.phase=='test':\n", 81 | " if self.phase=='test':\n", 82 | " data_base=load_data(file_train_categories_test_phase)\n", 83 | " data_novel=load_data(file_test_categories_test_phase)\n", 84 | " else:\n", 85 | " data_base=load_data(file_train_categories_val_phase)\n", 86 | " data_novel=load_data(file_val_categories_val_phase)\n", 87 | " \n", 88 | " self.data=np.concatenate(\n", 89 | " [data_base['data'],data_novel['data']],axis=0)\n", 90 | " self.labels=data_base['labels']+data_novel['labels']\n", 91 | " \n", 92 | " self.label2ind=buildLabelIndex(self.labels)\n", 93 | " self.labelIds=sorted(self.label2ind.keys())\n", 94 | " self.num_cats=len(self.labelIds)\n", 95 | " \n", 96 | " self.labelIds_base=buildLabelIndex(data_base['labels']).keys()\n", 97 | " self.labelIds_novel=buildLabelIndex(data_novel['labels']).keys()\n", 98 | " self.num_cats_base=len(self.labelIds_base)\n", 99 | " self.num_cats_novel=len(self.labelIds_novel)\n", 100 | " intersection=set(self.labelIds_base) & set(self.labelIds_novel)\n", 101 | " assert(len(intersection)==0)\n", 102 | " else:\n", 103 | " raise ValueError('Not valid phase {0}'.fotmat(self.phase))\n", 104 | " \n", 105 | " mean_pix = [x/255.0 for x in [120.39586422, 115.59361427, 104.54012653]]\n", 106 | " std_pix = [x/255.0 for x in [70.68188272, 68.27635443, 72.54505529]]\n", 107 | " normalize = transforms.Normalize(mean=mean_pix, std=std_pix)\n", 108 | " \n", 109 | " if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True):\n", 110 | " self.transform=transforms.Compose([\n", 111 | " lambda x:np.array(x),\n", 112 | " transforms.ToTensor(),\n", 113 | " normalize\n", 114 | " ])\n", 115 | " else:\n", 116 | " self.transform=transforms.Compose([\n", 117 | " transforms.RandomCrop(84,padding=8),\n", 118 | " transforms.RandomHorizontalFlip(),\n", 119 | " lambda x : np.array(x),\n", 120 | " transforms.ToTensor(),\n", 121 | " normalize\n", 122 | " ])\n", 123 | " \n", 124 | " def __getitem__(self,index):\n", 125 | " img,label=self.data[index],self.labels[index]\n", 126 | " #doing this so that it is consistent with all other datasets to return a PIL image\n", 127 | " img=Image.fromarray(img)\n", 128 | " if self.transform is not None:\n", 129 | " img=self.transform(img)\n", 130 | " return img,label\n", 131 | " \n", 132 | " def __len__(self):\n", 133 | " return len(self.data)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "class FewShotDataloader():\n", 143 | " def __init__(self,dataset,\n", 144 | " nKnovel=5,#number of novel categories\n", 145 | " nKbase=-1,#number of base categories\n", 146 | " nExemplars=1,#number of training examples per novel category\n", 147 | " nTestNovel=15*5,#number of test examples for all novel categories\n", 148 | " nTestBase=15*5,#number of test examples for all base categories\n", 149 | " batch_size=1,#number of training episodes per batch\n", 150 | " num_workers=4,\n", 151 | " epoch_size=2000):\n", 152 | " self.dataset=dataset\n", 153 | " 
self.phase=self.dataset.phase\n",
154 | " max_possible_nKnovel=(self.dataset.num_cats_base if self.phase=='train'\n",
155 | " else self.dataset.num_cats_novel)\n",
156 | " assert(nKnovel>=0 and nKnovel<=max_possible_nKnovel)\n",
157 | " self.nKnovel=nKnovel\n",
158 | " \n",
159 | " max_possible_nKbase=self.dataset.num_cats_base\n",
160 | " nKbase=nKbase if nKbase>=0 else max_possible_nKbase\n",
161 | " \n",
162 | " if self.phase=='train' and nKbase>0:\n",
163 | " nKbase-=self.nKnovel\n",
164 | " max_possible_nKbase-=self.nKnovel\n",
165 | " \n",
166 | " assert(nKbase>=0 and nKbase <=max_possible_nKbase)\n",
167 | " self.nKbase=nKbase\n",
168 | " \n",
169 | " self.nExemplars=nExemplars\n",
170 | " self.nTestNovel=nTestNovel\n",
171 | " self.nTestBase=nTestBase\n",
172 | " self.batch_size=batch_size\n",
173 | " self.epoch_size=epoch_size\n",
174 | " self.num_workers=num_workers\n",
175 | " self.is_eval_mode=(self.phase=='test') or (self.phase=='val')\n",
176 | " \n",
177 | " def sampleImageIdsFrom(self, cat_id , sample_size=1):\n",
178 | " \"\"\"\n",
179 | " samples 'sample_size' number of unique image ids picked from the\n",
180 | " category 'cat_id'\n",
181 | " \"\"\"\n",
182 | " assert(cat_id in self.dataset.label2ind)\n",
183 | " assert(len(self.dataset.label2ind[cat_id])>=sample_size)\n",
184 | " #Note : random.sample samples elements without replacement.\n",
185 | " return random.sample(self.dataset.label2ind[cat_id],sample_size)\n",
186 | " \n",
187 | " def sampleCategories(self,cat_set,sample_size=1):\n",
188 | " \"\"\"\n",
189 | " Samples 'sample_size' number of unique categories picked from the\n",
190 | " 'cat_set' set of categories. 'cat_set' can be either 'base' or 'novel'.\n",
191 | " \"\"\"\n",
192 | " if cat_set=='base':\n",
193 | " labelIds=self.dataset.labelIds_base\n",
194 | " elif cat_set=='novel':\n",
195 | " labelIds=self.dataset.labelIds_novel\n",
196 | " else:\n",
197 | " raise ValueError('Unrecognized category set {}'.format(cat_set))\n",
198 | " \n",
199 | " assert(len(labelIds)>=sample_size)\n",
200 | " \n",
201 | " return random.sample(labelIds,sample_size)\n",
202 | " \n",
203 | " def sample_base_and_novel_categories(self,nKbase , nKnovel):\n",
204 | " \"\"\"\n",
205 | " Samples 'nKbase' number of base categories and 'nKnovel' number of novel categories.\n",
206 | " \"\"\"\n",
207 | " if self.is_eval_mode:\n",
208 | " assert(nKnovel<=self.dataset.num_cats_novel)\n",
209 | " \n",
210 | " Kbase = sorted(self.sampleCategories('base',nKbase))\n",
211 | " Knovel = sorted(self.sampleCategories('novel',nKnovel))\n",
212 | " else:\n",
213 | " cats_ids = self.sampleCategories('base',nKbase+nKnovel)\n",
214 | " assert(len(cats_ids)==(nKbase+nKnovel))\n",
215 | " \n",
216 | " random.shuffle(cats_ids)\n",
217 | " Knovel=sorted(cats_ids[:nKnovel])\n",
218 | " Kbase=sorted(cats_ids[nKnovel:])\n",
219 | " return Kbase,Knovel\n",
220 | " \n",
221 | " def sample_test_examples_for_base_categories(self,Kbase,nTestBase):\n",
222 | " \"\"\"\n",
223 | " Sample 'nTestBase' number of images from the 'Kbase' categories.\n",
224 | " \n",
225 | " \"\"\"\n",
226 | " Tbase=[]\n",
227 | " if len(Kbase)>0:\n",
228 | " KbaseIndices=np.random.choice(np.arange(len(Kbase)),size=nTestBase,replace=True)\n",
229 | " KbaseIndices,NumImagesPerCategory=np.unique(KbaseIndices,return_counts=True)\n",
230 | " \n",
231 | " for Kbase_idx,NumImages in zip(KbaseIndices,NumImagesPerCategory):\n",
232 | " imd_ids=self.sampleImageIdsFrom(Kbase[Kbase_idx],sample_size=NumImages)\n",
233 | " Tbase+=[(img_id,Kbase_idx) for img_id in imd_ids]\n",
234 | " assert(len(Tbase)==nTestBase)\n",
235 | " \n",
236 | " return Tbase\n",
237 | " \n",
238 | " def 
sample_train_and_test_examples_for_novel_categories(\n", 239 | " self,Knovel,nTestNovel,nExemplars,nKbase):\n", 240 | " \"\"\"\n", 241 | " Samples train and test examples of the novel categories.\n", 242 | " \n", 243 | " Args:\n", 244 | " Knovel:a list with the ids of the novel categories\n", 245 | " nTestNovel:the total number of test imgs that will be sampled from all novel categories\n", 246 | " nExemplars:the number of training examples per novel category that will be sampled\n", 247 | " nKbase:the number of base categories.it's used as offset of the category index of each sampled img\n", 248 | " \n", 249 | " Returns:\n", 250 | " Tnovel:a list of length 'nTestNovel' with 2-element tuple.\n", 251 | " (img_id , category_label)\n", 252 | " Exemplars: a list of length len(Knovel)*nExemplars of 2-element tuple\n", 253 | " (img_id , category_label range in [nKbase,nKbase+len(Knovel)-1])\n", 254 | " \"\"\"\n", 255 | " if len(Knovel)==0:\n", 256 | " return [],[]\n", 257 | " \n", 258 | " nKnovel=len(Knovel)\n", 259 | " Tnovel=[]\n", 260 | " Exemplars=[]\n", 261 | " assert((nTestNovel % nKnovel)==0)\n", 262 | " nEvalExamplesPerClass = int(nTestNovel/nKnovel)\n", 263 | " \n", 264 | " for Knovel_idx in range(len(Knovel)):\n", 265 | " imd_ids=self.sampleImageIdsFrom(Knovel[Knovel_idx],sample_size=(nEvalExamplesPerClass+nExemplars))\n", 266 | " imds_tnovel=imd_ids[:nEvalExamplesPerClass]\n", 267 | " imds_ememplars=imd_ids[nEvalExamplesPerClass:]\n", 268 | " \n", 269 | " Tnovel+=[(img_id , nKbase+Knovel_idx) for img_id in imds_tnovel]\n", 270 | " Exemplars+=[(img_id,nKbase+Knovel_idx) for img_id in imds_ememplars]\n", 271 | " \n", 272 | " assert(len(Tnovel)==nTestNovel)\n", 273 | " assert(len(Exemplars)==len(Knovel)*nExemplars)\n", 274 | " random.shuffle(Exemplars)\n", 275 | " \n", 276 | " return Tnovel,Exemplars\n", 277 | " \n", 278 | " def sample_episode(self):\n", 279 | " \"\"\"\n", 280 | " Sample a training episode\n", 281 | " \"\"\"\n", 282 | " nKnovel=self.nKnovel\n", 283 | " nKbase=self.nKbase\n", 284 | " nTestNovel=self.nTestNovel\n", 285 | " nTestBase=self.nTestBase\n", 286 | " nExemplars=self.nExemplars\n", 287 | " \n", 288 | " Kbase,Knovel = self.sample_base_and_novel_categories(nKbase,nKnovel)\n", 289 | " Tbase=self.sample_test_examples_for_base_categories(Kbase,nTestBase)\n", 290 | " Tnovel,Exemplars=self.sample_train_and_test_examples_for_novel_categories(Knovel,nTestNovel,nExemplars,nKbase)\n", 291 | " \n", 292 | " #concatenate the base and novel category examples\n", 293 | " Test=Tbase+Tnovel\n", 294 | " random.shuffle(Test)\n", 295 | " Kall=Kbase+Knovel\n", 296 | " \n", 297 | " return Exemplars , Test , Kall , nKbase\n", 298 | " \n", 299 | " def createExamplesTensorData(self,examples):\n", 300 | " \"\"\"\n", 301 | " Create the examples image and label tensor data\n", 302 | " \"\"\"\n", 303 | " images=torch.stack(\n", 304 | " [self.dataset[img_idx][0] for img_idx ,_ in examples],dim=0)\n", 305 | " labels=torch.LongTensor([label for _,label in examples])\n", 306 | " return images,labels\n", 307 | " \n", 308 | " def get_iterator(self,epoch=0):\n", 309 | " rand_seed=epoch\n", 310 | " random.seed(rand_seed)\n", 311 | " np.random.seed(rand_seed)\n", 312 | " \n", 313 | " def load_function(iter_idx):\n", 314 | " Exemplars,Test,Kall,nKbase = self.sample_episode()\n", 315 | " Xt,Yt=self.createExamplesTensorData(Test)\n", 316 | " Kall=torch.LongTensor(Kall)\n", 317 | " if len(Exemplars)>0:\n", 318 | " Xe,Ye=self.createExamplesTensorData(Exemplars)\n", 319 | " return Xe,Ye,Xt,Yt,Kall,nKbase\n", 
320 | " else:\n", 321 | " return Xt,Yt,Kall,nKbase\n", 322 | " \n", 323 | " tnt_dataset=tnt.dataset.ListDataset(\n", 324 | " elem_list=range(self.epoch_size),load=load_function)\n", 325 | " data_loader=tnt_dataset.parallel(\n", 326 | " batch_size=self.batch_size,\n", 327 | " num_workers=(0 if self.is_eval_mode else self.num_workers),\n", 328 | " shuffle=(False if self.is_eval_mode else True))\n", 329 | " \n", 330 | " return data_loader\n", 331 | " def __call__(self,epoch=0):\n", 332 | " return self.get_iterator(epoch)\n", 333 | " \n", 334 | " def __len__(self):\n", 335 | " return int(self.epoch_size/self.batch_size)" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "import torch\n", 345 | "import torch.nn as nn\n", 346 | "import torch.nn.functional as F\n", 347 | "from tqdm import tqdm" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "class BasicModule(nn.Module):\n", 357 | " def __init__(self):\n", 358 | " super(BasicModule,self).__init__()\n", 359 | " self.model_name=str(type(self))\n", 360 | " \n", 361 | " def load(self,path):\n", 362 | " self.load_state_dict(torch.load(path))\n", 363 | " \n", 364 | " def save(self,path=None):\n", 365 | " if path is None:\n", 366 | " raise ValueError('Please specify the saving road!!!')\n", 367 | " torch.save(self.state_dict(),path)\n", 368 | " return path" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "def conv_block(in_channels,out_channels,use_relu=True):\n", 378 | " if use_relu:\n", 379 | " return nn.Sequential(\n", 380 | " nn.Conv2d(in_channels,out_channels,3,padding=1),\n", 381 | " nn.BatchNorm2d(out_channels),\n", 382 | " nn.ReLU(),\n", 383 | " nn.MaxPool2d(2)\n", 384 | " )\n", 385 | " else:\n", 386 | " return nn.Sequential(\n", 387 | " nn.Conv2d(in_channels,out_channels,3,padding=1),\n", 388 | " nn.BatchNorm2d(out_channels),\n", 389 | " nn.MaxPool2d(2)\n", 390 | " )\n", 391 | " \n", 392 | "class AvgBlock(BasicModule):\n", 393 | " def __init__(self,nFeat):\n", 394 | " super(AvgBlock,self).__init__()\n", 395 | " \n", 396 | " def forward(self,features_train , labels_train):\n", 397 | " labels_train_transposed=labels_train.transpose(1,2)\n", 398 | " weight_novel=torch.bmm(labels_train_transposed,features_train)\n", 399 | " weight_novel=weight_novel.div(\n", 400 | " labels_train_transposed.sum(dim=2,keepdim=True).expand_as(weight_novel))\n", 401 | " return weight_novel" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "class ConvNet(BasicModule):\n", 411 | " def __init__(self):\n", 412 | " super(ConvNet,self).__init__()\n", 413 | " self.encoder=nn.Sequential(\n", 414 | " conv_block(3,64),\n", 415 | " conv_block(64,64),\n", 416 | " conv_block(64,128),\n", 417 | " conv_block(128,128,use_relu=False),\n", 418 | " )\n", 419 | " def forward(self,x):\n", 420 | " out=self.encoder(x)\n", 421 | " out=out.view(out.size(0),-1)\n", 422 | " return out\n", 423 | " " 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "class AttentionBlock(BasicModule):\n", 433 | " def __init__(self,nFeat,nKall,scale_att=10.0):\n", 434 | " super(AttentionBlock,self).__init__()\n", 435 | 
" \n", 436 | " self.nFeat=nFeat\n", 437 | " self.queryLayer=nn.Linear(nFeat,nFeat)\n", 438 | " self.queryLayer.weight.data.copy_(\n", 439 | " torch.eye(nFeat,nFeat)+torch.randn(nFeat,nFeat)*0.001)\n", 440 | " self.queryLayer.bias.data.zero_()\n", 441 | " \n", 442 | " self.scale_att=nn.Parameter(torch.FloatTensor(1).fill_(scale_att),requires_grad=True)\n", 443 | " wkeys=torch.FloatTensor(nKall,nFeat).normal_(0.0,np.sqrt(2.0/nFeat))\n", 444 | " self.wkeys=nn.Parameter(wkeys,requires_grad=True)\n", 445 | " \n", 446 | " def forward(self,features_train,labels_train,weight_base,Kbase):\n", 447 | " \n", 448 | " batch_size,num_train_examples,num_features=features_train.size()\n", 449 | " nKbase=weight_base.size(1) #[batch_size,nKbase,num_features]\n", 450 | " labels_train_transposed=labels_train.transpose(1,2)\n", 451 | " nKnovel=labels_train_transposed.size(1) #[batch_size,nKnovel,num_train_examples]\n", 452 | " \n", 453 | " features_train=features_train.view(batch_size*num_train_examples,num_features)\n", 454 | " Qe=self.queryLayer(features_train)\n", 455 | " Qe=Qe.view(batch_size,num_train_examples,self.nFeat)\n", 456 | " Qe=F.normalize(Qe,p=2,dim=Qe.dim()-1,eps=1e-12)\n", 457 | " \n", 458 | " wkeys=self.wkeys[Kbase.view(-1)]\n", 459 | " wkeys=F.normalize(wkeys,p=2,dim=wkeys.dim()-1,eps=1e-12)\n", 460 | " #Transpose from[batch_size,nKbase,nFeat]->[batch_size,nFeat,nKbase]\n", 461 | " wkeys=wkeys.view(batch_size,nKbase,self.nFeat).transpose(1,2)\n", 462 | " \n", 463 | " #Compute the attention coefficients\n", 464 | " #AttenCoffiencients=Qe*wkeys -> \n", 465 | " #[batch_size x num_train_examples x nKbase] =[batch_size x num_train_examples x nFeat] * [batch_size x nFeat x nKbase]\n", 466 | " AttentionCoef=self.scale_att*torch.bmm(Qe,wkeys)\n", 467 | " AttentionCoef=F.softmax(AttentionCoef.view(batch_size*num_train_examples,nKbase))\n", 468 | " AttentionCoef=AttentionCoef.view(batch_size,num_train_examples,nKbase)\n", 469 | " \n", 470 | " #Compute the weight_novel\n", 471 | " #weight_novel=AttentionCoef * weight_base ->\n", 472 | " #[batch_size x num_train_examples x num_features] =[batch_size x num_train_examples x nKbase] * [batch_size x nKbase x num_features]\n", 473 | " weight_novel=torch.bmm(AttentionCoef,weight_base)\n", 474 | " #weight_novel=labels_train_transposed*weight_novel ->\n", 475 | " #[batch_size x nKnovel x num_features] = [batch_size x nKnovel x num_train_examples] * [batch_size x num_train_examples x num_features]\n", 476 | " weight_novel=torch.bmm(labels_train_transposed,weight_novel)\n", 477 | " #div K-shot ,get avg\n", 478 | " weight_novel=weight_novel.div(labels_train_transposed.sum(dim=2,keepdim=True).expand_as(weight_novel))\n", 479 | " return weight_novel" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": null, 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "class LinearDiag(BasicModule):\n", 489 | " def __init__(self,num_features,bias=False):\n", 490 | " super(LinearDiag,self).__init__()\n", 491 | " weight=torch.FloatTensor(num_features).fill_(1)#initialize to the identity transform\n", 492 | " self.weight=nn.Parameter(weight,requires_grad=True)\n", 493 | " \n", 494 | " if bias:\n", 495 | " bias=torch.FloatTensor(num_features).fill_(0)\n", 496 | " self.bias=nn.Parameter(bias,requires_grad=True)\n", 497 | " \n", 498 | " else:\n", 499 | " self.register_parameter('bias',None)\n", 500 | " \n", 501 | " def forward(self,X):\n", 502 | " assert(X.dim()==2 and X.size(1)==self.weight.size(0))\n", 503 | " 
out=X*self.weight.expand_as(X)\n", 504 | " if self.bias is not None:\n", 505 | " out=out+self.bias.expand_as(out)\n", 506 | " \n", 507 | " return out\n", 508 | " " 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": null, 514 | "metadata": {}, 515 | "outputs": [], 516 | "source": [ 517 | "class Classifier(BasicModule):\n", 518 | " def __init__(self,nKall=64,nFeat=128*5*5,weight_generator_type='none'):\n", 519 | " super(Classifier,self).__init__()\n", 520 | " self.nKall=nKall\n", 521 | " self.nFeat=nFeat\n", 522 | " self.weight_generator_type=weight_generator_type\n", 523 | " \n", 524 | " weight_base=torch.FloatTensor(nKall,nFeat).normal_(\n", 525 | " 0.0,np.sqrt(2.0/nFeat))\n", 526 | " self.weight_base=nn.Parameter(weight_base,requires_grad=True)\n", 527 | " self.bias=nn.Parameter(torch.FloatTensor(1).fill_(0),requires_grad=True)\n", 528 | " scale_cls=10.0\n", 529 | " self.scale_cls=nn.Parameter(torch.FloatTensor(1).fill_(scale_cls),requires_grad=True)\n", 530 | " \n", 531 | " if self.weight_generator_type=='none':\n", 532 | " #if type is none , then feature averaging is being used.\n", 533 | " #However,in this case the generator doesn't involve any learnable params ,thus doesn't require training\n", 534 | " self.favgblock=AvgBlock(nFeat)\n", 535 | " elif self.weight_generator_type=='attention_based':\n", 536 | " scale_att=10.0\n", 537 | " self.favgblock=AvgBlock(nFeat)\n", 538 | " self.attentionBlock=AttentionBlock(nFeat,nKall,scale_att=scale_att)\n", 539 | " \n", 540 | " self.wnLayerFavg=LinearDiag(nFeat)\n", 541 | " self.wnLayerWatt=LinearDiag(nFeat)\n", 542 | " else:\n", 543 | " raise ValueError('weight_generator_type is not supported!')\n", 544 | " \n", 545 | " def get_classification_weights(\n", 546 | " self,Kbase_ids,features_train=None,labels_train=None):\n", 547 | " \"\"\"\n", 548 | " Args:\n", 549 | " Get the classification weights of the base and novel categories.\n", 550 | " Kbase_ids:[batch_size , nKbase],the indices of base categories that used\n", 551 | " features_train:[batch_size,num_train_examples(way*shot),nFeat]\n", 552 | " labels_train :[batch_size,num_train_examples,nKnovel(way)] one-hot of features_train\n", 553 | " \n", 554 | " return:\n", 555 | " cls_weights:[batch_size,nK,nFeat] \n", 556 | " \"\"\"\n", 557 | " #get the classification weights for the base categories\n", 558 | " batch_size,nKbase=Kbase_ids.size()\n", 559 | " weight_base=self.weight_base[Kbase_ids.view(-1)]\n", 560 | " weight_base=weight_base.view(batch_size,nKbase,-1)\n", 561 | " \n", 562 | " #if training data for novel categories are not provided,return only base_weight\n", 563 | " if features_train is None or labels_train is None:\n", 564 | " return weight_base\n", 565 | " \n", 566 | " #get classification weights for novel categories\n", 567 | " _,num_train_examples , num_channels=features_train.size()\n", 568 | " nKnovel=labels_train.size(2)\n", 569 | " \n", 570 | " #before do cosine similarity ,do L2 normalize\n", 571 | " features_train=F.normalize(features_train,p=2,dim=features_train.dim()-1,eps=1e-12)\n", 572 | " if self.weight_generator_type=='none':\n", 573 | " weight_novel=self.favgblock(features_train,labels_train)\n", 574 | " weight_novel=weight_novel.view(batch_size,nKnovel,num_channels)\n", 575 | " elif self.weight_generator_type=='attention_based':\n", 576 | " weight_novel_avg=self.favgblock(features_train,labels_train)\n", 577 | " weight_novel_avg=self.wnLayerFavg(weight_novel_avg.view(batch_size*nKnovel,num_channels))\n", 578 | " \n", 579 | " #do L2 for 
weighr_base\n", 580 | " weight_base_tmp=F.normalize(weight_base,p=2,dim=weight_base.dim()-1,eps=1e-12)\n", 581 | " \n", 582 | " weight_novel_att=self.attentionBlock(features_train,labels_train,weight_base_tmp,Kbase_ids)\n", 583 | " weight_novel_att=self.wnLayerWatt(weight_novel_att.view(batch_size*nKnovel,num_channels))\n", 584 | " \n", 585 | " weight_novel=weight_novel_avg+weight_novel_att\n", 586 | " weight_novel=weight_novel.view(batch_size,nKnovel,num_channels)\n", 587 | " else:\n", 588 | " raise ValueError('weight generator type is not supported!')\n", 589 | " \n", 590 | " #Concatenate the base and novel classification weights and return\n", 591 | " weight_both=torch.cat([weight_base,weight_novel],dim=1)#[batch_size ,nKbase+nKnovel , num_channel]\n", 592 | " \n", 593 | " return weight_both\n", 594 | " \n", 595 | " def apply_classification_weights(self,features,cls_weights):\n", 596 | " \"\"\"\n", 597 | " Apply the classification weight vectors to the feature vectors\n", 598 | " Args:\n", 599 | " features:[batch_size,num_test_examples,num_channels]\n", 600 | " cls_weights:[batch_size,nK,num_channels]\n", 601 | " Return:\n", 602 | " cls_scores:[batch_size,num_test_examples(query set),nK]\n", 603 | " \"\"\"\n", 604 | " #do L2 normalize\n", 605 | " features=F.normalize(features,p=2,dim=features.dim()-1,eps=1e-12)\n", 606 | " cls_weights=F.normalize(cls_weights,p=2,dim=cls_weights.dim()-1,eps=1e-12)\n", 607 | " cls_scores=self.scale_cls*torch.baddbmm(1.0,\n", 608 | " self.bias.view(1,1,1),1.0,features,cls_weights.transpose(1,2))\n", 609 | " return cls_scores\n", 610 | " \n", 611 | " def forward(self,features_test,Kbase_ids,features_train=None,labels_train=None):\n", 612 | " \"\"\"\n", 613 | " Recognize on the test examples both base and novel categories.\n", 614 | " Args:\n", 615 | " features_test:[batch_size,num_test_examples(query set),num_channels]\n", 616 | " Kbase_ids:[batch_size,nKbase] , the indices of base categories that are being used.\n", 617 | " features_train:[batch_size,num_train_examples,num_channels]\n", 618 | " labels_train:[batch_size,num_train_examples,nKnovel]\n", 619 | " \n", 620 | " Return:\n", 621 | " cls_score:[batch_size,num_test_examples,nKbase+nKnovel]\n", 622 | " \n", 623 | " \"\"\"\n", 624 | " cls_weights=self.get_classification_weights(\n", 625 | " Kbase_ids,features_train,labels_train)\n", 626 | " cls_scores=self.apply_classification_weights(features_test,cls_weights)\n", 627 | " return cls_scores " 628 | ] 629 | }, 630 | { 631 | "cell_type": "markdown", 632 | "metadata": {}, 633 | "source": [ 634 | "### Training procedure\n", 635 | "### training step1 : training FE and pretrain cosine-based classifier" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": null, 641 | "metadata": {}, 642 | "outputs": [], 643 | "source": [ 644 | "#step 1 =========training Feature Extractor and pretrain cosine-based classifier\n", 645 | "use_cuda=torch.cuda.is_available()\n", 646 | "torch.cuda.set_device(0)\n", 647 | "torch.manual_seed(1234)\n", 648 | "if use_cuda:\n", 649 | " torch.cuda.manual_seed(1234)\n" 650 | ] 651 | }, 652 | { 653 | "cell_type": "code", 654 | "execution_count": null, 655 | "metadata": {}, 656 | "outputs": [], 657 | "source": [ 658 | "epoch=31\n", 659 | "lr=0.1\n", 660 | "momentum=0.9\n", 661 | "weight_decay=5e-4\n", 662 | "\n", 663 | "dataset_train=MiniImageNet(phase='train')\n", 664 | "# dataset_test=MiniImageNet(phase='val')\n", 665 | "\n", 666 | "dloader_train=FewShotDataloader(dataset=dataset_train,\n", 667 | " nKnovel=0,\n", 
668 | " nKbase=64,\n", 669 | " nExemplars=0,\n", 670 | " nTestNovel=0,\n", 671 | " nTestBase=32,\n", 672 | " batch_size=8,\n", 673 | " num_workers=1,\n", 674 | " epoch_size=8*1000)\n" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": null, 680 | "metadata": {}, 681 | "outputs": [], 682 | "source": [ 683 | "\n", 684 | "if not os.path.isdir('results/trace_file'):\n", 685 | " os.makedirs('results/trace_file')\n", 686 | " os.makedirs('results/pretrain_model')\n", 687 | " \n", 688 | "trace_file=os.path.join('results','trace_file','pre_train_trace.txt')\n", 689 | "if os.path.isfile(trace_file):\n", 690 | " os.remove(trace_file)\n", 691 | " \n", 692 | "#model\n", 693 | "fe_model=ConvNet()\n", 694 | "classifier=Classifier()\n", 695 | "if use_cuda:\n", 696 | " fe_model.cuda()\n", 697 | " classifier.cuda()\n", 698 | "\n", 699 | "#optimizer\n", 700 | "optimizer_fe=torch.optim.SGD(fe_model.parameters(),lr=lr,nesterov=True , momentum=momentum,weight_decay=weight_decay)\n", 701 | "optimizer_classifier=torch.optim.SGD(classifier.parameters(),lr=lr,nesterov=True , momentum=momentum,weight_decay=weight_decay)\n", 702 | "lr_schedule_fe=torch.optim.lr_scheduler.StepLR(optimizer=optimizer_fe,gamma=0.5,step_size=25)\n", 703 | "lr_schedule_classifier=torch.optim.lr_scheduler.StepLR(optimizer=optimizer_classifier,gamma=0.5,step_size=25)\n", 704 | "criterion=torch.nn.CrossEntropyLoss()\n", 705 | "\n", 706 | "print(\"----pre-train----\")\n", 707 | "for ep in range(epoch):\n", 708 | " train_loss=[]\n", 709 | " print(\"----epoch: %2d---- \"%ep)\n", 710 | " fe_model.train()\n", 711 | " classifier.train()\n", 712 | " \n", 713 | " for batch in tqdm(dloader_train(ep)):\n", 714 | " assert(len(batch)==4)\n", 715 | " \n", 716 | " optimizer_fe.zero_grad()\n", 717 | " optimizer_classifier.zero_grad()\n", 718 | " \n", 719 | " train_data=batch[0]\n", 720 | " train_label=batch[1]\n", 721 | " k_id=batch[2]\n", 722 | " \n", 723 | " if use_cuda:\n", 724 | " train_data=train_data.cuda()\n", 725 | " train_label=train_label.cuda()\n", 726 | " k_id=k_id.cuda()\n", 727 | " \n", 728 | " batch_size,nTestBase,channels,width,high=train_data.size()\n", 729 | " train_data=train_data.view(batch_size*nTestBase,channels,width,high)\n", 730 | " train_data_embedding=fe_model(train_data)\n", 731 | " pred_result=classifier(train_data_embedding.view(batch_size,nTestBase,-1),k_id)\n", 732 | "# print(\"pred_result.size\",pred_result.size())\n", 733 | " loss=criterion(pred_result.view(batch_size*nTestBase,-1),train_label.view(batch_size*nTestBase))\n", 734 | " loss.backward()\n", 735 | " optimizer_fe.step()\n", 736 | " optimizer_classifier.step()\n", 737 | " train_loss.append(float(loss))\n", 738 | " lr_schedule_fe.step()\n", 739 | " lr_schedule_classifier.step()\n", 740 | " \n", 741 | " avg_loss=np.mean(train_loss)\n", 742 | " print(\"epoch %2d training end : avg_loss = %.4f\"%(ep,avg_loss))\n", 743 | " with open(trace_file,'a') as f:\n", 744 | " f.write('epoch:{:2d} training end:avg_loss:{:.4f}'.format(ep,avg_loss))\n", 745 | " f.write('\\n')\n", 746 | " if ep==epoch-1:\n", 747 | " p1='results/pretrain_model/fe_%s.pth'%(str(ep))\n", 748 | " p2='results/pretrain_model/classifier_%s.pth'%(str(ep))\n", 749 | " m1=fe_model.save(path=p1)\n", 750 | " m2=classifier.save(path=p2)\n", 751 | " print(\"Epoch %2d model successfully saved!\"%(ep))" 752 | ] 753 | }, 754 | { 755 | "cell_type": "markdown", 756 | "metadata": {}, 757 | "source": [ 758 | "### training step2 : continue to train classifier and attention-based weight generator" 
759 | ] 760 | }, 761 | { 762 | "cell_type": "code", 763 | "execution_count": null, 764 | "metadata": {}, 765 | "outputs": [], 766 | "source": [ 767 | "#step 2 \n", 768 | "path_fe='results/pretrain_model/fe_30.pth'\n", 769 | "path_classifier='results/pretrain_model/classifier_30.pth'\n", 770 | "\n", 771 | "#load pretrain model\n", 772 | "fe_model=ConvNet()\n", 773 | "classifier=Classifier(weight_generator_type='attention_based')\n", 774 | "pre_train_classifier=torch.load(path_classifier)\n", 775 | "\n", 776 | "fe_model.load(path_fe)\n", 777 | "\n", 778 | "for pname , param in classifier.named_parameters():\n", 779 | " if pname in pre_train_classifier:\n", 780 | " param.data.copy_(pre_train_classifier[pname])\n", 781 | " " 782 | ] 783 | }, 784 | { 785 | "cell_type": "code", 786 | "execution_count": null, 787 | "metadata": {}, 788 | "outputs": [], 789 | "source": [ 790 | "#load training data\n", 791 | "epoch=60\n", 792 | "lr=0.1\n", 793 | "momentum=0.9\n", 794 | "weight_decay=5e-4\n", 795 | "\n", 796 | "dataset_train=MiniImageNet(phase='train')\n", 797 | "dataset_test=MiniImageNet(phase='val')\n", 798 | "\n", 799 | "dloader_train=FewShotDataloader(dataset=dataset_train,\n", 800 | " nKnovel=5,\n", 801 | " nKbase=-1,\n", 802 | " nExemplars=1,\n", 803 | " nTestNovel=5*3,\n", 804 | " nTestBase=5*3,\n", 805 | " batch_size=8,\n", 806 | " num_workers=1,\n", 807 | " epoch_size=8*1000)#8*1000\n", 808 | "dloader_test = FewShotDataloader(\n", 809 | " dataset=dataset_test,\n", 810 | " nKnovel=5,\n", 811 | " nKbase=64,\n", 812 | " nExemplars=1, # num training examples per novel category\n", 813 | " nTestNovel=15*5, # num test examples for all the novel categories\n", 814 | " nTestBase=15*5, # num test examples for all the base categories\n", 815 | " batch_size=1,\n", 816 | " num_workers=0,\n", 817 | " epoch_size=2000, #2000 num of batches per epoch\n", 818 | ")" 819 | ] 820 | }, 821 | { 822 | "cell_type": "code", 823 | "execution_count": null, 824 | "metadata": {}, 825 | "outputs": [], 826 | "source": [ 827 | "def get_labels_train_one_hot(labels_train,num_classes):\n", 828 | " res=[]\n", 829 | " batch_size,num=labels_train.size()\n", 830 | " for i in range(batch_size):\n", 831 | " min_value=torch.min(labels_train[i])\n", 832 | " labels=labels_train[i]-min_value\n", 833 | " one_hot=torch.zeros((num,num_classes))\n", 834 | " for i in range(len(labels)):\n", 835 | " one_hot[i][labels[i]]=1\n", 836 | " res.append(one_hot)\n", 837 | " return torch.cat(res).view(batch_size,num,num_classes)\n", 838 | " \n", 839 | "def get_acc(pred,labels):\n", 840 | " _,pred_inds=pred.max(dim=1)\n", 841 | " pred_inds=pred_inds.view(-1)\n", 842 | " labels=labels.view(-1)\n", 843 | " acc=100*pred_inds.eq(labels).float().mean()\n", 844 | " return acc" 845 | ] 846 | }, 847 | { 848 | "cell_type": "code", 849 | "execution_count": null, 850 | "metadata": {}, 851 | "outputs": [], 852 | "source": [ 853 | "if not os.path.isdir('results/stage_2_model'):\n", 854 | " os.makedirs('results/stage_2_model')\n", 855 | "\n", 856 | "trace_file=os.path.join('results','trace_file','train_stage_2_trace.txt')\n", 857 | "if os.path.isfile(trace_file):\n", 858 | " os.remove(trace_file)\n", 859 | " \n", 860 | "if use_cuda:\n", 861 | " fe_model.cuda()\n", 862 | " classifier.cuda()\n", 863 | "\n", 864 | "#optimizer\n", 865 | "# optimizer_fe=torch.optim.SGD(fe_model.parameters(),lr=lr,nesterov=True , momentum=momentum,weight_decay=weight_decay)\n", 866 | "optimizer_classifier=torch.optim.SGD(classifier.parameters(),lr=lr,nesterov=True , 
momentum=momentum,weight_decay=weight_decay)\n", 867 | "lr_schedule_classifier=torch.optim.lr_scheduler.StepLR(optimizer=optimizer_classifier,gamma=0.5,step_size=25)\n", 868 | "criterion=torch.nn.CrossEntropyLoss()\n", 869 | "\n", 870 | "print(\"---- train-stage-2 ----\")\n", 871 | "best_acc_both=0.0\n", 872 | "best_acc_novel=0.0\n", 873 | "for ep in range(epoch):\n", 874 | " train_loss=[]\n", 875 | " acc_both=[]\n", 876 | " acc_base=[]\n", 877 | " acc_novel=[]\n", 878 | " print(\"----epoch: %2d---- \"%ep)\n", 879 | " fe_model.train()\n", 880 | " classifier.train()\n", 881 | " \n", 882 | " for batch in tqdm(dloader_train(ep)):\n", 883 | " assert(len(batch)==6) #images_train, labels_train, images_test, labels_test, K, nKbase\n", 884 | " \n", 885 | "# optimizer_fe.zero_grad()\n", 886 | " optimizer_classifier.zero_grad()\n", 887 | " \n", 888 | " train_data=batch[0]\n", 889 | " train_label=batch[1]\n", 890 | " test_data=batch[2]\n", 891 | " test_label=batch[3]\n", 892 | " k_id=batch[4]\n", 893 | " nKbase=batch[5]\n", 894 | " KbaseId=k_id[:,:nKbase[0]]\n", 895 | " labels_train_one_hot=get_labels_train_one_hot(train_label,dloader_train.nKnovel)\n", 896 | " \n", 897 | " if use_cuda:\n", 898 | " train_data=train_data.cuda()\n", 899 | " train_label=train_label.cuda()\n", 900 | " test_data=test_data.cuda()\n", 901 | " test_label=test_label.cuda()\n", 902 | " k_id=k_id.cuda()\n", 903 | " nKbase=nKbase.cuda()\n", 904 | " KbaseId=KbaseId.cuda()\n", 905 | " labels_train_one_hot=labels_train_one_hot.cuda()\n", 906 | " \n", 907 | " batch_size,nExamples,channels,width,high=train_data.size()\n", 908 | " nTest=test_data.size(1)\n", 909 | " \n", 910 | " train_data=train_data.view(batch_size*nExamples,channels,width,high)\n", 911 | " test_data=test_data.view(batch_size*nTest,channels,width,high)\n", 912 | " \n", 913 | " train_data_embedding=fe_model(train_data)\n", 914 | " test_data_embedding=fe_model(test_data)\n", 915 | " \n", 916 | " pred_result=classifier(features_test=test_data_embedding.view(batch_size,nTest,-1),Kbase_ids=KbaseId,\n", 917 | " features_train=train_data_embedding.view(batch_size,nExamples,-1),labels_train=labels_train_one_hot)\n", 918 | "# print(\"pred_result.size\",pred_result.size())\n", 919 | " pred_result = pred_result.view(batch_size*nTest,-1)\n", 920 | " test_label = test_label.view(batch_size*nTest)\n", 921 | " \n", 922 | " loss=criterion(pred_result,test_label)\n", 923 | " loss.backward()\n", 924 | "# optimizer_fe.step()\n", 925 | " optimizer_classifier.step()\n", 926 | " \n", 927 | " train_loss.append(float(loss))\n", 928 | " \n", 929 | " accuracy_both=get_acc(pred_result,test_label)\n", 930 | " acc_both.append(float(accuracy_both))\n", 931 | " \n", 932 | " base_ids=torch.nonzero(test_label < nKbase[0]).view(-1)\n", 933 | " novel_ids=torch.nonzero(test_label >= nKbase[0]).view(-1)\n", 934 | " \n", 935 | " pred_base = pred_result[base_ids,:]\n", 936 | " pred_novel =pred_result[novel_ids,:]\n", 937 | " \n", 938 | " accuracy_base=get_acc(pred_base[:,:nKbase[0]],test_label[base_ids])\n", 939 | " accuracy_novel=get_acc(pred_novel[:,nKbase[0]:],(test_label[novel_ids]-nKbase[0]))\n", 940 | " \n", 941 | " acc_base.append(float(accuracy_base))\n", 942 | " acc_novel.append(float(accuracy_novel))\n", 943 | " \n", 944 | " \n", 945 | " lr_schedule_classifier.step()\n", 946 | " #------------------------------------------------------\n", 947 | " #validation stage\n", 948 | " print(\"----begin validation----\")\n", 949 | " fe_model.eval()\n", 950 | " classifier.eval()\n", 951 | " \n", 952 | " 
val_loss=[]\n", 953 | " val_acc_both=[]\n", 954 | " val_acc_base=[]\n", 955 | " val_acc_novel=[]\n", 956 | " for batch in tqdm(dloader_test(ep)):\n", 957 | " assert(len(batch)==6)\n", 958 | " train_data=batch[0]\n", 959 | " train_label=batch[1]\n", 960 | " test_data=batch[2]\n", 961 | " test_label=batch[3]\n", 962 | " k_id=batch[4]\n", 963 | " nKbase=batch[5]\n", 964 | " KbaseId=k_id[:,:nKbase[0]]\n", 965 | " labels_train_one_hot=get_labels_train_one_hot(train_label,dloader_test.nKnovel)\n", 966 | " \n", 967 | " if use_cuda:\n", 968 | " train_data=train_data.cuda()\n", 969 | " train_label=train_label.cuda()\n", 970 | " test_data=test_data.cuda()\n", 971 | " test_label=test_label.cuda()\n", 972 | " k_id=k_id.cuda()\n", 973 | " nKbase=nKbase.cuda()\n", 974 | " KbaseId=KbaseId.cuda()\n", 975 | " labels_train_one_hot=labels_train_one_hot.cuda()\n", 976 | " \n", 977 | " batch_size,nExamples,channels,width,high=train_data.size()\n", 978 | " nTest=test_data.size(1)\n", 979 | " \n", 980 | " train_data=train_data.view(batch_size*nExamples,channels,width,high)\n", 981 | " test_data=test_data.view(batch_size*nTest,channels,width,high)\n", 982 | " \n", 983 | " train_data_embedding=fe_model(train_data)\n", 984 | " test_data_embedding=fe_model(test_data)\n", 985 | " \n", 986 | " pred_result=classifier(features_test=test_data_embedding.view(batch_size,nTest,-1),Kbase_ids=KbaseId,\n", 987 | " features_train=train_data_embedding.view(batch_size,nExamples,-1),labels_train=labels_train_one_hot)\n", 988 | "# print(\"pred_result.size\",pred_result.size())\n", 989 | " pred_result = pred_result.view(batch_size*nTest,-1)\n", 990 | " test_label = test_label.view(batch_size*nTest)\n", 991 | " \n", 992 | " loss=criterion(pred_result,test_label)\n", 993 | " val_loss.append(float(loss))\n", 994 | " \n", 995 | " accuracy_both=get_acc(pred_result,test_label)\n", 996 | " val_acc_both.append(float(accuracy_both))\n", 997 | " \n", 998 | " base_ids=torch.nonzero(test_label < nKbase[0]).view(-1)\n", 999 | " novel_ids=torch.nonzero(test_label >= nKbase[0]).view(-1)\n", 1000 | " \n", 1001 | " pred_base = pred_result[base_ids,:]\n", 1002 | " pred_novel =pred_result[novel_ids,:]\n", 1003 | " \n", 1004 | " accuracy_base=get_acc(pred_base[:,:nKbase[0]],test_label[base_ids])\n", 1005 | " accuracy_novel=get_acc(pred_novel[:,nKbase[0]:],(test_label[novel_ids]-nKbase[0]))\n", 1006 | " \n", 1007 | " val_acc_base.append(float(accuracy_base))\n", 1008 | " val_acc_novel.append(float(accuracy_novel))\n", 1009 | " avg_loss=np.mean(train_loss)\n", 1010 | " avg_acc_both=np.mean(acc_both)\n", 1011 | " avg_acc_base=np.mean(acc_base)\n", 1012 | " avg_acc_novel=np.mean(acc_novel)\n", 1013 | " \n", 1014 | " val_avg_loss=np.mean(val_loss)\n", 1015 | " val_avg_acc_both=np.mean(val_acc_both)\n", 1016 | " val_avg_acc_base=np.mean(val_acc_base)\n", 1017 | " val_avg_acc_novel=np.mean(val_acc_novel)\n", 1018 | " \n", 1019 | " print(\"epoch %2d training end : training ---- avg_loss = %.4f , avg_acc_both = %.2f , avg_acc_base = %.2f , avg_acc_novel = %.2f \"%(ep,avg_loss,avg_acc_both,avg_acc_base,avg_acc_novel))\n", 1020 | " print(\"epoch %2d training end : validation ---- avg_loss = %.4f , avg_acc_both = %.2f , avg_acc_base = %.2f , avg_acc_novel = %.2f \"%(ep,val_avg_loss,val_avg_acc_both,val_avg_acc_base,val_avg_acc_novel))\n", 1021 | " with open(trace_file,'a') as f:\n", 1022 | " f.write('epoch:{:2d} training ---- avg_loss:{:.4f} , avg_acc_both:{:.2f} , avg_acc_base:{:.2f} , 
avg_acc_novel:{:.2f}'.format(ep,avg_loss,avg_acc_both,avg_acc_base,avg_acc_novel))\n", 1023 | " f.write('\\n')\n", 1024 | " f.write('epoch:{:2d} validation ---- avg_loss:{:.4f} , avg_acc_both:{:.2f} , avg_acc_base:{:.2f} , avg_acc_novel:{:.2f}'.format(ep,val_avg_loss,val_avg_acc_both,val_avg_acc_base,val_avg_acc_novel))\n", 1025 | " f.write('\\n')\n", 1026 | " if best_acc_both= nKbase[0]).view(-1)\n", 1162 | "\n", 1163 | " pred_base = pred_result[base_ids,:]\n", 1164 | " pred_novel =pred_result[novel_ids,:]\n", 1165 | "\n", 1166 | " accuracy_base=get_acc(pred_base[:,:nKbase[0]],test_label[base_ids])\n", 1167 | " accuracy_novel=get_acc(pred_novel[:,nKbase[0]:],(test_label[novel_ids]-nKbase[0]))\n", 1168 | "\n", 1169 | " test_acc_base.append(float(accuracy_base))\n", 1170 | " test_acc_novel.append(float(accuracy_novel))\n", 1171 | "\n", 1172 | "test_avg_loss=np.mean(test_loss)\n", 1173 | "test_avg_acc_both=np.mean(test_acc_both)\n", 1174 | "test_avg_acc_base=np.mean(test_acc_base)\n", 1175 | "test_avg_acc_novel=np.mean(test_acc_novel)\n", 1176 | "\n", 1177 | "print(\"%2d batch test end : avg_loss = %.4f , avg_acc_both = %.2f , avg_acc_base = %.2f , avg_acc_novel = %.2f \"%(dloader_test.epoch_size,test_avg_loss,test_avg_acc_both,test_avg_acc_base,test_avg_acc_novel))\n", 1178 | "with open(trace_file,'a') as f:\n", 1179 | " f.write('batch_size:{:2d} test ---- avg_loss:{:.4f} , avg_acc_both:{:.2f} , avg_acc_base:{:.2f} , avg_acc_novel:{:.2f}'.format(dloader_test.epoch_size,test_avg_loss,test_avg_acc_both,test_avg_acc_base,test_avg_acc_novel))\n", 1180 | " f.write('\\n')" 1181 | ] 1182 | } 1183 | ], 1184 | "metadata": { 1185 | "kernelspec": { 1186 | "display_name": "Python 3", 1187 | "language": "python", 1188 | "name": "python3" 1189 | }, 1190 | "language_info": { 1191 | "codemirror_mode": { 1192 | "name": "ipython", 1193 | "version": 3 1194 | }, 1195 | "file_extension": ".py", 1196 | "mimetype": "text/x-python", 1197 | "name": "python", 1198 | "nbconvert_exporter": "python", 1199 | "pygments_lexer": "ipython3", 1200 | "version": "3.6.5" 1201 | } 1202 | }, 1203 | "nbformat": 4, 1204 | "nbformat_minor": 2 1205 | } 1206 | --------------------------------------------------------------------------------
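For reference, a minimal, self-contained sketch of the cosine-similarity scoring that `Classifier.apply_classification_weights` performs in the notebook above: both the query features and the classification weight vectors are L2-normalized, so their batched dot products are cosine similarities, which are then multiplied by a scale factor. The helper name `cosine_scores` and the toy shapes are illustrative only; this sketch uses `torch.bmm` instead of the notebook's older `torch.baddbmm` call and drops the (near-zero) learnable bias.

```python
import torch
import torch.nn.functional as F

def cosine_scores(features, cls_weights, scale_cls=10.0):
    """features:    [batch_size, num_test_examples, nFeat]
    cls_weights: [batch_size, nK, nFeat]
    returns:     [batch_size, num_test_examples, nK]"""
    # L2-normalize both operands so the batched matrix product is a cosine similarity
    features = F.normalize(features, p=2, dim=-1, eps=1e-12)
    cls_weights = F.normalize(cls_weights, p=2, dim=-1, eps=1e-12)
    # scale the cosine similarities, as the notebook does with its learnable scale_cls
    return scale_cls * torch.bmm(features, cls_weights.transpose(1, 2))

# toy shapes: 2 episodes, 75 query images, 128*5*5 features, 64 base + 5 novel classes
scores = cosine_scores(torch.randn(2, 75, 128 * 5 * 5), torch.randn(2, 69, 128 * 5 * 5))
print(scores.shape)  # torch.Size([2, 75, 69])
```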