├── scratch.py
├── tests
│   ├── __init__.py
│   ├── test_proto.py
│   ├── test_few_shot.py
│   └── test_utils.py
├── few_shot
│   ├── __init__.py
│   ├── metrics.py
│   ├── eval.py
│   ├── proto.py
│   ├── utils.py
│   ├── train.py
│   ├── matching.py
│   ├── extmodel_proto_net_clf.py
│   ├── maml.py
│   ├── datasets.py
│   ├── core.py
│   ├── models.py
│   └── callbacks.py
├── scripts
│   ├── __init__.py
│   ├── prepare_mini_imagenet.py
│   └── prepare_omniglot.py
├── assets
│   ├── maml_diagram.png
│   ├── proto_nets_diagram.png
│   └── matching_nets_diagram.png
├── config.py
├── requirements.txt
├── app
│   ├── whale
│   │   ├── k_tta.py
│   │   ├── train.py
│   │   ├── train_exhaustive.py
│   │   ├── k_train.py
│   │   ├── whale_plus_utils.py
│   │   └── whale_utils.py
│   └── app_utils_clf.py
├── .gitignore
├── experiments
│   ├── proto_nets.py
│   ├── experiments.txt
│   └── Test_ImageNet_ResNet18_as_ProtoNet.ipynb
└── README.md
/scratch.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/few_shot/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/assets/maml_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daisukelab/protonet-fine-grained-clf/HEAD/assets/maml_diagram.png
--------------------------------------------------------------------------------
/assets/proto_nets_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daisukelab/protonet-fine-grained-clf/HEAD/assets/proto_nets_diagram.png
--------------------------------------------------------------------------------
/assets/matching_nets_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daisukelab/protonet-fine-grained-clf/HEAD/assets/matching_nets_diagram.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | PATH = os.path.dirname(os.path.realpath(__file__))
5 | 
6 | DATA_PATH = None
7 | 
8 | EPSILON = 1e-8
9 | 
10 | if DATA_PATH is None:
11 |     raise Exception('Configure your data folder location in config.py before continuing!')
--------------------------------------------------------------------------------
/few_shot/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def categorical_accuracy(y, y_pred):
5 |     """Calculates categorical accuracy.
6 | 
7 |     # Arguments:
8 |         y_pred: Prediction probabilities or logits of shape [batch_size, num_categories]
9 |         y: Ground truth categories.
Must have shape [batch_size,] 10 | """ 11 | return torch.eq(y_pred.argmax(dim=-1), y).sum().item() / y_pred.shape[0] 12 | 13 | 14 | NAMED_METRICS = { 15 | 'categorical_accuracy': categorical_accuracy 16 | } 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | atomicwrites==1.2.1 2 | attrs==18.2.0 3 | cloudpickle==0.6.1 4 | cycler==0.10.0 5 | dask==0.20.0 6 | decorator==4.3.0 7 | kiwisolver==1.0.1 8 | matplotlib==3.0.1 9 | more-itertools==4.3.0 10 | networkx==2.2 11 | numpy==1.15.3 12 | pandas==0.23.4 13 | Pillow==5.3.0 14 | pkg-resources==0.0.0 15 | pluggy==0.8.0 16 | py==1.7.0 17 | pyparsing==2.3.0 18 | pytest==3.9.3 19 | python-dateutil==2.7.5 20 | pytz==2018.7 21 | PyWavelets==1.0.1 22 | scikit-image==0.14.1 23 | scipy==1.1.0 24 | six==1.11.0 25 | toolz==0.9.0 26 | torch==0.4.1 27 | torchvision==0.2.1 28 | tqdm==4.28.1 29 | -------------------------------------------------------------------------------- /app/whale/k_tta.py: -------------------------------------------------------------------------------- 1 | from dlcliche.image import * 2 | sys.path.append('..') # app 3 | sys.path.append('../..') # root 4 | from easydict import EasyDict 5 | from app_utils_clf import * 6 | from whale_plus_utils import * 7 | from config import DATA_PATH 8 | 9 | calculate_results_plus(weight_file='k/k_k60_epoch600.pth', 10 | output_path='results', 11 | SZ=224, 12 | get_model_fn=get_resnet50, 13 | device=device, 14 | train_csv=DATA_PATH+'/train.csv', 15 | data_train='images/train-448-AC-CR', 16 | data_test='images/test-448-AC-CR', 17 | data_type='normal', 18 | normalize='imagenet', 19 | N_TTA=4) 20 | 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # Pycharm 107 | .idea/ 108 | -------------------------------------------------------------------------------- /scripts/prepare_mini_imagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run this script to prepare the miniImageNet dataset. 3 | 4 | This script uses the 100 classes of 600 images each used in the Matching Networks paper. The exact images used are 5 | given in data/mini_imagenet.txt which is downloaded from the link provided in the paper (https://goo.gl/e3orz6). 6 | 7 | 1. Download files from https://drive.google.com/file/d/0B3Irx3uQNoBMQ1FlNXJsZUdYWEE/view and place in 8 | data/miniImageNet/images 9 | 2. 
Run the script 10 | """ 11 | from tqdm import tqdm as tqdm 12 | import numpy as np 13 | import shutil 14 | import os 15 | 16 | from config import DATA_PATH 17 | from few_shot.utils import mkdir, rmdir 18 | 19 | 20 | # Clean up folders 21 | rmdir(DATA_PATH + '/miniImageNet/images_background') 22 | rmdir(DATA_PATH + '/miniImageNet/images_evaluation') 23 | mkdir(DATA_PATH + '/miniImageNet/images_background') 24 | mkdir(DATA_PATH + '/miniImageNet/images_evaluation') 25 | 26 | # Find class identities 27 | classes = [] 28 | for root, _, files in os.walk(DATA_PATH + '/miniImageNet/images/'): 29 | for f in files: 30 | if f.endswith('.jpg'): 31 | classes.append(f[:-12]) 32 | 33 | classes = list(set(classes)) 34 | 35 | # Train/test split 36 | np.random.seed(0) 37 | np.random.shuffle(classes) 38 | background_classes, evaluation_classes = classes[:80], classes[80:] 39 | 40 | # Create class folders 41 | for c in background_classes: 42 | mkdir(DATA_PATH + f'/miniImageNet/images_background/{c}/') 43 | 44 | for c in evaluation_classes: 45 | mkdir(DATA_PATH + f'/miniImageNet/images_evaluation/{c}/') 46 | 47 | # Move images to correct location 48 | for root, _, files in os.walk(DATA_PATH + '/miniImageNet/images'): 49 | for f in tqdm(files, total=600*100): 50 | if f.endswith('.jpg'): 51 | class_name = f[:-12] 52 | image_name = f[-12:] 53 | # Send to correct folder 54 | subset_folder = 'images_evaluation' if class_name in evaluation_classes else 'images_background' 55 | src = f'{root}/{f}' 56 | dst = DATA_PATH + f'/miniImageNet/{subset_folder}/{class_name}/{image_name}' 57 | shutil.copy(src, dst) 58 | -------------------------------------------------------------------------------- /few_shot/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | from torch.utils.data import DataLoader 4 | from typing import Callable, List, Union 5 | 6 | from few_shot.metrics import NAMED_METRICS 7 | 8 | 9 | def evaluate(model: Module, dataloader: DataLoader, prepare_batch: Callable, metrics: List[Union[str, Callable]], 10 | loss_fn: Callable = None, prefix: str = 'val_', suffix: str = ''): 11 | """Evaluate a model on one or more metrics on a particular dataset 12 | 13 | # Arguments 14 | model: Model to evaluate 15 | dataloader: Instance of torch.utils.data.DataLoader representing the dataset 16 | prepare_batch: Callable to perform any desired preprocessing 17 | metrics: List of metrics to evaluate the model with. Metrics must either be a named metric (see `metrics.py`) or 18 | a Callable that takes predictions and ground truth labels and returns a scalar value 19 | loss_fn: Loss function to calculate over the dataset 20 | prefix: Prefix to prepend to the name of each metric - used to identify the dataset. Defaults to 'val_' as 21 | it is typical to evaluate on a held-out validation dataset 22 | suffix: Suffix to append to the name of each metric. 
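    # Example
        An illustrative call (hypothetical: `model` and `val_dataset` are placeholders,
        the model is assumed to already be on the correct device and to output raw logits):

            loader = DataLoader(val_dataset, batch_size=64)
            prepare_batch = lambda batch: (batch[0].cuda(), batch[1].cuda())
            logs = evaluate(model, loader, prepare_batch,
                            metrics=['categorical_accuracy'],
                            loss_fn=torch.nn.CrossEntropyLoss(),
                            prefix='val_')
            # logs -> {'val_loss': ..., 'val_categorical_accuracy': ...}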
23 | """ 24 | logs = {} 25 | seen = 0 26 | totals = {m: 0 for m in metrics} 27 | if loss_fn is not None: 28 | totals['loss'] = 0 29 | model.eval() 30 | with torch.no_grad(): 31 | for batch in dataloader: 32 | x, y = prepare_batch(batch) 33 | y_pred = model(x) 34 | 35 | seen += x.shape[0] 36 | 37 | if loss_fn is not None: 38 | totals['loss'] += loss_fn(y_pred, y).item() * x.shape[0] 39 | 40 | for m in metrics: 41 | if isinstance(m, str): 42 | v = NAMED_METRICS[m](y, y_pred) 43 | else: 44 | # Assume metric is a callable function 45 | v = m(y, y_pred) 46 | 47 | totals[m] += v * x.shape[0] 48 | 49 | for m in ['loss'] + metrics: 50 | logs[prefix + m + suffix] = totals[m] / seen 51 | 52 | return logs 53 | -------------------------------------------------------------------------------- /app/whale/train.py: -------------------------------------------------------------------------------- 1 | from dlcliche.image import * 2 | sys.path.append('..') # app 3 | sys.path.append('../..') # root 4 | from easydict import EasyDict 5 | from app_utils_clf import * 6 | from whale_utils import * 7 | from config import DATA_PATH 8 | 9 | # Basic training parameters 10 | args = EasyDict() 11 | args.distance = 'l2' 12 | args.n_train = 1 13 | args.n_test = 1 14 | args.q_train = 1 15 | args.q_test = 1 16 | 17 | args.k_train = 50 18 | args.k_test = 10 19 | SZ = 224 20 | RE_SZ = 256 21 | 22 | args.n_epochs = 100 23 | args.drop_lr_every = 50 24 | args.lr = 3e-3 25 | args.init_weight = None 26 | 27 | data_train = DATA_PATH+'/train' 28 | data_test = DATA_PATH+'/test' 29 | 30 | args.param_str = f'app_whale_n{args.n_train}_k{args.k_train}_q{args.q_train}' 31 | args.checkpoint_monitor = 'categorical_accuracy' 32 | args.checkpoint_period = 50 33 | 34 | print(f'Training {args.param_str}.') 35 | 36 | # Data - 'more_than_two' or 'exhaustive' 37 | trn_images, trn_labels, val_images, val_labels = get_training_data_lists('more_than_two') 38 | 39 | args.episodes_per_epoch = len(trn_images) // args.k_train + 1 40 | args.evaluation_episodes = 100 # setting small value, anyway validation set is almost useless here 41 | 42 | print(f'Samples = {len(trn_images)}, {len(val_images)}') 43 | 44 | # Model 45 | feature_model = get_resnet18(device=device, weight_file=args.init_weight) 46 | 47 | # Dataloader 48 | background = WhaleImages(data_train, trn_images, trn_labels, re_size=RE_SZ, to_size=SZ) 49 | background_taskloader = DataLoader( 50 | background, 51 | batch_sampler=NShotTaskSampler(background, args.episodes_per_epoch, args.n_train, args.k_train, args.q_train), 52 | num_workers=8 53 | ) 54 | evaluation = WhaleImages(data_train, val_images, val_labels, re_size=RE_SZ, to_size=SZ, train=False) 55 | evaluation_taskloader = DataLoader( 56 | evaluation, 57 | batch_sampler=NShotTaskSampler(evaluation, args.episodes_per_epoch, args.n_test, args.k_test, args.q_test), 58 | num_workers=8 59 | ) 60 | 61 | # Train 62 | train_proto_net(args, 63 | model=feature_model, 64 | device=device, 65 | path='.', 66 | n_epochs=args.n_epochs, 67 | background_taskloader=background_taskloader, 68 | evaluation_taskloader=evaluation_taskloader, 69 | drop_lr_every=args.drop_lr_every, 70 | evaluation_episodes=args.evaluation_episodes, 71 | episodes_per_epoch=args.episodes_per_epoch, 72 | lr=args.lr, 73 | ) 74 | torch.save(feature_model.state_dict(), f'{args.param_str}_epoch{args.n_epochs}.pth') 75 | -------------------------------------------------------------------------------- /app/whale/train_exhaustive.py: 
-------------------------------------------------------------------------------- 1 | from dlcliche.image import * 2 | sys.path.append('..') # app 3 | sys.path.append('../..') # root 4 | from easydict import EasyDict 5 | from app_utils_clf import * 6 | from whale_utils import * 7 | from config import DATA_PATH 8 | 9 | # Basic training parameters 10 | args = EasyDict() 11 | args.distance = 'l2' 12 | args.n_train = 1 13 | args.n_test = 1 14 | args.q_train = 1 15 | args.q_test = 1 16 | 17 | args.k_train = 50 18 | args.k_test = 10 19 | SZ = 224 20 | RE_SZ = 256 21 | 22 | args.n_epochs = 100 23 | args.drop_lr_every = 50 24 | args.lr = 3e-3 25 | args.init_weight = None 26 | 27 | data_train = DATA_PATH+'/train' 28 | data_test = DATA_PATH+'/test' 29 | 30 | args.param_str = f'app_whale_n{args.n_train}_k{args.k_train}_q{args.q_train}' 31 | args.checkpoint_monitor = 'categorical_accuracy' 32 | args.checkpoint_period = 50 33 | 34 | print(f'Training {args.param_str}.') 35 | 36 | # Data - 'more_than_two' or 'exhaustive' 37 | trn_images, trn_labels, val_images, val_labels = get_training_data_lists('exhaustive') 38 | 39 | args.episodes_per_epoch = len(trn_images) // args.k_train + 1 40 | args.evaluation_episodes = 100 # setting small value, anyway validation set is almost useless here 41 | 42 | print(f'Samples = {len(trn_images)}, {len(val_images)}') 43 | 44 | # Model 45 | feature_model = get_resnet18(device=device, weight_file=args.init_weight) 46 | 47 | # Dataloader 48 | background = WhaleImages(data_train, trn_images, trn_labels, re_size=RE_SZ, to_size=SZ) 49 | background_taskloader = DataLoader( 50 | background, 51 | batch_sampler=NShotTaskSampler(background, args.episodes_per_epoch, args.n_train, args.k_train, args.q_train), 52 | num_workers=8 53 | ) 54 | evaluation = WhaleImages(data_train, val_images, val_labels, re_size=RE_SZ, to_size=SZ, train=False) 55 | evaluation_taskloader = DataLoader( 56 | evaluation, 57 | batch_sampler=NShotTaskSampler(evaluation, args.episodes_per_epoch, args.n_test, args.k_test, args.q_test), 58 | num_workers=8 59 | ) 60 | 61 | # Train 62 | train_proto_net(args, 63 | model=feature_model, 64 | device=device, 65 | path='.', 66 | n_epochs=args.n_epochs, 67 | background_taskloader=background_taskloader, 68 | evaluation_taskloader=evaluation_taskloader, 69 | drop_lr_every=args.drop_lr_every, 70 | evaluation_episodes=args.evaluation_episodes, 71 | episodes_per_epoch=args.episodes_per_epoch, 72 | lr=args.lr, 73 | ) 74 | torch.save(feature_model.state_dict(), f'{args.param_str}_epoch{args.n_epochs}.pth') -------------------------------------------------------------------------------- /tests/test_proto.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from torch.utils.data import DataLoader 4 | 5 | from few_shot.core import NShotTaskSampler 6 | from few_shot.datasets import DummyDataset, OmniglotDataset, MiniImageNet 7 | from few_shot.models import get_few_shot_encoder 8 | from few_shot.proto import compute_prototypes 9 | 10 | 11 | class TestProtoNets(unittest.TestCase): 12 | @classmethod 13 | def setUpClass(cls): 14 | cls.dataset = DummyDataset(samples_per_class=1000, n_classes=20) 15 | 16 | def _test_n_k_q_combination(self, n, k, q): 17 | n_shot_taskloader = DataLoader(self.dataset, 18 | batch_sampler=NShotTaskSampler(self.dataset, 100, n, k, q)) 19 | 20 | # Load a single n-shot, k-way task 21 | for batch in n_shot_taskloader: 22 | x, y = batch 23 | break 24 | 25 | support = x[:n * k] 26 | support_labels = y[:n * k] 
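        # NShotTaskSampler lays out every episode batch as
        #   [class_1]*n + ... + [class_k]*n   (the n*k support samples)
        # followed by
        #   [class_1]*q + ... + [class_k]*q   (the q*k query samples),
        # so the two slices above select exactly the support set and its labels.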
27 | prototypes = compute_prototypes(support, k, n) 28 | 29 | # By construction the second feature of samples from the 30 | # DummyDataset is equal to the label. 31 | # As class prototypes are constructed from the means of the support 32 | # set items of a particular class the value of the second feature 33 | # of the class prototypes should be equal to the label of that class. 34 | for i in range(k): 35 | self.assertEqual( 36 | support_labels[i * n], 37 | prototypes[i, 1], 38 | 'Prototypes computed incorrectly!' 39 | ) 40 | 41 | def test_compute_prototypes(self): 42 | test_combinations = [ 43 | (1, 5, 5), 44 | (5, 5, 5), 45 | (1, 20, 5), 46 | (5, 20, 5) 47 | ] 48 | 49 | for n, k, q in test_combinations: 50 | self._test_n_k_q_combination(n, k, q) 51 | 52 | def test_create_model(self): 53 | # Check output of encoder has shape specified in paper 54 | encoder = get_few_shot_encoder(num_input_channels=1).float() 55 | omniglot = OmniglotDataset('background') 56 | self.assertEqual( 57 | encoder(omniglot[0][0].unsqueeze(0).float()).shape[1], 58 | 64, 59 | 'Encoder network should produce 64 dimensional embeddings on Omniglot dataset.' 60 | ) 61 | 62 | encoder = get_few_shot_encoder(num_input_channels=3).float() 63 | omniglot = MiniImageNet('background') 64 | self.assertEqual( 65 | encoder(omniglot[0][0].unsqueeze(0).float()).shape[1], 66 | 1600, 67 | 'Encoder network should produce 1600 dimensional embeddings on miniImageNet dataset.' 68 | ) -------------------------------------------------------------------------------- /tests/test_few_shot.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | from few_shot.core import create_nshot_task_label, NShotTaskSampler 7 | from few_shot.datasets import DummyDataset 8 | 9 | 10 | class TestNShotLabel(unittest.TestCase): 11 | def test_label(self): 12 | n = 1 13 | k = 5 14 | q = 1 15 | 16 | y = create_nshot_task_label(k, q) 17 | 18 | 19 | class TestNShotSampler(unittest.TestCase): 20 | @classmethod 21 | def setUpClass(cls): 22 | cls.dataset = DummyDataset(samples_per_class=1000, n_classes=20) 23 | 24 | def test_n_shot_sampler(self): 25 | n, k, q = 2, 4, 3 26 | n_shot_taskloader = DataLoader(self.dataset, 27 | batch_sampler=NShotTaskSampler(self.dataset, 100, n, k, q)) 28 | 29 | # Load a single n-shot task and check it's properties 30 | for x, y in n_shot_taskloader: 31 | support = x[:n*k] 32 | queries = x[n*k:] 33 | support_labels = y[:n*k] 34 | query_labels = y[n*k:] 35 | 36 | # Check ordering of support labels is correct 37 | for i in range(0, n * k, n): 38 | support_set_labels_correct = torch.all(support_labels[i:i + n] == support_labels[i]) 39 | self.assertTrue( 40 | support_set_labels_correct, 41 | 'Classes of support set samples should be arranged like: ' 42 | '[class_1]*n + [class_2]*n + ... + [class_k]*n' 43 | ) 44 | 45 | # Check ordering of query labels is correct 46 | for i in range(0, q * k, q): 47 | support_set_labels_correct = torch.all(query_labels[i:i + q] == query_labels[i]) 48 | self.assertTrue( 49 | support_set_labels_correct, 50 | 'Classes of query set samples should be arranged like: ' 51 | '[class_1]*q + [class_2]*q + ... + [class_k]*q' 52 | ) 53 | 54 | # Check labels are consistent across query and support 55 | for i in range(k): 56 | self.assertEqual( 57 | support_labels[i*n], 58 | query_labels[i*q], 59 | 'Classes of query and support set should be consistent.' 
60 | ) 61 | 62 | # Check no overlap of IDs between support and query. 63 | # By construction the first feature in the DummyDataset is the 64 | # id of the sample in the dataset so we can use this to test 65 | # for overlap betwen query and suppport samples 66 | self.assertEqual( 67 | len(set(support[:, 0].numpy()).intersection(set(queries[:, 0].numpy()))), 68 | 0, 69 | 'There should be no overlap between support and query set samples.' 70 | ) 71 | 72 | break -------------------------------------------------------------------------------- /scripts/prepare_omniglot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run this script to prepare the Omniglot dataset from the raw Omniglot dataset that is found at 3 | https://github.com/brendenlake/omniglot/tree/master/python. 4 | 5 | This script prepares an enriched version of Omniglot the same as is used in the Matching Networks and Prototypical 6 | Networks papers. 7 | 8 | 1. Augment classes with rotations in multiples of 90 degrees. 9 | 2. Downsize images to 28x28 10 | 3. Uses background and evaluation sets present in the raw dataset 11 | """ 12 | from skimage import io 13 | from skimage import transform 14 | import zipfile 15 | import shutil 16 | import os 17 | 18 | from config import DATA_PATH 19 | from few_shot.utils import mkdir, rmdir 20 | 21 | 22 | # Parameters 23 | dataset_zip_files = ['images_background.zip', 'images_evaluation.zip'] 24 | raw_omniglot_location = DATA_PATH + '/Omniglot_Raw/' 25 | prepared_omniglot_location = DATA_PATH + '/Omniglot/' 26 | output_shape = (28, 28) 27 | 28 | 29 | def handle_characters(alphabet_folder, character_folder, rotate): 30 | for root, _, character_images in os.walk(character_folder): 31 | character_name = root.split('/')[-1] 32 | mkdir(f'{alphabet_folder}.{rotate}/{character_name}') 33 | for img_path in character_images: 34 | # print(root+'/'+img_path) 35 | img = io.imread(root+'/'+img_path) 36 | img = transform.rotate(img, angle=rotate) 37 | img = transform.resize(img, output_shape, anti_aliasing=True) 38 | img = (img - img.min()) / (img.max() - img.min()) 39 | # print(img.min(), img.max()) 40 | # print(f'{alphabet_folder}.{rotate}/{character_name}/{img_path}') 41 | io.imsave(f'{alphabet_folder}.{rotate}/{character_name}/{img_path}', img) 42 | # return 43 | 44 | 45 | def handle_alphabet(folder): 46 | print('{}...'.format(folder.split('/')[-1])) 47 | for rotate in [0, 90, 180, 270]: 48 | # Create new folders for each augmented alphabet 49 | mkdir(f'{folder}.{rotate}') 50 | for root, character_folders, _ in os.walk(folder): 51 | for character_folder in character_folders: 52 | # For each character folder in an alphabet rotate and resize all of the images and save 53 | # to the new folder 54 | handle_characters(folder, root + '/' + character_folder, rotate) 55 | # return 56 | 57 | # Delete original alphabet 58 | rmdir(folder) 59 | 60 | 61 | # Clean up previous extraction 62 | rmdir(prepared_omniglot_location) 63 | mkdir(prepared_omniglot_location) 64 | 65 | # Unzip dataset 66 | for root, _, files in os.walk(raw_omniglot_location): 67 | for f in files: 68 | if f in dataset_zip_files: 69 | print('Unzipping {}...'.format(f)) 70 | zip_ref = zipfile.ZipFile(root + f, 'r') 71 | zip_ref.extractall(prepared_omniglot_location) 72 | zip_ref.close() 73 | 74 | print('Processing background set...') 75 | for root, alphabets, _ in os.walk(prepared_omniglot_location + 'images_background/'): 76 | for alphabet in sorted(alphabets): 77 | handle_alphabet(root + alphabet) 78 | 
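# The per-image work done by handle_characters() above boils down to a
# rotate / resize / rescale pipeline, and each rotation becomes a separate class
# (so the class count grows 4x). A minimal standalone sketch of the same steps on
# a single file (the input/output paths here are placeholders):
#
#     from skimage import io, transform
#
#     img = io.imread('some_character.png')
#     for angle in (0, 90, 180, 270):
#         rotated = transform.rotate(img, angle=angle)
#         small = transform.resize(rotated, (28, 28), anti_aliasing=True)
#         small = (small - small.min()) / (small.max() - small.min())
#         io.imsave(f'some_character.rot{angle}.png', small)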
79 | print('Processing evaluation set...') 80 | for root, alphabets, _ in os.walk(prepared_omniglot_location + 'images_evaluation/'): 81 | for alphabet in sorted(alphabets): 82 | handle_alphabet(root + alphabet) 83 | -------------------------------------------------------------------------------- /app/whale/k_train.py: -------------------------------------------------------------------------------- 1 | from dlcliche.image import * 2 | sys.path.append('..') # app 3 | sys.path.append('../..') # root 4 | from easydict import EasyDict 5 | from app_utils_clf import * 6 | from whale_plus_utils import * 7 | from config import DATA_PATH 8 | args = EasyDict() 9 | 10 | #### LOOK HERE #### 11 | name = 'k' 12 | args.normalize = 'imagenet' 13 | #### LOOK HERE #### 14 | 15 | # Basic training parameters 16 | args.distance = 'l2' 17 | args.n_train = 1 18 | args.n_test = 1 19 | args.q_train = 1 20 | args.q_test = 1 21 | 22 | args.k_train = 60 23 | args.k_test = 10 24 | SZ = 224 25 | RE_SZ = 256 26 | 27 | args.n_epochs = 600 28 | args.drop_lr_every = 30 29 | args.lr = 3e-3 30 | args.init_weight = None 31 | args.part = 0 32 | args.n_part = 1 33 | args.augment_train = 'train' 34 | 35 | data_train = DATA_PATH+'/train' 36 | data_test = DATA_PATH+'/test' 37 | TRN_N_IMAGES = 2 38 | 39 | args.param_str = f'{name}_k{args.k_train}' 40 | args.checkpoint_monitor = 'categorical_accuracy' 41 | args.checkpoint_period = 50 42 | 43 | print(f'Training {args.param_str}.') 44 | 45 | # Data 46 | df = pd.read_csv(DATA_PATH+'/train.csv') 47 | df = df[df.Id != 'new_whale'] 48 | ids = df.Id.values 49 | classes = sorted(list(set(ids))) 50 | images = df.Image.values 51 | all_cls2imgs = {cls:images[ids == cls] for cls in classes} 52 | 53 | trn_images = [image for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) >= TRN_N_IMAGES] 54 | trn_labels = [_id for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) >= TRN_N_IMAGES] 55 | val_images = [image for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) == 2] 56 | val_labels = [_id for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) == 2] 57 | 58 | args.episodes_per_epoch = len(trn_images) // args.k_train + 1 59 | args.evaluation_episodes = 50 # setting small value, anyway validation set is almost useless here 60 | 61 | print(f'Samples = {len(trn_images)}, {len(val_images)}') 62 | 63 | # Model 64 | feature_model = get_resnet50(device=device, weight_file=args.init_weight) 65 | 66 | # Dataloader 67 | background = WhaleImagesPlus(data_train, trn_images, trn_labels, re_size=RE_SZ, to_size=SZ, augment=args.augment_train, 68 | part=args.part, n_part=args.n_part, normalize=args.normalize) 69 | background_taskloader = DataLoader( 70 | background, 71 | batch_sampler=NShotTaskSampler(background, args.episodes_per_epoch, args.n_train, args.k_train, args.q_train), 72 | num_workers=8 73 | ) 74 | evaluation = WhaleImagesPlus(data_train, val_images, val_labels, re_size=RE_SZ, to_size=SZ, augment='test', 75 | part=args.part, n_part=args.n_part, normalize=args.normalize) 76 | evaluation_taskloader = DataLoader( 77 | evaluation, 78 | batch_sampler=NShotTaskSampler(evaluation, args.episodes_per_epoch, args.n_test, args.k_test, args.q_test), 79 | num_workers=8 80 | ) 81 | 82 | # Train 83 | train_proto_net(args, 84 | model=feature_model, 85 | device=device, 86 | path=name, 87 | n_epochs=args.n_epochs, 88 | background_taskloader=background_taskloader, 89 | evaluation_taskloader=evaluation_taskloader, 90 | drop_lr_every=args.drop_lr_every, 91 | 
evaluation_episodes=args.evaluation_episodes, 92 | episodes_per_epoch=args.episodes_per_epoch, 93 | lr=args.lr, 94 | ) 95 | torch.save(feature_model.state_dict(), f'{name}/{args.param_str}_epoch{args.n_epochs}.pth') 96 | -------------------------------------------------------------------------------- /few_shot/proto.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | from torch.nn import Module 4 | from typing import Callable 5 | 6 | from few_shot.utils import pairwise_distances 7 | 8 | 9 | def proto_net_episode(model: Module, 10 | optimiser: Optimizer, 11 | loss_fn: Callable, 12 | x: torch.Tensor, 13 | y: torch.Tensor, 14 | n_shot: int, 15 | k_way: int, 16 | q_queries: int, 17 | distance: str, 18 | train: bool): 19 | """Performs a single training episode for a Prototypical Network. 20 | 21 | # Arguments 22 | model: Prototypical Network to be trained. 23 | optimiser: Optimiser to calculate gradient step 24 | loss_fn: Loss function to calculate between predictions and outputs. Should be cross-entropy 25 | x: Input samples of few shot classification task 26 | y: Input labels of few shot classification task 27 | n_shot: Number of examples per class in the support set 28 | k_way: Number of classes in the few shot classification task 29 | q_queries: Number of examples per class in the query set 30 | distance: Distance metric to use when calculating distance between class prototypes and queries 31 | train: Whether (True) or not (False) to perform a parameter update 32 | 33 | # Returns 34 | loss: Loss of the Prototypical Network on this task 35 | y_pred: Predicted class probabilities for the query set on this task 36 | """ 37 | if train: 38 | # Zero gradients 39 | model.train() 40 | optimiser.zero_grad() 41 | else: 42 | model.eval() 43 | 44 | # Embed all samples 45 | embeddings = model(x) 46 | 47 | # Samples are ordered by the NShotWrapper class as follows: 48 | # k lots of n support samples from a particular class 49 | # k lots of q query samples from those classes 50 | support = embeddings[:n_shot*k_way] 51 | queries = embeddings[n_shot*k_way:] 52 | prototypes = compute_prototypes(support, k_way, n_shot) 53 | 54 | # Calculate squared distances between all queries and all prototypes 55 | # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way) 56 | distances = pairwise_distances(queries, prototypes, distance) 57 | 58 | # Calculate log p_{phi} (y = k | x) 59 | log_p_y = (-distances).log_softmax(dim=1) 60 | loss = loss_fn(log_p_y, y) 61 | 62 | # Prediction probabilities are softmax over distances 63 | y_pred = (-distances).softmax(dim=1) 64 | 65 | if train: 66 | # Take gradient step 67 | loss.backward() 68 | optimiser.step() 69 | else: 70 | pass 71 | 72 | return loss, y_pred 73 | 74 | 75 | def compute_prototypes(support: torch.Tensor, k: int, n: int) -> torch.Tensor: 76 | """Compute class prototypes from support samples. 77 | 78 | # Arguments 79 | support: torch.Tensor. Tensor of shape (n * k, d) where d is the embedding 80 | dimension. 81 | k: int. "k-way" i.e. number of classes in the classification task 82 | n: int. 
"n-shot" of the classification task 83 | 84 | # Returns 85 | class_prototypes: Prototypes aka mean embeddings for each class 86 | """ 87 | # Reshape so the first dimension indexes by class then take the mean 88 | # along that dimension to generate the "prototypes" for each class 89 | class_prototypes = support.reshape(k, n, -1).mean(dim=1) 90 | return class_prototypes -------------------------------------------------------------------------------- /experiments/proto_nets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reproduce Omniglot results of Snell et al Prototypical networks. 3 | """ 4 | from torch.optim import Adam 5 | from torch.utils.data import DataLoader 6 | import argparse 7 | 8 | from few_shot.datasets import OmniglotDataset, MiniImageNet 9 | from few_shot.models import get_few_shot_encoder 10 | from few_shot.core import NShotTaskSampler, EvaluateFewShot, prepare_nshot_task 11 | from few_shot.proto import proto_net_episode 12 | from few_shot.train import fit 13 | from few_shot.callbacks import * 14 | from few_shot.utils import setup_dirs 15 | from config import PATH 16 | 17 | 18 | setup_dirs() 19 | assert torch.cuda.is_available() 20 | device = torch.device('cuda') 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | ############## 25 | # Parameters # 26 | ############## 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--dataset') 29 | parser.add_argument('--distance', default='l2') 30 | parser.add_argument('--n-train', default=1, type=int) 31 | parser.add_argument('--n-test', default=1, type=int) 32 | parser.add_argument('--k-train', default=60, type=int) 33 | parser.add_argument('--k-test', default=5, type=int) 34 | parser.add_argument('--q-train', default=5, type=int) 35 | parser.add_argument('--q-test', default=1, type=int) 36 | args = parser.parse_args() 37 | 38 | evaluation_episodes = 1000 39 | episodes_per_epoch = 100 40 | 41 | if args.dataset == 'omniglot': 42 | n_epochs = 40 43 | dataset_class = OmniglotDataset 44 | num_input_channels = 1 45 | drop_lr_every = 20 46 | elif args.dataset == 'miniImageNet': 47 | n_epochs = 80 48 | dataset_class = MiniImageNet 49 | num_input_channels = 3 50 | drop_lr_every = 40 51 | else: 52 | raise(ValueError, 'Unsupported dataset') 53 | 54 | param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \ 55 | f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}' 56 | 57 | print(param_str) 58 | 59 | ################### 60 | # Create datasets # 61 | ################### 62 | background = dataset_class('background') 63 | background_taskloader = DataLoader( 64 | background, 65 | batch_sampler=NShotTaskSampler(background, episodes_per_epoch, args.n_train, args.k_train, args.q_train), 66 | num_workers=4 67 | ) 68 | evaluation = dataset_class('evaluation') 69 | evaluation_taskloader = DataLoader( 70 | evaluation, 71 | batch_sampler=NShotTaskSampler(evaluation, episodes_per_epoch, args.n_test, args.k_test, args.q_test), 72 | num_workers=4 73 | ) 74 | 75 | 76 | ######### 77 | # Model # 78 | ######### 79 | model = get_few_shot_encoder(num_input_channels) 80 | model.to(device, dtype=torch.float) 81 | 82 | 83 | ############ 84 | # Training # 85 | ############ 86 | print(f'Training Prototypical network on {args.dataset}...') 87 | optimiser = Adam(model.parameters(), lr=1e-3) 88 | loss_fn = torch.nn.NLLLoss().cuda() 89 | 90 | 91 | def lr_schedule(epoch, lr): 92 | # Drop lr every 2000 episodes 93 | if epoch % drop_lr_every == 0: 94 | return lr / 2 95 | else: 96 
| return lr 97 | 98 | 99 | callbacks = [ 100 | EvaluateFewShot( 101 | eval_fn=proto_net_episode, 102 | num_tasks=evaluation_episodes, 103 | n_shot=args.n_test, 104 | k_way=args.k_test, 105 | q_queries=args.q_test, 106 | taskloader=evaluation_taskloader, 107 | prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test), 108 | distance=args.distance 109 | ), 110 | ModelCheckpoint( 111 | filepath=PATH + f'/models/proto_nets/{param_str}.pth', 112 | monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc' 113 | ), 114 | LearningRateScheduler(schedule=lr_schedule), 115 | CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'), 116 | ] 117 | 118 | fit( 119 | model, 120 | optimiser, 121 | loss_fn, 122 | epochs=n_epochs, 123 | dataloader=background_taskloader, 124 | prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train), 125 | callbacks=callbacks, 126 | metrics=['categorical_accuracy'], 127 | fit_function=proto_net_episode, 128 | fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train, 'q_queries': args.q_train, 'train': True, 129 | 'distance': args.distance}, 130 | ) 131 | -------------------------------------------------------------------------------- /experiments/experiments.txt: -------------------------------------------------------------------------------- 1 | # Proto Net experiments 2 | python -m experiments.proto_nets --dataset omniglot --k-test 5 --n-test 1 3 | python -m experiments.proto_nets --dataset omniglot --k-test 5 --n-test 5 4 | python -m experiments.proto_nets --dataset omniglot --k-test 20 --n-test 1 5 | python -m experiments.proto_nets --dataset omniglot --k-test 20 --n-test 5 --n-train 5 6 | 7 | python -m experiments.proto_nets --dataset miniImageNet --k-test 5 --n-test 1 --k-train 20 --n-train 1 --q-train 15 8 | python -m experiments.proto_nets --dataset miniImageNet --k-test 5 --n-test 5 --k-train 20 --n-train 5 --q-train 15 9 | 10 | # Matching Network experiments 11 | python -m experiments.matching_nets --dataset omniglot --fce False --k-test 5 --n-test 1 --distance l2 12 | python -m experiments.matching_nets --dataset omniglot --fce False --k-test 5 --n-test 5 --distance l2 13 | python -m experiments.matching_nets --dataset omniglot --fce False --k-test 20 --n-test 1 --distance l2 14 | python -m experiments.matching_nets --dataset omniglot --fce False --k-test 20 --n-test 5 --distance l2 15 | 16 | python -m experiments.matching_nets --dataset omniglot --fce False --k-test 5 --n-test 1 --distance cosine 17 | python -m experiments.matching_nets --dataset omniglot --fce False --k-test 5 --n-test 5 --distance cosine 18 | python -m experiments.matching_nets --dataset omniglot --fce False --k-test 20 --n-test 1 --distance cosine 19 | python -m experiments.matching_nets --dataset omniglot --fce False --k-test 20 --n-test 5 --distance cosine 20 | 21 | python -m experiments.matching_nets --dataset miniImageNet --fce False --k-test 5 --n-test 1 --distance l2 22 | python -m experiments.matching_nets --dataset miniImageNet --fce False --k-test 5 --n-test 5 --distance l2 23 | python -m experiments.matching_nets --dataset miniImageNet --fce True --k-test 5 --n-test 1 --distance l2 24 | python -m experiments.matching_nets --dataset miniImageNet --fce True --k-test 5 --n-test 5 --n-train 5 --distance l2 25 | 26 | python -m experiments.matching_nets --dataset miniImageNet --fce False --k-test 5 --n-test 1 --distance cosine 27 | python -m experiments.matching_nets --dataset miniImageNet --fce False --k-test 5 --n-test 5 --distance cosine 28 | python 
-m experiments.matching_nets --dataset miniImageNet --fce True --k-test 5 --n-test 1 --distance cosine 29 | python -m experiments.matching_nets --dataset miniImageNet --fce True --k-test 5 --n-test 5 --n-train 5 --distance cosine 30 | 31 | # 1st order MAML 32 | python -m experiments.maml --dataset omniglot --order 1 --n 1 --k 5 --eval-batches 10 --epoch-len 50 33 | python -m experiments.maml --dataset omniglot --order 1 --n 5 --k 5 --eval-batches 10 --epoch-len 50 34 | python -m experiments.maml --dataset omniglot --order 1 --n 1 --k 20 --meta-batch-size 16 \ 35 | --inner-train-steps 5 --inner-val-steps 5 --inner-lr 0.1 --eval-batches 20 --epoch-len 100 36 | python -m experiments.maml --dataset omniglot --order 1 --n 5 --k 20 --meta-batch-size 16 \ 37 | --inner-train-steps 5 --inner-val-steps 5 --inner-lr 0.1 --eval-batches 20 --epoch-len 100 38 | 39 | python -m experiments.maml --dataset miniImageNet --order 1 --n 1 --k 5 --q 5 --meta-batch-size 4 \ 40 | --inner-train-steps 5 --inner-val-steps 10 --inner-lr 0.01 --eval-batches 40 --epoch-len 400 41 | python -m experiments.maml --dataset miniImageNet --order 1 --n 5 --k 5 --q 5 --meta-batch-size 4 \ 42 | --inner-train-steps 5 --inner-val-steps 10 --inner-lr 0.01 --eval-batches 40 --epoch-len 400 43 | 44 | # 2nd order MAML 45 | python -m experiments.maml --dataset omniglot --order 2 --n 1 --k 5 --eval-batches 10 --epoch-len 50 46 | python -m experiments.maml --dataset omniglot --order 2 --n 5 --k 5 --eval-batches 20 --epoch-len 100 \ 47 | --meta-batch-size 16 --eval-batches 20 48 | python -m experiments.maml --dataset omniglot --order 2 --n 1 --k 20 --meta-batch-size 16 \ 49 | --inner-train-steps 5 --inner-val-steps 5 --inner-lr 0.1 --eval-batches 40 --epoch-len 200 50 | python -m experiments.maml --dataset omniglot --order 2 --n 5 --k 20 --meta-batch-size 4 \ 51 | --inner-train-steps 5 --inner-val-steps 5 --inner-lr 0.1 --eval-batches 80 --epoch-len 400 52 | 53 | python -m experiments.maml --dataset miniImageNet --order 2 --n 1 --k 5 --q 5 --meta-batch-size 4 \ 54 | --inner-train-steps 5 --inner-val-steps 10 --inner-lr 0.01 --eval-batches 80 --epoch-len 400 55 | python -m experiments.maml --dataset miniImageNet --order 2 --n 5 --k 5 --q 5 --meta-batch-size 2 \ 56 | --inner-train-steps 5 --inner-val-steps 10 --inner-lr 0.01 --eval-batches 80 --epoch-len 800 57 | -------------------------------------------------------------------------------- /few_shot/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import shutil 4 | from typing import Tuple, List 5 | 6 | from config import EPSILON, PATH 7 | 8 | 9 | def mkdir(dir): 10 | """Create a directory, ignoring exceptions 11 | 12 | # Arguments: 13 | dir: Path of directory to create 14 | """ 15 | try: 16 | os.mkdir(dir) 17 | except: 18 | pass 19 | 20 | 21 | def rmdir(dir): 22 | """Recursively remove a directory and contents, ignoring exceptions 23 | 24 | # Arguments: 25 | dir: Path of directory to recursively remove 26 | """ 27 | try: 28 | shutil.rmtree(dir) 29 | except: 30 | pass 31 | 32 | 33 | def setup_dirs(): 34 | """Creates directories for this project.""" 35 | mkdir(PATH + '/logs/') 36 | mkdir(PATH + '/logs/proto_nets') 37 | mkdir(PATH + '/logs/matching_nets') 38 | mkdir(PATH + '/models/') 39 | mkdir(PATH + '/models/proto_nets') 40 | mkdir(PATH + '/models/matching_nets') 41 | 42 | 43 | def pairwise_distances(x: torch.Tensor, 44 | y: torch.Tensor, 45 | matching_fn: str) -> torch.Tensor: 46 | """Efficiently calculate 
pairwise distances (or other similarity scores) between 47 | two sets of samples. 48 | 49 | # Arguments 50 | x: Query samples. A tensor of shape (n_x, d) where d is the embedding dimension 51 | y: Class prototypes. A tensor of shape (n_y, d) where d is the embedding dimension 52 | matching_fn: Distance metric/similarity score to compute between samples 53 | """ 54 | n_x = x.shape[0] 55 | n_y = y.shape[0] 56 | 57 | if matching_fn == 'l2': 58 | distances = ( 59 | x.unsqueeze(1).expand(n_x, n_y, -1) - 60 | y.unsqueeze(0).expand(n_x, n_y, -1) 61 | ).pow(2).sum(dim=2) 62 | return distances 63 | elif matching_fn == 'cosine': 64 | normalised_x = x / (x.pow(2).sum(dim=1, keepdim=True).sqrt() + EPSILON) 65 | normalised_y = y / (y.pow(2).sum(dim=1, keepdim=True).sqrt() + EPSILON) 66 | 67 | expanded_x = normalised_x.unsqueeze(1).expand(n_x, n_y, -1) 68 | expanded_y = normalised_y.unsqueeze(0).expand(n_x, n_y, -1) 69 | 70 | cosine_similarities = (expanded_x * expanded_y).sum(dim=2) 71 | return 1 - cosine_similarities 72 | elif matching_fn == 'dot': 73 | expanded_x = x.unsqueeze(1).expand(n_x, n_y, -1) 74 | expanded_y = y.unsqueeze(0).expand(n_x, n_y, -1) 75 | 76 | return -(expanded_x * expanded_y).sum(dim=2) 77 | else: 78 | raise(ValueError('Unsupported similarity function')) 79 | 80 | 81 | def copy_weights(from_model: torch.nn.Module, to_model: torch.nn.Module): 82 | """Copies the weights from one model to another model. 83 | 84 | # Arguments: 85 | from_model: Model from which to source weights 86 | to_model: Model which will receive weights 87 | """ 88 | if not from_model.__class__ == to_model.__class__: 89 | raise(ValueError("Models don't have the same architecture!")) 90 | 91 | for m_from, m_to in zip(from_model.modules(), to_model.modules()): 92 | is_linear = isinstance(m_to, torch.nn.Linear) 93 | is_conv = isinstance(m_to, torch.nn.Conv2d) 94 | is_bn = isinstance(m_to, torch.nn.BatchNorm2d) 95 | if is_linear or is_conv or is_bn: 96 | m_to.weight.data = m_from.weight.data.clone() 97 | if m_to.bias is not None: 98 | m_to.bias.data = m_from.bias.data.clone() 99 | 100 | 101 | def autograd_graph(tensor: torch.Tensor) -> Tuple[ 102 | List[torch.autograd.Function], 103 | List[Tuple[torch.autograd.Function, torch.autograd.Function]] 104 | ]: 105 | """Recursively retrieves the autograd graph for a particular tensor. 
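    # A small worked example of how the repo combines compute_prototypes() and
    # pairwise_distances() for Prototypical Network inference (illustrative only,
    # with random tensors standing in for real encoder embeddings):
    #
    #     import torch
    #     from few_shot.proto import compute_prototypes
    #     from few_shot.utils import pairwise_distances
    #
    #     n, k, q, d = 5, 3, 2, 64                    # 5-shot, 3-way, 2 queries/class
    #     support = torch.randn(n * k, d)             # ordered [class_1]*n + ... + [class_k]*n
    #     queries = torch.randn(q * k, d)
    #
    #     prototypes = compute_prototypes(support, k, n)             # (k, d) class means
    #     distances = pairwise_distances(queries, prototypes, 'l2')  # (q*k, k)
    #     log_p_y = (-distances).log_softmax(dim=1)   # as in proto_net_episode
    #     pred = log_p_y.argmax(dim=1)                # nearest prototype per query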
106 | 107 | # Arguments 108 | tensor: The Tensor to retrieve the autograd graph for 109 | 110 | # Returns 111 | nodes: List of torch.autograd.Functions that are the nodes of the autograd graph 112 | edges: List of (Function, Function) tuples that are the edges between the nodes of the autograd graph 113 | """ 114 | nodes, edges = list(), list() 115 | 116 | def _add_nodes(tensor): 117 | if tensor not in nodes: 118 | nodes.append(tensor) 119 | 120 | if hasattr(tensor, 'next_functions'): 121 | for f in tensor.next_functions: 122 | if f[0] is not None: 123 | edges.append((f[0], tensor)) 124 | _add_nodes(f[0]) 125 | 126 | if hasattr(tensor, 'saved_tensors'): 127 | for t in tensor.saved_tensors: 128 | edges.append((t, tensor)) 129 | _add_nodes(t) 130 | 131 | _add_nodes(tensor.grad_fn) 132 | 133 | return nodes, edges 134 | -------------------------------------------------------------------------------- /app/app_utils_clf.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Adam 2 | from torch.utils.data import DataLoader 3 | from torchvision import models 4 | from torch import nn 5 | 6 | from few_shot.models import get_few_shot_encoder, Flatten 7 | from few_shot.core import NShotTaskSampler, EvaluateFewShot, prepare_nshot_task 8 | from few_shot.proto import proto_net_episode 9 | from few_shot.train import fit 10 | from few_shot.callbacks import * 11 | 12 | from dlcliche.image import * 13 | 14 | assert torch.cuda.is_available() 15 | device = torch.device('cuda') 16 | 17 | 18 | def show_normalized_image(img, ax=None, mono=False): 19 | if mono: 20 | img.numpy()[..., np.newaxis] 21 | np_img = img.numpy().transpose(1, 2, 0) 22 | lifted = np_img - np.min(np_img) 23 | ranged = lifted / np.max(lifted) 24 | show_np_image(ranged, ax=ax) 25 | 26 | 27 | class MonoTo3ChLayer(nn.Module): 28 | def __init__(self): 29 | super(MonoTo3ChLayer, self).__init__() 30 | def forward(self, x): 31 | x.unsqueeze_(1) 32 | return x.repeat(1, 3, 1, 1) 33 | 34 | 35 | def _get_model(weight_file, device, model_fn, mono): 36 | base_model = model_fn(pretrained=True) 37 | feature_model = nn.Sequential(*list(base_model.children())[:-1], 38 | nn.AdaptiveAvgPool2d(1), 39 | Flatten()) 40 | # Load initial weights 41 | if weight_file is not None: 42 | feature_model.load_state_dict(torch.load(weight_file)) 43 | # Add mono image input layer at the bottom of feature model 44 | if mono: 45 | feature_model = nn.Sequential(MonoTo3ChLayer(), feature_model) 46 | if device is not None: 47 | feature_model.to(device) 48 | 49 | feature_model.eval() 50 | return feature_model 51 | 52 | 53 | def get_resnet101(weight_file=None, device=None, mono=False): 54 | return _get_model(weight_file, device, models.resnet101, mono=mono) 55 | 56 | 57 | def get_resnet50(weight_file=None, device=None, mono=False): 58 | return _get_model(weight_file, device, models.resnet50, mono=mono) 59 | 60 | 61 | def get_resnet34(weight_file=None, device=None, mono=False): 62 | return _get_model(weight_file, device, models.resnet34, mono=mono) 63 | 64 | 65 | def get_resnet18(weight_file=None, device=None, mono=False): 66 | return _get_model(weight_file, device, models.resnet18, mono=mono) 67 | 68 | 69 | def get_densenet121(weight_file=None, device=None, mono=False): 70 | return _get_model(weight_file, device, models.densenet121, mono=mono) 71 | 72 | 73 | def train_proto_net(args, model, device, n_epochs, 74 | background_taskloader, 75 | evaluation_taskloader, 76 | path='.', 77 | lr=3e-3, 78 | drop_lr_every=100, 79 | 
evaluation_episodes=100, 80 | episodes_per_epoch=100, 81 | ): 82 | # Prepare model 83 | model.to(device, dtype=torch.float) 84 | model.train(True) 85 | 86 | # Prepare training etc. 87 | optimizer = Adam(model.parameters(), lr=lr) 88 | loss_fn = torch.nn.NLLLoss().cuda() 89 | ensure_folder(path + '/models') 90 | ensure_folder(path + '/logs') 91 | 92 | def lr_schedule(epoch, lr): 93 | if epoch % drop_lr_every == 0: 94 | return lr / 2 95 | else: 96 | return lr 97 | 98 | callbacks = [ 99 | EvaluateFewShot( 100 | eval_fn=proto_net_episode, 101 | num_tasks=evaluation_episodes, 102 | n_shot=args.n_test, 103 | k_way=args.k_test, 104 | q_queries=args.q_test, 105 | taskloader=evaluation_taskloader, 106 | prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test), 107 | distance=args.distance 108 | ), 109 | ModelCheckpoint( 110 | filepath=path + '/models/'+args.param_str+'_e{epoch:02d}.pth', 111 | monitor=args.checkpoint_monitor or f'val_{args.n_test}-shot_{args.k_test}-way_acc', 112 | period=args.checkpoint_period or 100, 113 | ), 114 | LearningRateScheduler(schedule=lr_schedule), 115 | CSVLogger(path + f'/logs/{args.param_str}.csv'), 116 | ] 117 | 118 | fit( 119 | model, 120 | optimizer, 121 | loss_fn, 122 | epochs=n_epochs, 123 | dataloader=background_taskloader, 124 | prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train), 125 | callbacks=callbacks, 126 | metrics=['categorical_accuracy'], 127 | epoch_metrics=[f'val_{args.n_test}-shot_{args.k_test}-way_acc'], 128 | fit_function=proto_net_episode, 129 | fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train, 'q_queries': args.q_train, 'train': True, 130 | 'distance': args.distance}, 131 | ) 132 | -------------------------------------------------------------------------------- /few_shot/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | The `fit` function in this file implements a slightly modified version 3 | of the Keras `model.fit()` API. 4 | """ 5 | import torch 6 | from torch.optim import Optimizer 7 | from torch.nn import Module 8 | from torch.utils.data import DataLoader 9 | from typing import Callable, List, Union 10 | 11 | from few_shot.callbacks import DefaultCallback, ProgressBarLogger, CallbackList, Callback 12 | from few_shot.metrics import NAMED_METRICS 13 | 14 | 15 | def gradient_step(model: Module, optimiser: Optimizer, loss_fn: Callable, x: torch.Tensor, y: torch.Tensor, **kwargs): 16 | """Takes a single gradient step. 
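    # Example
        Plain supervised use (illustrative only; the tensors below are synthetic):

            model = torch.nn.Linear(10, 3)
            optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
            loss_fn = torch.nn.CrossEntropyLoss()
            x, y = torch.randn(8, 10), torch.randint(0, 3, (8,))
            loss, y_pred = gradient_step(model, optimiser, loss_fn, x, y)
            # loss is a scalar tensor, y_pred has shape (8, 3)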
17 | 18 | # Arguments 19 | model: Model to be fitted 20 | optimiser: Optimiser to calculate gradient step from loss 21 | loss_fn: Loss function to calculate between predictions and outputs 22 | x: Input samples 23 | y: Input targets 24 | """ 25 | model.train() 26 | optimiser.zero_grad() 27 | y_pred = model(x) 28 | loss = loss_fn(y_pred, y) 29 | loss.backward() 30 | optimiser.step() 31 | 32 | return loss, y_pred 33 | 34 | 35 | def batch_metrics(model: Module, y_pred: torch.Tensor, y: torch.Tensor, metrics: List[Union[str, Callable]], 36 | batch_logs: dict): 37 | """Calculates metrics for the current training batch 38 | 39 | # Arguments 40 | model: Model being fit 41 | y_pred: predictions for a particular batch 42 | y: labels for a particular batch 43 | batch_logs: Dictionary of logs for the current batch 44 | """ 45 | model.eval() 46 | for m in metrics: 47 | if isinstance(m, str): 48 | batch_logs[m] = NAMED_METRICS[m](y, y_pred) 49 | else: 50 | # Assume metric is a callable function 51 | batch_logs = m(y, y_pred) 52 | 53 | return batch_logs 54 | 55 | 56 | def fit(model: Module, optimiser: Optimizer, loss_fn: Callable, epochs: int, dataloader: DataLoader, 57 | prepare_batch: Callable, metrics: List[Union[str, Callable]] = None, 58 | epoch_metrics: List[str] = None, callbacks: List[Callback] = None, 59 | verbose: bool =True, fit_function: Callable = gradient_step, fit_function_kwargs: dict = {}): 60 | """Function to abstract away training loop. 61 | 62 | The benefit of this function is that allows training scripts to be much more readable and allows for easy re-use of 63 | common training functionality provided they are written as a subclass of voicemap.Callback (following the 64 | Keras API). 65 | 66 | # Arguments 67 | model: Model to be fitted. 68 | optimiser: Optimiser to calculate gradient step from loss 69 | loss_fn: Loss function to calculate between predictions and outputs 70 | epochs: Number of epochs of fitting to be performed 71 | dataloader: `torch.DataLoader` instance to fit the model to 72 | prepare_batch: Callable to perform any desired preprocessing 73 | metrics: Optional list of metrics to evaluate the model with 74 | epoch_metrics: Optional list of metrics on top of metrics at the end of epoch 75 | callbacks: Additional functionality to incorporate into training such as logging metrics to csv, model 76 | checkpointing, learning rate scheduling etc... See voicemap.callbacks for more. 77 | verbose: All print output is muted if this argument is `False` 78 | fit_function: Function for calculating gradients. Leave as default for simple supervised training on labelled 79 | batches. For more complex training procedures (meta-learning etc...) 
you will need to write your own 80 | fit_function 81 | fit_function_kwargs: Keyword arguments to pass to `fit_function` 82 | """ 83 | # Determine number of samples: 84 | num_batches = len(dataloader) 85 | batch_size = dataloader.batch_size 86 | 87 | callbacks = CallbackList([DefaultCallback(), ] + (callbacks or []) + [ProgressBarLogger(), ]) 88 | callbacks.set_model(model) 89 | callbacks.set_params({ 90 | 'num_batches': num_batches, 91 | 'batch_size': batch_size, 92 | 'verbose': verbose, 93 | 'metrics': (metrics or []), 94 | 'epoch_metrics': (epoch_metrics or []), 95 | 'prepare_batch': prepare_batch, 96 | 'loss_fn': loss_fn, 97 | 'optimiser': optimiser 98 | }) 99 | 100 | if verbose: 101 | print('Begin training...') 102 | 103 | callbacks.on_train_begin() 104 | 105 | for epoch in range(1, epochs+1): 106 | callbacks.on_epoch_begin(epoch) 107 | 108 | epoch_logs = {} 109 | for batch_index, batch in enumerate(dataloader): 110 | batch_logs = dict(batch=batch_index, size=(batch_size or 1)) 111 | 112 | callbacks.on_batch_begin(batch_index, batch_logs) 113 | 114 | x, y = prepare_batch(batch) 115 | 116 | loss, y_pred = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs) 117 | batch_logs['loss'] = loss.item() 118 | 119 | # Loops through all metrics 120 | batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs) 121 | 122 | callbacks.on_batch_end(batch_index, batch_logs) 123 | 124 | # Run on epoch end 125 | callbacks.on_epoch_end(epoch, epoch_logs) 126 | 127 | # Run on train end 128 | if verbose: 129 | print('Finished.') 130 | 131 | callbacks.on_train_end() 132 | -------------------------------------------------------------------------------- /few_shot/matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils import clip_grad_norm_ 3 | from torch.optim import Optimizer 4 | from torch.nn import Module 5 | from torch.nn.modules.loss import _Loss as Loss 6 | 7 | from config import EPSILON 8 | from few_shot.core import create_nshot_task_label 9 | from few_shot.utils import pairwise_distances 10 | 11 | 12 | def matching_net_episode(model: Module, 13 | optimiser: Optimizer, 14 | loss_fn: Loss, 15 | x: torch.Tensor, 16 | y: torch.Tensor, 17 | n_shot: int, 18 | k_way: int, 19 | q_queries: int, 20 | distance: str, 21 | fce: bool, 22 | train: bool): 23 | """Performs a single training episode for a Matching Network. 24 | 25 | # Arguments 26 | model: Matching Network to be trained. 
27 | optimiser: Optimiser to calculate gradient step from loss 28 | loss_fn: Loss function to calculate between predictions and outputs 29 | x: Input samples of few shot classification task 30 | y: Input labels of few shot classification task 31 | n_shot: Number of examples per class in the support set 32 | k_way: Number of classes in the few shot classification task 33 | q_queries: Number of examples per class in the query set 34 | distance: Distance metric to use when calculating distance between support and query set samples 35 | fce: Whether or not to us fully conditional embeddings 36 | train: Whether (True) or not (False) to perform a parameter update 37 | 38 | # Returns 39 | loss: Loss of the Matching Network on this task 40 | y_pred: Predicted class probabilities for the query set on this task 41 | """ 42 | if train: 43 | # Zero gradients 44 | model.train() 45 | optimiser.zero_grad() 46 | else: 47 | model.eval() 48 | 49 | # Embed all samples 50 | embeddings = model.encoder(x) 51 | 52 | # Samples are ordered by the NShotWrapper class as follows: 53 | # k lots of n support samples from a particular class 54 | # k lots of q query samples from those classes 55 | support = embeddings[:n_shot * k_way] 56 | queries = embeddings[n_shot * k_way:] 57 | 58 | # Optionally apply full context embeddings 59 | if fce: 60 | # LSTM requires input of shape (seq_len, batch, input_size). `support` is of 61 | # shape (k_way * n_shot, embedding_dim) and we want the LSTM to treat the 62 | # support set as a sequence so add a single dimension to transform support set 63 | # to the shape (k_way * n_shot, 1, embedding_dim) and then remove the batch dimension 64 | # afterwards 65 | 66 | # Calculate the fully conditional embedding, g, for support set samples as described 67 | # in appendix A.2 of the paper. g takes the form of a bidirectional LSTM with a 68 | # skip connection from inputs to outputs 69 | support, _, _ = model.g(support.unsqueeze(1)) 70 | support = support.squeeze(1) 71 | 72 | # Calculate the fully conditional embedding, f, for the query set samples as described 73 | # in appendix A.1 of the paper. 74 | queries = model.f(support, queries) 75 | 76 | # Efficiently calculate distance between all queries and all prototypes 77 | # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way) 78 | distances = pairwise_distances(queries, support, distance) 79 | 80 | # Calculate "attention" as softmax over support-query distances 81 | attention = (-distances).softmax(dim=1) 82 | 83 | # Calculate predictions as in equation (1) from Matching Networks 84 | # y_hat = \sum_{i=1}^{k} a(x_hat, x_i) y_i 85 | y_pred = matching_net_predictions(attention, n_shot, k_way, q_queries) 86 | 87 | # Calculated loss with negative log likelihood 88 | # Clip predictions for numerical stability 89 | clipped_y_pred = y_pred.clamp(EPSILON, 1 - EPSILON) 90 | loss = loss_fn(clipped_y_pred.log(), y) 91 | 92 | if train: 93 | # Backpropagate gradients 94 | loss.backward() 95 | # I found training to be quite unstable so I clip the norm 96 | # of the gradient to be at most 1 97 | clip_grad_norm_(model.parameters(), 1) 98 | # Take gradient step 99 | optimiser.step() 100 | 101 | return loss, y_pred 102 | 103 | 104 | def matching_net_predictions(attention: torch.Tensor, n: int, k: int, q: int) -> torch.Tensor: 105 | """Calculates Matching Network predictions based on equation (1) of the paper. 
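    # Equation (1) is just a matrix product between the attention weights and the
    # one-hot support-set labels; a shape-level sketch mirroring the body below
    # (synthetic attention values, CUDA transfer omitted):
    #
    #     from few_shot.core import create_nshot_task_label
    #     n, k, q = 1, 5, 3
    #     attention = torch.rand(q * k, k * n).softmax(dim=1)   # stand-in for softmax(-distances)
    #     y = create_nshot_task_label(k, n).unsqueeze(-1)        # support labels, shape (k*n, 1)
    #     y_onehot = torch.zeros(k * n, k).scatter(1, y, 1)      # one-hot, shape (k*n, k)
    #     y_pred = torch.mm(attention, y_onehot)                 # (q*k, k); each row sums to 1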
106 | 107 | The predictions are the weighted sum of the labels of the support set where the 108 | weights are the "attentions" (i.e. softmax over query-support distances) pointing 109 | from the query set samples to the support set samples. 110 | 111 | # Arguments 112 | attention: torch.Tensor containing softmax over query-support distances. 113 | Should be of shape (q * k, k * n) 114 | n: Number of support set samples per class, n-shot 115 | k: Number of classes in the episode, k-way 116 | q: Number of query samples per-class 117 | 118 | # Returns 119 | y_pred: Predicted class probabilities 120 | """ 121 | if attention.shape != (q * k, k * n): 122 | raise(ValueError(f'Expecting attention Tensor to have shape (q * k, k * n) = ({q * k, k * n})')) 123 | 124 | # Create one hot label vector for the support set 125 | y_onehot = torch.zeros(k * n, k) 126 | 127 | # Unsqueeze to force y to be of shape (K*n, 1) as this 128 | # is needed for .scatter() 129 | y = create_nshot_task_label(k, n).unsqueeze(-1) 130 | y_onehot = y_onehot.scatter(1, y, 1) 131 | 132 | y_pred = torch.mm(attention, y_onehot.cuda().float()) 133 | 134 | return y_pred -------------------------------------------------------------------------------- /few_shot/extmodel_proto_net_clf.py: -------------------------------------------------------------------------------- 1 | """ 2 | For testing what if we use ImageNet pretrained model as ProtoNet?? 3 | """ 4 | from dlcliche.utils import * 5 | from dlcliche.math import * 6 | from dlcliche.image import show_np_image, subplot_matrix 7 | 8 | from torchvision import models 9 | from torch import nn 10 | import torch 11 | from tqdm import tqdm 12 | 13 | # TODO: Support cpu environment 14 | 15 | class BasePretrainedModel(nn.Module): 16 | def __init__(self, base_model=models.resnet18, n_embs=512, print_shape=False): 17 | super(BasePretrainedModel, self).__init__() 18 | resnet = base_model(pretrained=True) 19 | self.body = nn.Sequential(*list(resnet.children())[:-1]) 20 | self.n_embs = n_embs 21 | self.print_shape = print_shape 22 | 23 | def forward(self, x): 24 | x = self.body(x) 25 | if self.print_shape: 26 | print(x.shape) 27 | return x.view(-1, self.n_embs) 28 | 29 | 30 | class ExtModelProtoNetClf(object): 31 | """ProtoNet as conventional classifier using external model. 32 | Created for testing what if we use ImageNet pretrained model for getting embeddings. 33 | 34 | TODO Fix bad design for member-call-order dependency... 
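The intended member call order is illustrated by this sketch (illustrative only; `support_dl` and `test_dl` stand for any DataLoaders yielding (image_tensor, label) batches):

        clf = ExtModelProtoNetClf(model, classes, device)
        clf.make_prototypes(support_dl)    # build one prototype (mean embedding) per class
        probs, gts = clf.predict(test_dl)  # softmax over negative squared distances to prototypes
        metrics = clf.evaluate(test_dl)    # same prediction plus classification metrics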
35 | """ 36 | 37 | def __init__(self, model, classes, device, n_embeddings=None): 38 | model.to(device) 39 | model.eval() 40 | self.model = model 41 | self.classes = classes 42 | self.device = device 43 | self.n_embeddings = n_embeddings # First get_embeddings() will set this, if it is None 44 | self.n_classes = len(classes) 45 | self._make_null_prototypes() 46 | self.log = get_logger() 47 | 48 | def _make_null_prototypes(self): 49 | if self.n_embeddings: 50 | self.prototypes = [OnlineStats(self.n_embeddings) \ 51 | for _ in range(self.n_classes)] 52 | else: 53 | self.prototypes = None 54 | 55 | def get_embeddings(self, dl, visualize=False): 56 | """Get embeddings for all samples available in dataloader.""" 57 | gts, cur = [], 0 58 | with torch.no_grad(): 59 | for batch_index, (X, y_gt) in tqdm(enumerate(dl), total=len(dl)): 60 | dev_X, y_gt = X.to(self.device), list(y_gt) 61 | this_embs = self.model(dev_X).cpu().detach().numpy() 62 | if cur == 0: 63 | self.n_embeddings = this_embs.shape[-1] 64 | embs = np.zeros((len(dl.dataset), self.n_embeddings)) 65 | 66 | if visualize: 67 | for i, ax in enumerate(subplot_matrix(columns=4, rows=2, figsize=(16, 8))): 68 | if len(dl) <= batch_index * 8 + i: break 69 | show_np_image(np.transpose(X[i].cpu().detach().numpy(), [1, 2, 0]), ax=ax) 70 | plt.show() 71 | 72 | for i in range(len(this_embs)): 73 | embs[cur] = this_embs[i] 74 | gts.append(y_gt[i]) 75 | cur += 1 76 | return np.array(embs), gts 77 | 78 | def make_prototypes(self, support_set_dl, repeat=1, update=False, visualize=False): 79 | """Calculate prototypes by accumulating embeddings of all samples in the given support set. 80 | Args: 81 | support_set_dl: support set dataloader. 82 | repeat: experimental parameter; feed the support set this many times (useful for testing prototypes built from augmented samples). 83 | update: set True to keep the current prototypes and update them with new samples from the dataloader; set False to build new prototypes from scratch.
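Example (illustrative; mirrors how this is used elsewhere in the repository, with `plain_dl` and `aug_dl` as placeholder DataLoaders over the same support set, the latter with augmentation):

        clf.make_prototypes(plain_dl)                      # update=False: build prototypes from scratch
        clf.make_prototypes(aug_dl, repeat=4, update=True) # keep current prototypes and refine them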
84 | """ 85 | # Get embeddings of support set samples 86 | embs, gts = self.get_embeddings(support_set_dl, visualize=visualize) 87 | # Make prototypes if not there 88 | if update: 89 | self.log.info('Using current prototypes.') 90 | else: 91 | self.log.info('Making new prototypes.') 92 | self._make_null_prototypes() 93 | # Update prototypes (just by feeding to online stat class) 94 | for i in range(repeat): 95 | for emb, cls in zip(embs, gts): 96 | if not isinstance(cls, int): 97 | cls = self.classes.index(cls) 98 | self.prototypes[cls].put(emb) 99 | if i < repeat - 1: 100 | embs, gts = self.get_embeddings(support_set_dl) # no visualization 101 | 102 | def get_prototypes(self): 103 | return np.array([p.mean() for p in self.prototypes]) 104 | 105 | def save_prototypes(self, filename): 106 | np.save(filename, self.get_prototypes()) 107 | 108 | def load_prototypes(self, filename): 109 | prototypes = np.load(filename) 110 | for i, prototype in enumerate(prototypes): 111 | self.prototypes[i].set_mean(prototype) 112 | 113 | def predict_embeddings(self, X_embs, softmax=True, normalized_softmax=True): 114 | preds = np.zeros((len(X_embs), self.n_classes)) 115 | proto_embs = [p.mean() for p in self.prototypes] 116 | for idx_sample, x in tqdm(enumerate(X_embs), total=len(X_embs)): 117 | for idx_class, proto in enumerate(proto_embs): 118 | preds[idx_sample, idx_class] = -(np.linalg.norm(x - proto)**2) 119 | if softmax: 120 | if normalized_softmax: 121 | preds /= np.max([1.0, np.abs(preds.mean())]) 122 | preds = np_softmax(preds) 123 | return preds 124 | 125 | def predict(self, data_loader): 126 | embs, y_gts = self.get_embeddings(data_loader) 127 | return self.predict_embeddings(embs), y_gts 128 | 129 | def evaluate(self, data_loader): 130 | y_hat, y_gts = self.predict(data_loader) 131 | return calculate_clf_metrics(y_gts, y_hat) 132 | 133 | @staticmethod 134 | def get_uncertainty(dists): 135 | _max = np.max(dists, axis=-1) 136 | _mean = np.mean(dists, axis=-1) 137 | uncertainty = _max/_mean 138 | return uncertainty 139 | -------------------------------------------------------------------------------- /few_shot/maml.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | from torch.optim import Optimizer 4 | from torch.nn import Module 5 | from typing import Dict, List, Callable, Union 6 | 7 | from few_shot.core import create_nshot_task_label 8 | 9 | 10 | def replace_grad(parameter_gradients, parameter_name): 11 | def replace_grad_(module): 12 | return parameter_gradients[parameter_name] 13 | 14 | return replace_grad_ 15 | 16 | 17 | def meta_gradient_step(model: Module, 18 | optimiser: Optimizer, 19 | loss_fn: Callable, 20 | x: torch.Tensor, 21 | y: torch.Tensor, 22 | n_shot: int, 23 | k_way: int, 24 | q_queries: int, 25 | order: int, 26 | inner_train_steps: int, 27 | inner_lr: float, 28 | train: bool, 29 | device: Union[str, torch.device]): 30 | """ 31 | Perform a gradient step on a meta-learner. 
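A typical invocation looks like the sketch below (illustrative only; the x/y batch comes from a DataLoader built around NShotTaskSampler with num_tasks equal to the meta-batch size, and the hyperparameter values shown are placeholders rather than recommendations):

        loss, y_pred = meta_gradient_step(model, optimiser, loss_fn, x, y,
                                          n_shot=1, k_way=5, q_queries=1,
                                          order=1, inner_train_steps=1, inner_lr=0.4,
                                          train=True, device=device)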
32 | 33 | # Arguments 34 | model: Base model of the meta-learner being trained 35 | optimiser: Optimiser to calculate gradient step from loss 36 | loss_fn: Loss function to calculate between predictions and outputs 37 | x: Input samples for all few shot tasks 38 | y: Input labels of all few shot tasks 39 | n_shot: Number of examples per class in the support set of each task 40 | k_way: Number of classes in the few shot classification task of each task 41 | q_queries: Number of examples per class in the query set of each task. The query set is used to calculate 42 | meta-gradients after applying the update to 43 | order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the 44 | query set) or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated 45 | weights on the query with respect to the original weights). 46 | inner_train_steps: Number of gradient steps to fit the fast weights during each inner update 47 | inner_lr: Learning rate used to update the fast weights on the inner update 48 | train: Whether to update the meta-learner weights at the end of the episode. 49 | device: Device on which to run computation 50 | """ 51 | data_shape = x.shape[2:] 52 | create_graph = (True if order == 2 else False) and train 53 | 54 | task_gradients = [] 55 | task_losses = [] 56 | task_predictions = [] 57 | for meta_batch in x: 58 | # By construction x is a 5D tensor of shape: (meta_batch_size, n*k + q*k, channels, width, height) 59 | # Hence when we iterate over the first dimension we are iterating through the meta batches 60 | x_task_train = meta_batch[:n_shot * k_way] 61 | x_task_val = meta_batch[n_shot * k_way:] 62 | 63 | # Create a fast model using the current meta model weights 64 | fast_weights = OrderedDict(model.named_parameters()) 65 | 66 | # Train the model for `inner_train_steps` iterations 67 | for inner_batch in range(inner_train_steps): 68 | # Perform update of model weights 69 | y = create_nshot_task_label(k_way, n_shot).to(device) 70 | logits = model.functional_forward(x_task_train, fast_weights) 71 | loss = loss_fn(logits, y) 72 | gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph) 73 | 74 | # Update weights manually 75 | fast_weights = OrderedDict( 76 | (name, param - inner_lr * grad) 77 | for ((name, param), grad) in zip(fast_weights.items(), gradients) 78 | ) 79 | 80 | # Do a pass of the model on the validation data from the current task 81 | y = create_nshot_task_label(k_way, q_queries).to(device) 82 | logits = model.functional_forward(x_task_val, fast_weights) 83 | loss = loss_fn(logits, y) 84 | loss.backward(retain_graph=True) 85 | 86 | # Get post-update accuracies 87 | y_pred = logits.softmax(dim=1) 88 | task_predictions.append(y_pred) 89 | 90 | # Accumulate losses and gradients 91 | task_losses.append(loss) 92 | gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph) 93 | named_grads = {name: g for ((name, _), g) in zip(fast_weights.items(), gradients)} 94 | task_gradients.append(named_grads) 95 | 96 | if order == 1: 97 | if train: 98 | sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0) 99 | for k in task_gradients[0].keys()} 100 | hooks = [] 101 | for name, param in model.named_parameters(): 102 | hooks.append( 103 | param.register_hook(replace_grad(sum_task_gradients, name)) 104 | ) 105 | 106 | model.train() 107 | optimiser.zero_grad() 108 | # Dummy pass in order to create `loss` 
variable 109 | # Replace dummy gradients with mean task gradients using hooks 110 | logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.float)) 111 | loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device)) 112 | loss.backward() 113 | optimiser.step() 114 | 115 | for h in hooks: 116 | h.remove() 117 | 118 | return torch.stack(task_losses).mean(), torch.cat(task_predictions) 119 | 120 | elif order == 2: 121 | model.train() 122 | optimiser.zero_grad() 123 | meta_batch_loss = torch.stack(task_losses).mean() 124 | 125 | if train: 126 | meta_batch_loss.backward() 127 | optimiser.step() 128 | 129 | return meta_batch_loss, torch.cat(task_predictions) 130 | else: 131 | raise ValueError('Order must be either 1 or 2.') 132 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | from torch.nn.modules.distance import CosineSimilarity, PairwiseDistance 5 | 6 | from few_shot.utils import * 7 | from config import PATH 8 | 9 | 10 | class TestDistance(unittest.TestCase): 11 | def test_query_support_distances(self): 12 | # Create some dummy data with easily verifiable distances 13 | q = 1 # 1 query per class 14 | k = 3 # 3 way classification 15 | d = 2 # embedding dimension of two 16 | 17 | query = torch.zeros([q * k, d], dtype=torch.float) 18 | query[0] = torch.Tensor([0, 0]) 19 | query[1] = torch.Tensor([0, 1]) 20 | query[2] = torch.Tensor([1, 0]) 21 | support = torch.zeros([k, d], dtype=torch.float) 22 | support[0] = torch.Tensor([1, 1]) 23 | support[1] = torch.Tensor([1, 2]) 24 | support[2] = torch.Tensor([2, 2]) 25 | 26 | distances = pairwise_distances(query, support, 'l2') 27 | self.assertEqual( 28 | distances.shape, (q * k, k), 29 | 'Output should have shape (q * k, k).' 30 | ) 31 | 32 | # Calculate squared distances by iterating through all query-support pairs 33 | for i, q_ in enumerate(query): 34 | for j, s_ in enumerate(support): 35 | self.assertEqual( 36 | (q_ - s_).pow(2).sum(), 37 | distances[i, j].item(), 38 | 'The jth column of the ith row should be the squared distance between the ' 39 | 'ith query sample and the kth query sample' 40 | ) 41 | 42 | # Create some dummy data with easily verifiable distances 43 | q = 1 # 1 query per class 44 | k = 3 # 3 way classification 45 | d = 2 # embedding dimension of two 46 | query = torch.zeros([q * k, d], dtype=torch.float) 47 | query[0] = torch.Tensor([1, 0]) 48 | query[1] = torch.Tensor([0, 1]) 49 | query[2] = torch.Tensor([1, 1]) 50 | support = torch.zeros([k, d], dtype=torch.float) 51 | support[0] = torch.Tensor([1, 1]) 52 | support[1] = torch.Tensor([-1, -1]) 53 | support[2] = torch.Tensor([0, 2]) 54 | 55 | distances = pairwise_distances(query, support, 'cosine') 56 | 57 | # Calculate distances by iterating through all query-support pairs 58 | for i, q_ in enumerate(query): 59 | for j, s_ in enumerate(support): 60 | self.assertTrue( 61 | torch.isclose(1-CosineSimilarity(dim=0)(q_, s_), distances[i, j], atol=2e-8), 62 | 'The jth column of the ith row should be the squared distance between the ' 63 | 'ith query sample and the kth query sample' 64 | ) 65 | 66 | def test_no_nans_on_zero_vectors(self): 67 | """Cosine distance calculation involves a divide-through by vector magnitude which 68 | can divide by zeros to occur. 
69 | """ 70 | # Create some dummy data with easily verifiable distances 71 | q = 1 # 1 query per class 72 | k = 3 # 3 way classification 73 | d = 2 # embedding dimension of two 74 | query = torch.zeros([q * k, d], dtype=torch.float) 75 | query[0] = torch.Tensor([0, 0]) # First query sample is all zeros 76 | query[1] = torch.Tensor([0, 1]) 77 | query[2] = torch.Tensor([1, 1]) 78 | support = torch.zeros([k, d], dtype=torch.float) 79 | support[0] = torch.Tensor([1, 1]) 80 | support[1] = torch.Tensor([-1, -1]) 81 | support[2] = torch.Tensor([0, 0]) # Third support sample is all zeros 82 | 83 | distances = pairwise_distances(query, support, 'cosine') 84 | 85 | self.assertTrue(torch.isnan(distances).sum() == 0, 'Cosine distances between 0-vectors should not be nan') 86 | 87 | 88 | class TestAutogradGraphRetrieval(unittest.TestCase): 89 | def test_retrieval(self): 90 | """Create a simple autograd graph and check that the output is what is expected""" 91 | x = torch.ones(2, 2, requires_grad=True) 92 | y = x + 2 93 | # The operation on the next line will create two edges because the 94 | # y variable is used twice. 95 | z = y * y 96 | out = z.mean() 97 | 98 | nodes, edges = autograd_graph(out) 99 | 100 | # This is quite a brittle test as it will break if the names of the autograd Functions change 101 | # TODO: Less brittle test 102 | 103 | expected_nodes = [ 104 | 'MeanBackward1', 105 | 'ThMulBackward', 106 | 'AddBackward', 107 | 'AccumulateGrad', 108 | ] 109 | 110 | self.assertEqual( 111 | set(expected_nodes), 112 | set(n.__class__.__name__ for n in nodes), 113 | 'autograd_graph() must return all nodes in the autograd graph.' 114 | ) 115 | 116 | # Check for the existence of the expected edges 117 | expected_edges = [ 118 | ('ThMulBackward', 'MeanBackward1'), # z = y * y, out = z.mean() 119 | ('AddBackward', 'ThMulBackward'), # y = x + 2, z = y * y 120 | ('AccumulateGrad', 'AddBackward'), # x = torch.ones(2, 2, requires_grad=True), y = x + 2 121 | ] 122 | for e in edges: 123 | self.assertIn( 124 | (e[0].__class__.__name__, e[1].__class__.__name__), 125 | expected_edges, 126 | 'autograd_graph() must return all edges in the autograd graph.' 127 | ) 128 | 129 | # Check for two edges between the AddBackward node and the ThMulBackward node 130 | num_y_squared_edges = 0 131 | for e in edges: 132 | if e[0].__class__.__name__ == 'AddBackward' and e[1].__class__.__name__ == 'ThMulBackward': 133 | num_y_squared_edges += 1 134 | 135 | self.assertEqual( 136 | num_y_squared_edges, 137 | 2, 138 | 'autograd_graph() must return multiple edges between nodes if they exist.' 139 | ) 140 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # (Using) Prototypical Networks as a Fine Grained Classifier 2 | 3 | This repository is heavily based on [Oscar Knagg](https://towardsdatascience.com/@oknagg)'s few-shot learning 4 | implementation [github.com/oscarknagg/few-shot](https://github.com/oscarknagg/few-shot), and focuses on applying the simple but strong [Prototypical Networks](https://arxiv.org/pdf/1703.05175.pdf) to fine-grained classification tasks. 5 | 6 | Main contributions of this repository: 7 | 8 | - A practical few-shot machine learning system, ready for real-world fine-grained classification problems. 9 | - Transfer-learning ready to make quick training possible: ImageNet pre-trained models are used by default, and any network (even a non-CNN) can be plugged in.
10 | - Demonstrated in a fairly difficult Kaggle competition that an ImageNet pretrained model works well as the core model of Prototypical Networks. 11 | 12 | Unlike the very clean original implementation, this repository contains 13 | some quick-and-dirty code to present a sample solution to the Kaggle competition 14 | "[Humpback Whale Identification](https://www.kaggle.com/c/humpback-whale-identification/)". 15 | 16 | Some of the submission code borrows functions from [Radek Osmulski](https://medium.com/@radekosmulski)'s github repository. 17 | 18 | I'd like to express my sincere appreciation to both [Oscar Knagg](https://towardsdatascience.com/@oknagg) and [Radek Osmulski](https://medium.com/@radekosmulski). Thank you. 19 | 20 | ## Prototypical Networks 21 | 22 | Prototypical Networks was proposed in the paper [Prototypical Networks for Few-shot Learning](https://arxiv.org/pdf/1703.05175.pdf) 23 | (Snell et al.), which computes a _prototype_ as the central point of each class in Euclidean space; test samples can then be classified simply by measuring their distances to the class prototypes. 24 | 25 | In Prototypical Networks, the model learns all the non-linearity: it encapsulates everything between the non-linear inputs and the linear outputs, and the system design and training algorithm make this possible. 26 | 27 | ![fig](assets/proto_nets_diagram.png) 28 | 29 | Figure from the original paper. Colored circles: training samples; $c_i$: prototypes; $x$: test sample. 30 | 31 | What the Prototypical Networks scheme trains is a metric embedding in Euclidean space, and this makes it a quite handy tool for real-world engineering. 32 | 33 | Here's a summary of nice traits for machine learning practitioners: 34 | 35 | - Explainable: It discriminates classes in multi-dimensional Euclidean space, which many old-fashioned engineers are familiar with. This is important so that we can explain the model to non-ML project stakeholders and finally bring it to real-world projects. It’s not even cosine distance, just a conventional distance. 36 | - Customizable: Any model can be used, so it is applicable to any problem; the model is simply trained to map input data points to output data points in Euclidean space so that all classes can be distinguished by old-fashioned distance. 37 | - Few-shot ready: It works with long-tail problems where only a very small number of samples is available for some classes, as well as with sample imbalance between classes. It is (almost, as of now) proven in the Kaggle competition 38 | "[Humpback Whale Identification](https://www.kaggle.com/c/humpback-whale-identification/)". 39 | - Easy to train: (I think) this is almost free from the difficult and computationally intensive hard mining that selects training samples to make the task harder as training goes on. 40 | 41 | ## Quick start 42 | 43 | This project inherits the prerequisites below: 44 | 45 | This project is written in Python 3.6 and PyTorch and assumes you have 46 | a GPU. 47 | 48 | 1. Install [dl-cliche from github](https://github.com/daisukelab/dl-cliche); excuse me, this is my almost-private library for repeating cliche code. 49 | 50 | pip install git+https://github.com/daisukelab/dl-cliche.git@master --upgrade 51 | 52 | 2. Install [albumentations](https://github.com/albu/albumentations/). 53 | 3. Edit the `DATA_PATH` variable in `config.py` to the location where 54 | you downloaded the dataset copy from Kaggle. 55 | 4. Open and run `app/whale/Example_Humpback_Whale_Identification.ipynb` to reproduce the whale identification solution. A minimal script-style sketch of the same flow is shown below.
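If you prefer a plain script to the notebook, the sketch below shows the intended flow. It is illustrative only: `class_names`, `support_dl` and `test_dl` are placeholders you prepare yourself (ordinary torchvision-style `DataLoader`s yielding `(image_tensor, label)` batches), not part of the repository API.

```python
import torch
from torchvision import models
from few_shot.extmodel_proto_net_clf import ExtModelProtoNetClf, BasePretrainedModel

device = torch.device('cuda')

# Any embedding network works; an ImageNet pre-trained ResNet-18 is used here as an example.
model = BasePretrainedModel(base_model=models.resnet18)

clf = ExtModelProtoNetClf(model, classes=class_names, device=device)
clf.make_prototypes(support_dl)      # one prototype (mean embedding) per class
probs, gts = clf.predict(test_dl)    # softmax over negative squared distances to the prototypes
```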
56 | 57 | ## Benefits and drawbacks summary 58 | 59 | - Very simple design for both the networks and the training algorithm. 60 | - All non-linearity can be learned by the model. 61 | - Independent of model design; we can choose an arbitrary network that best fits our problem. 62 | - Embeddings produced by the learnt model are simple data points in multi-dimensional Euclidean space, where distances between data points are quite simply calculated. 63 | - Training is easier compared to Siamese networks, for example. 64 | - Less sensitive to class imbalance; the training algorithm always picks an equal number of samples from k classes. 65 | - Test time augmentation can be naturally applied both for getting prototypes and for test samples' embeddings. 66 | 67 | But 68 | 69 | - The number of classes ProtoNets can train on is mainly limited by memory size. A single GTX 1080 Ti can handle up to 20 classes for 1-shot with 384x384 images, for example. 70 | - As far as I have tried, a larger k-way (more classes) results in better performance, and it is limited by memory as written above. 71 | 72 | ## Towards better performance 73 | 74 | - Augmentation matters. 75 | - Image size also matters. 76 | - TTA pushes the score. 77 | - and more... 78 | 79 | ## Resources 80 | 81 | - Original paper: [Prototypical Networks for Few-shot Learning](https://arxiv.org/pdf/1703.05175.pdf) 82 | (Snell et al.). 83 | - [Oscar Knagg](https://towardsdatascience.com/@oknagg)'s article: [Theory and concepts](https://towardsdatascience.com/advances-in-few-shot-learning-a-guided-tour-36bc10a68b77) 84 | - [Oscar Knagg](https://towardsdatascience.com/@oknagg)'s article: [Discussion of implementation details](https://towardsdatascience.com/advances-in-few-shot-learning-reproducing-results-in-pytorch-aba70dee541d) 85 | - [Radek Osmulski](https://www.kaggle.com/radek1)'s post on Kaggle discussion: [[LB 0.760] Fastai Starter Pack](https://www.kaggle.com/c/humpback-whale-identification/discussion/74647) 86 | - [Radek Osmulski](https://medium.com/@radekosmulski)'s github repository: [Humpback Whale Identification Competition Starter Pack](https://github.com/radekosmulski/whale) 87 | -------------------------------------------------------------------------------- /app/whale/whale_plus_utils.py: -------------------------------------------------------------------------------- 1 | from whale_utils import * 2 | from PIL import Image #import PIL 3 | 4 | 5 | def partition_np_image(image, part, n_part=2): 6 | ih, iw = image.shape[:2] 7 | assert part < n_part 8 | M = 2 9 | dw = iw // (n_part * M + 1) # 1/5 if n_part=2 10 | x0 = dw * part * M 11 | x1 = min(x0 + (M + 1) * dw, iw) 12 | image = image[:, x0:x1, :] 13 | return image 14 | 15 | 16 | def get_aug_plus(re_size=224, to_size=224, augment='train', normalize='imagenet'): 17 | augs = [A.Resize(height=re_size, width=re_size)] 18 | if augment == 'train': 19 | augs.extend([ 20 | A.RandomCrop(height=to_size, width=to_size), 21 | A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.3, rotate_limit=30, p=0.75), 22 | A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=0.75), 23 | A.Blur(p=0.5), 24 | A.Cutout(max_h_size=to_size//12, max_w_size=to_size//12, p=0.5), 25 | ]) 26 | elif augment == 'train_hard': 27 | augs.extend([ 28 | A.RandomCrop(height=to_size, width=to_size), 29 | A.IAAAffine(scale=1.3, translate_percent=0.2, translate_px=None, 30 | rotate=40, shear=20), 31 | A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=0.75), 32 | A.IAAPerspective(p=1), 33 | A.IAAAdditiveGaussianNoise(p=0.2), 34 |
A.Blur(p=0.5), 35 | A.Cutout(max_h_size=to_size//12, max_w_size=to_size//12, p=0.5), 36 | ]) 37 | elif augment == 'test': 38 | augs.extend([ 39 | A.CenterCrop(height=to_size, width=to_size) 40 | ]) 41 | elif augment == 'tta': 42 | augs.extend([ 43 | A.IAAAffine(scale=1.05, translate_percent=0.1, translate_px=None, 44 | rotate=20, shear=10), 45 | A.CenterCrop(height=to_size, width=to_size), 46 | ]) 47 | else: 48 | raise Exception(f'aug level not supported: {augment}') 49 | return A.Compose(augs + [A.Normalize(samplewise=(normalize=='samplewise'))]) 50 | 51 | 52 | class WhaleImagesPlus(WhaleImages): 53 | def __init__(self, path, images, labels, re_size=256, to_size=224, part=0, n_part=1, 54 | augment='normal', normalize='samplewise'): 55 | super().__init__(path, images, labels, re_size=re_size, to_size=to_size, train=(augment=='train')) 56 | self.transform = get_aug_plus(re_size=re_size, to_size=to_size, augment=augment, normalize=normalize) 57 | self.part, self.n_part = part, n_part 58 | 59 | def __getitem__(self, item): 60 | instance = self.loader(self.datasetid_to_filepath[item]) 61 | if self.n_part > 1: 62 | instance = partition_np_image(instance, self.part, n_part=self.n_part) 63 | instance = self.transform(image=instance)['image'] 64 | instance = self.to_tensor(instance) 65 | label = self.datasetid_to_class_id[item] 66 | return instance, label 67 | 68 | 69 | def calculate_results_plus(weight_file, output_path, SZ, get_model_fn, device, train_csv='data/data.csv', 70 | data_train='data/train', data_test='data/test', 71 | data_type='normal', normalize='samplewise', part=0, n_part=1, N_TTA=4): 72 | weight_file = Path(weight_file) 73 | output_path = Path(output_path) 74 | ensure_folder(output_path) 75 | submission_file_stem = ('NS_' if normalize=='samplewise' else '') + weight_file.stem 76 | 77 | # Training samples 78 | df = pd.read_csv(train_csv) 79 | df = df[df.Id != 'new_whale'] 80 | images = df.Image.values 81 | labels = df.Id.values 82 | 83 | # Test samples 84 | test_images = get_test_images(data_test) 85 | dummy_test_gts = list(range(len(test_images))) 86 | 87 | print(f'Training samples: {len(images)}, # of labels: {len(list(set(labels)))}.') 88 | print(f'Test samples: {len(test_images)}.') 89 | print(f'Work in progress for {submission_file_stem}...') 90 | 91 | # Making dataloaders 92 | def get_dl(images, labels, folder, SZ=SZ, batch_size=64, augment='test', normalize='samplewise'): 93 | if data_type == 'normal': 94 | ds = WhaleImagesPlus(folder, images, labels, re_size=SZ, to_size=SZ, 95 | augment=augment, normalize=normalize, part=part, n_part=n_part) 96 | else: 97 | raise ValueError('invalid data type') 98 | dl = DataLoader(ds, batch_size=batch_size) 99 | return dl 100 | 101 | # 1. NORMAL RESULT 102 | # Make prototypes 103 | trn_dl = get_dl(images, labels, data_train) 104 | model = get_model_fn(device=device, weight_file=weight_file) 105 | proto_net = ExtModelProtoNetClf(model, trn_dl.dataset.classes, device) 106 | 107 | proto_net.make_prototypes(trn_dl) 108 | 109 | # Calculate distances 110 | test_dl = get_dl(test_images, dummy_test_gts, data_test) 111 | test_embs, gts = proto_net.get_embeddings(test_dl) 112 | test_dists = proto_net.predict_embeddings(test_embs, softmax=False) 113 | 114 | np.save(output_path/f'test_dists_{submission_file_stem}.npy', test_dists) 115 | np.save(output_path/f'prototypes_{submission_file_stem}.npy', np.array([x.mean() for x in proto_net.prototypes])) 116 | 117 | # 2. 
PTA RESULT 118 | print(f'Work in progress for PTA_{submission_file_stem}...') 119 | trn_dl = get_dl(images, labels, data_train, augment='tta', normalize=normalize) 120 | proto_net.make_prototypes(trn_dl, repeat=N_TTA, update=True) 121 | 122 | test_dists = proto_net.predict_embeddings(test_embs, softmax=False) 123 | 124 | np.save(output_path/f'test_dists_PTA_{submission_file_stem}.npy', test_dists) 125 | np.save(output_path/f'prototypes_PTA_{submission_file_stem}.npy', np.array([x.mean() for x in proto_net.prototypes])) 126 | 127 | # 3. PTTA RESULT 128 | print(f'Work in progress for PTTA_{submission_file_stem}...') 129 | test_dl = get_dl(test_images, dummy_test_gts, data_test, augment='tta', normalize=normalize) 130 | tta_embs = [] 131 | for i in range(N_TTA): 132 | embs, gts = proto_net.get_embeddings(test_dl) 133 | tta_embs.append(embs) 134 | all_test_embs = np.array([test_embs] + tta_embs) 135 | mean_test_embs = all_test_embs.mean(axis=0) 136 | 137 | test_dists = proto_net.predict_embeddings(mean_test_embs, softmax=False) 138 | 139 | np.save(output_path/f'test_dists_PTTA_{submission_file_stem}.npy', test_dists) 140 | -------------------------------------------------------------------------------- /few_shot/datasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | from PIL import Image 4 | from torchvision import transforms 5 | from skimage import io 6 | from tqdm import tqdm 7 | import pandas as pd 8 | import numpy as np 9 | import os 10 | 11 | from config import DATA_PATH 12 | 13 | 14 | class OmniglotDataset(Dataset): 15 | def __init__(self, subset): 16 | """Dataset class representing Omniglot dataset 17 | 18 | # Arguments: 19 | subset: Whether the dataset represents the background or evaluation set 20 | """ 21 | if subset not in ('background', 'evaluation'): 22 | raise(ValueError, 'subset must be one of (background, evaluation)') 23 | self.subset = subset 24 | 25 | self.df = pd.DataFrame(self.index_subset(self.subset)) 26 | 27 | # Index of dataframe has direct correspondence to item in dataset 28 | self.df = self.df.assign(id=self.df.index.values) 29 | 30 | # Convert arbitrary class names of dataset to ordered 0-(num_speakers - 1) integers 31 | self.unique_characters = sorted(self.df['class_name'].unique()) 32 | self.class_name_to_id = {self.unique_characters[i]: i for i in range(self.num_classes())} 33 | self.df = self.df.assign(class_id=self.df['class_name'].apply(lambda c: self.class_name_to_id[c])) 34 | 35 | # Create dicts 36 | self.datasetid_to_filepath = self.df.to_dict()['filepath'] 37 | self.datasetid_to_class_id = self.df.to_dict()['class_id'] 38 | 39 | def __getitem__(self, item): 40 | instance = io.imread(self.datasetid_to_filepath[item]) 41 | # Reindex to channels first format as supported by pytorch 42 | instance = instance[np.newaxis, :, :] 43 | 44 | # Normalise to 0-1 45 | instance = (instance - instance.min()) / (instance.max() - instance.min()) 46 | 47 | label = self.datasetid_to_class_id[item] 48 | 49 | return torch.from_numpy(instance), label 50 | 51 | def __len__(self): 52 | return len(self.df) 53 | 54 | def num_classes(self): 55 | return len(self.df['class_name'].unique()) 56 | 57 | @staticmethod 58 | def index_subset(subset): 59 | """Index a subset by looping through all of its files and recording relevant information. 
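The expected on-disk layout (an assumption inferred from the directory walk below; it is normally produced by `scripts/prepare_omniglot.py`) is:

        DATA_PATH/Omniglot/images_background/<alphabet>/<character>/<image>.png
        DATA_PATH/Omniglot/images_evaluation/<alphabet>/<character>/<image>.png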
60 | 61 | # Arguments 62 | subset: Name of the subset 63 | 64 | # Returns 65 | A list of dicts containing information about all the image files in a particular subset of the 66 | Omniglot dataset dataset 67 | """ 68 | images = [] 69 | print('Indexing {}...'.format(subset)) 70 | # Quick first pass to find total for tqdm bar 71 | subset_len = 0 72 | for root, folders, files in os.walk(DATA_PATH + '/Omniglot/images_{}/'.format(subset)): 73 | subset_len += len([f for f in files if f.endswith('.png')]) 74 | 75 | progress_bar = tqdm(total=subset_len) 76 | for root, folders, files in os.walk(DATA_PATH + '/Omniglot/images_{}/'.format(subset)): 77 | if len(files) == 0: 78 | continue 79 | 80 | alphabet = root.split('/')[-2] 81 | class_name = '{}.{}'.format(alphabet, root.split('/')[-1]) 82 | 83 | for f in files: 84 | progress_bar.update(1) 85 | images.append({ 86 | 'subset': subset, 87 | 'alphabet': alphabet, 88 | 'class_name': class_name, 89 | 'filepath': os.path.join(root, f) 90 | }) 91 | 92 | progress_bar.close() 93 | return images 94 | 95 | 96 | class MiniImageNet(Dataset): 97 | def __init__(self, subset): 98 | """Dataset class representing miniImageNet dataset 99 | 100 | # Arguments: 101 | subset: Whether the dataset represents the background or evaluation set 102 | """ 103 | if subset not in ('background', 'evaluation'): 104 | raise(ValueError, 'subset must be one of (background, evaluation)') 105 | self.subset = subset 106 | 107 | self.df = pd.DataFrame(self.index_subset(self.subset)) 108 | 109 | # Index of dataframe has direct correspondence to item in dataset 110 | self.df = self.df.assign(id=self.df.index.values) 111 | 112 | # Convert arbitrary class names of dataset to ordered 0-(num_speakers - 1) integers 113 | self.unique_characters = sorted(self.df['class_name'].unique()) 114 | self.class_name_to_id = {self.unique_characters[i]: i for i in range(self.num_classes())} 115 | self.df = self.df.assign(class_id=self.df['class_name'].apply(lambda c: self.class_name_to_id[c])) 116 | 117 | # Create dicts 118 | self.datasetid_to_filepath = self.df.to_dict()['filepath'] 119 | self.datasetid_to_class_id = self.df.to_dict()['class_id'] 120 | 121 | # Setup transforms 122 | self.transform = transforms.Compose([ 123 | transforms.CenterCrop(224), 124 | transforms.Resize(84), 125 | transforms.ToTensor(), 126 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 127 | std=[0.229, 0.224, 0.225]) 128 | ]) 129 | 130 | def __getitem__(self, item): 131 | instance = Image.open(self.datasetid_to_filepath[item]) 132 | instance = self.transform(instance) 133 | label = self.datasetid_to_class_id[item] 134 | return instance, label 135 | 136 | def __len__(self): 137 | return len(self.df) 138 | 139 | def num_classes(self): 140 | return len(self.df['class_name'].unique()) 141 | 142 | @staticmethod 143 | def index_subset(subset): 144 | """Index a subset by looping through all of its files and recording relevant information. 
145 | 146 | # Arguments 147 | subset: Name of the subset 148 | 149 | # Returns 150 | A list of dicts containing information about all the image files in a particular subset of the 151 | miniImageNet dataset 152 | """ 153 | images = [] 154 | print('Indexing {}...'.format(subset)) 155 | # Quick first pass to find total for tqdm bar 156 | subset_len = 0 157 | for root, folders, files in os.walk(DATA_PATH + '/miniImageNet/images_{}/'.format(subset)): 158 | subset_len += len([f for f in files if f.endswith('.png')]) 159 | 160 | progress_bar = tqdm(total=subset_len) 161 | for root, folders, files in os.walk(DATA_PATH + '/miniImageNet/images_{}/'.format(subset)): 162 | if len(files) == 0: 163 | continue 164 | 165 | class_name = root.split('/')[-1] 166 | 167 | for f in files: 168 | progress_bar.update(1) 169 | images.append({ 170 | 'subset': subset, 171 | 'class_name': class_name, 172 | 'filepath': os.path.join(root, f) 173 | }) 174 | 175 | progress_bar.close() 176 | return images 177 | 178 | 179 | class DummyDataset(Dataset): 180 | def __init__(self, samples_per_class=10, n_classes=10, n_features=1): 181 | """Dummy dataset for debugging/testing purposes 182 | 183 | A sample from the DummyDataset has (n_features + 1) features. The first feature is the index of the sample 184 | in the data and the remaining features are the class index. 185 | 186 | # Arguments 187 | samples_per_class: Number of samples per class in the dataset 188 | n_classes: Number of distinct classes in the dataset 189 | n_features: Number of extra features each sample should have. 190 | """ 191 | self.samples_per_class = samples_per_class 192 | self.n_classes = n_classes 193 | self.n_features = n_features 194 | 195 | # Create a dataframe to be consistent with other Datasets 196 | self.df = pd.DataFrame({ 197 | 'class_id': [i % self.n_classes for i in range(len(self))] 198 | }) 199 | self.df = self.df.assign(id=self.df.index.values) 200 | 201 | def __len__(self): 202 | return self.samples_per_class * self.n_classes 203 | 204 | def __getitem__(self, item): 205 | class_id = item % self.n_classes 206 | return np.array([item] + [class_id]*self.n_features, dtype=np.float), float(class_id) 207 | -------------------------------------------------------------------------------- /few_shot/core.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Sampler 2 | from typing import List, Iterable, Callable, Tuple 3 | import numpy as np 4 | import torch 5 | 6 | from few_shot.metrics import categorical_accuracy 7 | from few_shot.callbacks import Callback 8 | 9 | 10 | class NShotTaskSampler(Sampler): 11 | def __init__(self, 12 | dataset: torch.utils.data.Dataset, 13 | episodes_per_epoch: int = None, 14 | n: int = None, 15 | k: int = None, 16 | q: int = None, 17 | num_tasks: int = 1, 18 | fixed_tasks: List[Iterable[int]] = None): 19 | """PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks. 20 | 21 | Each n-shot task contains a "support set" of `k` sets of `n` samples and a "query set" of `k` sets 22 | of `q` samples. The support set and the query set are all grouped into one Tensor such that the first n * k 23 | samples are from the support set while the remaining q * k samples are from the query set. 24 | 25 | The support and query sets are sampled such that they are disjoint i.e. do not contain overlapping samples. 
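Typical wiring (an illustrative sketch): the sampler is passed to a DataLoader as its `batch_sampler`, so that every batch drawn from the loader is one or more complete n-shot tasks:

        background = OmniglotDataset('background')
        loader = DataLoader(background,
                            batch_sampler=NShotTaskSampler(background, episodes_per_epoch=100,
                                                           n=1, k=5, q=1))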
26 | 27 | # Arguments 28 | dataset: Instance of torch.utils.data.Dataset from which to draw samples 29 | episodes_per_epoch: Arbitrary number of batches of n-shot tasks to generate in one epoch 30 | n_shot: int. Number of samples for each class in the n-shot classification tasks. 31 | k_way: int. Number of classes in the n-shot classification tasks. 32 | q_queries: int. Number query samples for each class in the n-shot classification tasks. 33 | num_tasks: Number of n-shot tasks to group into a single batch 34 | fixed_tasks: If this argument is specified this Sampler will always generate tasks from 35 | the specified classes 36 | """ 37 | super(NShotTaskSampler, self).__init__(dataset) 38 | self.episodes_per_epoch = episodes_per_epoch 39 | self.dataset = dataset 40 | if num_tasks < 1: 41 | raise ValueError('num_tasks must be > 1.') 42 | 43 | self.num_tasks = num_tasks 44 | # TODO: Raise errors if initialise badly 45 | self.k = k 46 | self.n = n 47 | self.q = q 48 | self.fixed_tasks = fixed_tasks 49 | 50 | self.i_task = 0 51 | 52 | def __len__(self): 53 | return self.episodes_per_epoch 54 | 55 | def __iter__(self): 56 | for _ in range(self.episodes_per_epoch): 57 | batch = [] 58 | 59 | for task in range(self.num_tasks): 60 | if self.fixed_tasks is None: 61 | # Get random classes 62 | episode_classes = np.random.choice(self.dataset.df['class_id'].unique(), size=self.k, replace=False) 63 | else: 64 | # Loop through classes in fixed_tasks 65 | episode_classes = self.fixed_tasks[self.i_task % len(self.fixed_tasks)] 66 | self.i_task += 1 67 | 68 | df = self.dataset.df[self.dataset.df['class_id'].isin(episode_classes)] 69 | 70 | support_k = {k: None for k in episode_classes} 71 | for k in episode_classes: 72 | # Select support examples 73 | support = df[df['class_id'] == k].sample(self.n) 74 | support_k[k] = support 75 | 76 | for i, s in support.iterrows(): 77 | batch.append(s['id']) 78 | 79 | for k in episode_classes: 80 | query = df[(df['class_id'] == k) & (~df['id'].isin(support_k[k]['id']))].sample(self.q) 81 | for i, q in query.iterrows(): 82 | batch.append(q['id']) 83 | 84 | yield np.stack(batch) 85 | 86 | 87 | class EvaluateFewShot(Callback): 88 | """Evaluate a network on an n-shot, k-way classification tasks after every epoch. 89 | 90 | # Arguments 91 | eval_fn: Callable to perform few-shot classification. Examples include `proto_net_episode`, 92 | `matching_net_episode` and `meta_gradient_step` (MAML). 93 | num_tasks: int. Number of n-shot classification tasks to evaluate the model with. 94 | n_shot: int. Number of samples for each class in the n-shot classification tasks. 95 | k_way: int. Number of classes in the n-shot classification tasks. 96 | q_queries: int. Number query samples for each class in the n-shot classification tasks. 97 | task_loader: Instance of NShotWrapper class 98 | prepare_batch: function. The preprocessing function to apply to samples from the dataset. 99 | prefix: str. Prefix to identify dataset. 
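Example construction (an illustrative sketch; `proto_net_episode` is assumed to be importable from `few_shot.proto`, and extra keyword arguments such as `distance` are simply forwarded to the eval_fn):

        callback = EvaluateFewShot(eval_fn=proto_net_episode,
                                   num_tasks=100, n_shot=1, k_way=5, q_queries=1,
                                   taskloader=evaluation_taskloader,
                                   prepare_batch=prepare_nshot_task(1, 5, 1),
                                   distance='l2')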
100 | """ 101 | 102 | def __init__(self, 103 | eval_fn: Callable, 104 | num_tasks: int, 105 | n_shot: int, 106 | k_way: int, 107 | q_queries: int, 108 | taskloader: torch.utils.data.DataLoader, 109 | prepare_batch: Callable, 110 | prefix: str = 'val_', 111 | **kwargs): 112 | super(EvaluateFewShot, self).__init__() 113 | self.eval_fn = eval_fn 114 | self.num_tasks = num_tasks 115 | self.n_shot = n_shot 116 | self.k_way = k_way 117 | self.q_queries = q_queries 118 | self.taskloader = taskloader 119 | self.prepare_batch = prepare_batch 120 | self.prefix = prefix 121 | self.kwargs = kwargs 122 | self.metric_name = f'{self.prefix}{self.n_shot}-shot_{self.k_way}-way_acc' 123 | 124 | def on_train_begin(self, logs=None): 125 | self.loss_fn = self.params['loss_fn'] 126 | self.optimiser = self.params['optimiser'] 127 | 128 | def on_epoch_end(self, epoch, logs=None): 129 | logs = logs or {} 130 | seen = 0 131 | totals = {'loss': 0, self.metric_name: 0} 132 | for batch_index, batch in enumerate(self.taskloader): 133 | x, y = self.prepare_batch(batch) 134 | 135 | loss, y_pred = self.eval_fn( 136 | self.model, 137 | self.optimiser, 138 | self.loss_fn, 139 | x, 140 | y, 141 | n_shot=self.n_shot, 142 | k_way=self.k_way, 143 | q_queries=self.q_queries, 144 | train=False, 145 | **self.kwargs 146 | ) 147 | 148 | seen += y_pred.shape[0] 149 | 150 | totals['loss'] += loss.item() * y_pred.shape[0] 151 | totals[self.metric_name] += categorical_accuracy(y, y_pred) * y_pred.shape[0] 152 | 153 | logs[self.prefix + 'loss'] = totals['loss'] / seen 154 | logs[self.metric_name] = totals[self.metric_name] / seen 155 | 156 | 157 | def prepare_nshot_task(n: int, k: int, q: int) -> Callable: 158 | """Typical n-shot task preprocessing. 159 | 160 | # Arguments 161 | n: Number of samples for each class in the n-shot classification task 162 | k: Number of classes in the n-shot classification task 163 | q: Number of query samples for each class in the n-shot classification task 164 | 165 | # Returns 166 | prepare_nshot_task_: A Callable that processes a few shot tasks with specified n, k and q 167 | """ 168 | def prepare_nshot_task_(batch: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: 169 | """Create 0-k label and move to GPU. 170 | 171 | TODO: Move to arbitrary device 172 | """ 173 | x, y = batch 174 | x = x.float().cuda() 175 | # Create dummy 0-(num_classes - 1) label 176 | y = create_nshot_task_label(k, q).cuda() 177 | return x, y 178 | 179 | return prepare_nshot_task_ 180 | 181 | 182 | def create_nshot_task_label(k: int, q: int) -> torch.Tensor: 183 | """Creates an n-shot task label. 184 | 185 | Label has the structure: 186 | [0]*q + [1]*q + ... 
+ [k-1]*q 187 | 188 | # TODO: Test this 189 | 190 | # Arguments 191 | k: Number of classes in the n-shot classification task 192 | q: Number of query samples for each class in the n-shot classification task 193 | 194 | # Returns 195 | y: Label vector for n-shot task of shape [q * k, ] 196 | """ 197 | y = torch.arange(0, k, 1 / q).long() 198 | return y 199 | -------------------------------------------------------------------------------- /app/whale/whale_utils.py: -------------------------------------------------------------------------------- 1 | from dlcliche.image import * 2 | from dlcliche.math import * 3 | from sklearn.decomposition import PCA 4 | from mpl_toolkits.mplot3d import axes3d 5 | from IPython.display import display 6 | 7 | from torch.utils.data import DataLoader, Dataset 8 | from torchvision import transforms 9 | import albumentations as A 10 | 11 | sys.path.append('..') # app 12 | sys.path.append('../..') # root 13 | from few_shot.extmodel_proto_net_clf import ExtModelProtoNetClf 14 | from config import DATA_PATH 15 | from app_utils_clf import * 16 | 17 | 18 | def get_test_images(data_test): 19 | return sorted([str(f).replace(data_test+'/', '') for f in Path(data_test).glob('*.jpg')]) 20 | 21 | 22 | def get_training_data_lists(sampling_type='more_than_two', train_csv=DATA_PATH+'/train.csv'): 23 | """Get lists of training data for train/valid images/labels according to sampling type.""" 24 | df = pd.read_csv(train_csv) 25 | 26 | if sampling_type == 'more_than_two': 27 | df = df[df.Id != 'new_whale'] 28 | ids = df.Id.values 29 | classes = sorted(list(set(ids))) 30 | images = df.Image.values 31 | all_cls2imgs = {cls:images[ids == cls] for cls in classes} 32 | 33 | trn_images = [image for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) >= 2] 34 | trn_labels = [_id for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) >= 2] 35 | val_images = [image for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) == 2] 36 | val_labels = [_id for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) == 2] 37 | elif sampling_type == 'exhaustive': 38 | # Assign fake Id to new_whale 39 | n_new_whale = len(df[df.Id == 'new_whale']) 40 | df.at[df.Id == 'new_whale', 'Id'] = [f'new{i:05d}' for i in range(n_new_whale)] 41 | 42 | ids = df.Id.values 43 | classes = sorted(list(set(ids))) 44 | images = df.Image.values 45 | all_cls2imgs = {cls:images[ids == cls] for cls in classes} 46 | 47 | # Duplicate all the single image classes 48 | single_images = [image for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) == 1] 49 | single_labels = [_id for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) == 1] 50 | 51 | trn_images = list(images) + single_images 52 | trn_labels = list(ids) + single_labels 53 | val_images = [image for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) > 2] 54 | val_labels = [_id for image, _id in zip(images, ids) if len(all_cls2imgs[_id]) > 2] 55 | else: 56 | raise ValueError('unknown sampling_type option') 57 | return trn_images, trn_labels, val_images, val_labels 58 | 59 | 60 | def get_aug(re_size=224, to_size=224, train=True): 61 | augs = [A.Resize(height=re_size, width=re_size)] 62 | if train: 63 | augs.extend([ 64 | A.RandomCrop(height=to_size, width=to_size), 65 | A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.3, rotate_limit=30, p=0.75), 66 | A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=0.75), 67 | A.Blur(p=0.5), 68 | A.Cutout(max_h_size=to_size//12, max_w_size=to_size//12, p=0.5), 69 | ]) 70 | else: 
71 | augs.extend([A.CenterCrop(height=to_size, width=to_size)]) 72 | return A.Compose(augs + [A.Normalize()]) 73 | 74 | 75 | def get_img_loader(folder, to_gray=False): 76 | def _loader(filename): 77 | img = load_rgb_image(folder + '/' + str(filename)) 78 | if to_gray: 79 | img = np.mean(img, axis=-1).astype(np.uint8) 80 | img = np.stack((img,)*3, axis=-1) 81 | return img 82 | return _loader 83 | 84 | 85 | class WhaleImages(Dataset): 86 | def __init__(self, path, images, labels, re_size=256, to_size=224, train=True): 87 | self.datasetid_to_filepath = images 88 | self.datasetid_to_class_id = labels 89 | self.classes = sorted(list(set(labels))) 90 | 91 | self.df = pd.DataFrame({'class_id':labels, 'id':list(range(len(images)))}) 92 | 93 | self.loader = get_img_loader(path, to_gray=True) 94 | self.transform = get_aug(re_size=re_size, to_size=to_size, train=train) 95 | self.to_tensor = transforms.ToTensor() 96 | 97 | def __getitem__(self, item): 98 | instance = self.loader(self.datasetid_to_filepath[item]) 99 | instance = self.transform(image=instance)['image'] 100 | instance = self.to_tensor(instance) 101 | label = self.datasetid_to_class_id[item] 102 | return instance, label 103 | 104 | def __len__(self): 105 | return len(self.df) 106 | 107 | def num_classes(self): 108 | return len(self.cls2imgs) 109 | 110 | 111 | def plot_prototype_2d_space_distribution(prototypes): 112 | X = prototypes 113 | pca = PCA(n_components=2) 114 | X_pca = pca.fit(X).transform(X) 115 | print('PCA: Explained variance ratio: %s' 116 | % str(pca.explained_variance_ratio_)) 117 | plt.figure() 118 | plt.scatter(X_pca[:, 0], X_pca[:, 1], alpha=.6) 119 | plt.title('Prototype Distribution PCA') 120 | plt.xlim((-4, 4)) 121 | plt.ylim((-3, 3)) 122 | plt.show() 123 | return X_pca 124 | 125 | 126 | def plot_prototype_3d_space_distribution(prototypes): 127 | X = prototypes 128 | pca = PCA(n_components=3) 129 | X_pca = pca.fit(X).transform(X) 130 | print('PCA: Explained variance ratio: %s' 131 | % str(pca.explained_variance_ratio_)) 132 | fig = plt.figure(figsize=(10,10)) 133 | ax = fig.add_subplot(111, projection='3d') 134 | ax.scatter3D(X_pca[:, 0],X_pca[:, 1],X_pca[:, 2]) 135 | ax.set_title('Prototype Distribution PCA') 136 | ax.set_xlim((-4, 4)) 137 | ax.set_ylim((-3, 3)) 138 | ax.set_zlim((-3, 3)) 139 | plt.show() 140 | return X_pca 141 | 142 | 143 | def get_classes(data='data', except_new_whale=True, append_new_whale_last=True): 144 | df = pd.read_csv(data+'/train.csv') 145 | if except_new_whale: 146 | df = df[df.Id != 'new_whale'] 147 | classes = sorted(list(set(df.Id.values))) 148 | if append_new_whale_last: 149 | classes.append('new_whale') 150 | return classes 151 | 152 | 153 | def calculate_results(weight, SZ, get_model_fn, device, train_csv='data/data.csv', 154 | data_train='data/train', data_test='data/test'): 155 | # Training samples 156 | df = pd.read_csv(train_csv) 157 | df = df[df.Id != 'new_whale'] 158 | images = df.Image.values 159 | labels = df.Id.values 160 | 161 | # Test samples 162 | test_images = get_test_images(data_test) 163 | dummy_test_gts = list(range(len(test_images))) 164 | 165 | print(f'Training samples: {len(images)}, # of labels: {len(list(set(labels)))}.') 166 | print(f'Test samples: {len(test_images)}.') 167 | print(f'Work in progress for {weight}...') 168 | 169 | def get_dl(images, labels, folder, SZ=SZ, batch_size=64): 170 | ds = WhaleImages(folder, images, labels, re_size=SZ, to_size=SZ, train=False) 171 | dl = DataLoader(ds, batch_size=batch_size) 172 | return dl 173 | 174 | # Make 
prototypes 175 | trn_dl = get_dl(images, labels, data_train) 176 | model = get_model_fn(device=device, weight_file=weight+'.pth') 177 | proto_net = ExtModelProtoNetClf(model, trn_dl.dataset.classes, device) 178 | 179 | proto_net.make_prototypes(trn_dl) 180 | 181 | # Calculate distances 182 | test_dl = get_dl(test_images, dummy_test_gts, data_test) 183 | test_embs, gts = proto_net.get_embeddings(test_dl) 184 | test_dists = proto_net.predict_embeddings(test_embs, softmax=False) 185 | 186 | np.save(f'test_dists_{weight}.npy', test_dists) 187 | np.save(f'prototypes_{weight}.npy', np.array([x.mean() for x in proto_net.prototypes])) 188 | 189 | 190 | # Thanks to https://github.com/radekosmulski/whale/blob/master/utils.py 191 | def top_5_preds(preds): return np.argsort(preds.numpy())[:, ::-1][:, :5] 192 | 193 | def top_5_pred_labels(preds, classes): 194 | top_5 = top_5_preds(preds) 195 | labels = [] 196 | for i in range(top_5.shape[0]): 197 | labels.append(' '.join([classes[idx] for idx in top_5[i]])) 198 | return labels 199 | 200 | 201 | def prepare_submission(submission_filename, test_dists, new_whale_thresh, data_test, classes): 202 | def _create_proto_submission(preds, name, classes): 203 | sub = pd.DataFrame({'Image': get_test_images(data_test)}) 204 | sub['Id'] = [classes[i] if not isinstance(i, str) else i for i in 205 | top_5_pred_labels(torch.tensor(preds), classes)] 206 | ensure_folder('subs') 207 | sub.to_csv(f'subs/{name}.csv.gz', index=False, compression='gzip') 208 | 209 | dist_new_whale = np.ones_like(test_dists[:, :1]) 210 | dist_new_whale[:] = new_whale_thresh 211 | final_answer = np.c_[test_dists, dist_new_whale] 212 | 213 | _create_proto_submission(final_answer, submission_filename, classes) 214 | print(submission_filename, 215 | pd.read_csv(f'subs/{submission_filename}.csv.gz').Id.str.split().apply(lambda x: x[0] == 'new_whale').mean(), 216 | len(set(pd.read_csv(f'subs/{submission_filename}.csv.gz').Id.str.split().apply(lambda x: x[0]).values))) 217 | display(pd.read_csv(f'subs/{submission_filename}.csv.gz').head()) 218 | -------------------------------------------------------------------------------- /few_shot/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch 5 | from typing import Dict 6 | 7 | 8 | ########## 9 | # Layers # 10 | ########## 11 | class Flatten(nn.Module): 12 | """Converts N-dimensional Tensor of shape [batch_size, d1, d2, ..., dn] to 2-dimensional Tensor 13 | of shape [batch_size, d1*d2*...*dn]. 
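For example (illustrative), an input of shape (8, 64, 5, 5) is flattened to shape (8, 1600).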
14 | 15 | # Arguments 16 | input: Input tensor 17 | """ 18 | def forward(self, input): 19 | return input.view(input.size(0), -1) 20 | 21 | 22 | class GlobalMaxPool1d(nn.Module): 23 | """Performs global max pooling over the entire length of a batched 1D tensor 24 | 25 | # Arguments 26 | input: Input tensor 27 | """ 28 | def forward(self, input): 29 | return nn.functional.max_pool1d(input, kernel_size=input.size()[2:]).view(-1, input.size(1)) 30 | 31 | 32 | class GlobalAvgPool2d(nn.Module): 33 | """Performs global average pooling over the entire height and width of a batched 2D tensor 34 | 35 | # Arguments 36 | input: Input tensor 37 | """ 38 | def forward(self, input): 39 | return nn.functional.avg_pool2d(input, kernel_size=input.size()[2:]).view(-1, input.size(1)) 40 | 41 | 42 | def conv_block(in_channels: int, out_channels: int) -> nn.Module: 43 | """Returns a Module that performs 3x3 convolution, ReLu activation, 2x2 max pooling. 44 | 45 | # Arguments 46 | in_channels: 47 | out_channels: 48 | """ 49 | return nn.Sequential( 50 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 51 | nn.BatchNorm2d(out_channels), 52 | nn.ReLU(), 53 | nn.MaxPool2d(kernel_size=2, stride=2) 54 | ) 55 | 56 | 57 | def functional_conv_block(x: torch.Tensor, weights: torch.Tensor, biases: torch.Tensor, 58 | bn_weights, bn_biases) -> torch.Tensor: 59 | """Performs 3x3 convolution, ReLu activation, 2x2 max pooling in a functional fashion. 60 | 61 | # Arguments: 62 | x: Input Tensor for the conv block 63 | weights: Weights for the convolutional block 64 | biases: Biases for the convolutional block 65 | bn_weights: 66 | bn_biases: 67 | """ 68 | x = F.conv2d(x, weights, biases, padding=1) 69 | x = F.batch_norm(x, running_mean=None, running_var=None, weight=bn_weights, bias=bn_biases, training=True) 70 | x = F.relu(x) 71 | x = F.max_pool2d(x, kernel_size=2, stride=2) 72 | return x 73 | 74 | 75 | ########## 76 | # Models # 77 | ########## 78 | def get_few_shot_encoder(num_input_channels=1) -> nn.Module: 79 | """Creates a few shot encoder as used in Matching and Prototypical Networks 80 | 81 | # Arguments: 82 | num_input_channels: Number of color channels the model expects input data to contain. Omniglot = 1, 83 | miniImageNet = 3 84 | """ 85 | return nn.Sequential( 86 | conv_block(num_input_channels, 64), 87 | conv_block(64, 64), 88 | conv_block(64, 64), 89 | conv_block(64, 64), 90 | Flatten(), 91 | ) 92 | 93 | 94 | class FewShotClassifier(nn.Module): 95 | def __init__(self, num_input_channels: int, k_way: int, final_layer_size: int = 64): 96 | """Creates a few shot classifier as used in MAML. 97 | 98 | This network should be identical to the one created by `get_few_shot_encoder` but with a 99 | classification layer on top. 100 | 101 | # Arguments: 102 | num_input_channels: Number of color channels the model expects input data to contain. 
Omniglot = 1, 103 | miniImageNet = 3 104 | k_way: Number of classes the model will discriminate between 105 | final_layer_size: 64 for Omniglot, 1600 for miniImageNet 106 | """ 107 | super(FewShotClassifier, self).__init__() 108 | self.conv1 = conv_block(num_input_channels, 64) 109 | self.conv2 = conv_block(64, 64) 110 | self.conv3 = conv_block(64, 64) 111 | self.conv4 = conv_block(64, 64) 112 | 113 | self.logits = nn.Linear(final_layer_size, k_way) 114 | 115 | def forward(self, x): 116 | x = self.conv1(x) 117 | x = self.conv2(x) 118 | x = self.conv3(x) 119 | x = self.conv4(x) 120 | 121 | x = x.view(x.size(0), -1) 122 | 123 | return self.logits(x) 124 | 125 | def functional_forward(self, x, weights): 126 | """Applies the same forward pass using PyTorch functional operators using a specified set of weights.""" 127 | 128 | for block in [1, 2, 3, 4]: 129 | x = functional_conv_block(x, weights[f'conv{block}.0.weight'], weights[f'conv{block}.0.bias'], 130 | weights.get(f'conv{block}.1.weight'), weights.get(f'conv{block}.1.bias')) 131 | 132 | x = x.view(x.size(0), -1) 133 | 134 | x = F.linear(x, weights['logits.weight'], weights['logits.bias']) 135 | 136 | return x 137 | 138 | 139 | class MatchingNetwork(nn.Module): 140 | def __init__(self, n: int, k: int, q: int, fce: bool, num_input_channels: int, 141 | lstm_layers: int, lstm_input_size: int, unrolling_steps: int, device: torch.device): 142 | """Creates a Matching Network as described in Vinyals et al. 143 | 144 | # Arguments: 145 | n: Number of examples per class in the support set 146 | k: Number of classes in the few shot classification task 147 | q: Number of examples per class in the query set 148 | fce: Whether or not to us fully conditional embeddings 149 | num_input_channels: Number of color channels the model expects input data to contain. Omniglot = 1, 150 | miniImageNet = 3 151 | lstm_layers: Number of LSTM layers in the bidrectional LSTM g that embeds the support set (fce = True) 152 | lstm_input_size: Input size for the bidirectional and Attention LSTM. This is determined by the embedding 153 | dimension of the few shot encoder which is in turn determined by the size of the input data. Hence we 154 | have Omniglot -> 64, miniImageNet -> 1600. 155 | unrolling_steps: Number of unrolling steps to run the Attention LSTM 156 | device: Device on which to run computation 157 | """ 158 | super(MatchingNetwork, self).__init__() 159 | self.n = n 160 | self.k = k 161 | self.q = q 162 | self.fce = fce 163 | self.num_input_channels = num_input_channels 164 | self.encoder = get_few_shot_encoder(self.num_input_channels) 165 | if self.fce: 166 | self.g = BidrectionalLSTM(lstm_input_size, lstm_layers).to(device, dtype=torch.float) 167 | self.f = AttentionLSTM(lstm_input_size, unrolling_steps=unrolling_steps).to(device, dtype=torch.float) 168 | 169 | def forward(self, inputs): 170 | pass 171 | 172 | 173 | class BidrectionalLSTM(nn.Module): 174 | def __init__(self, size: int, layers: int): 175 | """Bidirectional LSTM used to generate fully conditional embeddings (FCE) of the support set as described 176 | in the Matching Networks paper. 177 | 178 | # Arguments 179 | size: Size of input and hidden layers. 
These are constrained to be the same in order to implement the skip 180 | connection described in Appendix A.2 181 | layers: Number of LSTM layers 182 | """ 183 | super(BidrectionalLSTM, self).__init__() 184 | self.num_layers = layers 185 | self.batch_size = 1 186 | # Force input size and hidden size to be the same in order to implement 187 | # the skip connection as described in Appendix A.1 and A.2 of Matching Networks 188 | self.lstm = nn.LSTM(input_size=size, 189 | num_layers=layers, 190 | hidden_size=size, 191 | bidirectional=True) 192 | 193 | def forward(self, inputs): 194 | # Give None as initial state and Pytorch LSTM creates initial hidden states 195 | output, (hn, cn) = self.lstm(inputs, None) 196 | 197 | forward_output = output[:, :, :self.lstm.hidden_size] 198 | backward_output = output[:, :, self.lstm.hidden_size:] 199 | 200 | # g(x_i, S) = h_forward_i + h_backward_i + g'(x_i) as written in Appendix A.2 201 | # AKA A skip connection between inputs and outputs is used 202 | output = forward_output + backward_output + inputs 203 | return output, hn, cn 204 | 205 | 206 | class AttentionLSTM(nn.Module): 207 | def __init__(self, size: int, unrolling_steps: int): 208 | """Attentional LSTM used to generate fully conditional embeddings (FCE) of the query set as described 209 | in the Matching Networks paper. 210 | 211 | # Arguments 212 | size: Size of input and hidden layers. These are constrained to be the same in order to implement the skip 213 | connection described in Appendix A.2 214 | unrolling_steps: Number of steps of attention over the support set to compute. Analogous to number of 215 | layers in a regular LSTM 216 | """ 217 | super(AttentionLSTM, self).__init__() 218 | self.unrolling_steps = unrolling_steps 219 | self.lstm_cell = nn.LSTMCell(input_size=size, 220 | hidden_size=size) 221 | 222 | def forward(self, support, queries): 223 | # Get embedding dimension, d 224 | if support.shape[-1] != queries.shape[-1]: 225 | raise(ValueError("Support and query set have different embedding dimension!")) 226 | 227 | batch_size = queries.shape[0] 228 | embedding_dim = queries.shape[1] 229 | 230 | h_hat = torch.zeros_like(queries).cuda().float() 231 | c = torch.zeros(batch_size, embedding_dim).cuda().float() 232 | 233 | for k in range(self.unrolling_steps): 234 | # Calculate hidden state cf. equation (4) of appendix A.2 235 | h = h_hat + queries 236 | 237 | # Calculate softmax attentions between hidden states and support set embeddings 238 | # cf. equation (6) of appendix A.2 239 | attentions = torch.mm(h, support.t()) 240 | attentions = attentions.softmax(dim=1) 241 | 242 | # Calculate readouts from support set embeddings cf. equation (5) 243 | readout = torch.mm(attentions, support) 244 | 245 | # Run LSTM cell cf. equation (3) 246 | # h_hat, c = self.lstm_cell(queries, (torch.cat([h, readout], dim=1), c)) 247 | h_hat, c = self.lstm_cell(queries, (h + readout, c)) 248 | 249 | h = h_hat + queries 250 | 251 | return h 252 | -------------------------------------------------------------------------------- /experiments/Test_ImageNet_ResNet18_as_ProtoNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# What if we use ImageNet pretrained model?" 
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from dlcliche.notebook import *\n", 17 | "from dlcliche.utils import *\n", 18 | "sys.path.append('..')\n", 19 | "\n", 20 | "from torch.utils.data import DataLoader\n", 21 | "from torchvision import datasets, transforms\n", 22 | "from few_shot.models import get_few_shot_encoder\n", 23 | "from few_shot.callbacks import *\n", 24 | "\n", 25 | "assert torch.cuda.is_available()\n", 26 | "device = torch.device('cuda')\n", 27 | "#torch.backends.cudnn.benchmark = True\n", 28 | "\n", 29 | "from few_shot.extmodel_proto_net_clf import ExtModelProtoNetClf, BasePretrainedModel\n", 30 | "from torchvision import models\n", 31 | "from torch import nn\n", 32 | "from few_shot.models import Flatten" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Make dataset split\n", 42 | "from config import DATA_PATH\n", 43 | "DATA_PATH = Path(DATA_PATH)/'miniImageNet'\n", 44 | "SRC_EVAL_PATH = DATA_PATH/'images_evaluation'\n", 45 | "SRC_TRN_PATH = DATA_PATH/'images_background'\n", 46 | "\n", 47 | "EVAL_TRN_PATH = Path('data/clf_eval_train')\n", 48 | "EVAL_VAL_PATH = Path('data/clf_eval_valid')\n", 49 | "K_WAY = 5 # Class\n", 50 | "N_SKOT = 5 # Samples to build a prototype\n", 51 | "N_INPUT_CHANNELS = 3\n", 52 | "\n", 53 | "def rebuild_data_files(src=SRC_EVAL_PATH, trn_path=EVAL_TRN_PATH, val_path=EVAL_VAL_PATH,\n", 54 | " K_WAY=K_WAY, N_SKOT=N_SKOT):\n", 55 | " ensure_delete(trn_path)\n", 56 | " ensure_delete(val_path)\n", 57 | "\n", 58 | " classes = [str(d.name) for d in src.glob('*')][:K_WAY]\n", 59 | " for cls in classes:\n", 60 | " dest_trn_folder = trn_path/cls\n", 61 | " dest_val_folder = val_path/cls\n", 62 | " ensure_folder(dest_trn_folder)\n", 63 | " ensure_folder(dest_val_folder)\n", 64 | " files = sorted(list((src/cls).glob('*.jpg')))\n", 65 | " for i in range(N_SKOT):\n", 66 | " copy_file(files[i], dest_trn_folder/files[i].name)\n", 67 | " for i in range(N_SKOT, len(files)):\n", 68 | " copy_file(files[i], dest_val_folder/files[i].name)\n", 69 | "\n", 70 | " global plain_train_ds, train_dl, valid_ds, valid_dl\n", 71 | " plain_train_ds = datasets.ImageFolder(\n", 72 | " trn_path,\n", 73 | " transform = transforms.Compose([\n", 74 | " transforms.Resize((224, 224)),\n", 75 | " transforms.ToTensor(),\n", 76 | " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 77 | " std=[0.229, 0.224, 0.225])\n", 78 | " ]))\n", 79 | " train_dl = DataLoader(\n", 80 | " plain_train_ds,\n", 81 | " batch_size=8,\n", 82 | " )\n", 83 | " valid_ds = datasets.ImageFolder(\n", 84 | " val_path,\n", 85 | " transform = transforms.Compose([\n", 86 | " transforms.Resize((224, 224)),\n", 87 | " transforms.ToTensor(),\n", 88 | " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 89 | " std=[0.229, 0.224, 0.225])\n", 90 | " ]))\n", 91 | " valid_dl = DataLoader(\n", 92 | " valid_ds,\n", 93 | " batch_size=8,\n", 94 | " )" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "## Create model" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 3, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "model = BasePretrainedModel(base_model=models.resnet18)\n", 111 | "model.to(device, dtype=torch.float)\n", 112 | "model.eval()\n", 113 | "\n", 114 | "proto_net_clf = ExtModelProtoNetClf(model, device=device)" 115 | ] 116 | }, 117 | { 118 | 
"cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "## Test 5-way" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 4, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stderr", 131 | "output_type": "stream", 132 | "text": [ 133 | "100%|██████████| 4/4 [00:00<00:00, 42.06it/s]\n", 134 | "2019-01-23 12:29:17,494 dlcliche.utils make_prototypes [INFO]: Making new prototypes.\n", 135 | "100%|██████████| 372/372 [00:08<00:00, 43.71it/s]\n", 136 | "100%|██████████| 2975/2975 [00:00<00:00, 20467.07it/s]\n" 137 | ] 138 | }, 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "F1/Recall/Precision/Accuracy = (0.9882539207551945, 0.9882352941176471, 0.9883894134090533, 0.9882352941176471)\n" 144 | ] 145 | }, 146 | { 147 | "name": "stderr", 148 | "output_type": "stream", 149 | "text": [ 150 | "100%|██████████| 1/1 [00:00<00:00, 69.66it/s]\n", 151 | "2019-01-23 12:29:26,553 dlcliche.utils make_prototypes [INFO]: Making new prototypes.\n", 152 | "100%|██████████| 375/375 [00:08<00:00, 43.87it/s]\n", 153 | "100%|██████████| 2995/2995 [00:00<00:00, 21287.13it/s]\n" 154 | ] 155 | }, 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "F1/Recall/Precision/Accuracy = (0.8297963255981461, 0.8410684474123539, 0.9035678645381544, 0.8410684474123539)\n" 161 | ] 162 | }, 163 | { 164 | "name": "stderr", 165 | "output_type": "stream", 166 | "text": [ 167 | "100%|██████████| 4/4 [00:00<00:00, 58.42it/s]\n", 168 | "2019-01-23 12:29:35,697 dlcliche.utils make_prototypes [INFO]: Making new prototypes.\n", 169 | "100%|██████████| 372/372 [00:08<00:00, 36.94it/s]\n", 170 | "100%|██████████| 2975/2975 [00:00<00:00, 20611.19it/s]\n" 171 | ] 172 | }, 173 | { 174 | "name": "stdout", 175 | "output_type": "stream", 176 | "text": [ 177 | "F1/Recall/Precision/Accuracy = (0.9804990776074115, 0.9805042016806723, 0.9811282105492662, 0.9805042016806723)\n" 178 | ] 179 | }, 180 | { 181 | "name": "stderr", 182 | "output_type": "stream", 183 | "text": [ 184 | "100%|██████████| 1/1 [00:00<00:00, 63.47it/s]\n", 185 | "2019-01-23 12:29:45,100 dlcliche.utils make_prototypes [INFO]: Making new prototypes.\n", 186 | "100%|██████████| 375/375 [00:08<00:00, 42.41it/s]\n", 187 | "100%|██████████| 2995/2995 [00:00<00:00, 20847.89it/s]" 188 | ] 189 | }, 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "F1/Recall/Precision/Accuracy = (0.7793769948555661, 0.7796327212020033, 0.8696235374513783, 0.7796327212020033)\n" 195 | ] 196 | }, 197 | { 198 | "name": "stderr", 199 | "output_type": "stream", 200 | "text": [ 201 | "\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "rebuild_data_files(src=SRC_TRN_PATH, K_WAY=5, N_SKOT=5)\n", 207 | "prototypes = proto_net_clf.make_prototypes(train_dl)\n", 208 | "print('F1/Recall/Precision/Accuracy =', proto_net_clf.evaluate(data_loader=valid_dl))\n", 209 | "rebuild_data_files(src=SRC_TRN_PATH, K_WAY=5, N_SKOT=1)\n", 210 | "prototypes = proto_net_clf.make_prototypes(train_dl)\n", 211 | "print('F1/Recall/Precision/Accuracy =', proto_net_clf.evaluate(data_loader=valid_dl))\n", 212 | "\n", 213 | "rebuild_data_files(src=SRC_EVAL_PATH, K_WAY=5, N_SKOT=5)\n", 214 | "prototypes = proto_net_clf.make_prototypes(train_dl)\n", 215 | "print('F1/Recall/Precision/Accuracy =', proto_net_clf.evaluate(data_loader=valid_dl))\n", 216 | "rebuild_data_files(src=SRC_EVAL_PATH, K_WAY=5, N_SKOT=1)\n", 217 | "prototypes = proto_net_clf.make_prototypes(train_dl)\n", 218 | 
"print('F1/Recall/Precision/Accuracy =', proto_net_clf.evaluate(data_loader=valid_dl))" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "## Test 80-way (only available with SRC_TRN_PATH)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 6, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stderr", 235 | "output_type": "stream", 236 | "text": [ 237 | "100%|██████████| 10/10 [00:00<00:00, 46.86it/s]\n", 238 | "2019-01-23 12:31:05,807 dlcliche.utils make_prototypes [INFO]: Making new prototypes.\n", 239 | "100%|██████████| 5990/5990 [02:12<00:00, 43.69it/s]\n", 240 | "100%|██████████| 47920/47920 [00:27<00:00, 1748.40it/s]\n" 241 | ] 242 | }, 243 | { 244 | "name": "stdout", 245 | "output_type": "stream", 246 | "text": [ 247 | "F1/Recall/Precision/Accuracy = (0.5320421397494145, 0.532262103505843, 0.6443292247615288, 0.532262103505843)\n" 248 | ] 249 | }, 250 | { 251 | "name": "stderr", 252 | "output_type": "stream", 253 | "text": [ 254 | "100%|██████████| 50/50 [00:01<00:00, 43.68it/s]\n", 255 | "2019-01-23 12:33:53,551 dlcliche.utils make_prototypes [INFO]: Making new prototypes.\n", 256 | "100%|██████████| 5950/5950 [02:14<00:00, 44.39it/s]\n", 257 | "100%|██████████| 47600/47600 [00:27<00:00, 1720.76it/s]\n" 258 | ] 259 | }, 260 | { 261 | "name": "stdout", 262 | "output_type": "stream", 263 | "text": [ 264 | "F1/Recall/Precision/Accuracy = (0.8249790996404818, 0.8227100840336135, 0.8402276374817964, 0.8227100840336135)\n" 265 | ] 266 | } 267 | ], 268 | "source": [ 269 | "rebuild_data_files(src=SRC_TRN_PATH, K_WAY=80, N_SKOT=1)\n", 270 | "prototypes = proto_net_clf.make_prototypes(train_dl)\n", 271 | "print('F1/Recall/Precision/Accuracy =', proto_net_clf.evaluate(data_loader=valid_dl))\n", 272 | "rebuild_data_files(src=SRC_TRN_PATH, K_WAY=80, N_SKOT=5)\n", 273 | "prototypes = proto_net_clf.make_prototypes(train_dl)\n", 274 | "print('F1/Recall/Precision/Accuracy =', proto_net_clf.evaluate(data_loader=valid_dl))" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": { 281 | "collapsed": true 282 | }, 283 | "outputs": [], 284 | "source": [] 285 | } 286 | ], 287 | "metadata": { 288 | "kernelspec": { 289 | "display_name": "Python 3", 290 | "language": "python", 291 | "name": "python3" 292 | }, 293 | "language_info": { 294 | "codemirror_mode": { 295 | "name": "ipython", 296 | "version": 3 297 | }, 298 | "file_extension": ".py", 299 | "mimetype": "text/x-python", 300 | "name": "python", 301 | "nbconvert_exporter": "python", 302 | "pygments_lexer": "ipython3", 303 | "version": "3.6.2" 304 | } 305 | }, 306 | "nbformat": 4, 307 | "nbformat_minor": 2 308 | } 309 | -------------------------------------------------------------------------------- /few_shot/callbacks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ports of Callback classes from the Keras library. 3 | """ 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | from collections import OrderedDict, Iterable 8 | import warnings 9 | import os 10 | import csv 11 | import io 12 | 13 | from few_shot.eval import evaluate 14 | 15 | 16 | class CallbackList(object): 17 | """Container abstracting a list of callbacks. 18 | 19 | # Arguments 20 | callbacks: List of `Callback` instances. 
21 | """ 22 | def __init__(self, callbacks): 23 | self.callbacks = [c for c in callbacks] 24 | 25 | def set_params(self, params): 26 | for callback in self.callbacks: 27 | callback.set_params(params) 28 | 29 | def set_model(self, model): 30 | for callback in self.callbacks: 31 | callback.set_model(model) 32 | 33 | def on_epoch_begin(self, epoch, logs=None): 34 | """Called at the start of an epoch. 35 | # Arguments 36 | epoch: integer, index of epoch. 37 | logs: dictionary of logs. 38 | """ 39 | logs = logs or {} 40 | for callback in self.callbacks: 41 | callback.on_epoch_begin(epoch, logs) 42 | 43 | def on_epoch_end(self, epoch, logs=None): 44 | """Called at the end of an epoch. 45 | # Arguments 46 | epoch: integer, index of epoch. 47 | logs: dictionary of logs. 48 | """ 49 | logs = logs or {} 50 | for callback in self.callbacks: 51 | callback.on_epoch_end(epoch, logs) 52 | 53 | def on_batch_begin(self, batch, logs=None): 54 | """Called right before processing a batch. 55 | # Arguments 56 | batch: integer, index of batch within the current epoch. 57 | logs: dictionary of logs. 58 | """ 59 | logs = logs or {} 60 | for callback in self.callbacks: 61 | callback.on_batch_begin(batch, logs) 62 | 63 | def on_batch_end(self, batch, logs=None): 64 | """Called at the end of a batch. 65 | # Arguments 66 | batch: integer, index of batch within the current epoch. 67 | logs: dictionary of logs. 68 | """ 69 | logs = logs or {} 70 | for callback in self.callbacks: 71 | callback.on_batch_end(batch, logs) 72 | 73 | def on_train_begin(self, logs=None): 74 | """Called at the beginning of training. 75 | # Arguments 76 | logs: dictionary of logs. 77 | """ 78 | logs = logs or {} 79 | for callback in self.callbacks: 80 | callback.on_train_begin(logs) 81 | 82 | def on_train_end(self, logs=None): 83 | """Called at the end of training. 84 | # Arguments 85 | logs: dictionary of logs. 86 | """ 87 | logs = logs or {} 88 | for callback in self.callbacks: 89 | callback.on_train_end(logs) 90 | 91 | 92 | class Callback(object): 93 | def __init__(self): 94 | self.model = None 95 | 96 | def set_params(self, params): 97 | self.params = params 98 | 99 | def set_model(self, model): 100 | self.model = model 101 | 102 | def on_epoch_begin(self, epoch, logs=None): 103 | pass 104 | 105 | def on_epoch_end(self, epoch, logs=None): 106 | pass 107 | 108 | def on_batch_begin(self, batch, logs=None): 109 | pass 110 | 111 | def on_batch_end(self, batch, logs=None): 112 | pass 113 | 114 | def on_train_begin(self, logs=None): 115 | pass 116 | 117 | def on_train_end(self, logs=None): 118 | pass 119 | 120 | 121 | class DefaultCallback(Callback): 122 | """Records metrics over epochs by averaging over each batch. 123 | 124 | NB The metrics are calculated with a moving model 125 | """ 126 | def on_epoch_begin(self, batch, logs=None): 127 | self.seen = 0 128 | self.totals = {} 129 | self.metrics = ['loss'] + self.params['metrics'] 130 | 131 | def on_batch_end(self, batch, logs=None): 132 | logs = logs or {} 133 | batch_size = logs.get('size', 1) or 1 134 | self.seen += batch_size 135 | 136 | for k, v in logs.items(): 137 | if k in self.totals: 138 | self.totals[k] += v * batch_size 139 | else: 140 | self.totals[k] = v * batch_size 141 | 142 | def on_epoch_end(self, epoch, logs=None): 143 | if logs is not None: 144 | for k in self.metrics: 145 | if k in self.totals: 146 | # Make value available to next callbacks. 
147 | logs[k] = self.totals[k] / self.seen 148 | 149 | 150 | class ProgressBarLogger(Callback): 151 | """TQDM progress bar that displays the running average of loss and other metrics.""" 152 | def __init__(self): 153 | super(ProgressBarLogger, self).__init__() 154 | 155 | def on_train_begin(self, logs=None): 156 | self.num_batches = self.params['num_batches'] 157 | self.verbose = self.params['verbose'] 158 | self.metrics = ['loss'] + self.params['metrics'] 159 | self.epoch_metrics = self.params['epoch_metrics'] 160 | 161 | def on_epoch_begin(self, epoch, logs=None): 162 | self.target = self.num_batches 163 | self.pbar = tqdm(total=self.target, desc='Epoch {}'.format(epoch)) 164 | self.seen = 0 165 | 166 | def on_batch_begin(self, batch, logs=None): 167 | self.log_values = {} 168 | 169 | def on_batch_end(self, batch, logs=None): 170 | logs = logs or {} 171 | self.seen += 1 172 | 173 | for k in self.metrics: 174 | if k in logs: 175 | self.log_values[k] = logs[k] 176 | 177 | # Skip progbar update for the last batch; 178 | # will be handled by on_epoch_end. 179 | if self.verbose and self.seen < self.target: 180 | self.pbar.update(1) 181 | self.pbar.set_postfix(self.log_values) 182 | 183 | def on_epoch_end(self, epoch, logs=None): 184 | # Update log values 185 | self.log_values = {} 186 | for k in self.metrics + self.epoch_metrics: 187 | if k in logs: 188 | self.log_values[k] = logs[k] 189 | 190 | if self.verbose: 191 | self.pbar.update(1) 192 | self.pbar.set_postfix(self.log_values) 193 | 194 | self.pbar.close() 195 | 196 | 197 | class CSVLogger(Callback): 198 | """Callback that streams epoch results to a csv file. 199 | Supports all values that can be represented as a string, 200 | including 1D iterables such as np.ndarray. 201 | 202 | # Arguments 203 | filename: filename of the csv file, e.g. 'run/log.csv'. 204 | separator: string used to separate elements in the csv file. 205 | append: True: append if file exists (useful for continuing 206 | training). 
False: overwrite existing file. 207 | """ 208 | 209 | def __init__(self, filename, separator=',', append=False): 210 | self.sep = separator 211 | self.filename = filename 212 | self.append = append 213 | self.writer = None 214 | self.keys = None 215 | self.append_header = True 216 | self.file_flags = '' 217 | self._open_args = {'newline': '\n'} 218 | super(CSVLogger, self).__init__() 219 | 220 | def on_train_begin(self, logs=None): 221 | if self.append: 222 | if os.path.exists(self.filename): 223 | with open(self.filename, 'r' + self.file_flags) as f: 224 | self.append_header = not bool(len(f.readline())) 225 | mode = 'a' 226 | else: 227 | mode = 'w' 228 | 229 | self.csv_file = io.open(self.filename, 230 | mode + self.file_flags, 231 | **self._open_args) 232 | 233 | def on_epoch_end(self, epoch, logs=None): 234 | logs = logs or {} 235 | 236 | def handle_value(k): 237 | is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0 238 | if isinstance(k, str): 239 | return k 240 | elif isinstance(k, Iterable) and not is_zero_dim_ndarray: 241 | return '"[%s]"' % (', '.join(map(str, k))) 242 | else: 243 | return k 244 | 245 | if self.keys is None: 246 | self.keys = sorted(logs.keys()) 247 | 248 | if not self.writer: 249 | class CustomDialect(csv.excel): 250 | delimiter = self.sep 251 | fieldnames = ['epoch'] + self.keys 252 | self.writer = csv.DictWriter(self.csv_file, 253 | fieldnames=fieldnames, 254 | dialect=CustomDialect) 255 | if self.append_header: 256 | self.writer.writeheader() 257 | 258 | row_dict = OrderedDict({'epoch': epoch}) 259 | row_dict.update((key, handle_value(logs[key])) for key in self.keys) 260 | self.writer.writerow(row_dict) 261 | self.csv_file.flush() 262 | 263 | def on_train_end(self, logs=None): 264 | self.csv_file.close() 265 | self.writer = None 266 | 267 | 268 | class EvaluateMetrics(Callback): 269 | """Evaluates metrics on a dataset after every epoch. 270 | 271 | # Arguments 272 | dataloader: torch.DataLoader of the dataset on which the model will be evaluated 273 | prefix: Prefix to prepend to the names of the metrics when they are logged. Defaults to 'val_' but can be changed 274 | if the model is to be evaluated on many datasets separately. 275 | suffix: Suffix to append to the names of the metrics when they are logged. 276 | """ 277 | def __init__(self, dataloader, prefix='val_', suffix=''): 278 | super(EvaluateMetrics, self).__init__() 279 | self.dataloader = dataloader 280 | self.prefix = prefix 281 | self.suffix = suffix 282 | 283 | def on_train_begin(self, logs=None): 284 | self.metrics = self.params['metrics'] 285 | self.prepare_batch = self.params['prepare_batch'] 286 | self.loss_fn = self.params['loss_fn'] 287 | 288 | def on_epoch_end(self, epoch, logs=None): 289 | logs = logs or {} 290 | logs.update( 291 | evaluate(self.model, self.dataloader, self.prepare_batch, self.metrics, self.loss_fn, self.prefix, self.suffix) 292 | ) 293 | 294 | 295 | class ReduceLROnPlateau(Callback): 296 | """Reduce learning rate when a metric has stopped improving. 297 | 298 | Models often benefit from reducing the learning rate by a factor 299 | of 2-10 once learning stagnates. This callback monitors a 300 | quantity and if no improvement is seen for a 'patience' number 301 | of epochs, the learning rate is reduced. 302 | 303 | # Arguments 304 | monitor: quantity to be monitored. 305 | factor: factor by which the learning rate will 306 | be reduced. new_lr = lr * factor 307 | patience: number of epochs with no improvement 308 | after which learning rate will be reduced.
309 | verbose: int. 0: quiet, 1: update messages. 310 | mode: one of {auto, min, max}. In `min` mode, 311 | lr will be reduced when the quantity 312 | monitored has stopped decreasing; in `max` 313 | mode it will be reduced when the quantity 314 | monitored has stopped increasing; in `auto` 315 | mode, the direction is automatically inferred 316 | from the name of the monitored quantity. 317 | min_delta: threshold for measuring the new optimum, 318 | to only focus on significant changes. 319 | cooldown: number of epochs to wait before resuming 320 | normal operation after lr has been reduced. 321 | min_lr: lower bound on the learning rate. 322 | """ 323 | 324 | def __init__(self, monitor='val_loss', factor=0.1, patience=10, 325 | verbose=0, mode='auto', min_delta=1e-4, cooldown=0, min_lr=0, 326 | **kwargs): 327 | super(ReduceLROnPlateau, self).__init__() 328 | 329 | self.monitor = monitor 330 | if factor >= 1.0: 331 | raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.') 332 | self.factor = factor 333 | self.min_lr = min_lr 334 | self.min_delta = min_delta 335 | self.patience = patience 336 | self.verbose = verbose 337 | self.cooldown = cooldown 338 | self.cooldown_counter = 0  # Cooldown counter. 339 | self.wait = 0 340 | self.best = 0 341 | if mode not in ['auto', 'min', 'max']: 342 | raise ValueError('Mode must be one of (auto, min, max).') 343 | self.mode = mode 344 | self.monitor_op = None 345 | 346 | self._reset() 347 | 348 | def _reset(self): 349 | """Resets wait counter and cooldown counter. 350 | """ 351 | if (self.mode == 'min' or 352 | (self.mode == 'auto' and 'acc' not in self.monitor)): 353 | self.monitor_op = lambda a, b: np.less(a, b - self.min_delta) 354 | self.best = np.Inf 355 | else: 356 | self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta) 357 | self.best = -np.Inf 358 | self.cooldown_counter = 0 359 | self.wait = 0 360 | 361 | def on_train_begin(self, logs=None): 362 | self.optimiser = self.params['optimiser'] 363 | self.min_lrs = [self.min_lr] * len(self.optimiser.param_groups) 364 | self._reset() 365 | 366 | def on_epoch_end(self, epoch, logs=None): 367 | logs = logs or {} 368 | if len(self.optimiser.param_groups) == 1: 369 | logs['lr'] = self.optimiser.param_groups[0]['lr'] 370 | else: 371 | for i, param_group in enumerate(self.optimiser.param_groups): 372 | logs['lr_{}'.format(i)] = param_group['lr'] 373 | 374 | current = logs.get(self.monitor) 375 | 376 | if self.in_cooldown(): 377 | self.cooldown_counter -= 1 378 | self.wait = 0 379 | 380 | if self.monitor_op(current, self.best): 381 | self.best = current 382 | self.wait = 0 383 | elif not self.in_cooldown(): 384 | self.wait += 1 385 | if self.wait >= self.patience: 386 | self._reduce_lr(epoch) 387 | self.cooldown_counter = self.cooldown 388 | self.wait = 0 389 | 390 | def _reduce_lr(self, epoch): 391 | for i, param_group in enumerate(self.optimiser.param_groups): 392 | old_lr = float(param_group['lr']) 393 | new_lr = max(old_lr * self.factor, self.min_lrs[i]) 394 | if old_lr - new_lr > self.min_delta: 395 | param_group['lr'] = new_lr 396 | if self.verbose: 397 | print('Epoch {:5d}: reducing learning rate' 398 | ' of group {} to {:.4e}.'.format(epoch, i, new_lr)) 399 | 400 | def in_cooldown(self): 401 | return self.cooldown_counter > 0 402 | 403 | 404 | class ModelCheckpoint(Callback): 405 | """Save the model after every epoch. 406 | 407 | `filepath` can contain named formatting options, which will be filled with the value of `epoch` and keys in `logs` 408 | (passed in `on_epoch_end`).
409 | 410 | For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model checkpoints will be saved 411 | with the epoch number and the validation loss in the filename. 412 | 413 | # Arguments 414 | filepath: string, path to save the model file. 415 | monitor: quantity to monitor. 416 | verbose: verbosity mode, 0 or 1. 417 | save_best_only: if `save_best_only=True`, 418 | the latest best model according to 419 | the quantity monitored will not be overwritten. 420 | mode: one of {auto, min, max}. 421 | If `save_best_only=True`, the decision 422 | to overwrite the current save file is made 423 | based on either the maximization or the 424 | minimization of the monitored quantity. For `val_acc`, 425 | this should be `max`, for `val_loss` this should 426 | be `min`, etc. In `auto` mode, the direction is 427 | automatically inferred from the name of the monitored quantity. 428 | save_weights_only: if True, then only the model's weights will be 429 | saved (`model.save_weights(filepath)`), else the full model 430 | is saved (`model.save(filepath)`). 431 | period: Interval (number of epochs) between checkpoints. 432 | """ 433 | 434 | def __init__(self, filepath, monitor='val_loss', verbose=0, save_best_only=False, mode='auto', period=1): 435 | super(ModelCheckpoint, self).__init__() 436 | self.monitor = monitor 437 | self.verbose = verbose 438 | self.filepath = filepath 439 | self.save_best_only = save_best_only 440 | self.period = period 441 | self.epochs_since_last_save = 0 442 | 443 | if mode not in ['auto', 'min', 'max']: 444 | raise ValueError('Mode must be one of (auto, min, max).') 445 | 446 | if mode == 'min': 447 | self.monitor_op = np.less 448 | self.best = np.Inf 449 | elif mode == 'max': 450 | self.monitor_op = np.greater 451 | self.best = -np.Inf 452 | else: 453 | if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): 454 | self.monitor_op = np.greater 455 | self.best = -np.Inf 456 | else: 457 | self.monitor_op = np.less 458 | 459 | self.best = np.Inf 460 | 461 | def on_epoch_end(self, epoch, logs=None): 462 | logs = logs or {} 463 | self.epochs_since_last_save += 1 464 | if self.epochs_since_last_save >= self.period: 465 | self.epochs_since_last_save = 0 466 | filepath = self.filepath.format(epoch=epoch + 1, **logs) 467 | if self.save_best_only: 468 | current = logs.get(self.monitor) 469 | if current is None: 470 | warnings.warn('Can save best model only with %s available, ' 471 | 'skipping.' % (self.monitor), RuntimeWarning) 472 | else: 473 | if self.monitor_op(current, self.best): 474 | if self.verbose > 0: 475 | print('\nEpoch %05d: %s improved from %0.5f to %0.5f,' 476 | ' saving model to %s' 477 | % (epoch + 1, self.monitor, self.best, 478 | current, filepath)) 479 | self.best = current 480 | torch.save(self.model.state_dict(), filepath) 481 | else: 482 | if self.verbose > 0: 483 | print('\nEpoch %05d: %s did not improve from %0.5f' % 484 | (epoch + 1, self.monitor, self.best)) 485 | else: 486 | if self.verbose > 0: 487 | print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath)) 488 | torch.save(self.model.state_dict(), filepath) 489 | 490 | 491 | class LearningRateScheduler(Callback): 492 | """Learning rate scheduler. 493 | # Arguments 494 | schedule: a function that takes an epoch index as input 495 | (integer, indexed from 0) and current learning rate 496 | and returns a new learning rate as output (float). 497 | verbose: int. 0: quiet, 1: update messages. 
498 | """ 499 | 500 | def __init__(self, schedule, verbose=0): 501 | super(LearningRateScheduler, self).__init__() 502 | self.schedule = schedule 503 | self.verbose = verbose 504 | 505 | def on_train_begin(self, logs=None): 506 | self.optimiser = self.params['optimiser'] 507 | 508 | def on_epoch_begin(self, epoch, logs=None): 509 | lrs = [self.schedule(epoch, param_group['lr']) for param_group in self.optimiser.param_groups] 510 | 511 | if not all(isinstance(lr, (float, np.float32, np.float64)) for lr in lrs): 512 | raise ValueError('The output of the "schedule" function ' 513 | 'should be float.') 514 | self.set_lr(epoch, lrs) 515 | 516 | def on_epoch_end(self, epoch, logs=None): 517 | logs = logs or {} 518 | if len(self.optimiser.param_groups) == 1: 519 | logs['lr'] = self.optimiser.param_groups[0]['lr'] 520 | else: 521 | for i, param_group in enumerate(self.optimiser.param_groups): 522 | logs['lr_{}'.format(i)] = param_group['lr'] 523 | 524 | def set_lr(self, epoch, lrs): 525 | for i, param_group in enumerate(self.optimiser.param_groups): 526 | new_lr = lrs[i] 527 | param_group['lr'] = new_lr 528 | if self.verbose: 529 | print('Epoch {:5d}: setting learning rate' 530 | ' of group {} to {:.4e}.'.format(epoch, i, new_lr)) 531 | --------------------------------------------------------------------------------