├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── 1-shot-5-class.py
│   ├── 5-shot-5-class.py
│   ├── __init__.py
│   ├── baselines
│   │   ├── __init__.py
│   │   ├── test-conv-nearest-neighbor.py
│   │   ├── test-matching-net.py
│   │   ├── train-conv-nearest-neighbor.py
│   │   ├── train-matching-net.py
│   │   ├── train-pixel-nearest-neighbor.py
│   │   └── train-pre-trained-SGD.py
│   ├── imagenet.py
│   └── lstm
│       ├── __init__.py
│       ├── test-lstm.py
│       ├── train-imagenet-1shot.py
│       └── train-imagenet-5shot.py
├── datasets
│   ├── __init__.py
│   ├── data-loader.py
│   └── miniImagenet.py
├── logger.py
├── main.py
├── model
│   ├── __init__.py
│   ├── baselines
│   │   ├── __init__.py
│   │   ├── fce-embedding.py
│   │   ├── matching-net.py
│   │   └── simple-embedding.py
│   ├── lstm-classifier.py
│   ├── lstm
│   │   ├── __init__.py
│   │   ├── bnlstm.py
│   │   ├── learner.py
│   │   ├── lstmhelper.py
│   │   ├── metaLearner.py
│   │   ├── metalstm.py
│   │   ├── recurrentLSTMNetwork.py
│   │   └── train-lstm.py
│   └── matching-net-classifier.py
├── option.py
├── utils
│   ├── __init__.py
│   ├── create_miniImagenet.py
│   └── util.py
└── visualize
    ├── __init__.py
    └── visualize.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 gitabcworld 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 |
copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimization as a Model for Few-Shot Learning 2 | This repo provides a PyTorch implementation of the [Optimization as a Model for Few-Shot Learning](https://openreview.net/pdf?id=rJY0-Kcll) paper. 3 | 4 | ## Installing PyTorch 5 | The experiments require [PyTorch](http://pytorch.org/). 6 | 7 | ## Data 8 | For miniImageNet, download the ImageNet dataset and run the script utils/create_miniImagenet.py after changing these lines: 9 | ``` 10 | pathImageNet = '/ILSVRC2012_img_train' 11 | pathminiImageNet = '/miniImagenet/' 12 | ``` 13 | Then change this line in option.py, or pass the path as a command-line argument: 14 | ``` 15 | parser.add_argument('--dataroot', type=str, default='/miniImagenet/',help='path to dataset') 16 | ``` 17 | 18 | ## Installation 19 | 20 | $ pip install -r requirements.txt 21 | $ python main.py 22 | 23 | 24 | ## Acknowledgements 25 | Special thanks to @sachinravi14 for the original Torch implementation, which this repo replicates in PyTorch. More details at https://github.com/twitter/meta-learning-lstm 26 | 27 | ## Cite 28 | ``` 29 | @inproceedings{Sachin2017, 30 | title={Optimization as a model for few-shot learning}, 31 | author={Ravi, Sachin and Larochelle, Hugo}, 32 | booktitle={International Conference on Learning Representations (ICLR)}, 33 | year={2017} 34 | } 35 | ``` 36 | 37 | 38 | ## Authors 39 | 40 | * Albert Berenguel (@aberenguel) [Webpage](https://scholar.google.es/citations?user=HJx2fRsAAAAJ&hl=en) 41 | -------------------------------------------------------------------------------- /config/1-shot-5-class.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['nClasses'] = {'train':5, 'val':5, 'test':5} 12 | opt['nTrainShot'] = 1 13 | opt['nEval'] = 15 14 | 15 | opt['nTest'] = [100, 250, 600] 16 | opt['nTestShot'] = [1, 5] 17 | return opt --------------------------------------------------------------------------------
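These config modules all expose the same params(opt) hook; main.py (shown further below) chains them with importlib so that task, data and model settings accumulate in a single opt dict. A minimal sketch of that pattern, using the config files above:

```python
import importlib

opt = {}
# importlib.import_module is used because hyphenated module names such as
# '1-shot-5-class' cannot appear in a plain import statement.
opt = importlib.import_module('config.1-shot-5-class').params(opt)  # task
opt = importlib.import_module('config.imagenet').params(opt)        # data
print(opt['nTrainShot'], opt['nClasses']['train'])                  # -> 1 5
```

/config/5-shot-5-class.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC).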
Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['nClasses'] = {'train':5, 'val':5, 'test':5} 12 | opt['nTrainShot'] = 5 13 | opt['nEval'] = 15 14 | 15 | opt['nTest'] = [100, 250, 600] 16 | opt['nTestShot'] = [1, 5] 17 | return opt -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabcworld/FewShotLearning/2eff1881f96212e0fb31737a48f82fab9c2fac83/config/__init__.py -------------------------------------------------------------------------------- /config/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabcworld/FewShotLearning/2eff1881f96212e0fb31737a48f82fab9c2fac83/config/baselines/__init__.py -------------------------------------------------------------------------------- /config/baselines/test-conv-nearest-neighbor.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['nEpisode'] = 0 12 | opt['paramsFile'] = 'saved_params/matching-net-FCE/matching-net_params_snapshot.th' 13 | opt['networkFile'] = 'saved_params/matching-net-FCE/matching-net-models.th' 14 | return opt 15 | -------------------------------------------------------------------------------- /config/baselines/test-matching-net.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['nEpochs'] = 0 12 | opt['paramsFile'] = 'saved_params/conv-nearest-neighbor/conv-nearest-neighbor-model.th' 13 | return opt -------------------------------------------------------------------------------- /config/baselines/train-conv-nearest-neighbor.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). 
Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['learner'] = 'model.matching-net-classifier' 12 | opt['metaLearner'] = 'model.baselines.conv-nearest-neighbor' 13 | 14 | opt['trainFull'] = True 15 | opt['nClasses.train'] = 64 16 | opt['learningRate'] = 0.001 17 | opt['trainBatchSize'] = 64 18 | opt['nEpochs'] = 30000 19 | opt['nValidationEpisode'] = 100 20 | opt['printPer'] = 1000 21 | opt['useCUDA'] = True 22 | return opt 23 | 24 | -------------------------------------------------------------------------------- /config/baselines/train-matching-net.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['learner'] = 'model.matching-net-classifier' 12 | opt['metaLearner'] = 'model.baselines.matching-net' 13 | 14 | # simple or FCE - embedding model? 15 | # opt['embedModel'] = 'model.baselines.simple-embedding' 16 | opt['embedModel'] = 'model.baselines.fce-embedding' 17 | 18 | opt['steps'] = 3 19 | opt['classify'] = False 20 | opt['useDropout'] = True 21 | opt['optimMethod'] = 'adam' 22 | opt['lr'] = 1e-03 23 | opt['lr_decay'] = 1e-6 24 | opt['weight_decay'] = 1e-4 25 | opt['batchSize'] = opt['nClasses']['train'] * opt['nEval'] 26 | opt['nEpisode'] = 75000 27 | opt['nValidationEpisode'] = 100 28 | opt['printPer'] = 1000 29 | opt['useCUDA'] = True 30 | opt['ngpu'] = 2 31 | return opt 32 | -------------------------------------------------------------------------------- /config/baselines/train-pixel-nearest-neighbor.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['learner'] = 'model.matching-net-classifier' 12 | opt['metaLearner'] = 'model.baselines.pixel-nearest-neighbor' 13 | 14 | opt['trainFull'] = True 15 | opt['nClasses.train'] = 64 - (-20) - (-16) 16 | opt['nAllClasses'] = 64 - (-4112) 17 | opt['useCUDA'] = False 18 | return opt 19 | 20 | -------------------------------------------------------------------------------- /config/baselines/train-pre-trained-SGD.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). 
Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['learner'] = 'model.matching-net-classifier' 12 | opt['metaLearner'] = 'model.baselines.pre-trained-SGD' 13 | 14 | 15 | opt['trainFull'] = True 16 | opt['nClasses.train'] = 64 17 | 18 | opt['learningRate'] = 0.001 19 | opt['trainBatchSize'] = 64 20 | 21 | opt['learningRates'] = [0.5, 0.1, 0.01, 0.001, 0.0001, 0.00001] 22 | opt['learningRateDecays'] = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 0] 23 | opt['nUpdates'] = [15] 24 | 25 | opt['nEpochs'] = 30000 26 | opt['nValidationEpisode'] = 100 27 | opt['printPer'] = 1000 28 | opt['useCUDA'] = True 29 | return opt 30 | 31 | -------------------------------------------------------------------------------- /config/imagenet.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['nExamples'] = 20 12 | opt['nDepth'] = 3 13 | opt['nIn'] = 84 14 | 15 | opt['rawDataDir'] = '/home/aberenguel/Dataset/miniImagenet' 16 | opt['dataName'] = 'datasets.miniImagenet' 17 | opt['dataLoader'] = 'datasets.data-loader' 18 | opt['episodeSamplerKind'] = 'permutation' 19 | 20 | return opt 21 | 22 | -------------------------------------------------------------------------------- /config/lstm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabcworld/FewShotLearning/2eff1881f96212e0fb31737a48f82fab9c2fac83/config/lstm/__init__.py -------------------------------------------------------------------------------- /config/lstm/test-lstm.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['nEpisode'] = 0 12 | opt['paramsFile'] = 'saved_params/miniImagenet/meta-learner-5shot-momentum/metaLearner_params_snapshot.th' 13 | return opt 14 | 15 | -------------------------------------------------------------------------------- /config/lstm/train-imagenet-1shot.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). 
Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['learner'] = 'model.lstm-classifier' 12 | opt['metaLearner'] = 'model.lstm.train-lstm' 13 | 14 | 15 | opt['BN_momentum'] = 0.9 16 | opt['optimMethod'] = 'adam' 17 | opt['lr'] = 1e-03 18 | opt['lr_decay'] = 1e-6 19 | opt['weight_decay'] = 1e-4 20 | opt['maxGradNorm'] = 0.25 21 | 22 | opt['batchSize'] = {1: 5, 5: 5} 23 | opt['nEpochs'] = {1: 12, 5: 5} 24 | 25 | opt['nEpisode'] = 7500 26 | opt['nValidationEpisode'] = 100 27 | opt['printPer'] = 1000 28 | opt['useCUDA'] = True 29 | return opt 30 | 31 | -------------------------------------------------------------------------------- /config/lstm/train-imagenet-5shot.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | def params(opt): 11 | opt['learner'] = 'model.lstm-classifier' 12 | opt['metaLearner'] = 'model.lstm.train-lstm' 13 | 14 | opt['BN_momentum'] = 0.95 15 | opt['classify'] = True 16 | opt['useDropout'] = False 17 | 18 | opt['optimMethod'] = 'adam' 19 | opt['lr'] = 1e-03 20 | opt['lr_decay'] = 1e-6 21 | opt['weight_decay'] = 1e-4 22 | opt['maxGradNorm'] = 0.25 23 | 24 | opt['batchSize'] = {1:5, 5:25} 25 | opt['nEpochs'] = {1: 5, 5: 8} 26 | 27 | opt['nEpisode'] = 500 28 | opt['nValidationEpisode'] = 100 29 | opt['printPer'] = 100 30 | opt['useCUDA'] = True 31 | return opt -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabcworld/FewShotLearning/2eff1881f96212e0fb31737a48f82fab9c2fac83/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/data-loader.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). 
Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import importlib 12 | import numpy as np 13 | 14 | def getData(opt): 15 | # set up meta-train, meta-validation & meta-test datasets 16 | dataTrain = importlib.import_module(opt['dataName']).DatasetLoader(dataroot=opt['rawDataDir'], 17 | type='train', 18 | nEpisodes=opt['nEpisode'], 19 | classes_per_set=opt['nClasses']['train'], 20 | samples_per_class=opt['nTrainShot']) 21 | 22 | dataVal = importlib.import_module(opt['dataName']).DatasetLoader(dataroot=opt['rawDataDir'], 23 | type='val', 24 | nEpisodes=opt['nValidationEpisode'], 25 | classes_per_set=opt['nClasses']['val'], 26 | samples_per_class=opt['nEval']) 27 | dataTest = [] 28 | for nTest in opt['nTest']: 29 | dataTest.append(importlib.import_module(opt['dataName']).DatasetLoader(dataroot=opt['rawDataDir'], 30 | type='test', 31 | nEpisodes=np.sum(opt['nTest']), 32 | classes_per_set=opt['nClasses']['test'], 33 | samples_per_class=np.max( 34 | opt['nTestShot']))) 35 | data = {'train': dataTrain, 'validation': dataVal, 'test': dataTest} 36 | return data -------------------------------------------------------------------------------- /datasets/miniImagenet.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | import torch 11 | import torch.utils.data as data 12 | import torchvision.transforms as transforms 13 | from PIL import Image 14 | import os.path 15 | import csv 16 | import math 17 | import collections 18 | from tqdm import tqdm 19 | 20 | import numpy as np 21 | np.random.seed(2191) # for reproducibility 22 | 23 | # LAMBDA FUNCTIONS 24 | filenameToPILImage = lambda x: Image.open(x) 25 | PiLImageResize = lambda x: x.resize((84,84)) 26 | 27 | class DatasetLoader(data.Dataset): 28 | def __init__(self, dataroot = './data/miniImagenet', type = 'train', 29 | nEpisodes = 1000, classes_per_set=10, samples_per_class=1): 30 | 31 | self.nEpisodes = nEpisodes 32 | self.classes_per_set = classes_per_set 33 | self.samples_per_class = samples_per_class 34 | self.samples_per_class_eval = 15 35 | self.n_samples = self.samples_per_class * self.classes_per_set 36 | self.n_samples_eval = self.samples_per_class_eval * self.classes_per_set 37 | # Transformations to the image 38 | self.transform = transforms.Compose([filenameToPILImage, 39 | PiLImageResize, 40 | transforms.ToTensor() 41 | ]) 42 | 43 | def loadSplit(splitFile): 44 | dictLabels = {} 45 | with open(splitFile) as csvfile: 46 | csvreader = csv.reader(csvfile, delimiter=',') 47 | next(csvreader, None) 48 | for i,row in enumerate(csvreader): 49 | filename = row[0] 50 | label = row[1] 51 | if label in dictLabels.keys(): 52 | dictLabels[label].append(filename) 53 | else: 54 | dictLabels[label] = [filename] 55 | return dictLabels 56 | 57 | #requiredFiles = 
['train','val','test'] 58 | self.miniImagenetImagesDir = os.path.join(dataroot,'images') 59 | self.data = loadSplit(splitFile = os.path.join(dataroot,type + '.csv')) 60 | self.data = collections.OrderedDict(sorted(self.data.items())) 61 | self.classes_dict = {list(self.data.keys())[i]:i for i in range(len(self.data.keys()))} # list() so keys are indexable on Python 3 62 | self.create_episodes(self.nEpisodes) 63 | 64 | def create_episodes(self,episodes): 65 | 66 | nClasses = len(self.data.keys()) 67 | 68 | self.support_set_x_batch = [] 69 | self.target_x_batch = [] 70 | for b in np.arange(episodes): 71 | # select n classes_per_set randomly 72 | selected_classes = np.random.choice(nClasses, self.classes_per_set, False) 73 | support_set_x = [] 74 | target_x = [] 75 | for c in selected_classes: 76 | className = list(self.data.keys())[c] # list() for Python 3, as above 77 | selected_samples = np.random.choice(len(self.data[className]), 78 | self.samples_per_class + self.samples_per_class_eval, False) 79 | indexDtrain = np.array(selected_samples[:self.samples_per_class]) 80 | indexDtest = np.array(selected_samples[self.samples_per_class:]) 81 | support_set_x.append(np.array(self.data[className])[indexDtrain].tolist()) 82 | target_x.append(np.array(self.data[className])[indexDtest].tolist()) 83 | self.support_set_x_batch.append(support_set_x) 84 | self.target_x_batch.append(target_x) 85 | 86 | def __getitem__(self, index): 87 | 88 | support_set_x = torch.FloatTensor(self.n_samples, 3, 84, 84) 89 | support_set_y = np.zeros((self.n_samples), dtype=int) 90 | target_x = torch.FloatTensor(self.n_samples_eval, 3, 84, 84) 91 | target_y = np.zeros((self.n_samples_eval), dtype=int) 92 | 93 | flatten_support_set_x_batch = [os.path.join(self.miniImagenetImagesDir,item) 94 | for sublist in self.support_set_x_batch[index] for item in sublist] 95 | support_set_y = np.array([self.classes_dict[item[:9]] # first 9 chars of the filename are the WordNet class id 96 | for sublist in self.support_set_x_batch[index] for item in sublist]) 97 | flatten_target_x = [os.path.join(self.miniImagenetImagesDir,item) 98 | for sublist in self.target_x_batch[index] for item in sublist] 99 | target_y = np.array([self.classes_dict[item[:9]] 100 | for sublist in self.target_x_batch[index] for item in sublist]) 101 | 102 | for i,path in enumerate(flatten_support_set_x_batch): 103 | if self.transform is not None: 104 | support_set_x[i] = self.transform(path) 105 | 106 | for i,path in enumerate(flatten_target_x): 107 | if self.transform is not None: 108 | target_x[i] = self.transform(path) 109 | return support_set_x, torch.IntTensor(support_set_y), target_x, torch.IntTensor(target_y) 110 | 111 | def __len__(self): 112 | return self.nEpisodes 113 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tensorboard_logger import configure, log_value 3 | 4 | class Logger(object): 5 | def __init__(self, log_dir): 6 | # clean previous logged data under the same directory name 7 | self._remove(log_dir) 8 | 9 | # configure the project 10 | configure(log_dir) 11 | 12 | self.global_step = 0 13 | 14 | def log_value(self, name, value): 15 | log_value(name, value, self.global_step) 16 | return self 17 | 18 | def step(self): 19 | self.global_step += 1 20 | 21 | @staticmethod 22 | def _remove(path): 23 | """Remove `path`, which may be a relative or absolute file or directory.""" 24 | if os.path.isfile(path): 25 | os.remove(path) # remove the file 26 | elif os.path.isdir(path): 27 | import shutil 28 | shutil.rmtree(path) # remove the dir and all its contents
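A quick usage sketch for this Logger (it assumes the tensorboard_logger package is installed; the directory and tag names are illustrative):

```python
from logger import Logger

logger = Logger('./logs/example-run')    # wipes and re-creates this directory
for episode in range(100):
    loss = 1.0 / (episode + 1)           # stand-in for a real training loss
    logger.log_value('train_loss', loss)
    logger.step()                        # advance the global step counter
```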
""" 24 | if os.path.isfile(path): 25 | os.remove(path) # remove the file 26 | elif os.path.isdir(path): 27 | import shutil 28 | shutil.rmtree(path) # remove dir and all contains -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | import importlib 11 | from option import Options 12 | from logger import Logger 13 | import numpy as np 14 | 15 | # Import params from Config 16 | # Parse other options 17 | args = Options().parse() 18 | 19 | # load config info for task, data, and model 20 | opt = {} 21 | opt = importlib.import_module(args.task).params(opt) 22 | opt = importlib.import_module(args.data).params(opt) 23 | opt = importlib.import_module(args.model).params(opt) 24 | if not args.test == '-': 25 | opt = importlib.import_module(args.test).params(opt) 26 | LOG_DIR = args.log_dir + '/task_{}_data_{}_model_{}' \ 27 | .format(args.task,args.data,args.model) 28 | # create logger 29 | logger = Logger(LOG_DIR) 30 | 31 | # Print options 32 | print('Training with options:') 33 | for key in sorted(opt.iterkeys()): 34 | print "%s: %s" % (key, opt[key]) 35 | 36 | # set up meta-train, meta-validation and meta-test datasets 37 | data = importlib.import_module(opt['dataLoader']).getData(opt) 38 | # Run the training, validation and test. 39 | results = importlib.import_module(opt['metaLearner']).run(opt,data) 40 | print('Task: %s. Data: %s. Model: %s' % (args.task,args.data,args.model) ) 41 | 42 | 43 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabcworld/FewShotLearning/2eff1881f96212e0fb31737a48f82fab9c2fac83/model/__init__.py -------------------------------------------------------------------------------- /model/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabcworld/FewShotLearning/2eff1881f96212e0fb31737a48f82fab9c2fac83/model/baselines/__init__.py -------------------------------------------------------------------------------- /model/baselines/fce-embedding.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). 
Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import os 12 | import torch 13 | import torch.nn as nn 14 | import importlib 15 | import pickle 16 | import numpy as np 17 | #from model.lstm.bnlstm import RecurrentLSTMNetwork 18 | 19 | class FceEmbedding(): 20 | def __init__(self, opt): 21 | self.opt = opt # Store the parameters 22 | self.maxGradNorm = opt['maxGradNorm'] if ['maxGradNorm'] in opt.keys() else 0.25 23 | self.numLayersAttLstm = opt['numLayersAttLstm'] if ['numLayersAttLstm'] in opt.keys() else 1 24 | self.numLayersBiLstm = opt['numLayersBiLstm'] if ['numLayersBiLstm'] in opt.keys() else 1 25 | self.buildModels(self.opt) 26 | self.setCuda() 27 | 28 | # Build F and G models 29 | def buildModels(self,opt): 30 | # F function 31 | modelF = importlib.import_module(opt['learner']).build(opt) 32 | self.embedNetF = modelF.net 33 | # G function 34 | modelG = importlib.import_module(opt['learner']).build(opt) 35 | self.embedNetG = modelG.net 36 | 37 | ''' 38 | # Build LSTM for attention model. 39 | self.attLSTM = RecurrentLSTMNetwork({ 40 | 'inputFeatures': self.embedNetF.outSize + self.embedNetG.outSize, 41 | 'hiddenFeatures': self.embedNetF.outSize, 42 | 'outputType': 'all' 43 | }) 44 | 45 | self.biLSTMForward = RecurrentLSTMNetwork({ 46 | 'inputFeatures': self.embedNetG.outSize, 47 | 'hiddenFeatures': self.embedNetG.outSize, 48 | 'outputType': 'all' 49 | }) 50 | 51 | self.biLSTMBackward = RecurrentLSTMNetwork({ 52 | 'inputFeatures': self.embedNetG.outSize, 53 | 'hiddenFeatures': self.embedNetG.outSize, 54 | 'outputType': 'all' 55 | }) 56 | ''' 57 | 58 | self.attLSTM = nn.LSTM(input_size=self.embedNetF.outSize + self.embedNetG.outSize, 59 | hidden_size=self.embedNetF.outSize, 60 | num_layers = self.numLayersAttLstm) 61 | # Build bidirectional LSTM 62 | self.biLSTM = nn.LSTM(input_size=self.embedNetG.outSize, 63 | hidden_size=self.embedNetG.outSize, 64 | num_layers=self.numLayersBiLstm, 65 | bidirectional=True) 66 | 67 | self.softmax = nn.Softmax() 68 | 69 | # Build list of parameters for optim 70 | def parameters(self): 71 | # TODO: why in the original code creates a dictionary with the same 72 | # parameters. 
model.params = {f=paramsG, g=paramsG, attLSTM, biLSTM}; here F, G and both LSTMs are returned, since returning G twice would leave F's weights out of the optimizer. 73 | return list(self.embedNetF.parameters()) + \ 74 | list(self.embedNetG.parameters()) + \ 75 | list(self.attLSTM.parameters()) + \ 76 | list(self.biLSTM.parameters()) 77 | 78 | # Set training or evaluation mode 79 | def set(self,mode): 80 | if mode == 'training': 81 | self.embedNetF.train() 82 | self.embedNetG.train() 83 | elif mode == 'evaluate': 84 | self.embedNetF.eval() 85 | self.embedNetG.eval() 86 | else: 87 | print('model.set: undefined mode - %s' % (mode)) 88 | 89 | def isTraining(self): 90 | return self.embedNetF.training 91 | 92 | def attLSTM_forward(self,gS,fX, K): 93 | 94 | r = gS.mean(0).expand_as(fX) 95 | for i in np.arange(K): 96 | x = torch.cat((fX, r), 1) 97 | x = x.unsqueeze(0) 98 | if i == 0: 99 | #dim: [sequence = 1, batch_size, num_features * 2] 100 | output, (h, c) = self.attLSTM(x) 101 | else: 102 | output, (h, c) = self.attLSTM(x,(h,c)) 103 | h = fX.squeeze(0) + output 104 | 105 | embed = None 106 | # Iterate over batch size 107 | for j in np.arange(h.size(1)): 108 | hInd = h[0, j, :].expand_as(gS) # j indexes the batch element here 109 | weight = (gS*hInd).sum(1).unsqueeze(1) 110 | embed_tmp = (self.softmax(weight.view(1, -1)).view(-1, 1).expand_as(gS) * gS).sum(0).unsqueeze(0) # softmax over the support dimension; Softmax() on an [n,1] tensor would return all ones 111 | if embed is None: 112 | embed = embed_tmp 113 | else: 114 | embed = torch.cat([embed,embed_tmp],0) 115 | r = embed # feed the attention read-out back in as r for the next of the K steps 116 | # output dim: [batch, num_features] 117 | return h.squeeze(0) 118 | 119 | def biLSTM_forward(self, input): 120 | gX = input 121 | # Expected input dimension of the form [sequence_length, batch_size, num_features] 122 | gX = gX.unsqueeze(1) 123 | output, (hn, cn) = self.biLSTM(gX) 124 | # output dim: [sequence, batch_size, num_features * 2] 125 | output = output[:, :, :self.embedNetG.outSize] + output[:, :, self.embedNetG.outSize:] 126 | output = output.squeeze(1) 127 | # output dim: [sequence, num_features] 128 | return output 129 | 130 | def embedG(self, input): 131 | g = self.embedNetG(input) 132 | return self.biLSTM_forward(g) 133 | 134 | def embedF(self, input, g, K): 135 | f = self.embedNetF(input) 136 | return self.attLSTM_forward(g,f,K) 137 | 138 | def save(self, path = './data'): 139 | # Save the opt parameters 140 | optParametersFile = open(os.path.join(path,'SimpleEmbedding_opt.pkl'), 'wb') 141 | pickle.dump(self.opt, optParametersFile) 142 | optParametersFile.close() 143 | # clearState() was a Torch7 idiom; PyTorch modules have no such method 144 | torch.save(self.embedNetF.state_dict(), os.path.join(path,'embedNetF.pth.tar')) 145 | torch.save(self.embedNetG.state_dict(), os.path.join(path, 'embedNetG.pth.tar')) 146 | 147 | def load(self, pathParams, pathModelF, pathModelG): 148 | # Load opt parameters 'SimpleEmbedding_opt.pkl' 149 | optParametersFile = open(pathParams, 'rb') 150 | self.opt = pickle.load(optParametersFile) 151 | optParametersFile.close() 152 | # build the models 153 | self.buildModels(self.opt) 154 | # Load the weights and biases of F and G; save() stores the state_dicts 155 | # directly, so no ['state_dict'] indexing is needed 156 | self.embedNetF.load_state_dict(torch.load(pathModelF)) 157 | self.embedNetG.load_state_dict(torch.load(pathModelG)) 158 | # Set cuda 159 | self.setCuda() 160 | 161 | def setCuda(self, value = 'default'): 162 | # If value is a string then use self.opt 163 | # If it is not a string then it should be True or False 164 | if type(value) == str: 165 | value = self.opt['useCUDA'] 166 | else: 167 | assert(type(value)==bool) 168 | 169 | if value == True: 170 | print('Check CUDA') 171 |
self.embedNetF.cuda() 174 | self.embedNetG.cuda() 175 | self.attLSTM.cuda() 176 | self.biLSTM.cuda() 177 | else: 178 | self.embedNetF.cpu() 179 | self.embedNetG.cpu() 180 | self.attLSTM.cpu() 181 | self.biLSTM.cpu() 182 | 183 | def build(opt): 184 | model = FceEmbedding(opt) 185 | return model 186 | 187 | 188 | -------------------------------------------------------------------------------- /model/baselines/matching-net.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import os 12 | import torch 13 | import torch.nn as nn 14 | from torch.autograd import Variable 15 | import importlib 16 | import numpy as np 17 | import time 18 | import math 19 | from sklearn.metrics import classification_report 20 | from sklearn.metrics import accuracy_score 21 | from utils import util 22 | 23 | class MatchingNet(nn.Module): 24 | def __init__(self, opt): 25 | super(MatchingNet, self).__init__() 26 | 27 | # function cosine-similarity layer 28 | self.cosineSim = nn.CosineSimilarity() 29 | 30 | # local embedding model (simple or FCE) 31 | self.embedModel = importlib.import_module(opt['embedModel']).build(opt) 32 | # set Cuda 33 | self.embedModel.setCuda(opt['useCUDA']) 34 | 35 | # load loss. Why does not load with the model 36 | self.lossF = nn.CrossEntropyLoss() 37 | 38 | # Set training or evaluation mode 39 | def set(self, mode): 40 | self.embedModel.set(mode) 41 | 42 | def forward(self, opt, input ): 43 | 44 | trainInput = input['trainInput'] 45 | trainTarget = input['trainTarget'] 46 | testInput = input['testInput'] 47 | testTarget = input['testTarget'] 48 | 49 | # Create one-hot vector 50 | trainTarget = trainTarget.view(-1,1) 51 | y_one_hot = trainTarget.clone() 52 | y_one_hot = y_one_hot.expand( 53 | trainTarget.size()[0], opt['nClasses']['train']) 54 | y_one_hot.data.zero_() 55 | y_one_hot = y_one_hot.float().scatter_(1, trainTarget, 1) 56 | 57 | # embed support set & test items using g and f respectively 58 | gS = self.embedModel.embedG(trainInput) 59 | fX = self.embedModel.embedF(testInput, gS, opt['steps']) 60 | 61 | # repeat tensors so that can get cosine sims in one call 62 | repeatgS = gS.repeat(fX.size(0),1) 63 | repeatfX = fX.repeat(1, gS.size(0)).view(fX.size(0)*gS.size(0),fX.size(1)) 64 | 65 | # weights are num_test x num_train (weigths per test item) 66 | weights = self.cosineSim(repeatgS, repeatfX).view(fX.size(0), gS.size(0),1) 67 | 68 | # one-hot matrix of train labels is expanded to num_train x num_test x num_labels 69 | expandOneHot = y_one_hot.view(1,y_one_hot.size(0),y_one_hot.size(1)).expand( 70 | fX.size(0),y_one_hot.size(0),y_one_hot.size(1)) 71 | 72 | # weights are expanded to match one-hot matrix 73 | expandWeights = weights.expand_as(expandOneHot) 74 | 75 | # cmul one-hot matrix by weights and sum along rows to get weight per label 76 | # final size: num_train x num_labels 77 | out = expandOneHot.mul(expandWeights).sum(1) 78 | 79 | # calculate NLL 80 | if self.embedModel.isTraining(): 81 | loss = self.lossF(out,testTarget) 82 | return out, loss 83 | else: 84 | return 
out 85 | 86 | def create_optimizer(opt, model): 87 | if opt['optimMethod'] == 'sgd': 88 | optimizer = torch.optim.SGD(model.parameters(), lr=opt['lr'], 89 | momentum=0.9, dampening=0.9, 90 | weight_decay=opt['weight_decay']) 91 | elif opt['optimMethod'] == 'adam': # explicit check: a bare truthiness test here would make the else branch unreachable 92 | optimizer = torch.optim.Adam(model.parameters(), lr=opt['lr'], 93 | weight_decay=opt['weight_decay']) 94 | else: 95 | raise Exception('Not supported optimizer: {0}'.format(opt['optimMethod'])) 96 | return optimizer 97 | 98 | def adjust_learning_rate(opt,optimizer): 99 | """Updates the learning rate given the learning rate decay. 100 | The routine has been implemented according to the original Lua SGD optimizer. 101 | """ 102 | for group in optimizer.param_groups: 103 | if 'step' not in group: 104 | group['step'] = 0 105 | group['step'] += 1 106 | 107 | group['lr'] = opt['lr'] / (1 + group['step'] * opt['lr_decay']) 108 | 109 | return optimizer 110 | 111 | def run(opt,data): 112 | 113 | # Set the model 114 | network = MatchingNet(opt) 115 | 116 | # Keep track of errors 117 | trainConf_pred = [] 118 | trainConf_gt = [] 119 | valConf_pred = {} 120 | valConf_gt = {} 121 | testConf_pred = {} 122 | testConf_gt = {} 123 | for i in opt['nTestShot']: 124 | valConf_pred[i] = [] 125 | valConf_gt[i] = [] 126 | testConf_pred[i] = [] 127 | testConf_gt[i] = [] 128 | 129 | # load params from file 130 | # paramsFile format: /SimpleEmbedding_opt.pkl 131 | # pathModelF format: /embedNet_F.pth.tar 132 | # pathModelG format: /embedNet_G.pth.tar 133 | if np.all([key in opt.keys() for key in ['paramsFile','pathModelF','pathModelG']]): 134 | 135 | if (os.path.isfile(opt['paramsFile']) and \ 136 | os.path.isfile(opt['pathModelF']) and \ 137 | os.path.isfile(opt['pathModelG'])): 138 | print('loading from params: %s' % (opt['paramsFile'])) 139 | print('loading model F: %s' % (opt['pathModelF'])) 140 | print('loading model G: %s' % (opt['pathModelG'])) 141 | network.embedModel.load(opt['paramsFile'], 142 | opt['pathModelF'], 143 | opt['pathModelG']) 144 | 145 | cost = 0 146 | timer = time.time() 147 | 148 | ################################################################# 149 | ############ Meta-training 150 | ################################################################# 151 | 152 | # Init optimizer 153 | optimizer = create_optimizer(opt, network.embedModel) 154 | 155 | # set net for training 156 | network.set('training') 157 | 158 | # train episode loop 159 | for episodeTrain,(x_support_set, y_support_set, x_target, target_y) in enumerate(data['train']): 160 | 161 | # Re-arrange the target labels into the range [0..nClasses_train] 162 | dictLabels, dictLabelsInverse = util.createDictLabels(y_support_set) 163 | y_support_set = util.fitLabelsToRange(dictLabels, y_support_set) 164 | target_y = util.fitLabelsToRange(dictLabels, target_y) 165 | 166 | # Wrap them in autograd Variables 167 | input = {} 168 | input['trainInput'] = Variable(x_support_set).float() 169 | input['trainTarget'] = Variable(y_support_set,requires_grad=False).long() 170 | input['testInput'] = Variable(x_target).float() 171 | input['testTarget'] = Variable(target_y,requires_grad=False).long() 172 | 173 | # Move to GPU if needed 174 | if opt['useCUDA']: 175 | input['trainInput'] = input['trainInput'].cuda() 176 | input['trainTarget'] = input['trainTarget'].cuda() 177 | input['testInput'] = input['testInput'].cuda() 178 | input['testTarget'] = input['testTarget'].cuda() 179 | 180 | output, loss = network(opt,input) 181 | optimizer.zero_grad() 182 | loss.backward() 183 | optimizer.step() 184 |
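# Note (added): adjust_learning_rate below applies inverse-time decay,
# lr_t = lr / (1 + t * lr_decay); with the train-matching-net defaults
# (lr = 1e-3, lr_decay = 1e-6) this is gentle: after all 75,000 training
# episodes the rate has dropped only about 7%, to roughly 9.3e-4.
185 | #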
Adjust learning rate 186 | optimizer = adjust_learning_rate(opt, optimizer) 187 | 188 | cost = cost + loss 189 | 190 | # update stats 191 | values_pred, indices_pred = torch.max(output, 1) 192 | target_y = util.fitLabelsToRange(dictLabelsInverse, target_y) 193 | indices_pred = util.fitLabelsToRange(dictLabelsInverse, indices_pred.cpu().data) 194 | trainConf_pred.append(indices_pred.numpy()) 195 | trainConf_gt.append(target_y.numpy()) 196 | 197 | if episodeTrain % opt['printPer'] == 0: 198 | trainConf_pred = np.concatenate(trainConf_pred, axis=0) 199 | trainConf_gt = np.concatenate(trainConf_gt, axis=0) 200 | target_names = [str(i) for i in np.unique(trainConf_gt)] 201 | print( 202 | 'Training Episode: [{}/{} ({:.0f}%)]\tLoss: {:.3f}. Elapsed: {:.4f} s'.format( 203 | episodeTrain, len(data['train']), 100. * episodeTrain / len(data['train']), 204 | (cost.cpu().data.numpy() / opt['printPer'])[0],time.time() - timer)) 205 | print(classification_report(trainConf_gt, trainConf_pred, 206 | target_names=target_names)) 207 | # Set to 0 208 | trainConf_pred = [] 209 | trainConf_gt = [] 210 | 211 | ################################################################# 212 | ############ Meta-evaluation 213 | ################################################################# 214 | 215 | timerEval = time.time() 216 | 217 | # evaluate validation set 218 | network.set('evaluate') 219 | # validation episode loop 220 | for episodeValidation, (x_support_set, y_support_set, x_target, target_y) in enumerate(data['validation']): 221 | 222 | # Re-arange the Target vectors between [0..nClasses_train] 223 | dictLabels, dictLabelsInverse = util.createDictLabels(y_support_set) 224 | y_support_set = util.fitLabelsToRange(dictLabels, y_support_set) 225 | target_y = util.fitLabelsToRange(dictLabels, target_y) 226 | unique_labels = dictLabels.keys() 227 | 228 | # k-shot loop 229 | for k in opt['nTestShot']: 230 | 231 | # Select k samples from each class from x_support_set and 232 | indexes_selected = [] 233 | for k_selected in unique_labels: 234 | selected = np.random.choice(np.squeeze(np.where(y_support_set.numpy() == dictLabels[k_selected])) 235 | ,k, False) 236 | indexes_selected.append(selected) 237 | 238 | # Select the k-shot examples from the Tensors 239 | x_support_set_k = x_support_set[torch.from_numpy(np.squeeze(indexes_selected).flatten())] 240 | y_support_set_k = y_support_set[torch.from_numpy(np.squeeze(indexes_selected).flatten())] 241 | 242 | # Convert them in Variables 243 | input = {} 244 | input['trainInput'] = Variable(x_support_set_k).float() 245 | input['trainTarget'] = Variable(y_support_set_k, requires_grad=False).long() 246 | input['testInput'] = Variable(x_target).float() 247 | input['testTarget'] = Variable(target_y, requires_grad=False).long() 248 | 249 | # Convert to GPU if needed 250 | if opt['useCUDA']: 251 | input['trainInput'] = input['trainInput'].cuda() 252 | input['trainTarget'] = input['trainTarget'].cuda() 253 | input['testInput'] = input['testInput'].cuda() 254 | input['testTarget'] = input['testTarget'].cuda() 255 | 256 | output = network(opt, input) 257 | 258 | # update stats validation 259 | values_pred, indices_pred = torch.max(output, 1) 260 | target_y = util.fitLabelsToRange(dictLabelsInverse, target_y) 261 | indices_pred = util.fitLabelsToRange(dictLabelsInverse, indices_pred.cpu().data) 262 | valConf_pred[k].append(indices_pred.numpy()) 263 | valConf_gt[k].append(target_y.numpy()) 264 | 265 | for k in opt['nTestShot']: 266 | valConf_pred[k] = np.concatenate(valConf_pred[k], 
axis=0) 267 | valConf_gt[k] = np.concatenate(valConf_gt[k], axis=0) 268 | print('Validation: {}-shot Acc: {:.3f}. Elapsed: {:.4f} s.'.format( 269 | k,accuracy_score(valConf_gt[k],valConf_pred[k]),time.time() - timerEval)) 270 | target_names = [str(i) for i in np.unique(valConf_gt[k])] 271 | print(classification_report(valConf_gt[k], valConf_pred[k], 272 | target_names=target_names)) 273 | valConf_pred[k] = [] 274 | valConf_gt[k] = [] 275 | 276 | cost = 0 277 | timer = time.time() 278 | network.set('training') 279 | 280 | ################################################################# 281 | ############ Meta-testing 282 | ################################################################# 283 | # set net for testing 284 | network.set('evaluate') 285 | 286 | results = [] 287 | for n in np.arange(len(opt['nTest'])): 288 | # validation episode loop 289 | for episodeTest, (x_support_set, y_support_set, x_target, target_y) in enumerate(data['test'][n]): 290 | 291 | # Re-arange the Target vectors between [0..nClasses_train] 292 | dictLabels, dictLabelsInverse = util.createDictLabels(y_support_set) 293 | y_support_set = util.fitLabelsToRange(dictLabels, y_support_set) 294 | target_y = util.fitLabelsToRange(dictLabels, target_y) 295 | unique_labels = dictLabels.keys() 296 | 297 | # k-shot loop 298 | for k in opt['nTestShot']: 299 | 300 | # Select k samples from each class from x_support_set and 301 | indexes_selected = [] 302 | for k_selected in unique_labels: 303 | selected = np.random.choice(np.squeeze(np.where((y_support_set.numpy() == dictLabels[k_selected]))) 304 | , k, False) 305 | indexes_selected.append(selected) 306 | 307 | # Select the k-shot examples from the Tensors 308 | x_support_set_k = x_support_set[torch.from_numpy(np.squeeze(indexes_selected).flatten())] 309 | y_support_set_k = y_support_set[torch.from_numpy(np.squeeze(indexes_selected).flatten())] 310 | 311 | # Convert them in Variables 312 | input = {} 313 | input['trainInput'] = Variable(x_support_set_k).float() 314 | input['trainTarget'] = Variable(y_support_set_k, requires_grad=False).long() 315 | input['testInput'] = Variable(x_target).float() 316 | input['testTarget'] = Variable(target_y, requires_grad=False).long() 317 | 318 | # Convert to GPU if needed 319 | if opt['useCUDA']: 320 | input['trainInput'] = input['trainInput'].cuda() 321 | input['trainTarget'] = input['trainTarget'].cuda() 322 | input['testInput'] = input['testInput'].cuda() 323 | input['testTarget'] = input['testTarget'].cuda() 324 | 325 | output = network(opt, input) 326 | 327 | # update stats test 328 | values_pred, indices_pred = torch.max(output, 1) 329 | target_y = util.fitLabelsToRange(dictLabelsInverse, target_y) 330 | indices_pred = util.fitLabelsToRange(dictLabelsInverse, indices_pred.cpu().data) 331 | testConf_pred[k].append(indices_pred.numpy()) 332 | testConf_gt[k].append(target_y.numpy()) 333 | 334 | for k in opt['nTestShot']: 335 | acc = [] 336 | for i in np.arange(len(testConf_gt[k])): 337 | acc.append(accuracy_score(testConf_gt[k][i],testConf_pred[k][i])) 338 | low = np.mean(acc) - 1.96*(np.std(acc)/math.sqrt(len(acc))) 339 | high = np.mean(acc) + 1.96 * (np.std(acc) / math.sqrt(len(acc))) 340 | print('Test: nTest: {}. {}-shot. mAcc: {:.3f}. low mAcc: {:.3f}. 
high mAcc: {:.3f}.'.format( 341 | opt['nTest'][n],k,np.mean(acc),low, high)) 342 | testConf_pred[k] = [] 343 | testConf_gt[k] = [] 344 | results.append((opt['nTest'][n],k,np.mean(acc),low, high)) 345 | 346 | return results -------------------------------------------------------------------------------- /model/baselines/simple-embedding.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import os 12 | import torch 13 | import torch.nn as nn 14 | import importlib 15 | import pickle 16 | 17 | class SimpleEmbedding(): 18 | def __init__(self, opt): 19 | self.opt = opt # Store the parameters 20 | self.buildModels(self.opt) 21 | self.setCuda() 22 | 23 | # Build F and G models 24 | def buildModels(self,opt): 25 | modelF = importlib.import_module(opt['learner']).build(opt) 26 | self.embedNetF = modelF.net # F function 27 | modelG = importlib.import_module(opt['learner']).build(opt) 28 | self.embedNetG = modelG.net # G function 29 | 30 | # Build list of parameters for optim 31 | def parameters(self): 32 | # TODO: why in the original code creates a dictionary with the same 33 | # parameters. model.params = {f=paramsG, g=paramsG} 34 | return list(self.embedNetG.parameters()) + list(self.embedNetG.parameters()) 35 | 36 | # Set training or evaluation mode 37 | def set(self,mode): 38 | if mode == 'training': 39 | self.embedNetF.train() 40 | self.embedNetG.train() 41 | elif mode == 'evaluate': 42 | self.embedNetF.eval() 43 | self.embedNetG.eval() 44 | else: 45 | print('model.set: undefined mode - %s' % (mode)) 46 | 47 | def isTraining(self): 48 | return self.embedNetF.training 49 | 50 | def default(self, dfDefault): 51 | self.df = dfDefault 52 | 53 | def embedF(self, input, g = [], K = []): 54 | return self.embedNetF(input) 55 | 56 | def embedG(self, input): 57 | return self.embedNetG(input) 58 | 59 | def save(self, path = './data'): 60 | # Save the opt parameters 61 | optParametersFile = open(os.path.join(path,'SimpleEmbedding_opt.pkl'), 'wb') 62 | pickle.dump(self.opt, optParametersFile) 63 | optParametersFile.close() 64 | # Clean not needed data of the models 65 | self.embedNetF.clearState() 66 | self.embedNetG.clearState() 67 | torch.save(self.embedNetF.state_dict(), os.path.join(path,'embedNetF.pth.tar')) 68 | torch.save(self.embedNetG.state_dict(), os.path.join(path, 'embedNetG.pth.tar')) 69 | 70 | def load(self, pathParams, pathModelF, pathModelG): 71 | # Load opt parameters 'SimpleEmbedding_opt.pkl' 72 | optParametersFile = open(pathParams, 'rb') 73 | self.opt = pickle.load(optParametersFile) 74 | optParametersFile.close() 75 | # build the models 76 | self.buildModels(self.opt) 77 | # Load the weights and biases of F and G 78 | checkpoint = torch.load(pathModelF) 79 | self.embedNetF.load_state_dict(checkpoint['state_dict']) 80 | checkpoint = torch.load(pathModelG) 81 | self.embedNetG.load_state_dict(checkpoint['state_dict']) 82 | # Set cuda 83 | self.setCuda() 84 | 85 | def setCuda(self, value = 'default'): 86 | # If value is a string then use self.opt 87 | # If it is not a string then it 
should be True or False 88 | if type(value) == str: 89 | value = self.opt['useCUDA'] 90 | else: 91 | assert(type(value)==bool) 92 | 93 | if value == True: 94 | print('Check CUDA') 95 | self.embedNetF.cuda() 96 | self.embedNetG.cuda() 97 | else: 98 | self.embedNetF.cpu() 99 | self.embedNetG.cpu() 100 | 101 | def build(opt): 102 | model = SimpleEmbedding(opt) 103 | return model 104 | 105 | 106 | -------------------------------------------------------------------------------- /model/lstm-classifier.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | import numpy as np 14 | import math 15 | 16 | def convLayer(opt, layer_pos, nInput, nOutput, k ): 17 | "3x3 convolution with padding" 18 | #if 'BN_momentum' in opt.keys(): 19 | # batchNorm = nn.BatchNorm2d(nOutput,momentum=opt['BN_momentum']) 20 | #else: 21 | # batchNorm = nn.BatchNorm2d(nOutput) 22 | 23 | seq = nn.Sequential( 24 | nn.Conv2d(nInput, nOutput, kernel_size=k, 25 | stride=1, padding=1, bias=True), 26 | #batchNorm, 27 | opt['bnorm2d'][layer_pos], 28 | nn.ReLU(True), 29 | nn.MaxPool2d(kernel_size=2, stride=2) 30 | ) 31 | if opt['useDropout']: # Add dropout module 32 | list_seq = list(seq.modules())[1:] 33 | list_seq.append(nn.Dropout(0.1)) 34 | seq = nn.Sequential(*list_seq) 35 | return seq 36 | 37 | class Classifier(nn.Module): 38 | def __init__(self, opt): 39 | super(Classifier, self).__init__() 40 | 41 | finalSize = int(math.floor(opt['nIn'] / (2 * 2 * 2 * 2))) # four 2x2 max-pools: 84 -> 5 42 | 43 | self.layer1 = convLayer(opt, 0, opt['nDepth'], opt['nFilters'], 3) 44 | self.layer2 = convLayer(opt, 1, opt['nFilters'], opt['nFilters'], 3) 45 | self.layer3 = convLayer(opt, 2, opt['nFilters'], opt['nFilters'], 3) 46 | self.layer4 = convLayer(opt, 3, opt['nFilters'], opt['nFilters'], 3) 47 | 48 | self.outSize = opt['nFilters']*finalSize*finalSize 49 | self.classify = opt['classify'] 50 | if self.classify: 51 | self.layer5 = nn.Linear(opt['nFilters']*finalSize*finalSize, opt['nClasses']['train']) 52 | self.outSize = opt['nClasses']['train'] 53 | 54 | # Initialize layers 55 | self.reset() 56 | 57 | def weights_init(self,module): 58 | for m in module.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 61 | init.constant(m.bias, 0) 62 | elif isinstance(m, nn.BatchNorm2d): 63 | m.weight.data.fill_(1) 64 | m.bias.data.zero_() 65 | 66 | def reset(self): 67 | self.weights_init(self.layer1) 68 | self.weights_init(self.layer2) 69 | self.weights_init(self.layer3) 70 | self.weights_init(self.layer4) 71 | 72 | def forward(self, x): 73 | """ 74 | Runs the CNN that produces the embeddings. 75 | :param x: Image batch of size [batch_size, nDepth, nIn, nIn], i.e. [batch_size, 3, 84, 84] for miniImagenet 76 | :return: Embeddings of size [batch_size, nFilters*finalSize*finalSize], or class scores of size [batch_size, nClasses] when classify=True 77 | """ 78 | x = self.layer1(x) 79 | x = self.layer2(x) 80 | x = self.layer3(x) 81 | x = self.layer4(x) 82 | x = x.view(x.size(0), -1) 83 | if self.classify: 84 | x = self.layer5(x) 85 | return x 86 | 87 | 88 | class MatchingNetClassifier(): 89 | def __init__(self, opt): 90 | 91 | self.net = Classifier(opt) 92 | if opt['classify']: 93 | self.criterion = nn.CrossEntropyLoss() 94 | else: 95 | self.criterion = [] 96 | self.nParams = sum([i.view(-1).size()[0] for i in self.net.parameters()]) 97 | self.outSize = self.net.outSize 98 | 99 | def build(opt): 100 | 101 | model = MatchingNetClassifier(opt) 102 | print('created net:') 103 | print(model.net) 104 | return model 105 | -------------------------------------------------------------------------------- /model/lstm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabcworld/FewShotLearning/2eff1881f96212e0fb31737a48f82fab9c2fac83/model/lstm/__init__.py -------------------------------------------------------------------------------- /model/lstm/bnlstm.py: -------------------------------------------------------------------------------- 1 | #PyTorch implementation of Recurrent Batch Normalization 2 | # proposed by Cooijmans et al. (2017). https://arxiv.org/abs/1603.09025 3 | # Source code from: https://github.com/jihunchoi/recurrent-batch-normalization-pytorch 4 | 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Variable 8 | from torch.nn import functional, init 9 | 10 | class SeparatedBatchNorm1d(nn.Module): 11 | 12 | """ 13 | A batch normalization module which keeps its running mean 14 | and variance separately per timestep. 15 | """ 16 | 17 | def __init__(self, num_features, max_length, eps=1e-5, momentum=0.1, 18 | affine=True): 19 | """ 20 | Most parts are copied from 21 | torch.nn.modules.batchnorm._BatchNorm.
22 | """ 23 | 24 | super(SeparatedBatchNorm1d, self).__init__() 25 | self.num_features = num_features 26 | self.max_length = max_length 27 | self.affine = affine 28 | self.eps = eps 29 | self.momentum = momentum 30 | if self.affine: 31 | self.weight = nn.Parameter(torch.FloatTensor(num_features)) 32 | self.bias = nn.Parameter(torch.FloatTensor(num_features)) 33 | else: 34 | self.register_parameter('weight', None) 35 | self.register_parameter('bias', None) 36 | for i in range(max_length): 37 | self.register_buffer( 38 | 'running_mean_{}'.format(i), torch.zeros(num_features)) 39 | self.register_buffer( 40 | 'running_var_{}'.format(i), torch.ones(num_features)) 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | for i in range(self.max_length): 45 | running_mean_i = getattr(self, 'running_mean_{}'.format(i)) 46 | running_var_i = getattr(self, 'running_var_{}'.format(i)) 47 | running_mean_i.zero_() 48 | running_var_i.fill_(1) 49 | if self.affine: 50 | self.weight.data.uniform_() 51 | self.bias.data.zero_() 52 | 53 | def _check_input_dim(self, input_): 54 | if input_.size(1) != self.running_mean_0.nelement(): 55 | raise ValueError('got {}-feature tensor, expected {}' 56 | .format(input_.size(1), self.num_features)) 57 | 58 | def forward(self, input_, time): 59 | self._check_input_dim(input_) 60 | if time >= self.max_length: 61 | time = self.max_length - 1 62 | running_mean = getattr(self, 'running_mean_{}'.format(time)) 63 | running_var = getattr(self, 'running_var_{}'.format(time)) 64 | return functional.batch_norm( 65 | input=input_, running_mean=running_mean, running_var=running_var, 66 | weight=self.weight, bias=self.bias, training=self.training, 67 | momentum=self.momentum, eps=self.eps) 68 | 69 | def __repr__(self): 70 | return ('{name}({num_features}, eps={eps}, momentum={momentum},' 71 | ' max_length={max_length}, affine={affine})' 72 | .format(name=self.__class__.__name__, **self.__dict__)) 73 | 74 | 75 | class LSTMCell(nn.Module): 76 | 77 | """A basic LSTM cell.""" 78 | 79 | def __init__(self, input_size, hidden_size, use_bias=True): 80 | """ 81 | Most parts are copied from torch.nn.LSTMCell. 82 | """ 83 | 84 | super(LSTMCell, self).__init__() 85 | self.input_size = input_size 86 | self.hidden_size = hidden_size 87 | self.use_bias = use_bias 88 | self.weight_ih = nn.Parameter( 89 | torch.FloatTensor(input_size, 4 * hidden_size)) 90 | self.weight_hh = nn.Parameter( 91 | torch.FloatTensor(hidden_size, 4 * hidden_size)) 92 | if use_bias: 93 | self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size)) 94 | else: 95 | self.register_parameter('bias', None) 96 | self.reset_parameters() 97 | 98 | def reset_parameters(self): 99 | """ 100 | Initialize parameters following the way proposed in the paper. 101 | """ 102 | 103 | init.orthogonal(self.weight_ih.data) 104 | weight_hh_data = torch.eye(self.hidden_size) 105 | weight_hh_data = weight_hh_data.repeat(1, 4) 106 | self.weight_hh.data.set_(weight_hh_data) 107 | # The bias is just set to zero vectors. 108 | if self.use_bias: 109 | init.constant(self.bias.data, val=0) 110 | 111 | def forward(self, input_, hx): 112 | """ 113 | Args: 114 | input_: A (batch, input_size) tensor containing input 115 | features. 116 | hx: A tuple (h_0, c_0), which contains the initial hidden 117 | and cell state, where the size of both states is 118 | (batch, hidden_size). 119 | Returns: 120 | h_1, c_1: Tensors containing the next hidden and cell state. 
121 | """ 122 | 123 | h_0, c_0 = hx 124 | batch_size = h_0.size(0) 125 | bias_batch = (self.bias.unsqueeze(0) 126 | .expand(batch_size, *self.bias.size())) 127 | wh_b = torch.addmm(bias_batch, h_0, self.weight_hh) 128 | wi = torch.mm(input_, self.weight_ih) 129 | f, i, o, g = torch.split(wh_b + wi, 130 | split_size=self.hidden_size, dim=1) 131 | c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g) 132 | h_1 = torch.sigmoid(o) * torch.tanh(c_1) 133 | return h_1, c_1 134 | 135 | def __repr__(self): 136 | s = '{name}({input_size}, {hidden_size})' 137 | return s.format(name=self.__class__.__name__, **self.__dict__) 138 | 139 | 140 | class BNLSTMCell(nn.Module): 141 | 142 | """A BN-LSTM cell.""" 143 | 144 | def __init__(self, input_size, hidden_size, max_length, use_bias=True): 145 | 146 | super(BNLSTMCell, self).__init__() 147 | self.input_size = input_size 148 | self.hidden_size = hidden_size 149 | self.max_length = max_length 150 | self.use_bias = use_bias 151 | self.weight_ih = nn.Parameter( 152 | torch.FloatTensor(input_size, 4 * hidden_size)) 153 | self.weight_hh = nn.Parameter( 154 | torch.FloatTensor(hidden_size, 4 * hidden_size)) 155 | if use_bias: 156 | self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size)) 157 | else: 158 | self.register_parameter('bias', None) 159 | # BN parameters 160 | self.bn_ih = SeparatedBatchNorm1d( 161 | num_features=4 * hidden_size, max_length=max_length) 162 | self.bn_hh = SeparatedBatchNorm1d( 163 | num_features=4 * hidden_size, max_length=max_length) 164 | self.bn_c = SeparatedBatchNorm1d( 165 | num_features=hidden_size, max_length=max_length) 166 | self.reset_parameters() 167 | 168 | def reset_parameters(self): 169 | """ 170 | Initialize parameters following the way proposed in the paper. 171 | """ 172 | 173 | # The input-to-hidden weight matrix is initialized orthogonally. 174 | init.orthogonal(self.weight_ih.data) 175 | # The hidden-to-hidden weight matrix is initialized as an identity 176 | # matrix. 177 | weight_hh_data = torch.eye(self.hidden_size) 178 | weight_hh_data = weight_hh_data.repeat(1, 4) 179 | self.weight_hh.data.set_(weight_hh_data) 180 | # The bias is just set to zero vectors. 181 | init.constant(self.bias.data, val=0) 182 | # Initialization of BN parameters. 183 | self.bn_ih.reset_parameters() 184 | self.bn_hh.reset_parameters() 185 | self.bn_c.reset_parameters() 186 | self.bn_ih.bias.data.fill_(0) 187 | self.bn_hh.bias.data.fill_(0) 188 | self.bn_ih.weight.data.fill_(0.1) 189 | self.bn_hh.weight.data.fill_(0.1) 190 | self.bn_c.weight.data.fill_(0.1) 191 | 192 | def forward(self, input_, hx, time): 193 | """ 194 | Args: 195 | input_: A (batch, input_size) tensor containing input 196 | features. 197 | hx: A tuple (h_0, c_0), which contains the initial hidden 198 | and cell state, where the size of both states is 199 | (batch, hidden_size). 200 | time: The current timestep value, which is used to 201 | get appropriate running statistics. 202 | Returns: 203 | h_1, c_1: Tensors containing the next hidden and cell state. 
204 | """ 205 | 206 | h_0, c_0 = hx 207 | batch_size = h_0.size(0) 208 | bias_batch = (self.bias.unsqueeze(0) 209 | .expand(batch_size, *self.bias.size())) 210 | wh = torch.mm(h_0, self.weight_hh) 211 | wi = torch.mm(input_, self.weight_ih) 212 | bn_wh = self.bn_hh(wh, time=time) 213 | bn_wi = self.bn_ih(wi, time=time) 214 | f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch, 215 | split_size=self.hidden_size, dim=1) 216 | c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g) 217 | h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time)) 218 | return h_1, c_1 219 | 220 | 221 | class LSTM(nn.Module): 222 | 223 | """A module that runs multiple steps of LSTM.""" 224 | 225 | def __init__(self, cell_class, input_size, hidden_size, num_layers=1, 226 | use_bias=True, batch_first=False, dropout=0, **kwargs): 227 | super(LSTM, self).__init__() 228 | self.cell_class = cell_class 229 | self.input_size = input_size 230 | self.hidden_size = hidden_size 231 | self.num_layers = num_layers 232 | self.use_bias = use_bias 233 | self.batch_first = batch_first 234 | self.dropout = dropout 235 | 236 | self.cells = [] 237 | for layer in range(num_layers): 238 | layer_input_size = input_size if layer == 0 else hidden_size 239 | cell = cell_class(input_size=layer_input_size, 240 | hidden_size=hidden_size, 241 | **kwargs) 242 | self.cells.append(cell) 243 | setattr(self, 'cell_{}'.format(layer), cell) 244 | self.dropout_layer = nn.Dropout(dropout) 245 | self.reset_parameters() 246 | 247 | def reset_parameters(self): 248 | for cell in self.cells: 249 | cell.reset_parameters() 250 | 251 | @staticmethod 252 | def _forward_rnn(cell, input_, length, hx): 253 | max_time = input_.size(0) 254 | output = [] 255 | for time in range(max_time): 256 | if isinstance(cell, BNLSTMCell): 257 | h_next, c_next = cell(input_=input_[time], hx=hx, time=time) 258 | else: 259 | h_next, c_next = cell(input_=input_[time], hx=hx) 260 | mask = (time < length).float().unsqueeze(1).expand_as(h_next) 261 | h_next = h_next*mask + hx[0]*(1 - mask) 262 | c_next = c_next*mask + hx[1]*(1 - mask) 263 | hx_next = (h_next, c_next) 264 | output.append(h_next) 265 | hx = hx_next 266 | output = torch.stack(output, 0) 267 | return output, hx 268 | 269 | def forward(self, input_, length=None, hx=None): 270 | if self.batch_first: 271 | input_ = input_.transpose(0, 1) 272 | max_time, batch_size, _ = input_.size() 273 | if length is None: 274 | length = Variable(torch.LongTensor([max_time] * batch_size)) 275 | if input_.is_cuda: 276 | length = length.cuda() 277 | if hx is None: 278 | hx = Variable(input_.data.new(batch_size, self.hidden_size).zero_()) 279 | hx = (hx, hx) 280 | h_n = [] 281 | c_n = [] 282 | layer_output = None 283 | for layer in range(self.num_layers): 284 | layer_output, (layer_h_n, layer_c_n) = LSTM._forward_rnn( 285 | cell=self.cells[layer], input_=input_, length=length, hx=hx) 286 | input_ = self.dropout_layer(layer_output) 287 | h_n.append(layer_h_n) 288 | c_n.append(layer_c_n) 289 | output = layer_output 290 | h_n = torch.stack(h_n, 0) 291 | c_n = torch.stack(c_n, 0) 292 | return output, (h_n, c_n) 293 | 294 | ''' 295 | class RecurrentLSTMNetwork(nn.Module): 296 | def __init__(self, opt): 297 | super(RecurrentLSTMNetwork, self).__init__() 298 | 299 | self.inputFeatures = opt['inputFeatures'] if 'inputFeatures' in opt.keys() else 10 300 | self.hiddenFeatures = opt['hiddenFeatures'] if 'hiddenFeatures' in opt.keys() else 100 301 | self.outputType = opt['outputType'] if 'outputType' in opt.keys() else 'last' # 'last' or 'all' 
302 | self.batchNormalization = opt['batchNormalization'] if 'batchNormalization' in opt.keys() else False
303 | self.maxBatchNormalizationLayers = opt['maxBatchNormalizationLayers'] if 'batchNormalization' in opt.keys() else 10
304 | 
305 | # containers
306 | self.layers = {}
307 | 
308 | # parameters
309 | self.p = {}
310 | self.p['W'] = torch.zeros(self.inputFeatures+self.hiddenFeatures,4 * self.hiddenFeatures)
311 | 
312 | #TODO: delete this line. only for debugging
313 | self.batchNormalization = True
314 | 
315 | if self.batchNormalization:
316 | # TODO: check if nn.BatchNorm1d or torch.legacy.nn.BatchNormalization
317 | # translation and scaling parameters are shared across time.
318 | lstm_bn = nn.BatchNorm1d(4*self.hiddenFeatures)
319 | cell_bn = nn.BatchNorm1d(self.hiddenFeatures)
320 | self.layers = {'lstm_bn':[lstm_bn],'cell_bn':[cell_bn]}
321 | 
322 | for i in range(2,self.maxBatchNormalizationLayers):
323 | lstm_bn = nn.BatchNorm1d(4*self.hiddenFeatures)
324 | cell_bn = nn.BatchNorm1d(self.hiddenFeatures)
325 | self.layers['lstm_bn'].append(lstm_bn)
326 | self.layers['cell_bn'].append(cell_bn)
327 | 
328 | # Initializing scaling to <1 is recommended for LSTM batch norm
329 | self.layers['lstm_bn'][0].weight.data.fill_(0.1)
330 | self.layers['lstm_bn'][0].bias.data.zero_()
331 | self.layers['cell_bn'][0].weight.data.fill_(0.1)
332 | self.layers['cell_bn'][0].bias.data.zero_()
333 | else:
334 | self.p['b'] = torch.zeros(1, 4*self.hiddenFeatures)
335 | 
336 | def forward(self, x, prevState = None ):
337 | 
338 | # dimensions
339 | if len(x.size()) == 2: x = x.unsqueeze(0)
340 | batch = x.size(0)
341 | steps = x.size(1)
342 | 
343 | if prevState == None: prevState = {}
344 | hs = {}
345 | cs = {}
346 | for t in range(steps):
347 | # xt
348 | xt = x[:,t,:]
349 | # prev h and prev c
350 | hp = hs[t-1] or prevState.h or torch.zeros()
351 | a = 0
352 | '''
--------------------------------------------------------------------------------
/model/lstm/learner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import importlib
5 | import numpy as np
6 | 
7 | class Learner(nn.Module):
8 | def __init__(self, opt):
9 | super(Learner, self).__init__()
10 | 
11 | # Note: we are using two networks to simulate the learner: one network
12 | # is used for the backward pass on the test set and the other is used simply
13 | # to get gradients which serve as input to the meta-learner.
14 | # This is a simple way to make the computation graph work
15 | # so that it doesn't include the gradients of the learner.
16 | 
17 | # Create another network with only 'running_mean' and 'running_var' shared.
18 | # These buffers can be found in BatchNormalization layers (or InstanceNormalization).
19 | # In Torch this is achieved with the instruction: model.net:clone('running_mean', 'running_var');
20 | # with Pytorch we need to copy those parameters with
21 | # state_dict and load_state_dict every time we want to use one of the shared
22 | # networks.
23 | 
24 | # Add dimension filters for the cnn
25 | opt['nFilters'] = 32
26 | # Create 4 layers with batch norm.
Share layers between self.model and self.modelF
27 | self.bn_layers = []
28 | for i in range(4):
29 | if 'BN_momentum' in opt.keys():
30 | self.bn_layers.append(nn.BatchNorm2d(opt['nFilters'],
31 | momentum=opt['BN_momentum']))
32 | else:
33 | self.bn_layers.append(nn.BatchNorm2d(opt['nFilters']))
34 | opt['bnorm2d'] = self.bn_layers
35 | 
36 | # local embedding model
37 | self.model = importlib.import_module(opt['learner']).build(opt)
38 | self.modelF = importlib.import_module(opt['learner']).build(opt)
39 | self.nParams = self.modelF.nParams
40 | self.params = {param[0]: param[1] for param in self.modelF.net.named_parameters()}
41 | 
42 | def unflattenParams_net(self,flatParams):
43 | flatParams = flatParams.squeeze()
44 | indx = 0
45 | for param in self.model.net.parameters():
46 | lengthParam = param.view(-1).size()[0]
47 | param.data = flatParams[indx:indx + lengthParam].view_as(param).data
48 | indx = indx + lengthParam
49 | def forward(self, inputs, targets ):
50 | 
51 | output = self.modelF.net(inputs)
52 | loss = self.modelF.criterion(output, targets)
53 | return output, loss
54 | 
55 | def feval(self, inputs, targets):
56 | # reset gradients
57 | self.model.net.zero_grad()
58 | # evaluate function for complete mini batch
59 | outputs = self.model.net(inputs)
60 | loss = self.model.criterion(outputs, targets)
61 | loss.backward()
62 | grads = torch.cat([param.grad.view(-1) for param in self.model.net.parameters()], 0)
63 | return grads,loss
64 | 
65 | def reset(self):
66 | self.model.net.reset()
67 | self.modelF.net.reset()
68 | 
69 | # Set training or evaluation mode
70 | def set(self,mode):
71 | if mode == 'training':
72 | self.model.net.train()
73 | self.modelF.net.train()
74 | elif mode == 'evaluate':
75 | self.model.net.eval()
76 | self.modelF.net.eval()
77 | else:
78 | print('model.set: undefined mode - %s' % (mode))
79 | 
80 | def setCuda(self, value = True):
81 | # value should be a boolean: True moves both networks to the GPU,
82 | # False moves them back to the CPU
83 | if value == True:
84 | self.model.net.cuda()
85 | self.modelF.net.cuda()
86 | else:
87 | self.model.net.cpu()
88 | self.modelF.net.cpu()
--------------------------------------------------------------------------------
/model/lstm/lstmhelper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | 
5 | P = Variable(torch.FloatTensor(1).fill_(10))
6 | expP = Variable(torch.exp(P.data))
7 | negExpP = Variable(torch.exp(-P.data))
8 | 
9 | def preProc1(x):
10 | # Access the global variables
11 | global P,expP,negExpP
12 | P = P.type_as(x)
13 | expP = expP.type_as(x)
14 | negExpP = negExpP.type_as(x)
15 | 
16 | # Create a variable filled with -1. Second part of the condition
17 | z = Variable(torch.zeros(x.size()).fill_(-1)).type_as(x)
18 | absX = torch.abs(x)
19 | cond1 = torch.gt(absX, negExpP)
20 | if (torch.sum(cond1) > 0).data.all():
21 | x1 = torch.log(torch.abs(x[cond1]))/P
22 | z[cond1] = x1
23 | return z
24 | 
25 | def preProc2(x):
26 | # Access the global variables
27 | global P, expP, negExpP
28 | P = P.type_as(x)
29 | expP = expP.type_as(x)
30 | negExpP = negExpP.type_as(x)
31 | 
32 | # Create a variable filled with zeros.
Second part of the condition
33 | z = Variable(torch.zeros(x.size())).type_as(x)
34 | absX = torch.abs(x)
35 | cond1 = torch.gt(absX, negExpP)
36 | cond2 = torch.le(absX, negExpP)
37 | if (torch.sum(cond1) > 0).data.all():
38 | x1 = torch.sign(x[cond1])
39 | z[cond1] = x1
40 | if (torch.sum(cond2) > 0).data.all():
41 | x2 = x[cond2]*expP
42 | z[cond2] = x2
43 | return z
44 | 
45 | def preprocess(grad,loss):
46 | 
47 | #preGrad = Variable(grad.data.new(grad.data.size()[0], 1, 2).zero_())
48 | # repeat, not expand: an expanded view aliases memory across columns
49 | preGrad = grad.clone().repeat(1, 1, 2)
50 | preGrad[:, :, 0] = preProc1(grad)
51 | preGrad[:, :, 1] = preProc2(grad)
52 | 
53 | #lossT = Variable(loss.data.new(1,1,1).zero_())
54 | #lossT[0] = loss
55 | #preLoss = Variable(loss.data.new(1,1,2).zero_())
56 | # idem: repeat so that the two columns can be written independently
57 | preLoss = loss.clone().repeat(1, 1, 2)
58 | preLoss[:, :, 0] = preProc1(loss)
59 | preLoss[:, :, 1] = preProc2(loss)
60 | return preGrad,preLoss
61 | 
--------------------------------------------------------------------------------
/model/lstm/metaLearner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import importlib
5 | import model.lstm.bnlstm as bnlstm
6 | import model.lstm.metalstm as metalstm
7 | from model.lstm.recurrentLSTMNetwork import RecurrentLSTMNetwork
8 | from model.lstm.lstmhelper import preprocess
9 | from utils import util
10 | from visualize.visualize import make_dot
11 | 
12 | class MetaLearner(nn.Module):
13 | def __init__(self, opt):
14 | super(MetaLearner, self).__init__()
15 | 
16 | self.nHidden = opt['nHidden'] if 'nHidden' in opt.keys() else 20
17 | self.maxGradNorm = opt['maxGradNorm'] if 'maxGradNorm' in opt.keys() else 0.25
18 | 
19 | #inputFeatures = 4 #loss(2) + preGrad(2) = 4
20 | inputFeatures = 2 # raw grad(1) + raw loss(1) = 2
21 | batchNormalization1 = opt['BN1'] if 'BN1' in opt.keys() else False
22 | maxBatchNormalizationLayers = opt['steps'] if 'steps' in opt.keys() else 1
23 | batchNormalization1 = False # TODO: delete this override. only for debugging
24 | if batchNormalization1:
25 | self.lstm = bnlstm.LSTM(cell_class=bnlstm.BNLSTMCell, input_size=inputFeatures,
26 | hidden_size=self.nHidden, batch_first=True,
27 | max_length=maxBatchNormalizationLayers)
28 | else:
29 | self.lstm = nn.LSTM(input_size=inputFeatures,
30 | hidden_size=self.nHidden,
31 | batch_first=True,
32 | num_layers=maxBatchNormalizationLayers)
33 | 
34 | # set initial hidden layer and cell state
35 | # num_layers * num_directions, batch, hidden_size
36 | batch_size = 1
37 | self.lstm_h0_c0 = None
38 | 
39 | #self.lstm_c0 = Variable(torch.rand(self.lstm.num_layers, batch_size, self.lstm.hidden_size),
40 | # requires_grad=False).cuda()
41 | #self.lstm_h0 = Variable(torch.rand(self.lstm.num_layers, batch_size, self.lstm.hidden_size),
42 | # requires_grad=False).cuda()
43 | 
44 | # Meta-learner LSTM
45 | # TODO: BatchNormalization in MetaLSTM
46 | batchNormalization2 = opt['BN2'] if 'BN2' in opt.keys() else False
47 | self.lstm2 = metalstm.MetaLSTM(input_size = opt['nParams'],
48 | hidden_size = self.nHidden,
49 | batch_first=True,
50 | num_layers=maxBatchNormalizationLayers)
51 | 
52 | # set initial c0 and h0 states for lstm2
53 | batch_size = 1
54 | self.lstm2_fS_iS_cS_deltaS = None
55 | 
56 | # Join parameters as input for optimizer
57 | self.params = lambda: list(self.lstm.named_parameters()) + list(self.lstm2.named_parameters())
58 | self.params = { param[0]:param[1] for
param in self.params()}
59 | 
60 | # initialize weights learner
61 | for names in self.lstm._all_weights:
62 | for name in filter(lambda n: "weight" in n, names):
63 | weight = getattr(self.lstm, name)
64 | weight.data.uniform_(-0.01, 0.01)
65 | 
66 | # initialize weights meta-learner for all layers.
67 | for params in self.lstm2.named_parameters():
68 | if 'WF' in params[0] or 'WI' in params[0] or 'cI' in params[0]:
69 | params[1].data.uniform_(-0.01, 0.01)
70 | 
71 | # want initial forget value to be high and input value
72 | # to be low so that model starts with gradient descent
73 | for params in self.lstm2.named_parameters():
74 | if "cell_0.bF" in params[0]:
75 | params[1].data.uniform_(4, 5)
76 | if "cell_0.bI" in params[0]:
77 | params[1].data.uniform_(-5, -4)
78 | 
79 | # Set initial cell state = learner's initial parameters
80 | initialParams = torch.cat([value.view(-1) for key,value in opt['learnerParams'].items()], 0)
81 | #self.lstm2.cells[0].cI = initialParams.unsqueeze(1).clone()
82 | 
83 | #torch.nn.Parameter(initial_param['weight'])
84 | for params in self.lstm2.named_parameters():
85 | if "cell_0.cI" in params[0]:
86 | params[1].data = initialParams.data.clone()
87 | # self.lstm2.cells[0].cI.data = initialParams.view_as(self.lstm2.cells[0].cI).clone()
88 | 
89 | 
90 | #for params in self.lstm2.parameters():
91 | # params.retain_grad()
92 | 
93 | #for params in self.lstm.parameters():
94 | # params.retain_grad()
95 | 
96 | 
97 | def forward(self, learner, trainInput, trainTarget, testInput, testTarget
98 | , steps, batchSize, evaluate = False ):
99 | 
100 | trainSize = trainInput.size(0)
101 | 
102 | # reset parameters for each dataset
103 | # Modules with learnable parameters have a reset(). This function
104 | # allows to re-initialize parameters. It's also used for weight
105 | # initialization.
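# Sketch of the inner loop below: the learner is re-initialized from the
# meta-learner's cell state cI; for every training batch the learner's flat
# (gradient, loss) pair is fed through self.lstm, whose output drives
# self.lstm2, and the resulting cell state cS is copied back into the
# learner as its new flat parameter vector.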
106 | learner.reset()
107 | learner.set('training')
108 | 
109 | # Set learner's initial parameters = initial cell state
110 | util.unflattenParams(learner.model, self.lstm2.cells[0].cI)
111 | 
112 | #for params in self.lstm2.named_parameters():
113 | # if "cell_0.cI" in params[0]:
114 | # util.unflattenParams(learner.model, params[1])
115 | #util.unflattenParams(learner.model, self.lstm2.cells[0].cI)
116 | 
117 | idx = 0
118 | for s in range(steps):
119 | for i in range(0,trainSize,batchSize):
120 | # get image input & label
121 | x = trainInput[i:i+batchSize,:]
122 | y = trainTarget[i:i+batchSize]
123 | 
124 | #if idx > 0:
125 | # # break computational graph
126 | # learnerParams = output.detach()
127 | # # Unflatten params and copy parameters to learner network
128 | # util.unflattenParams(learner.model,learnerParams)
129 | 
130 | # get gradient and loss w/r/t learnerParams for input+label
131 | grad_model, loss_model = learner.feval(x,y)
132 | grad_model = grad_model.view(grad_model.size()[0], 1, 1)
133 | 
134 | # Feed the raw gradient and loss to the meta-learner LSTM
135 | inputs = torch.cat((grad_model, loss_model.expand_as(grad_model)), 2)
136 | '''
137 | # preprocess grad & loss by DeepMind "Learning to learn"
138 | preGrad, preLoss = preprocess(grad_model,loss_model)
139 | # use meta-learner to get learner's next parameters
140 | lossExpand = preLoss.expand_as(preGrad)
141 | inputs = torch.cat((lossExpand,preGrad),2)
142 | '''
143 | output, self.lstm_h0_c0 = self.lstm(inputs, self.lstm_h0_c0)
144 | self.lstm2_fS_iS_cS_deltaS = self.lstm2((output,grad_model),
145 | hx=self.lstm2_fS_iS_cS_deltaS)
146 | 
147 | ## Debug: forward the test set with the current cell state
148 | util.unflattenParams(learner.modelF, self.lstm2_fS_iS_cS_deltaS[2])
149 | output, loss = learner(testInput, testTarget)
150 | #g = make_dot(loss)
151 | #g.render('/home/aberenguel/tmp/g.gv', view=True)
152 | #torch.autograd.grad(loss, list(self.lstm2.parameters()) + list(self.lstm.parameters()))
153 | ######
154 | 
155 | # get the internal cell state
156 | output = self.lstm2_fS_iS_cS_deltaS[2]
157 | 
158 | # Unflatten params and copy parameters to learner network
159 | util.unflattenParams(learner.model, output)
160 | 
161 | idx = idx + 1
162 | 
163 | # Unflatten params and copy parameters to learner network
164 | util.unflattenParams(learner.modelF, output)
165 | 
166 | ## get loss + predictions from learner.
167 | ## use batch-stats when meta-training; otherwise, use running-stats 168 | if evaluate: 169 | learner.set('evaluate') 170 | # Do a dummy forward / backward pass to get the correct grads into learner 171 | output, loss = learner(testInput, testTarget) 172 | 173 | # replace the gradients with the lstm2 174 | torch.autograd.grad(loss, self.lstm2.parameters()) 175 | 176 | return output, loss 177 | 178 | 179 | def gradNorm(self, loss): 180 | 181 | print('Grads lstm + lstm2:') 182 | for params in self.lstm.parameters(): 183 | print(params.grad) 184 | 185 | for params in self.lstm2.parameters(): 186 | print(params.grad) 187 | a = 0 188 | 189 | def setCuda(self, value = True): 190 | if value: 191 | self.lstm.cuda() 192 | self.lstm2.cuda() 193 | else: 194 | self.lstm.cpu() 195 | self.lstm2.cpu() -------------------------------------------------------------------------------- /model/lstm/metalstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from torch.nn import functional, init 5 | 6 | class MetaLSTMCell(nn.Module): 7 | 8 | """A basic LSTM cell.""" 9 | 10 | def __init__(self, input_size, hidden_size): 11 | """ 12 | Most parts are copied from torch.nn.LSTMCell. 13 | """ 14 | super(MetaLSTMCell, self).__init__() 15 | self.input_size = input_size 16 | self.hidden_size = hidden_size 17 | 18 | self.WF = nn.Parameter(torch.FloatTensor(hidden_size + 2, 1)) 19 | self.WI = nn.Parameter(torch.FloatTensor(hidden_size + 2, 1)) 20 | # initial cell state is a param 21 | self.cI = nn.Parameter(torch.FloatTensor(input_size, 1)) 22 | self.bI = nn.Parameter(torch.FloatTensor(1, 1)) 23 | self.bF = nn.Parameter(torch.FloatTensor(1, 1)) 24 | self.m = nn.Parameter(torch.FloatTensor(1)) 25 | 26 | ''' 27 | self.WF = Variable(torch.FloatTensor(hidden_size + 2, 1), requires_grad=True) 28 | self.WI = Variable(torch.FloatTensor(hidden_size + 2, 1), requires_grad=True) 29 | # initial cell state is a param 30 | self.cI = Variable(torch.FloatTensor(input_size, 1), requires_grad=True) 31 | self.bI = Variable(torch.FloatTensor(1, 1), requires_grad=True) 32 | self.bF = Variable(torch.FloatTensor(1, 1), requires_grad=True) 33 | self.m = Variable(torch.FloatTensor(1), requires_grad=True) 34 | ''' 35 | 36 | self.reset_parameters() 37 | 38 | def reset_parameters(self): 39 | """ 40 | Initialize parameters 41 | """ 42 | self.WF.data.uniform_(-0.01, 0.01) 43 | self.WI.data.uniform_(-0.01, 0.01) 44 | self.cI.data.uniform_(-0.01, 0.01) 45 | self.bI.data.zero_() 46 | self.bF.data.zero_() 47 | self.m.data.zero_() 48 | 49 | def forward(self, input_, grads_, hx): 50 | """ 51 | Args: 52 | input_: A (batch, input_size) tensor containing input 53 | features. 54 | hx: A tuple (h_0, c_0), which contains the initial hidden 55 | and cell state, where the size of both states is 56 | (batch, hidden_size). 57 | Returns: 58 | h_1, c_1: Tensors containing the next hidden and cell state. 
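Note: in this cell hx is actually the tuple (fS, iS, cS, deltaS)
holding the forget-gate and input-gate pre-activations, the cell
state (the learner's flat parameters) and the previous update; the
updated (fS, iS, cS, deltaS) is returned.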
59 | """ 60 | 61 | # next forget, input gate 62 | (fS, iS, cS, deltaS) = hx 63 | fS = torch.cat((cS, fS), 1) 64 | iS = torch.cat((cS, iS), 1) 65 | 66 | fS = torch.mm(torch.cat((input_,fS), 1),self.WF) 67 | fS += self.bF.expand_as(fS) 68 | 69 | iS = torch.mm(torch.cat((input_,iS), 1),self.WI) 70 | iS += self.bI.expand_as(iS) 71 | 72 | # next delta 73 | deltaS = self.m * deltaS - nn.Sigmoid()(iS).mul(grads_) 74 | 75 | # next cell/params 76 | cS = nn.Sigmoid()(fS).mul(cS) + deltaS 77 | 78 | return fS, iS, cS, deltaS 79 | 80 | def __repr__(self): 81 | s = '{name}({input_size}, {hidden_size})' 82 | return s.format(name=self.__class__.__name__, **self.__dict__) 83 | 84 | 85 | class MetaLSTM(nn.Module): 86 | 87 | """A module that runs multiple steps of LSTM.""" 88 | 89 | def __init__(self, input_size, hidden_size, 90 | batch_first = False, num_layers=1): 91 | super(MetaLSTM, self).__init__() 92 | self.input_size = input_size 93 | self.hidden_size = hidden_size 94 | self.num_layers = num_layers 95 | self.batch_first = batch_first 96 | 97 | self.cells = [] 98 | for layer in range(num_layers): 99 | layer_input_size = input_size if layer == 0 else hidden_size 100 | cell = MetaLSTMCell(input_size=layer_input_size, 101 | hidden_size=hidden_size) 102 | self.cells.append(cell) 103 | setattr(self, 'cell_{}'.format(layer), cell) 104 | self.reset_parameters() 105 | 106 | def reset_parameters(self): 107 | for cell in self.cells: 108 | cell.reset_parameters() 109 | 110 | @staticmethod 111 | def _forward_rnn(cell, input_, grads_, length, hx): 112 | max_time = input_.size(0) 113 | output = [] 114 | for time in range(max_time): 115 | hx = cell(input_=input_[time],grads_=grads_[time], hx=hx) 116 | #mask = (time < length).float().unsqueeze(1).expand_as(h_next[0]) 117 | #fS_next = h_next[0] * mask + hx[0] * (1 - mask) 118 | #iS_next = h_next[1] * mask + hx[1] * (1 - mask) 119 | #cS_next = h_next[2] * mask + hx[2] * (1 - mask) 120 | #deltaS_next = h_next[3] * mask + hx[3] * (1 - mask) 121 | #hx_next = (fS_next, iS_next, cS_next, deltaS_next) 122 | #output.append(h_next) 123 | #hx = hx_next 124 | #output = torch.stack(output, 0) 125 | #return output,hx 126 | #return hx[2],hx 127 | return hx 128 | 129 | def forward(self, input_, length=None, hx=None): 130 | 131 | x_input = input_[0] # output from lstm 132 | grad_input = input_[1] # gradients from learner 133 | if self.batch_first: 134 | x_input = x_input.transpose(0, 1) 135 | grad_input = grad_input.transpose(0, 1) 136 | max_time, batch_size, _ = x_input.data.size() 137 | if length is None: 138 | length = Variable(torch.LongTensor([max_time] * batch_size)) 139 | if x_input.is_cuda: 140 | length = length.cuda() 141 | # hidden variables. Here we have fS, iS and cS. 
142 | if hx is None:
143 | fS = Variable(grad_input.data.new(batch_size, 1).zero_())
144 | iS = Variable(grad_input.data.new(batch_size, 1).zero_())
145 | cS = (self.cells[0].cI).unsqueeze(1)
146 | deltaS = Variable(grad_input.data.new(batch_size, 1).zero_())
147 | hx = (fS, iS, cS, deltaS)
148 | 
149 | fS_n = []
150 | iS_n = []
151 | cS_n = []
152 | deltaS_n = []
153 | for layer in range(self.num_layers):
154 | hx_new = MetaLSTM._forward_rnn(
155 | cell=self.cells[layer], input_=x_input,
156 | grads_= grad_input, length=length, hx=hx)
157 | fS_n.append(hx_new[0])
158 | iS_n.append(hx_new[1])
159 | cS_n.append(hx_new[2])
160 | deltaS_n.append(hx_new[3])
161 | fS_n = torch.stack(fS_n, 0)
162 | iS_n = torch.stack(iS_n, 0)
163 | cS_n = torch.stack(cS_n, 0)
164 | 
165 | deltaS_n = torch.stack(deltaS_n, 0)
166 | # return cS and the actual state
167 | return (fS_n, iS_n, cS_n, deltaS_n)
168 | 
--------------------------------------------------------------------------------
/model/lstm/recurrentLSTMNetwork.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | 
5 | class RecurrentLSTMNetwork(nn.Module):
6 | def __init__(self, opt):
7 | super(RecurrentLSTMNetwork, self).__init__()
8 | 
9 | self.inputFeatures = opt['inputFeatures'] if 'inputFeatures' in opt.keys() else 10
10 | self.hiddenFeatures = opt['hiddenFeatures'] if 'hiddenFeatures' in opt.keys() else 100
11 | self.outputType = opt['outputType'] if 'outputType' in opt.keys() else 'last' # 'last' or 'all'
12 | self.batchNormalization = opt['batchNormalization'] if 'batchNormalization' in opt.keys() else False
13 | self.maxBatchNormalizationLayers = opt['maxBatchNormalizationLayers'] if 'maxBatchNormalizationLayers' in opt.keys() else 10
14 | 
15 | # parameters
16 | self.p = {}
17 | self.p['W'] = Variable(torch.zeros(self.inputFeatures+self.hiddenFeatures,4 * self.hiddenFeatures),
18 | requires_grad = True)
19 | self.params = [self.p['W']]
20 | 
21 | #TODO: delete this line. only for debugging
22 | self.batchNormalization = True
23 | 
24 | if self.batchNormalization:
25 | # TODO: check if nn.BatchNorm1d or torch.legacy.nn.BatchNormalization
26 | # translation and scaling parameters are shared across time.
27 | lstm_bn = nn.BatchNorm1d(4*self.hiddenFeatures)
28 | cell_bn = nn.BatchNorm1d(self.hiddenFeatures)
29 | self.layers = {'lstm_bn':[lstm_bn],'cell_bn':[cell_bn]}
30 | 
31 | for i in range(2,self.maxBatchNormalizationLayers):
32 | lstm_bn = nn.BatchNorm1d(4*self.hiddenFeatures)
33 | cell_bn = nn.BatchNorm1d(self.hiddenFeatures)
34 | self.layers['lstm_bn'].append(lstm_bn)
35 | self.layers['cell_bn'].append(cell_bn)
36 | 
37 | # Initializing scaling to <1 is recommended for LSTM batch norm
38 | # TODO: why only the first are initialized??
39 | self.layers['lstm_bn'][0].weight.data.fill_(0.1)
40 | self.layers['lstm_bn'][0].bias.data.zero_()
41 | self.layers['cell_bn'][0].weight.data.fill_(0.1)
42 | self.layers['cell_bn'][0].bias.data.zero_()
43 | 
44 | self.params = self.params + \
45 | list(self.layers['lstm_bn'][0].parameters()) + \
46 | list(self.layers['cell_bn'][0].parameters())
47 | else:
48 | self.p['b'] = Variable(torch.zeros(1, 4*self.hiddenFeatures),
49 | requires_grad = True)
50 | self.params = self.params + [self.p['b']]
51 | self.layers = {}
52 | 
53 | def setCuda(self, value = True):
54 | # Move the raw parameter tensors and any batch-norm layers
55 | # between GPU and CPU; value should be True or False
56 | if value == True:
57 | for key in self.p.keys():
58 | self.p[key].cuda()
59 | for key in self.layers.keys():
60 | for i in range(len(self.layers[key])):
61 | self.layers[key][i].cuda()
62 | else:
63 | for key in self.p.keys():
64 | self.p[key].cpu()
65 | for key in self.layers.keys():
66 | for i in range(len(self.layers[key])):
67 | self.layers[key][i].cpu()
68 | 
69 | def forward(self, x, prevState = None ):
70 | 
71 | # dimensions
72 | if len(x.size()) == 2: x = x.unsqueeze(0)
73 | batch = x.size(0)
74 | steps = x.size(1)
75 | 
76 | if prevState is None: prevState = {}
77 | hs = {}
78 | cs = {}
79 | for t in range(steps):
80 | # xt
81 | xt = x[:,t,:]
82 | # prev h and prev c (port of the Lua idiom: hs[t-1] or prevState.h or zeros)
83 | hp = hs[t-1] if t-1 in hs else prevState.get('h', torch.zeros(batch, self.hiddenFeatures))
84 | # TODO: the LSTM step itself is not implemented yet
85 | 
--------------------------------------------------------------------------------
/model/lstm/train-lstm.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Albert Berenguel
3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona
4 | ## Email: aberenguel@cvc.uab.es
5 | ## Copyright (c) 2017
6 | ## 
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | 
11 | import os
12 | import torch
13 | import torch.nn as nn
14 | from torch.autograd import Variable
15 | import importlib
16 | import numpy as np
17 | import time
18 | import math
19 | from sklearn.metrics import classification_report
20 | from sklearn.metrics import accuracy_score
21 | from utils import util
22 | from learner import Learner
23 | from metaLearner import MetaLearner
24 | 
25 | def create_optimizer(opt, params):
26 | if opt['optimMethod'] == 'sgd':
27 | optimizer = torch.optim.SGD(params, lr=opt['lr'],
28 | momentum=0.9, dampening=0.9,
29 | weight_decay=opt['weight_decay'])
30 | elif opt['optimMethod'] == 'adam':
31 | optimizer = torch.optim.Adam(params, lr=opt['lr'],
32 | weight_decay=opt['weight_decay'])
33 | else:
34 | raise Exception('Not supported optimizer: {0}'.format(opt['optimMethod']))
35 | return optimizer
36 | 
37 | def adjust_learning_rate(opt,optimizer):
38 | """Updates the learning rate given the learning rate decay.
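Each parameter group's rate is set to lr_t = lr / (1 + step * lr_decay).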
39 | The routine has been implemented according to the original Lua SGD optimizer 40 | """ 41 | for group in optimizer.param_groups: 42 | if 'step' not in group: 43 | group['step'] = 0 44 | group['step'] += 1 45 | 46 | group['lr'] = opt['lr'] / (1 + group['step'] * opt['lr_decay']) 47 | 48 | return optimizer 49 | 50 | def run(opt,data): 51 | 52 | # learner 53 | learner = Learner(opt) 54 | print('Learner nParams: %d' % (learner.nParams)) 55 | 56 | # meta-learner 57 | params_dict = {'learnerParams': learner.params, 58 | 'nParams': learner.nParams} 59 | for param in ['debug','homePath','nHidden','BN1','BN2']: 60 | if param in opt.keys(): 61 | params_dict[param] = opt[param] 62 | metaLearner = MetaLearner(params_dict) 63 | # set cuda 64 | metaLearner.setCuda(opt['useCUDA']) 65 | learner.setCuda(opt['useCUDA']) 66 | 67 | # Keep track of errors 68 | trainConf_pred = [] 69 | trainConf_gt = [] 70 | valConf_pred = {} 71 | valConf_gt = {} 72 | testConf_pred = {} 73 | testConf_gt = {} 74 | for i in opt['nTestShot']: 75 | valConf_pred[i] = [] 76 | valConf_gt[i] = [] 77 | testConf_pred[i] = [] 78 | testConf_gt[i] = [] 79 | 80 | cost = 0 81 | timer = time.time() 82 | 83 | ################################################################# 84 | ############ Meta-training 85 | ################################################################# 86 | 87 | # Init optimizer 88 | #optimizer = create_optimizer(opt, metaLearner.params.values()) 89 | #optimizer = create_optimizer(opt, learner.modelF.net.parameters()) 90 | optimizer = create_optimizer(opt, list(metaLearner.lstm.parameters()) + list(metaLearner.lstm2.parameters())) 91 | 92 | # train episode loop 93 | for episodeTrain,(x_support_set, y_support_set, x_target, target_y) in enumerate(data['train']): 94 | 95 | # Re-arange the Target vectors between [0..nClasses_train] 96 | dictLabels, dictLabelsInverse = util.createDictLabels(y_support_set) 97 | y_support_set = util.fitLabelsToRange(dictLabels, y_support_set) 98 | target_y = util.fitLabelsToRange(dictLabels, target_y) 99 | 100 | # Convert them in Variables 101 | input = {} 102 | trainInput = Variable(x_support_set).float() 103 | trainTarget = Variable(y_support_set,requires_grad=False).long() 104 | testInput = Variable(x_target).float() 105 | testTarget = Variable(target_y,requires_grad=False).long() 106 | 107 | # Convert to GPU if needed 108 | if opt['useCUDA']: 109 | trainInput = trainInput.cuda() 110 | trainTarget = trainTarget.cuda() 111 | testInput = testInput.cuda() 112 | testTarget = testTarget.cuda() 113 | 114 | # learner-optimizer with learner.model.net 115 | 116 | 117 | # forward metalearner 118 | output, loss = metaLearner(learner, trainInput, trainTarget, 119 | testInput, testTarget, 120 | opt['nEpochs'][opt['nTrainShot']], 121 | opt['batchSize'][opt['nTrainShot']]) 122 | optimizer.zero_grad() 123 | loss.backward() 124 | metaLearner.gradNorm(loss) 125 | optimizer.step() 126 | 127 | # Adjust learning rate 128 | optimizer = adjust_learning_rate(opt, optimizer) 129 | 130 | cost = cost + loss 131 | 132 | # update stats 133 | values_pred, indices_pred = torch.max(output, 1) 134 | target_y = util.fitLabelsToRange(dictLabelsInverse, target_y) 135 | indices_pred = util.fitLabelsToRange(dictLabelsInverse, indices_pred.cpu().data) 136 | trainConf_pred.append(indices_pred.numpy()) 137 | trainConf_gt.append(target_y.numpy()) 138 | 139 | print( 140 | 'Training Episode: [{}/{} ({:.0f}%)]\tLoss: {:.3f}. Elapsed: {:.4f} s'.format( 141 | episodeTrain, len(data['train']), 100. 
* episodeTrain / len(data['train']), 142 | loss.data[0], time.time() - timer)) 143 | 144 | 145 | if episodeTrain % opt['printPer'] == 0: 146 | trainConf_pred = np.concatenate(trainConf_pred, axis=0) 147 | trainConf_gt = np.concatenate(trainConf_gt, axis=0) 148 | target_names = [str(i) for i in np.unique(trainConf_gt)] 149 | 150 | print( 151 | 'Training Episode: [{}/{} ({:.0f}%)]\tCost: {:.3f}. Elapsed: {:.4f} s'.format( 152 | episodeTrain, len(data['train']), 100. * episodeTrain / len(data['train']), 153 | (cost.cpu().data.numpy() / opt['printPer'])[0],time.time() - timer)) 154 | print(classification_report(trainConf_gt, trainConf_pred, 155 | target_names=target_names)) 156 | # Set to 0 157 | trainConf_pred = [] 158 | trainConf_gt = [] 159 | 160 | cost = 0 161 | timer = time.time() 162 | 163 | -------------------------------------------------------------------------------- /model/matching-net-classifier.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Albert Berenguel 3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona 4 | ## Email: aberenguel@cvc.uab.es 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | import numpy as np 14 | import math 15 | 16 | def convLayer(opt, nInput, nOutput, k): 17 | "3x3 convolution with padding" 18 | seq = nn.Sequential( 19 | nn.Conv2d(nInput, nOutput, kernel_size=k, 20 | stride=1, padding=1, bias=True), 21 | nn.BatchNorm2d(nOutput), 22 | nn.ReLU(True), 23 | nn.MaxPool2d(kernel_size=2, stride=2) 24 | ) 25 | if opt['useDropout']: # Add dropout module 26 | list_seq = list(seq.modules())[1:] 27 | list_seq.append(nn.Dropout(0.1)) 28 | seq = nn.Sequential(*list_seq) 29 | return seq 30 | 31 | class Classifier(nn.Module): 32 | def __init__(self, opt): 33 | super(Classifier, self).__init__() 34 | 35 | nFilters = 64 36 | finalSize = int(math.floor(opt['nIn'] / (2 * 2 * 2 * 2))) 37 | 38 | self.layer1 = convLayer(opt, opt['nDepth'], nFilters, 3) 39 | self.layer2 = convLayer(opt, nFilters, nFilters, 3) 40 | self.layer3 = convLayer(opt, nFilters, nFilters, 3) 41 | self.layer4 = convLayer(opt, nFilters, nFilters, 3) 42 | 43 | self.outSize = nFilters*finalSize*finalSize 44 | self.classify = opt['classify'] 45 | if self.classify: 46 | self.layer5 = nn.Linear(nFilters*finalSize*finalSize, opt['nClasses']['train']) 47 | self.outSize = opt['nClasses']['train'] 48 | 49 | # Initialize layers 50 | self.weights_init(self.layer1) 51 | self.weights_init(self.layer2) 52 | self.weights_init(self.layer3) 53 | self.weights_init(self.layer4) 54 | 55 | def weights_init(self,module): 56 | for m in module.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 59 | init.constant(m.bias, 0) 60 | elif isinstance(m, nn.BatchNorm2d): 61 | m.weight.data.fill_(1) 62 | m.bias.data.zero_() 63 | 64 | def forward(self, x): 65 | """ 66 | Runs the CNN producing the embeddings and the gradients. 67 | :param image_input: Image input to produce embeddings for. 
[batch_size, nDepth, nIn, nIn]
68 | :return: Flattened embeddings of size [batch_size, outSize]
69 | """
70 | x = self.layer1(x)
71 | x = self.layer2(x)
72 | x = self.layer3(x)
73 | x = self.layer4(x)
74 | x = x.view(x.size(0), -1)
75 | if self.classify:
76 | x = self.layer5(x)
77 | return x
78 | 
79 | 
80 | class MatchingNetClassifier():
81 | def __init__(self, opt):
82 | 
83 | self.net = Classifier(opt)
84 | if opt['classify']:
85 | self.criterion = nn.CrossEntropyLoss()
86 | else:
87 | self.criterion = []
88 | self.nParams = sum([i.view(-1).size()[0] for i in self.net.parameters()])
89 | self.outSize = self.net.outSize
90 | 
91 | def build(opt):
92 | 
93 | model = MatchingNetClassifier(opt)
94 | print('created net:')
95 | print(model.net)
96 | return model
97 | 
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Albert Berenguel
3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona
4 | ## Email: aberenguel@cvc.uab.es
5 | ## Copyright (c) 2017
6 | ## 
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | 
11 | import argparse
12 | 
13 | class Options():
14 | def __init__(self):
15 | # Training settings
16 | parser = argparse.ArgumentParser(description='Few-Shot Learning')
17 | parser.add_argument('--task', type=str, default='config.5-shot-5-class',
18 | help='path to config file for task')
19 | parser.add_argument('--data', type=str, default='config.imagenet',
20 | help='path to config file for data')
21 | parser.add_argument('--model', type=str, default='config.lstm.train-imagenet-5shot',
22 | #parser.add_argument('--model', type=str, default='config.baselines.train-matching-net',
23 | help='path to config file for model')
24 | parser.add_argument('--test', type=str, default='-',
25 | help='path to config file for test details')
26 | parser.add_argument('--log-dir', default='./logs',
27 | help='folder to output model checkpoints')
28 | self.parser = parser
29 | 
30 | def parse(self):
31 | return self.parser.parse_args()
32 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gitabcworld/FewShotLearning/2eff1881f96212e0fb31737a48f82fab9c2fac83/utils/__init__.py
--------------------------------------------------------------------------------
/utils/create_miniImagenet.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Albert Berenguel
3 | ## Computer Vision Center (CVC). Universitat Autonoma de Barcelona
4 | ## Email: aberenguel@cvc.uab.es
5 | ## Copyright (c) 2017
6 | ## 
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | 
11 | '''
12 | This code creates the MiniImagenet dataset, following
the partitions given
13 | by Sachin Ravi and Hugo Larochelle in
14 | https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet
15 | '''
16 | 
17 | import numpy as np
18 | import csv
19 | import glob, os
20 | from shutil import copyfile
21 | import cv2
22 | from tqdm import tqdm
23 | 
24 | pathImageNet = '/home/aberenguel/Dataset/Imagenet/ILSVRC2012_img_train'
25 | pathminiImageNet = '/home/aberenguel/TensorFlow/maml/data/miniImagenet/'
26 | pathImages = os.path.join(pathminiImageNet,'images/')
27 | filesCSVSachinRavi = [os.path.join(pathminiImageNet,'train.csv'),
28 | os.path.join(pathminiImageNet,'val.csv'),
29 | os.path.join(pathminiImageNet,'test.csv')]
30 | 
31 | # Check if the folder of images exists. If not, create it.
32 | if not os.path.exists(pathImages):
33 | os.makedirs(pathImages)
34 | 
35 | for filename in filesCSVSachinRavi:
36 | with open(filename) as csvfile:
37 | csv_reader = csv.reader(csvfile, delimiter=',')
38 | next(csv_reader, None)
39 | images = {}
40 | print('Reading IDs....')
41 | for row in tqdm(csv_reader):
42 | if row[1] in images.keys():
43 | images[row[1]].append(row[0])
44 | else:
45 | images[row[1]] = [row[0]]
46 | 
47 | print('Writing photos....')
48 | for c in tqdm(images.keys()): # Iterate over all the classes
49 | lst_files = []
50 | for file in glob.glob(pathImageNet + "/*"+c+"*"):
51 | lst_files.append(file)
52 | # TODO: Sort by name or by index number of the image?
53 | # I sort by the number of the image
54 | lst_index = [int(i[i.rfind('_')+1:i.rfind('.')]) for i in lst_files]
55 | index_sorted = sorted(range(len(lst_index)), key=lst_index.__getitem__)
56 | 
57 | # Now iterate
58 | index_selected = [int(i[i.index('.') - 4:i.index('.')]) for i in images[c]]
59 | selected_images = np.array(index_sorted)[np.array(index_selected) - 1]
60 | for i in np.arange(len(selected_images)):
61 | # read file and resize to 84x84x3
62 | #im = cv2.imread(os.path.join(pathImageNet,lst_files[selected_images[i]]))
63 | #im_resized = cv2.resize(im, (84, 84), interpolation=cv2.INTER_AREA)
64 | #cv2.imwrite(os.path.join(pathImages, images[c][i]),im_resized)
65 | copyfile(os.path.join(pathImageNet,lst_files[selected_images[i]]),os.path.join(pathImages, images[c][i]))
66 | 
67 | 
68 | 
69 | 
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | def createDictLabels(labels):
5 | 
6 | """
7 | Creates dictionaries that fit data with non-sequential labels into a
8 | sequential label order from [0...nClasses].
9 | :param labels: all the non-sequential labels
10 | :return: dict that converts from non-sequential to sequential,
11 | dict that converts from sequential to non-sequential
12 | """
13 | 
14 | # Re-arrange the Target vectors between [0..nClasses_train]
15 | labels = labels.numpy()
16 | unique_labels = np.unique(labels)
17 | dictLabels = {val: i for i, val in enumerate(unique_labels)}
18 | dictLabelsInverse = {i: val for i, val in enumerate(unique_labels)}
19 | return dictLabels,dictLabelsInverse
20 | 
21 | 
22 | def fitLabelsToRange(dictLabels,labels):
23 | 
24 | """
25 | Converts Tensor values to the values contained in the dictionary
26 | :param dictLabels: dictionary with the conversion values
27 | :param labels: Tensor to convert
28 | :return: Tensor with the converted labels.
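Example with hypothetical labels (np.unique sorts, so 7->0, 12->1, 40->2):
    >>> d, dInv = createDictLabels(torch.LongTensor([12, 12, 40, 7]))
    >>> fitLabelsToRange(d, torch.LongTensor([12, 40, 7, 7]))  # -> [1, 2, 0, 0]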
29 | """ 30 | labels = labels.numpy() 31 | unique_labels = np.unique(labels) 32 | labels_temp = np.array(labels) 33 | for i in dictLabels.keys(): 34 | labels_temp[labels == i] = dictLabels[i] 35 | labels = labels_temp 36 | return torch.from_numpy(labels) 37 | 38 | def unflattenParams(model,flatParams): 39 | flatParams = flatParams.squeeze() 40 | indx = 0 41 | for param in model.net.parameters(): 42 | lengthParam = param.view(-1).size()[0] 43 | param.data = flatParams[indx:indx+lengthParam].view_as(param).data 44 | indx = indx + lengthParam 45 | a = 0 46 | 47 | 48 | -------------------------------------------------------------------------------- /visualize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabcworld/FewShotLearning/2eff1881f96212e0fb31737a48f82fab9c2fac83/visualize/__init__.py -------------------------------------------------------------------------------- /visualize/visualize.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/szagoruyko/functional-zoo/blob/master/visualize.py 2 | from graphviz import Digraph 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | 7 | def make_dot(var, params=None): 8 | """ Produces Graphviz representation of PyTorch autograd graph 9 | Blue nodes are the Variables that require grad, orange are Tensors 10 | saved for backward in torch.autograd.Function 11 | Args: 12 | var: output Variable 13 | params: dict of (name, Variable) to add names to node that 14 | require grad (TODO: make optional) 15 | """ 16 | if params is not None: 17 | assert isinstance(params.values()[0], Variable) 18 | param_map = {id(v): k for k, v in params.items()} 19 | 20 | node_attr = dict(style='filled', 21 | shape='box', 22 | align='left', 23 | fontsize='12', 24 | ranksep='0.1', 25 | height='0.2') 26 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12")) 27 | seen = set() 28 | 29 | def size_to_str(size): 30 | return '('+(', ').join(['%d' % v for v in size])+')' 31 | 32 | def add_nodes(var): 33 | if var not in seen: 34 | if torch.is_tensor(var): 35 | dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange') 36 | elif hasattr(var, 'variable'): 37 | u = var.variable 38 | name = param_map[id(u)] if params is not None else '' 39 | node_name = '%s\n %s' % (name, size_to_str(u.size())) 40 | dot.node(str(id(var)), node_name, fillcolor='lightblue') 41 | else: 42 | dot.node(str(id(var)), str(type(var).__name__)) 43 | seen.add(var) 44 | if hasattr(var, 'next_functions'): 45 | for u in var.next_functions: 46 | if u[0] is not None: 47 | dot.edge(str(id(u[0])), str(id(var))) 48 | add_nodes(u[0]) 49 | if hasattr(var, 'saved_tensors'): 50 | for t in var.saved_tensors: 51 | dot.edge(str(id(t)), str(id(var))) 52 | add_nodes(t) 53 | add_nodes(var.grad_fn) 54 | return dot --------------------------------------------------------------------------------