├── plots ├── mini.png └── inference_summary.png ├── src ├── datasets │ ├── __init__.py │ ├── sampler.py │ ├── ingredient.py │ ├── loader.py │ └── transform.py ├── models │ ├── __init__.py │ ├── ingredient.py │ ├── Conv4.py │ ├── ProtoNet.py │ ├── MobileNet.py │ ├── WideResNet.py │ ├── DenseNet.py │ └── ResNet.py ├── optim.py ├── utils.py ├── main.py ├── trainer.py ├── eval.py └── tim.py ├── .gitignore ├── scripts ├── evaluate │ ├── tim_adm │ │ ├── all.sh │ │ ├── cub.sh │ │ ├── mini2cub.sh │ │ ├── mini_10_20_ways.sh │ │ ├── tiered.sh │ │ └── mini.sh │ ├── tim_gd │ │ ├── all.sh │ │ ├── cub.sh │ │ ├── mini2cub.sh │ │ ├── mini_10_20_ways.sh │ │ ├── mini.sh │ │ └── tiered.sh │ └── baseline │ │ ├── cub.sh │ │ ├── mini2cub.sh │ │ ├── mini_10_20_ways.sh │ │ ├── mini.sh │ │ └── tiered.sh ├── downloads │ ├── download_environment.py │ ├── download_models.py │ ├── download_data.py │ └── utils.py ├── train │ ├── wideres.sh │ ├── densenet.sh │ └── resnet18.sh ├── make_splits │ ├── tieredImagenet_split.py │ └── cub_split.py └── ablation │ ├── tim_gd_all │ ├── mini.sh │ ├── cub.sh │ └── tiered.sh │ ├── tim_adm │ └── weighting_effect.sh │ └── tim_gd │ └── weighting_effect.sh ├── LICENSE └── README.md /plots/mini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mboudiaf/TIM/HEAD/plots/mini.png -------------------------------------------------------------------------------- /plots/inference_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mboudiaf/TIM/HEAD/plots/inference_summary.png -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import CategoriesSampler 2 | from .loader import DatasetFolder 3 | from .transform import with_augment, without_augment 
-------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ResNet import * 2 | from .DenseNet import * 3 | from .Conv4 import Conv4 as conv4 4 | from .MobileNet import MobileNet as mobilenet 5 | from .WideResNet import wideres 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.sublime-project 2 | *.sublime-workspace 3 | results/ 4 | data/ 5 | checkpoints/ 6 | deploy/ 7 | logs/ 8 | tmp/ 9 | env/ 10 | tim_github.sublime-project 11 | tim_github.sublime-workspace 12 | .gitignore 13 | __pycache__/ 14 | -------------------------------------------------------------------------------- /scripts/evaluate/tim_adm/all.sh: -------------------------------------------------------------------------------- 1 | bash scripts/evaluate/tim_adm/mini.sh 2 | bash scripts/evaluate/tim_adm/tiered.sh 3 | bash scripts/evaluate/tim_adm/cub.sh 4 | bash scripts/evaluate/tim_adm/mini_10_20_ways.sh 5 | bash scripts/evaluate/tim_adm/mini2cub.sh -------------------------------------------------------------------------------- /scripts/evaluate/tim_gd/all.sh: -------------------------------------------------------------------------------- 1 | bash scripts/evaluate/tim_gd/mini.sh 2 | bash scripts/evaluate/tim_gd/tiered.sh 3 | bash scripts/evaluate/tim_gd/cub.sh 4 | bash scripts/evaluate/tim_gd/mini_10_20_ways.sh 5 | bash scripts/evaluate/tim_gd/mini2cub.sh 6 | # bash scripts/evaluate/tim_gd/iNat.sh 7 | -------------------------------------------------------------------------------- /src/models/ingredient.py: -------------------------------------------------------------------------------- 1 | from sacred import Ingredient 2 | from src.models import __dict__ 3 | 4 | model_ingredient = Ingredient('model') 5 | @model_ingredient.config 6 | 
def config(): 7 | arch = 'resnet18' 8 | num_classes = 64 9 | 10 | 11 | @model_ingredient.capture 12 | def get_model(arch, num_classes): 13 | return __dict__[arch](num_classes=num_classes) -------------------------------------------------------------------------------- /scripts/evaluate/tim_gd/cub.sh: -------------------------------------------------------------------------------- 1 | 2 | # =============> Resnet18 <================ 3 | python3 -m src.main \ 4 | -F logs/tim_gd/cub/ \ 5 | with dataset.path="data/cub/CUB_200_2011/images" \ 6 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 7 | dataset.split_dir="split/cub" \ 8 | model.arch='resnet18' \ 9 | model.num_classes=100 \ 10 | tim.iter=1000 \ 11 | evaluate=True \ 12 | eval.method='tim_gd' \ 13 | 14 | -------------------------------------------------------------------------------- /scripts/evaluate/baseline/cub.sh: -------------------------------------------------------------------------------- 1 | 2 | # ===========================> Resnet 18 <========================================= 3 | 4 | python3 -m src.main \ 5 | -F logs/non_augmented/baseline/cub/resnet18 \ 6 | with dataset.path="data/cub/CUB_200_2011/images" \ 7 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 8 | dataset.split_dir="split/cub" \ 9 | model.arch='resnet18' \ 10 | model.num_classes=100 \ 11 | evaluate=True -------------------------------------------------------------------------------- /scripts/evaluate/tim_adm/cub.sh: -------------------------------------------------------------------------------- 1 | # ===========================> Resnet18 <========================================= 2 | 3 | python3 -m src.main \ 4 | -F logs/tim_adm/cub/resnet18 \ 5 | with dataset.path="data/cub/CUB_200_2011/images" \ 6 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 7 | dataset.split_dir="split/cub" \ 8 | evaluate=True \ 9 | model.arch='resnet18' \ 10 | model.num_classes=100 \ 11 | eval.method="tim_adm" \ 12 | 13 | 
-------------------------------------------------------------------------------- /scripts/downloads/download_environment.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scripts.downloads.utils import download_file_from_google_drive 3 | 4 | 5 | def download_env(): 6 | id = '1-4WhU-bM3wJFcYV0kIyokShddmQFNXTM' 7 | name = f"env.zip" 8 | print('Start Download (may take a few minutes)') 9 | download_file_from_google_drive(id, name) 10 | print('Finish Download') 11 | os.system(f'unzip {name}') 12 | os.system(f'rm -r {name}') 13 | 14 | 15 | if __name__ == '__main__': 16 | download_env() -------------------------------------------------------------------------------- /scripts/evaluate/baseline/mini2cub.sh: -------------------------------------------------------------------------------- 1 | 2 | # ===========================> Resnet 18 <========================================= 3 | 4 | python3 -m src.main \ 5 | -F logs/non_augmented/baseline/mini2cub/resnet18 \ 6 | with dataset.path="data/mini_imagenet" \ 7 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 8 | dataset.split_dir="split/mini" \ 9 | model.arch='resnet18' \ 10 | evaluate=True \ 11 | eval.target_data_path="data/cub/CUB_200_2011/images" \ 12 | eval.target_split_dir="split/cub" \ 13 | -------------------------------------------------------------------------------- /scripts/evaluate/tim_adm/mini2cub.sh: -------------------------------------------------------------------------------- 1 | 2 | # ===========================> Resnet 18 <======================================= 3 | python3 -m src.main \ 4 | -F logs/tim_adm/mini2cub/resnet18 \ 5 | with dataset.path="data/mini_imagenet" \ 6 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 7 | dataset.split_dir="split/mini" \ 8 | model.arch='resnet18' \ 9 | evaluate=True \ 10 | eval.method="tim_adm" \ 11 | eval.target_data_path="data/cub/CUB_200_2011/images" \ 12 | eval.target_split_dir="split/cub" 13 | 
-------------------------------------------------------------------------------- /scripts/evaluate/tim_gd/mini2cub.sh: -------------------------------------------------------------------------------- 1 | 2 | # ===========================> Resnet 18 <========================================= 3 | 4 | python3 -m src.main \ 5 | -F logs/tim_gd/mini2cub/resnet18 \ 6 | with dataset.path="data/mini_imagenet" \ 7 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 8 | dataset.split_dir="split/mini" \ 9 | model.arch='resnet18' \ 10 | tim.iter=1000 \ 11 | evaluate=True \ 12 | eval.method='tim_gd' \ 13 | eval.target_data_path="data/cub/CUB_200_2011/images" \ 14 | eval.target_split_dir="split/cub" 15 | -------------------------------------------------------------------------------- /scripts/downloads/download_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scripts.downloads.utils import download_file_from_google_drive 3 | 4 | 5 | ids = {'checkpoints': '15MFsig6pjXO7vZdo-1znJoXtHv4NY-AF'} 6 | 7 | 8 | def download_models(name): 9 | id = ids[name] 10 | name = f"{name}.zip" 11 | print('Start Download (may take a few minutes)') 12 | download_file_from_google_drive(id, name) 13 | print('Finish Download') 14 | os.system(f'unzip {name}') 15 | os.system(f'rm -r {name}') 16 | 17 | 18 | if __name__ == '__main__': 19 | download_models('checkpoints') 20 | -------------------------------------------------------------------------------- /scripts/evaluate/baseline/mini_10_20_ways.sh: -------------------------------------------------------------------------------- 1 | # # ===========================> TIM-GD <========================================= 2 | 3 | python3 -m src.main \ 4 | -F logs/10_ways/non_augmented/tim_gd/resnet18 \ 5 | with dataset.path="data/mini_imagenet" \ 6 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 7 | dataset.split_dir="split/mini" \ 8 | model.arch='resnet18' \ 9 | evaluate=True \ 10 | 
eval.meta_val_way=10 11 | 12 | python3 -m src.main \ 13 | -F logs/20_ways/non_augmented/tim_gd/resnet18 \ 14 | with dataset.path="data/mini_imagenet" \ 15 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 16 | dataset.split_dir="split/mini" \ 17 | model.arch='resnet18' \ 18 | evaluate=True \ 19 | eval.meta_val_way=20 20 | -------------------------------------------------------------------------------- /scripts/downloads/download_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scripts.downloads.utils import download_file_from_google_drive 3 | 4 | 5 | ids = {'mini_imagenet': '1roXWm-COzPhyel_OZu7vseIpYV_-EH3G', 6 | 'tiered_imagenet': '1bj9KM1uwRxHk6ZljxIOBdyNFqG017YVx', 7 | 'cub': '1Ay-91MwLmNWYv7nNIOoI45XZHiz2JQRb' 8 | } 9 | 10 | 11 | def download_data(dataset): 12 | id = ids[dataset] 13 | name = f"{dataset}.zip" 14 | print('Start Download (may take a few minutes)') 15 | download_file_from_google_drive(id, name) 16 | print('Finish Download') 17 | os.system(f'unzip {name}') 18 | os.system(f'rm {name}') 19 | 20 | 21 | if __name__ == '__main__': 22 | data_dir = "data" 23 | os.makedirs(data_dir, exist_ok=True) 24 | os.chdir(data_dir) 25 | download_data('mini_imagenet') 26 | download_data('cub') 27 | -------------------------------------------------------------------------------- /scripts/evaluate/tim_adm/mini_10_20_ways.sh: -------------------------------------------------------------------------------- 1 | # # ===========================> TIM-ADM <========================================= 2 | 3 | python3 -m src.main \ 4 | -F logs/tim_adm/mini_10_ways/resnet18 \ 5 | with dataset.path="data/mini_imagenet" \ 6 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 7 | dataset.split_dir="split/mini" \ 8 | model.arch='resnet18' \ 9 | evaluate=True \ 10 | eval.method="tim_adm" \ 11 | eval.meta_val_way=10 \ 12 | eval.meta_test_iter=1000 13 | 14 | python3 -m src.main \ 15 | -F logs/tim_adm/mini_20_ways/resnet18 \ 16 | 
with dataset.path="data/mini_imagenet" \ 17 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 18 | dataset.split_dir="split/mini" \ 19 | model.arch='resnet18' \ 20 | evaluate=True \ 21 | eval.method="tim_adm" \ 22 | eval.meta_val_way=20 \ 23 | eval.meta_test_iter=1000 24 | -------------------------------------------------------------------------------- /scripts/train/wideres.sh: -------------------------------------------------------------------------------- 1 | # ===============> Mini <=================== 2 | 3 | python3 -m src.main \ 4 | with dataset.path="data/mini_imagenet" \ 5 | dataset.split_dir="split/mini" \ 6 | ckpt_path="checkpoints/mini/softmax/wideres" \ 7 | dataset.batch_size=128 \ 8 | dataset.jitter=True \ 9 | model.arch='wideres' \ 10 | model.num_classes=64 \ 11 | optim.scheduler="multi_step" \ 12 | trainer.label_smoothing=0.1 13 | 14 | # ===============> Tiered <=================== 15 | 16 | python3 -m src.main \ 17 | with dataset.path="data/tiered_imagenet/data" \ 18 | dataset.split_dir="split/tiered" \ 19 | dataset.jitter=True \ 20 | ckpt_path="checkpoints/tiered/softmax/wideres" \ 21 | dataset.batch_size=256 \ 22 | model.arch='wideres' \ 23 | model.num_classes=351 \ 24 | optim.scheduler="multi_step" \ 25 | epochs=90 \ 26 | trainer.label_smoothing=0.1 27 | 28 | -------------------------------------------------------------------------------- /scripts/evaluate/tim_gd/mini_10_20_ways.sh: -------------------------------------------------------------------------------- 1 | # # ===========================> TIM-GD <========================================= 2 | 3 | python3 -m src.main \ 4 | -F logs/tim_gd/mini_10_ways/resnet18 \ 5 | with dataset.path="data/mini_imagenet" \ 6 | tim.iter=1000 \ 7 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 8 | dataset.split_dir="split/mini" \ 9 | model.arch='resnet18' \ 10 | evaluate=True \ 11 | eval.method="tim_gd" \ 12 | eval.meta_val_way=10 \ 13 | eval.meta_test_iter=1000 14 | 15 | python3 -m src.main \ 16 | -F 
logs/tim_gd/mini_20_ways/resnet18 \ 17 | with dataset.path="data/mini_imagenet" \ 18 | tim.iter=1000 \ 19 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 20 | dataset.split_dir="split/mini" \ 21 | model.arch='resnet18' \ 22 | evaluate=True \ 23 | eval.method="tim_gd" \ 24 | eval.meta_val_way=20 \ 25 | eval.meta_test_iter=1000 26 | -------------------------------------------------------------------------------- /scripts/train/densenet.sh: -------------------------------------------------------------------------------- 1 | # ===============> Mini <=================== 2 | 3 | python3 -m src.main \ 4 | with dataset.path="data/mini_imagenet" \ 5 | visdom_port=8097 \ 6 | dataset.split_dir="split/mini" \ 7 | ckpt_path="checkpoints/mini/softmax/densenet121" \ 8 | dataset.batch_size=256 \ 9 | dataset.jitter=True \ 10 | model.arch='densenet121' \ 11 | model.num_classes=64 \ 12 | optim.scheduler="multi_step" \ 13 | epochs=90 \ 14 | trainer.label_smoothing=0.1 15 | 16 | # ===============> Tiered <=================== 17 | python3 -m src.main \ 18 | with dataset.path="data/tiered_imagenet/data" \ 19 | visdom_port=8097 \ 20 | dataset.split_dir="split/tiered" \ 21 | ckpt_path="checkpoints/tiered/softmax/densenet121" \ 22 | dataset.batch_size=256 \ 23 | dataset.jitter=True \ 24 | model.arch='densenet121' \ 25 | model.num_classes=351 \ 26 | optim.scheduler="multi_step" \ 27 | epochs=90 \ 28 | trainer.label_smoothing=0.1 29 | -------------------------------------------------------------------------------- /scripts/downloads/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import sys 3 | import os 4 | 5 | 6 | def download_file_from_google_drive(id, destination): 7 | def get_confirm_token(response): 8 | for key, value in response.cookies.items(): 9 | if key.startswith('download_warning'): 10 | return value 11 | return None 12 | 13 | def save_response_content(response, destination): 14 | CHUNK_SIZE = 32768 15 | with 
open(destination, "wb") as f: 16 | for chunk in response.iter_content(CHUNK_SIZE): 17 | if chunk: # filter out keep-alive new chunks 18 | f.write(chunk) 19 | 20 | URL = "https://docs.google.com/uc?export=download" 21 | session = requests.Session() 22 | response = session.get(URL, params={'id': id}, stream=True) 23 | token = get_confirm_token(response) 24 | if token: 25 | params = {'id': id, 'confirm': token} 26 | response = session.get(URL, params=params, stream=True) 27 | save_response_content(response, destination) -------------------------------------------------------------------------------- /scripts/evaluate/baseline/mini.sh: -------------------------------------------------------------------------------- 1 | 2 | # ===========================> Resnet 18 <========================================= 3 | 4 | python3 -m src.main \ 5 | -F logs/non_augmented/baseline/mini/resnet18 \ 6 | with dataset.path="data/mini_imagenet" \ 7 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 8 | dataset.split_dir="split/mini" \ 9 | model.arch='resnet18' \ 10 | evaluate=True 11 | 12 | # ===========================> WRN 28-10 <========================================= 13 | 14 | python3 -m src.main \ 15 | -F logs/non_augmented/baseline/mini/wideres \ 16 | with dataset.path="data/mini_imagenet" \ 17 | ckpt_path="checkpoints/mini/softmax/wideres" \ 18 | dataset.split_dir="split/mini" \ 19 | model.arch='wideres' \ 20 | evaluate=True 21 | 22 | # ===========================> DenseNet <========================================= 23 | 24 | python3 -m src.main \ 25 | -F logs/non_augmented/baseline/mini/densenet121 \ 26 | with dataset.path="data/mini_imagenet" \ 27 | ckpt_path="checkpoints/mini/softmax/densenet121" \ 28 | dataset.split_dir="split/mini" \ 29 | model.arch='densenet121' \ 30 | evaluate=True -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 
| Copyright (c) 2023 Malik Boudiaf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /scripts/evaluate/tim_gd/mini.sh: -------------------------------------------------------------------------------- 1 | 2 | # ===========================> Resnet 18 <===================================== 3 | 4 | python3 -m src.main \ 5 | -F logs/tim_gd/mini/resnet18 \ 6 | with dataset.path="data/mini_imagenet" \ 7 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 8 | dataset.split_dir="split/mini" \ 9 | model.arch='resnet18' \ 10 | evaluate=True \ 11 | tim.iter=1000 \ 12 | eval.method='tim_gd' \ 13 | 14 | # ===========================> WRN 28-10 <===================================== 15 | 16 | python3 -m src.main \ 17 | -F logs/tim_gd/mini/wideres \ 18 | with dataset.path="data/mini_imagenet" \ 19 | ckpt_path="checkpoints/mini/softmax/wideres" \ 20 | dataset.split_dir="split/mini" \ 21 | model.arch='wideres' \ 22 | evaluate=True \ 23 | tim.iter=1000 \ 24 | eval.method='tim_gd' \ 25 | 26 | 27 | # # ===========================> DenseNet <===================================== 28 | 29 | python3 -m src.main \ 30 | -F logs/tim_gd/mini/densenet121 \ 31 | with dataset.path="data/mini_imagenet" \ 32 | ckpt_path="checkpoints/mini/softmax/densenet121" \ 33 | dataset.split_dir="split/mini" \ 34 | model.arch='densenet121' \ 35 | evaluate=True \ 36 | tim.iter=1000 \ 37 | eval.method='tim_gd' \ -------------------------------------------------------------------------------- /scripts/evaluate/baseline/tiered.sh: -------------------------------------------------------------------------------- 1 | 2 | # ===========================> Resnet 18 <========================================= 3 | 4 | python3 -m src.main \ 5 | -F logs/non_augmented/baseline/tiered/resnet18 \ 6 | with dataset.path="data/tiered_imagenet/data" \ 7 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 8 | dataset.split_dir="split/tiered" \ 9 | model.arch='resnet18' \ 10 | model.num_classes=351 \ 11 | evaluate=True 12 | 13 | # 
===========================> WRN 28-10 <========================================= 14 | 15 | python3 -m src.main \ 16 | -F logs/non_augmented/baseline/tiered/wideres \ 17 | with dataset.path="data/tiered_imagenet/data" \ 18 | ckpt_path="checkpoints/tiered/softmax/wideres" \ 19 | dataset.split_dir="split/tiered" \ 20 | model.arch='wideres' \ 21 | model.num_classes=351 \ 22 | evaluate=True 23 | 24 | # ===========================> DenseNet <========================================= 25 | 26 | python3 -m src.main \ 27 | -F logs/non_augmented/baseline/tiered/densenet121 \ 28 | with dataset.path="data/tiered_imagenet/data" \ 29 | ckpt_path="checkpoints/tiered/softmax/densenet121" \ 30 | dataset.split_dir="split/tiered" \ 31 | dataset.batch_size=16 \ 32 | model.arch='densenet121' \ 33 | model.num_classes=351 \ 34 | evaluate=True -------------------------------------------------------------------------------- /scripts/evaluate/tim_adm/tiered.sh: -------------------------------------------------------------------------------- 1 | 2 | # ===========================> Resnet 18 <========================================= 3 | 4 | python3 -m src.main \ 5 | -F logs/tim_adm/tiered/resnet18 \ 6 | with dataset.path="data/tiered_imagenet/data" \ 7 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 8 | dataset.split_dir="split/tiered" \ 9 | model.arch='resnet18' \ 10 | model.num_classes=351 \ 11 | evaluate=True \ 12 | eval.method="tim_adm" 13 | 14 | # ===========================> WRN 28-10 <========================================= 15 | 16 | python3 -m src.main \ 17 | -F logs/tim_adm/tiered/wideres \ 18 | with dataset.path="data/tiered_imagenet/data" \ 19 | ckpt_path="checkpoints/tiered/softmax/wideres" \ 20 | dataset.split_dir="split/tiered" \ 21 | model.num_classes=351 \ 22 | model.arch='wideres' \ 23 | evaluate=True \ 24 | eval.method="tim_adm" 25 | 26 | # # ===========================> DenseNet <========================================= 27 | 28 | python3 -m src.main \ 29 | -F 
logs/tim_adm/tiered/densenet121 \ 30 | with dataset.path="data/tiered_imagenet/data" \ 31 | ckpt_path="checkpoints/tiered/softmax/densenet121" \ 32 | dataset.split_dir="split/tiered" \ 33 | model.num_classes=351 \ 34 | model.arch='densenet121' \ 35 | evaluate=True \ 36 | eval.method="tim_adm" -------------------------------------------------------------------------------- /src/models/Conv4.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | __all__ = ['Conv4'] 4 | 5 | 6 | def conv_block(in_channels: int, out_channels: int) -> nn.Module: 7 | return nn.Sequential( 8 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 9 | nn.BatchNorm2d(out_channels), 10 | nn.ReLU(), 11 | nn.MaxPool2d(kernel_size=2, stride=2) 12 | ) 13 | 14 | 15 | class Conv4(nn.Module): 16 | def __init__(self, num_classes, remove_linear=False): 17 | super(Conv4, self).__init__() 18 | self.conv1 = conv_block(3, 64) 19 | self.conv2 = conv_block(64, 64) 20 | self.conv3 = conv_block(64, 64) 21 | self.conv4 = conv_block(64, 64) 22 | if remove_linear: 23 | self.logits = None 24 | else: 25 | self.logits = nn.Linear(1600, num_classes) 26 | 27 | def forward(self, x, feature=False): 28 | x = self.conv1(x) 29 | x = self.conv2(x) 30 | x = self.conv3(x) 31 | x = self.conv4(x) 32 | 33 | x = x.view(x.size(0), -1) 34 | if self.logits is None: 35 | if feature: 36 | return x, None 37 | else: 38 | return x 39 | if feature: 40 | x1 = self.logits(x) 41 | return x, x1 42 | 43 | return self.logits(x) 44 | -------------------------------------------------------------------------------- /src/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Sampler 4 | 5 | __all__ = ['CategoriesSampler'] 6 | 7 | 8 | class CategoriesSampler(Sampler): 9 | 10 | def __init__(self, label, n_iter, n_way, n_shot, n_query): 11 | 12 | self.n_iter = n_iter 13 | 
self.n_way = n_way 14 | self.n_shot = n_shot 15 | self.n_query = n_query 16 | 17 | label = np.array(label) 18 | self.m_ind = [] 19 | unique = np.unique(label) 20 | unique = np.sort(unique) 21 | for i in unique: 22 | ind = np.argwhere(label == i).reshape(-1) 23 | ind = torch.from_numpy(ind) 24 | self.m_ind.append(ind) 25 | 26 | def __len__(self): 27 | return self.n_iter 28 | 29 | def __iter__(self): 30 | for i in range(self.n_iter): 31 | batch_gallery = [] 32 | batch_query = [] 33 | classes = torch.randperm(len(self.m_ind))[:self.n_way] 34 | for c in classes: 35 | l = self.m_ind[c.item()] 36 | pos = torch.randperm(l.size()[0]) 37 | batch_gallery.append(l[pos[:self.n_shot]]) 38 | batch_query.append(l[pos[self.n_shot:self.n_shot + self.n_query]]) 39 | batch = torch.cat(batch_gallery + batch_query) 40 | yield batch 41 | -------------------------------------------------------------------------------- /scripts/evaluate/tim_adm/mini.sh: -------------------------------------------------------------------------------- 1 | 2 | # ===========================> Resnet 18 <========================================= 3 | 4 | python3 -m src.main \ 5 | -F logs/tim_adm/mini/resnet18 \ 6 | with dataset.path="data/mini_imagenet" \ 7 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 8 | dataset.split_dir="split/mini" \ 9 | model.arch='resnet18' \ 10 | evaluate=True \ 11 | eval.method="tim_adm" \ 12 | 13 | # ===========================> WRN 28-10 <========================================= 14 | 15 | python3 -m src.main \ 16 | -F logs/tim_adm/mini/wideres \ 17 | with dataset.path="data/mini_imagenet" \ 18 | ckpt_path="checkpoints/mini/softmax/wideres" \ 19 | dataset.split_dir="split/mini" \ 20 | model.arch='wideres' \ 21 | evaluate=True \ 22 | eval.method="tim_adm" \ 23 | 24 | # ===========================> DenseNet <========================================= 25 | 26 | python3 -m src.main \ 27 | -F logs/tim_adm/mini/densenet121 \ 28 | with dataset.path="data/mini_imagenet" \ 29 | 
ckpt_path="checkpoints/mini/softmax/densenet121" \ 30 | dataset.split_dir="split/mini" \ 31 | model.arch='densenet121' \ 32 | evaluate=True \ 33 | eval.method="tim_adm" \ 34 | 35 | 36 | # python3 -m src.main with dataset.path="data/mini_imagenet" ckpt_path="checkpoints/mini/softmax/wideres" dataset.split_dir="split/mini" model.arch='wideres' evaluate=True eval.method="tim_adm" -------------------------------------------------------------------------------- /scripts/train/resnet18.sh: -------------------------------------------------------------------------------- 1 | # ===============> Mini <=================== 2 | 3 | python3 -m src.main \ 4 | with dataset.path="data/mini_imagenet" \ 5 | visdom_port=8097 \ 6 | dataset.split_dir="split/mini" \ 7 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 8 | dataset.batch_size=256 \ 9 | dataset.jitter=True \ 10 | model.arch='resnet18' \ 11 | model.num_classes=64 \ 12 | optim.scheduler="multi_step" \ 13 | epochs=90 \ 14 | trainer.label_smoothing=0.1 15 | 16 | # ===============> Tiered <=================== 17 | python3 -m src.main \ 18 | with dataset.path="data/tiered_imagenet/data" \ 19 | visdom_port=8097 \ 20 | dataset.split_dir="split/tiered" \ 21 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 22 | dataset.batch_size=256 \ 23 | dataset.jitter=True \ 24 | model.arch='resnet18' \ 25 | model.num_classes=351 \ 26 | optim.scheduler="multi_step" \ 27 | epochs=90 \ 28 | trainer.label_smoothing=0.1 29 | 30 | # ===============> Tiered <=================== 31 | python3 -m src.main \ 32 | with dataset.path="data/cub/CUB_200_2011/images" \ 33 | visdom_port=8097 \ 34 | dataset.split_dir="split/cub" \ 35 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 36 | dataset.batch_size=256 \ 37 | dataset.jitter=True \ 38 | model.arch='resnet18' \ 39 | model.num_classes=100 \ 40 | optim.scheduler="multi_step" \ 41 | epochs=90 \ 42 | trainer.label_smoothing=0.1 43 | 
-------------------------------------------------------------------------------- /scripts/evaluate/tim_gd/tiered.sh: -------------------------------------------------------------------------------- 1 | 2 | # ===========================> Resnet 18 <========================================= 3 | 4 | python3 -m src.main \ 5 | -F logs/tim_gd/tiered/resnet18 \ 6 | with dataset.path="data/tiered_imagenet/data" \ 7 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 8 | dataset.split_dir="split/tiered" \ 9 | model.arch='resnet18' \ 10 | model.num_classes=351 \ 11 | evaluate=True \ 12 | tim.iter=1000 \ 13 | eval.method='tim_gd' \ 14 | 15 | # ===========================> WRN 28-10 <========================================= 16 | 17 | python3 -m src.main \ 18 | -F logs/tim_gd/tiered/wideres \ 19 | with dataset.path="data/tiered_imagenet/data" \ 20 | ckpt_path="checkpoints/tiered/softmax/wideres" \ 21 | dataset.split_dir="split/tiered" \ 22 | model.arch='wideres' \ 23 | model.num_classes=351 \ 24 | tim.iter=1000 \ 25 | evaluate=True \ 26 | eval.method='tim_gd' \ 27 | 28 | # ===========================> DenseNet <========================================= 29 | 30 | python3 -m src.main \ 31 | -F logs/tim_gd/tiered/densenet121 \ 32 | with dataset.path="data/tiered_imagenet/data" \ 33 | ckpt_path="checkpoints/tiered/softmax/densenet121" \ 34 | dataset.split_dir="split/tiered" \ 35 | dataset.batch_size=16 \ 36 | model.arch='densenet121' \ 37 | model.num_classes=351 \ 38 | tim.iter=1000 \ 39 | evaluate=True \ 40 | eval.method='tim_gd' \ 41 | 42 | -------------------------------------------------------------------------------- /src/optim.py: -------------------------------------------------------------------------------- 1 | import torch.optim 2 | from sacred import Ingredient 3 | from torch.optim.lr_scheduler import MultiStepLR, StepLR, CosineAnnealingLR 4 | 5 | optim_ingredient = Ingredient('optim') 6 | 7 | 8 | @optim_ingredient.config 9 | def config(): 10 | gamma = 0.1 11 | lr = 0.1 
12 | lr_stepsize = 30 13 | nesterov = False 14 | weight_decay = 1e-4 15 | optimizer_name = 'SGD' 16 | scheduler = 'multi_step' 17 | 18 | 19 | @optim_ingredient.capture 20 | def get_scheduler(epochs, num_batches, optimizer, gamma, lr_stepsize, scheduler): 21 | 22 | SCHEDULER = {'step': StepLR(optimizer, lr_stepsize, gamma), 23 | 'multi_step': MultiStepLR(optimizer, milestones=[int(.5 * epochs), int(.75 * epochs)], 24 | gamma=gamma), 25 | 'cosine': CosineAnnealingLR(optimizer, num_batches * epochs, eta_min=1e-9), 26 | None: None} 27 | print(epochs) 28 | return SCHEDULER[scheduler] 29 | 30 | 31 | @optim_ingredient.capture 32 | def get_optimizer(module, optimizer_name, nesterov, lr, weight_decay): 33 | OPTIMIZER = {'SGD': torch.optim.SGD(module.parameters(), lr=lr, momentum=0.9, 34 | weight_decay=weight_decay, nesterov=nesterov), 35 | 'Adam': torch.optim.Adam(module.parameters(), lr=lr, weight_decay=weight_decay)} 36 | return OPTIMIZER[optimizer_name] -------------------------------------------------------------------------------- /src/datasets/ingredient.py: -------------------------------------------------------------------------------- 1 | from sacred import Ingredient 2 | from src.datasets.loader import DatasetFolder 3 | from src.datasets.sampler import CategoriesSampler 4 | from src.datasets.transform import with_augment, without_augment 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | dataset_ingredient = Ingredient('dataset') 9 | @dataset_ingredient.config 10 | def config(): 11 | batch_size = 256 12 | enlarge = True 13 | num_workers = 4 14 | disable_random_resize = False 15 | jitter = False 16 | path = 'data' 17 | split_dir = None 18 | 19 | 20 | @dataset_ingredient.capture 21 | def get_dataloader(split, enlarge, num_workers, batch_size, disable_random_resize, 22 | path, split_dir, jitter, aug=False, shuffle=True, out_name=False, 23 | sample=None): 24 | # sample: iter, way, shot, query 25 | if aug: 26 | transform = with_augment(84, 
disable_random_resize=disable_random_resize, 27 | jitter=jitter) 28 | else: 29 | transform = without_augment(84, enlarge=enlarge) 30 | sets = DatasetFolder(path, split_dir, split, transform, out_name=out_name) 31 | if sample is not None: 32 | sampler = CategoriesSampler(sets.labels, *sample) 33 | loader = DataLoader(sets, batch_sampler=sampler, 34 | num_workers=num_workers, pin_memory=False) 35 | else: 36 | loader = DataLoader(sets, batch_size=batch_size, shuffle=shuffle, 37 | num_workers=num_workers, pin_memory=False) 38 | return loader -------------------------------------------------------------------------------- /src/models/ProtoNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def get_metric(metric_type): 7 | METRICS = { 8 | 'cosine': lambda gallery, query: 1. - F.cosine_similarity(query[:, None, :], gallery[None, :, :], dim=2), 9 | 'euclidean': lambda gallery, query: ((query[:, None, :] - gallery[None, :, :]) ** 2).sum(2), 10 | 'l1': lambda gallery, query: torch.norm((query[:, None, :] - gallery[None, :, :]), p=1, dim=2), 11 | 'l2': lambda gallery, query: torch.norm((query[:, None, :] - gallery[None, :, :]), p=2, dim=2), 12 | } 13 | return METRICS[metric_type] 14 | 15 | 16 | class ProtoNet(nn.Module): 17 | 18 | def __init__(self, feature_net, args=None): 19 | super().__init__() 20 | self.encoder = feature_net 21 | if args is None: 22 | self.train_info = [30, 1, 15, 'euclidean'] 23 | # self.val_info = [5, 1, 15] 24 | else: 25 | self.train_info = [args.meta_train_way, args.meta_train_shot, args.meta_train_query, args.meta_train_metric] 26 | # self.val_info = [args.meta_val_way, args.meta_val_shot, args.meta_val_query] 27 | 28 | def forward(self, data, _=False): 29 | if not self.training: 30 | return self.encoder(data, True) 31 | way, shot, query, metric_name = self.train_info 32 | proto, _ = self.encoder(data, True) 33 | shot_proto, 
query_proto = proto[:shot * way], proto[shot * way:] 34 | shot_proto = shot_proto.reshape(way, shot, -1).mean(1) 35 | logits = -get_metric(metric_name)(shot_proto, query_proto) 36 | 37 | return logits 38 | -------------------------------------------------------------------------------- /src/datasets/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import PIL.Image as Image 4 | import numpy as np 5 | 6 | __all__ = ['DatasetFolder'] 7 | 8 | 9 | class DatasetFolder(object): 10 | 11 | def __init__(self, root, split_dir, split_type, transform, out_name=False): 12 | assert split_type in ['train', 'test', 'val', 'query', 'support'] 13 | split_file = os.path.join(split_dir, split_type + '.csv') 14 | assert os.path.isfile(split_file), split_file 15 | with open(split_file, 'r') as f: 16 | split = [x.strip().split(',') for x in f.readlines()[1:] if x.strip() != ''] 17 | 18 | data, ori_labels = [x[0] for x in split], [x[1] for x in split] 19 | label_key = sorted(np.unique(np.array(ori_labels))) 20 | label_map = dict(zip(label_key, range(len(label_key)))) 21 | mapped_labels = [label_map[x] for x in ori_labels] 22 | 23 | self.root = root 24 | self.transform = transform 25 | self.data = data 26 | self.labels = mapped_labels 27 | self.out_name = out_name 28 | self.length = len(self.data) 29 | 30 | def __len__(self): 31 | return self.length 32 | 33 | def __getitem__(self, index): 34 | assert os.path.isfile(os.path.join(self.root, self.data[index])), os.path.join(self.root, self.data[index]) 35 | img = Image.open(os.path.join(self.root, self.data[index])).convert('RGB') 36 | label = self.labels[index] 37 | label = int(label) 38 | if self.transform: 39 | img = self.transform(img) 40 | if self.out_name: 41 | return img, label, self.data[index] 42 | else: 43 | return img, label, index 44 | -------------------------------------------------------------------------------- /scripts/make_splits/tieredImagenet_split.py: 
-------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import cv2 3 | import os 4 | import argparse 5 | import tqdm 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--data', type=str, help='path to the data') 9 | parser.add_argument('--split', type=str, help='path to the split folder') 10 | args = parser.parse_args() 11 | 12 | 13 | def save_imgs(prex1, prex2, tag): 14 | data_file = prex1 + tag + '_images_png.pkl' 15 | label_file = prex1 + tag + '_labels.pkl' 16 | with open(data_file, 'rb') as f: 17 | array = pkl.load(f) 18 | with open(label_file, 'rb') as f: 19 | labels = pkl.load(f) 20 | for idx, (img, glabel, slabel) in enumerate( 21 | tqdm.tqdm(zip(array, labels['label_general'], labels['label_specific']), total=len(array))): 22 | file_name = prex2 + tag + '_' + str(glabel) + '_' + str(slabel) + '_' + str(idx) + '.png' 23 | im = cv2.imdecode(img, 1) 24 | cv2.imwrite(file_name, im) 25 | # print('Finish ' + file_name) 26 | 27 | 28 | if __name__ == '__main__': 29 | prex1 = args.data 30 | prex2 = prex1 + 'data/' 31 | if not os.path.isdir(prex2): 32 | os.makedirs(prex2) 33 | save_imgs(prex1, prex2, 'test') 34 | save_imgs(prex1, prex2, 'train') 35 | save_imgs(prex1, prex2, 'val') 36 | 37 | if not os.path.isdir(args.split): 38 | os.makedirs(args.split) 39 | data = os.listdir(prex2) 40 | with open(args.split + '/train.csv', 'w') as f: 41 | for name in data: 42 | if 'train' in name: 43 | label = name.split('.')[0].split('_')[-2] 44 | f.write('{},{}\n'.format(name, label)) 45 | 46 | with open(args.split + '/val.csv', 'w') as f: 47 | for name in data: 48 | if 'val' in name: 49 | label = name.split('.')[0].split('_')[-2] 50 | f.write('{},{}\n'.format(name, label)) 51 | 52 | with open(args.split + '/test.csv', 'w') as f: 53 | for name in data: 54 | if 'test' in name: 55 | label = name.split('.')[0].split('_')[-2] 56 | f.write('{},{}\n'.format(name, label)) 57 | 
-------------------------------------------------------------------------------- /scripts/ablation/tim_gd_all/mini.sh: -------------------------------------------------------------------------------- 1 | META_TEST_ITER="1000" 2 | 3 | # Only Xent 4 | python3 -m src.main \ 5 | -F logs/ablation/tim_gd_all/mini \ 6 | with dataset.path="data/mini_imagenet" \ 7 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 8 | dataset.split_dir="split/mini" \ 9 | model.arch='resnet18' \ 10 | evaluate=True \ 11 | eval.method='tim_gd' \ 12 | tim.finetune_encoder=True \ 13 | tim.iter=25 \ 14 | tim.lr=5e-5 \ 15 | eval.meta_test_iter=$META_TEST_ITER \ 16 | tim.loss_weights="[0.1, 0., 0.]" \ 17 | 18 | 19 | # No H(Y) 20 | python3 -m src.main \ 21 | -F logs/ablation/tim_gd_all/mini \ 22 | with dataset.path="data/mini_imagenet" \ 23 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 24 | dataset.split_dir="split/mini" \ 25 | model.arch='resnet18' \ 26 | evaluate=True \ 27 | eval.method='tim_gd' \ 28 | tim.finetune_encoder=True \ 29 | tim.iter=25 \ 30 | tim.lr=5e-5 \ 31 | eval.meta_test_iter=$META_TEST_ITER \ 32 | tim.loss_weights="[0.1, 0., 0.1]" \ 33 | 34 | # No H(Y|X) 35 | python3 -m src.main \ 36 | -F logs/ablation/tim_gd_all/mini \ 37 | with dataset.path="data/mini_imagenet" \ 38 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 39 | dataset.split_dir="split/mini" \ 40 | model.arch='resnet18' \ 41 | evaluate=True \ 42 | eval.method='tim_gd' \ 43 | tim.finetune_encoder=True \ 44 | tim.iter=25 \ 45 | tim.lr=5e-5 \ 46 | eval.meta_test_iter=$META_TEST_ITER \ 47 | tim.loss_weights="[0.1, 1.0, 0.]" \ 48 | 49 | 50 | python3 -m src.main \ 51 | -F logs/ablation/tim_gd_all/mini \ 52 | with dataset.path="data/mini_imagenet" \ 53 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 54 | dataset.split_dir="split/mini" \ 55 | model.arch='resnet18' \ 56 | evaluate=True \ 57 | eval.method='tim_gd' \ 58 | tim.finetune_encoder=True \ 59 | tim.iter=25 \ 60 | tim.lr=5e-5 \ 61 | 
eval.meta_test_iter=$META_TEST_ITER \ 62 | tim.loss_weights="[0.1, 1.0, 0.1]" \ 63 | -------------------------------------------------------------------------------- /scripts/ablation/tim_gd_all/cub.sh: -------------------------------------------------------------------------------- 1 | META_TEST_ITER="1000" 2 | 3 | # Only XEnt 4 | python3 -m src.main \ 5 | -F logs/ablation/tim_gd_all/cub \ 6 | with dataset.path="data/cub/CUB_200_2011/images" \ 7 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 8 | dataset.split_dir="split/cub" \ 9 | model.arch='resnet18' \ 10 | model.num_classes=100 \ 11 | evaluate=True \ 12 | eval.method='tim_gd' \ 13 | tim.finetune_encoder=True \ 14 | tim.iter=25 \ 15 | tim.lr=5e-5 \ 16 | eval.meta_test_iter=$META_TEST_ITER \ 17 | tim.loss_weights="[0.1, 0., 0.]" \ 18 | 19 | 20 | # No H(Y) 21 | python3 -m src.main \ 22 | -F logs/ablation/tim_gd_all/cub \ 23 | with dataset.path="data/cub/CUB_200_2011/images" \ 24 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 25 | dataset.split_dir="split/cub" \ 26 | model.arch='resnet18' \ 27 | model.num_classes=100 \ 28 | evaluate=True \ 29 | eval.method='tim_gd' \ 30 | tim.finetune_encoder=True \ 31 | tim.iter=25 \ 32 | tim.lr=5e-5 \ 33 | eval.meta_test_iter=$META_TEST_ITER \ 34 | tim.loss_weights="[0.1, 0., 0.1]" \ 35 | 36 | # No H(Y|X) 37 | python3 -m src.main \ 38 | -F logs/ablation/tim_gd_all/cub \ 39 | with dataset.path="data/cub/CUB_200_2011/images" \ 40 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 41 | dataset.split_dir="split/cub" \ 42 | model.arch='resnet18' \ 43 | model.num_classes=100 \ 44 | evaluate=True \ 45 | eval.method='tim_gd' \ 46 | tim.finetune_encoder=True \ 47 | tim.iter=25 \ 48 | tim.lr=5e-5 \ 49 | eval.meta_test_iter=$META_TEST_ITER \ 50 | tim.loss_weights="[0.1, 1.0, 0.]" \ 51 | 52 | # Full loss 53 | python3 -m src.main \ 54 | -F logs/ablation/tim_gd_all/cub \ 55 | with dataset.path="data/cub/CUB_200_2011/images" \ 56 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 57 |
dataset.split_dir="split/cub" \ 58 | model.arch='resnet18' \ 59 | model.num_classes=100 \ 60 | evaluate=True \ 61 | eval.method='tim_gd' \ 62 | tim.finetune_encoder=True \ 63 | tim.iter=25 \ 64 | tim.lr=5e-5 \ 65 | eval.meta_test_iter=$META_TEST_ITER \ 66 | tim.loss_weights="[0.1, 1.0, 0.1]" \ -------------------------------------------------------------------------------- /scripts/ablation/tim_gd_all/tiered.sh: -------------------------------------------------------------------------------- 1 | META_TEST_ITER="1000" 2 | 3 | # Only XEnt 4 | python3 -m src.main \ 5 | -F logs/ablation/tim_gd_all/tiered \ 6 | with dataset.path="data/tiered_imagenet/data" \ 7 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 8 | dataset.split_dir="split/tiered" \ 9 | model.arch='resnet18' \ 10 | model.num_classes=351 \ 11 | evaluate=True \ 12 | eval.method='tim_gd' \ 13 | tim.finetune_encoder=True \ 14 | tim.iter=25 \ 15 | tim.lr=5e-5 \ 16 | eval.meta_test_iter=$META_TEST_ITER \ 17 | tim.loss_weights="[0.1, 0., 0.]" \ 18 | 19 | # No H(Y) 20 | python3 -m src.main \ 21 | -F logs/ablation/tim_gd_all/tiered \ 22 | with dataset.path="data/tiered_imagenet/data" \ 23 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 24 | dataset.split_dir="split/tiered" \ 25 | model.arch='resnet18' \ 26 | model.num_classes=351 \ 27 | evaluate=True \ 28 | eval.method='tim_gd' \ 29 | tim.finetune_encoder=True \ 30 | tim.iter=25 \ 31 | tim.lr=5e-5 \ 32 | eval.meta_test_iter=$META_TEST_ITER \ 33 | tim.loss_weights="[0.1, 0., 0.1]" \ 34 | 35 | # No H(Y|X) 36 | python3 -m src.main \ 37 | -F logs/ablation/tim_gd_all/tiered \ 38 | with dataset.path="data/tiered_imagenet/data" \ 39 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 40 | dataset.split_dir="split/tiered" \ 41 | model.num_classes=351 \ 42 | model.arch='resnet18' \ 43 | evaluate=True \ 44 | eval.method='tim_gd' \ 45 | tim.finetune_encoder=True \ 46 | tim.iter=25 \ 47 | tim.lr=5e-5 \ 48 | eval.meta_test_iter=$META_TEST_ITER \ 49 | 
tim.loss_weights="[0.1, 1.0, 0.]" \ 50 | 51 | 52 | python3 -m src.main \ 53 | -F logs/ablation/tim_gd_all/tiered \ 54 | with dataset.path="data/tiered_imagenet/data" \ 55 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 56 | dataset.split_dir="split/tiered" \ 57 | model.num_classes=351 \ 58 | model.arch='resnet18' \ 59 | evaluate=True \ 60 | eval.method='tim_gd' \ 61 | tim.finetune_encoder=True \ 62 | tim.iter=25 \ 63 | tim.lr=5e-5 \ 64 | eval.meta_test_iter=$META_TEST_ITER \ 65 | tim.loss_weights="[0.1, 1.0, 0.1]" \ 66 | -------------------------------------------------------------------------------- /src/datasets/transform.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 4 | std=[0.229, 0.224, 0.225]) 5 | 6 | 7 | import torch 8 | from PIL import ImageEnhance 9 | 10 | transformtypedict = dict(Brightness=ImageEnhance.Brightness, 11 | Contrast=ImageEnhance.Contrast, 12 | Sharpness=ImageEnhance.Sharpness, 13 | Color=ImageEnhance.Color) 14 | 15 | 16 | class ImageJitter(object): 17 | def __init__(self, transformdict): 18 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict] 19 | 20 | def __call__(self, img): 21 | out = img 22 | randtensor = torch.rand(len(self.transforms)) 23 | 24 | for i, (transformer, alpha) in enumerate(self.transforms): 25 | r = alpha*(randtensor[i]*2.0 - 1.0) + 1 26 | out = transformer(out).enhance(r).convert('RGB') 27 | 28 | return out 29 | 30 | 31 | def without_augment(size=84, enlarge=False): 32 | if enlarge: 33 | resize = int(size*256./224.) 
34 | else: 35 | resize = size 36 | return transforms.Compose([ 37 | transforms.Resize(resize), 38 | transforms.CenterCrop(size), 39 | transforms.ToTensor(), 40 | normalize, 41 | ]) 42 | 43 | 44 | def with_augment(size=84, disable_random_resize=False, jitter=False): 45 | if disable_random_resize: 46 | return transforms.Compose([ 47 | transforms.Resize(size), 48 | transforms.CenterCrop(size), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.ToTensor(), 51 | normalize, 52 | ]) 53 | else: 54 | if jitter: 55 | return transforms.Compose([ 56 | transforms.RandomResizedCrop(size), 57 | ImageJitter(dict(Brightness=0.4, Contrast=0.4, Color=0.4)), 58 | transforms.RandomHorizontalFlip(), 59 | transforms.ToTensor(), 60 | normalize, 61 | ]) 62 | else: 63 | return transforms.Compose([ 64 | transforms.RandomResizedCrop(size), 65 | transforms.RandomHorizontalFlip(), 66 | transforms.ToTensor(), 67 | normalize, 68 | ]) -------------------------------------------------------------------------------- /src/models/MobileNet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 
5 | ''' 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | __all__ = ['MobileNet'] 10 | 11 | 12 | class Block(nn.Module): 13 | '''Depthwise conv + Pointwise conv''' 14 | 15 | def __init__(self, in_planes, out_planes, stride=1): 16 | super(Block, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, 18 | bias=False) 19 | self.bn1 = nn.BatchNorm2d(in_planes) 20 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 21 | self.bn2 = nn.BatchNorm2d(out_planes) 22 | 23 | def forward(self, x): 24 | out = F.relu(self.bn1(self.conv1(x))) 25 | out = F.relu(self.bn2(self.conv2(out))) 26 | return out 27 | 28 | 29 | class MobileNet(nn.Module): 30 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 31 | cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), 1024] 32 | 33 | def __init__(self, num_classes=10, remove_linear=False): 34 | super(MobileNet, self).__init__() 35 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 36 | self.bn1 = nn.BatchNorm2d(32) 37 | self.layers = self._make_layers(in_planes=32) 38 | if remove_linear: 39 | self.linear = None 40 | else: 41 | self.linear = nn.Linear(1024, num_classes) 42 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 43 | 44 | def _make_layers(self, in_planes): 45 | layers = [] 46 | for x in self.cfg: 47 | out_planes = x if isinstance(x, int) else x[0] 48 | stride = 1 if isinstance(x, int) else x[1] 49 | layers.append(Block(in_planes, out_planes, stride)) 50 | in_planes = out_planes 51 | return nn.Sequential(*layers) 52 | 53 | def forward(self, x, feature=False): 54 | out = F.relu(self.bn1(self.conv1(x))) 55 | out = self.layers(out) 56 | out = self.avgpool(out) 57 | out = out.view(out.size(0), -1) 58 | if self.linear is None: 59 | if feature: 60 | return out, None 61 | else: 62 | return out 63 | if feature: 64 | out1 = 
self.linear(out) 65 | return out, out1 66 | out = self.linear(out) 67 | return out 68 | -------------------------------------------------------------------------------- /scripts/make_splits/cub_split.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from os import listdir 4 | from os.path import isfile, isdir, join 5 | import os 6 | import json 7 | import random 8 | 9 | cwd = os.getcwd() 10 | data_path = join(cwd, 'CUB_200_2011/images') 11 | savedir = os.path.join(cwd, '..', '..', 'split', 'cub/') 12 | dataset_list = ['train', 'val', 'test'] 13 | 14 | #if not os.path.exists(savedir): 15 | # os.makedirs(savedir) 16 | 17 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] 18 | folder_list.sort() 19 | label_dict = dict(zip(folder_list, range(0, len(folder_list)))) 20 | 21 | classfile_list_all = [] 22 | 23 | for i, folder in enumerate(folder_list): 24 | folder_path = join(data_path, folder) 25 | classfile_list_all.append([join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')]) 26 | random.shuffle(classfile_list_all[i]) 27 | 28 | 29 | for dataset in dataset_list: 30 | file_list = [] 31 | label_list = [] 32 | for i, classfile_list in enumerate(classfile_list_all): 33 | if 'train' in dataset: 34 | if (i % 2 == 0): 35 | file_list = file_list + classfile_list 36 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 37 | if 'val' in dataset: 38 | if (i % 4 == 1): 39 | file_list = file_list + classfile_list 40 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 41 | if 'test' in dataset: 42 | if (i % 4 == 3): 43 | file_list = file_list + classfile_list 44 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 45 | os.makedirs(savedir, exist_ok=True) 46 | with open(os.path.join(savedir, f'{dataset}.csv'), 'w') as f: 47 | for file, label in zip(file_list, label_list): 48 | file = 
file.split('/images/')[1] 49 | f.write('{},{}\n'.format(file, label)) 50 | 51 | # fo = open(savedir + dataset + ".csv", "w") 52 | # fo.write('{"label_names": [') 53 | # fo.writelines(['"%s",' % item for item in folder_list]) 54 | # fo.seek(0, os.SEEK_END) 55 | # fo.seek(fo.tell()-1, os.SEEK_SET) 56 | # fo.write('],') 57 | 58 | # fo.write('"image_names": [') 59 | # fo.writelines(['"%s",' % item for item in file_list]) 60 | # fo.seek(0, os.SEEK_END) 61 | # fo.seek(fo.tell()-1, os.SEEK_SET) 62 | # fo.write('],') 63 | 64 | # fo.write('"image_labels": [') 65 | # fo.writelines(['%d,' % item for item in label_list]) 66 | # fo.seek(0, os.SEEK_END) 67 | # fo.seek(fo.tell()-1, os.SEEK_SET) 68 | # fo.write(']}') 69 | 70 | # fo.close() 71 | print("%s -OK" %dataset) -------------------------------------------------------------------------------- /src/models/WideResNet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | from torch.autograd import Variable 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 11 | 12 | 13 | def conv_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 17 | init.constant(m.bias, 0) 18 | elif classname.find('BatchNorm') != -1: 19 | init.constant(m.weight, 1) 20 | init.constant(m.bias, 0) 21 | 22 | 23 | class wide_basic(nn.Module): 24 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 25 | super(wide_basic, self).__init__() 26 | self.bn1 = nn.BatchNorm2d(in_planes) 27 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 28 | self.dropout = nn.Dropout(p=dropout_rate) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.conv2 = nn.Conv2d(planes, planes, 
kernel_size=3, stride=stride, padding=1, bias=True) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_planes != planes: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 36 | ) 37 | 38 | def forward(self, x): 39 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 40 | out = self.conv2(F.relu(self.bn2(out))) 41 | out += self.shortcut(x) 42 | 43 | return out 44 | 45 | 46 | class Wide_ResNet(nn.Module): 47 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 48 | super(Wide_ResNet, self).__init__() 49 | self.in_planes = 16 50 | 51 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 52 | n = (depth - 4) // 6 53 | k = widen_factor 54 | 55 | print('| Wide-Resnet %dx%d' % (depth, k)) 56 | nStages = [16, 16 * k, 32 * k, 64 * k] 57 | 58 | self.conv1 = conv3x3(3, nStages[0]) 59 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 60 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 61 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 62 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 63 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 64 | self.linear = nn.Linear(nStages[3], num_classes) 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 68 | elif isinstance(m, nn.BatchNorm2d): 69 | nn.init.constant_(m.weight, 1) 70 | nn.init.constant_(m.bias, 0) 71 | 72 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 73 | strides = [stride] + [1] * (num_blocks - 1) 74 | layers = [] 75 | 76 | for stride in strides: 77 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 78 | self.in_planes = planes 79 | 80 | return nn.Sequential(*layers) 81 | 82 | def forward(self, x, feature=False): 83 | out = self.conv1(x) 84 | out = self.layer1(out) 85 | out = 
self.layer2(out) 86 | out = self.layer3(out) 87 | out = F.relu(self.bn1(out)) 88 | out = self.avgpool(out) 89 | out = out.view(out.size(0), -1) 90 | 91 | if self.linear is None: 92 | if feature: 93 | return out, None 94 | else: 95 | return out 96 | out1 = self.linear(out) 97 | if feature: 98 | return out, out1 99 | return out1 100 | 101 | 102 | def wideres(num_classes): 103 | """Constructs a wideres-28-10 model without dropout. 104 | """ 105 | return Wide_ResNet(28, 10, 0, num_classes) 106 | 107 | 108 | if __name__ == '__main__': 109 | net = Wide_ResNet(28, 10, 0.3, 10) 110 | y = net(Variable(torch.randn(1, 3, 32, 32))) 111 | -------------------------------------------------------------------------------- /scripts/ablation/tim_adm/weighting_effect.sh: -------------------------------------------------------------------------------- 1 | # ================> mini-Imagenet <================= 2 | # ======================================== 3 | 4 | python3 -m src.main \ 5 | -F logs/ablation/tim_adm/mini \ 6 | with dataset.path="data/mini_imagenet" \ 7 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 8 | dataset.split_dir="split/mini" \ 9 | model.arch='resnet18' \ 10 | evaluate=True \ 11 | eval.method='tim_adm' \ 12 | tim.loss_weights="[0.1, 0., 0.]" \ 13 | 14 | 15 | # No H(Y) 16 | python3 -m src.main \ 17 | -F logs/ablation/tim_adm/mini \ 18 | with dataset.path="data/mini_imagenet" \ 19 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 20 | dataset.split_dir="split/mini" \ 21 | model.arch='resnet18' \ 22 | evaluate=True \ 23 | eval.method='tim_adm' \ 24 | tim.loss_weights="[0.1, 0., 0.1]" \ 25 | 26 | 27 | # No H(Y|X) 28 | python3 -m src.main \ 29 | -F logs/ablation/tim_adm/mini \ 30 | with dataset.path="data/mini_imagenet" \ 31 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 32 | dataset.split_dir="split/mini" \ 33 | model.arch='resnet18' \ 34 | evaluate=True \ 35 | eval.method='tim_adm' \ 36 | tim.loss_weights="[0.1, 1.0, 0.]" \ 37 | 38 | # Full loss 39 | python3 -m 
src.main \ 40 | -F logs/ablation/tim_adm/mini \ 41 | with dataset.path="data/mini_imagenet" \ 42 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 43 | dataset.split_dir="split/mini" \ 44 | model.arch='resnet18' \ 45 | evaluate=True \ 46 | eval.method='tim_adm' \ 47 | tim.loss_weights="[0.1, 1.0, 0.1]" \ 48 | 49 | # ================> tiered-Imagenet <================= 50 | # ======================================== 51 | 52 | 53 | # No H(Y) 54 | python3 -m src.main \ 55 | -F logs/ablation/tim_adm/tiered \ 56 | with dataset.path="data/tiered_imagenet/data" \ 57 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 58 | dataset.split_dir="split/tiered" \ 59 | model.arch='resnet18' \ 60 | model.num_classes=351 \ 61 | evaluate=True \ 62 | eval.method='tim_adm' \ 63 | tim.loss_weights="[0.1, 0., 0.]" \ 64 | 65 | # No H(Y) 66 | python3 -m src.main \ 67 | -F logs/ablation/tim_adm/tiered \ 68 | with dataset.path="data/tiered_imagenet/data" \ 69 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 70 | dataset.split_dir="split/tiered" \ 71 | model.arch='resnet18' \ 72 | model.num_classes=351 \ 73 | evaluate=True \ 74 | eval.method='tim_adm' \ 75 | tim.loss_weights="[0.1, 0., 0.1]" \ 76 | 77 | # No H(Y|X) 78 | python3 -m src.main \ 79 | -F logs/ablation/tim_adm/tiered \ 80 | with dataset.path="data/tiered_imagenet/data" \ 81 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 82 | dataset.split_dir="split/tiered" \ 83 | model.num_classes=351 \ 84 | model.arch='resnet18' \ 85 | evaluate=True \ 86 | eval.method='tim_adm' \ 87 | tim.loss_weights="[0.1, 1.0, 0.]" \ 88 | 89 | # Full loss 90 | python3 -m src.main \ 91 | -F logs/ablation/tim_adm/tiered \ 92 | with dataset.path="data/tiered_imagenet/data" \ 93 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 94 | dataset.split_dir="split/tiered" \ 95 | model.num_classes=351 \ 96 | model.arch='resnet18' \ 97 | evaluate=True \ 98 | eval.method='tim_adm' \ 99 | tim.loss_weights="[0.1, 1.0, 0.1]" \ 100 | 101 | 102 | # ================> 
CUB <================= 103 | # ======================================== 104 | 105 | # No H(Y) 106 | python3 -m src.main \ 107 | -F logs/ablation/tim_adm/cub \ 108 | with dataset.path="data/cub/CUB_200_2011/images" \ 109 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 110 | dataset.split_dir="split/cub" \ 111 | model.arch='resnet18' \ 112 | model.num_classes=100 \ 113 | evaluate=True \ 114 | eval.method='tim_adm' \ 115 | tim.loss_weights="[0.1, 0., 0.]" \ 116 | 117 | # No H(Y) 118 | python3 -m src.main \ 119 | -F logs/ablation/tim_adm/cub \ 120 | with dataset.path="data/cub/CUB_200_2011/images" \ 121 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 122 | dataset.split_dir="split/cub" \ 123 | model.arch='resnet18' \ 124 | model.num_classes=100 \ 125 | evaluate=True \ 126 | eval.method='tim_adm' \ 127 | tim.loss_weights="[0.1, 0., 0.1]" \ 128 | 129 | # No H(Y|X) 130 | python3 -m src.main \ 131 | -F logs/ablation/tim_adm/cub \ 132 | with dataset.path="data/cub/CUB_200_2011/images" \ 133 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 134 | dataset.split_dir="split/cub" \ 135 | model.arch='resnet18' \ 136 | model.num_classes=100 \ 137 | evaluate=True \ 138 | eval.method='tim_adm' \ 139 | tim.loss_weights="[0.1, 1.0, 0.]" \ 140 | 141 | # Full loss 142 | python3 -m src.main \ 143 | -F logs/ablation/tim_adm/cub \ 144 | with dataset.path="data/cub/CUB_200_2011/images" \ 145 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 146 | dataset.split_dir="split/cub" \ 147 | model.arch='resnet18' \ 148 | model.num_classes=100 \ 149 | evaluate=True \ 150 | eval.method='tim_adm' \ 151 | tim.loss_weights="[0.1, 1.0, 0.1]" \ -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import shutil 4 | from tqdm import tqdm 5 | import logging 6 | import os 7 | import pickle 8 | import torch.nn.functional as F 9 | 10 | 11 | def 
get_one_hot(y_s): 12 | num_classes = torch.unique(y_s).size(0) 13 | eye = torch.eye(num_classes).to(y_s.device) 14 | one_hot = [] 15 | for y_task in y_s: 16 | one_hot.append(eye[y_task].unsqueeze(0)) 17 | one_hot = torch.cat(one_hot, 0) 18 | return one_hot 19 | 20 | 21 | def get_logs_path(model_path, method, shot): 22 | exp_path = '_'.join(model_path.split('/')[1:]) 23 | file_path = os.path.join('tmp', exp_path, method) 24 | os.makedirs(file_path, exist_ok=True) 25 | return os.path.join(file_path, f'{shot}.txt') 26 | 27 | 28 | def get_features(model, samples): 29 | features, _ = model(samples, True) 30 | features = F.normalize(features.view(features.size(0), -1), dim=1) 31 | return features 32 | 33 | 34 | def get_loss(logits_s, logits_q, labels_s, lambdaa): 35 | Q = logits_q.softmax(2) 36 | y_s_one_hot = get_one_hot(labels_s) 37 | ce_sup = - (y_s_one_hot * torch.log(logits_s.softmax(2) + 1e-12)).sum(2).mean(1) # Taking the mean over samples within a task, and summing over all samples 38 | ent_q = get_entropy(Q) 39 | cond_ent_q = get_cond_entropy(Q) 40 | loss = - (ent_q - cond_ent_q) + lambdaa * ce_sup 41 | return loss 42 | 43 | 44 | def get_mi(probs): 45 | q_cond_ent = get_cond_entropy(probs) 46 | q_ent = get_entropy(probs) 47 | return q_ent - q_cond_ent 48 | 49 | 50 | def get_entropy(probs): 51 | q_ent = - (probs.mean(1) * torch.log(probs.mean(1) + 1e-12)).sum(1, keepdim=True) 52 | return q_ent 53 | 54 | 55 | def get_cond_entropy(probs): 56 | q_cond_ent = - (probs * torch.log(probs + 1e-12)).sum(2).mean(1, keepdim=True) 57 | return q_cond_ent 58 | 59 | 60 | def get_metric(metric_type): 61 | METRICS = { 62 | 'cosine': lambda gallery, query: 1. 
- F.cosine_similarity(query[:, None, :], gallery[None, :, :], dim=2), 63 | 'euclidean': lambda gallery, query: ((query[:, None, :] - gallery[None, :, :]) ** 2).sum(2), 64 | 'l1': lambda gallery, query: torch.norm((query[:, None, :] - gallery[None, :, :]), p=1, dim=2), 65 | 'l2': lambda gallery, query: torch.norm((query[:, None, :] - gallery[None, :, :]), p=2, dim=2), 66 | } 67 | return METRICS[metric_type] 68 | 69 | 70 | class AverageMeter(object): 71 | def __init__(self): 72 | self.reset() 73 | 74 | def reset(self): 75 | self.val = 0 76 | self.avg = 0 77 | self.sum = 0 78 | self.count = 0 79 | 80 | def update(self, val, n=1): 81 | self.val = val 82 | self.sum += val * n 83 | self.count += n 84 | self.avg = self.sum / self.count 85 | 86 | 87 | def setup_logger(filepath): 88 | file_formatter = logging.Formatter( 89 | "[%(asctime)s %(filename)s:%(lineno)s] %(levelname)-8s %(message)s", 90 | datefmt='%Y-%m-%d %H:%M:%S', 91 | ) 92 | logger = logging.getLogger('example') 93 | # handler = logging.StreamHandler() 94 | # handler.setFormatter(file_formatter) 95 | # logger.addHandler(handler) 96 | 97 | file_handle_name = "file" 98 | if file_handle_name in [h.name for h in logger.handlers]: 99 | return 100 | if os.path.dirname(filepath) != '': 101 | if not os.path.isdir(os.path.dirname(filepath)): 102 | os.makedirs(os.path.dirname(filepath)) 103 | file_handle = logging.FileHandler(filename=filepath, mode="a") 104 | file_handle.set_name(file_handle_name) 105 | file_handle.setFormatter(file_formatter) 106 | logger.addHandler(file_handle) 107 | logger.setLevel(logging.DEBUG) 108 | return logger 109 | 110 | 111 | def warp_tqdm(data_loader, disable_tqdm): 112 | if disable_tqdm: 113 | tqdm_loader = data_loader 114 | else: 115 | tqdm_loader = tqdm(data_loader, total=len(data_loader)) 116 | return tqdm_loader 117 | 118 | 119 | def save_pickle(file, data): 120 | with open(file, 'wb') as f: 121 | pickle.dump(data, f) 122 | 123 | 124 | def load_pickle(file): 125 | with open(file, 'rb') 
as f: 126 | return pickle.load(f) 127 | 128 | 129 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', folder='result/default'): 130 | os.makedirs(folder, exist_ok=True) 131 | torch.save(state, os.path.join(folder, filename)) 132 | if is_best: 133 | shutil.copyfile(folder + '/' + filename, folder + '/model_best.pth.tar') 134 | 135 | 136 | def load_checkpoint(model, model_path, type='best'): 137 | if type == 'best': 138 | checkpoint = torch.load('{}/model_best.pth.tar'.format(model_path)) 139 | elif type == 'last': 140 | checkpoint = torch.load('{}/checkpoint.pth.tar'.format(model_path)) 141 | else: 142 | assert False, 'type should be in [best, or last], but got {}'.format(type) 143 | state_dict = checkpoint['state_dict'] 144 | names = [] 145 | for k, v in state_dict.items(): 146 | names.append(k) 147 | model.load_state_dict(state_dict) 148 | 149 | 150 | def compute_confidence_interval(data, axis=0): 151 | """ 152 | Compute 95% confidence interval 153 | :param data: An array of mean accuracy (or mAP) across a number of sampled episodes. 154 | :return: the 95% confidence interval for this data. 
155 | """ 156 | a = 1.0 * np.array(data) 157 | m = np.mean(a, axis=axis) 158 | std = np.std(a, axis=axis) 159 | pm = 1.96 * (std / np.sqrt(a.shape[axis])) 160 | return m, pm 161 | -------------------------------------------------------------------------------- /scripts/ablation/tim_gd/weighting_effect.sh: -------------------------------------------------------------------------------- 1 | # # ================> mini-Imagenet <================= 2 | # # ======================================== 3 | # Only Xent 4 | python3 -m src.main \ 5 | -F logs/ablation/tim_gd/mini \ 6 | with dataset.path="data/mini_imagenet" \ 7 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 8 | dataset.split_dir="split/mini" \ 9 | model.arch='resnet18' \ 10 | evaluate=True \ 11 | eval.method='tim_gd' \ 12 | tim.iter=1000 \ 13 | tim.loss_weights="[0.1, 0., 0.]" \ 14 | 15 | 16 | # No H(Y) 17 | python3 -m src.main \ 18 | -F logs/ablation/tim_gd/mini \ 19 | with dataset.path="data/mini_imagenet" \ 20 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 21 | dataset.split_dir="split/mini" \ 22 | model.arch='resnet18' \ 23 | evaluate=True \ 24 | eval.method='tim_gd' \ 25 | tim.iter=1000 \ 26 | tim.loss_weights="[0.1, 0., 0.1]" \ 27 | 28 | 29 | # No H(Y|X) 30 | python3 -m src.main \ 31 | -F logs/ablation/tim_gd/mini \ 32 | with dataset.path="data/mini_imagenet" \ 33 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 34 | dataset.split_dir="split/mini" \ 35 | model.arch='resnet18' \ 36 | evaluate=True \ 37 | eval.method='tim_gd' \ 38 | tim.iter=1000 \ 39 | tim.loss_weights="[0.1, 1.0, 0.]" \ 40 | 41 | # Full loss 42 | python3 -m src.main \ 43 | -F logs/ablation/tim_gd/mini \ 44 | with dataset.path="data/mini_imagenet" \ 45 | ckpt_path="checkpoints/mini/softmax/resnet18" \ 46 | dataset.split_dir="split/mini" \ 47 | model.arch='resnet18' \ 48 | evaluate=True \ 49 | eval.method='tim_gd' \ 50 | tim.iter=1000 \ 51 | tim.loss_weights="[0.1, 1.0, 0.1]" \ 52 | 53 | 54 | # # ================> tiered-Imagenet 
<================= 55 | # # ======================================== 56 | 57 | # Only XEnt 58 | python3 -m src.main \ 59 | -F logs/ablation/tim_gd/tiered \ 60 | with dataset.path="data/tiered_imagenet/data" \ 61 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 62 | dataset.split_dir="split/tiered" \ 63 | model.arch='resnet18' \ 64 | model.num_classes=351 \ 65 | evaluate=True \ 66 | eval.method='tim_gd' \ 67 | tim.iter=1000 \ 68 | tim.loss_weights="[0.1, 0., 0.]" \ 69 | 70 | 71 | # No H(Y) 72 | python3 -m src.main \ 73 | -F logs/ablation/tim_gd/tiered \ 74 | with dataset.path="data/tiered_imagenet/data" \ 75 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 76 | dataset.split_dir="split/tiered" \ 77 | model.arch='resnet18' \ 78 | model.num_classes=351 \ 79 | evaluate=True \ 80 | eval.method='tim_gd' \ 81 | tim.iter=1000 \ 82 | tim.loss_weights="[0.1, 0., 0.1]" \ 83 | 84 | # No H(Y|X) 85 | python3 -m src.main \ 86 | -F logs/ablation/tim_gd/tiered \ 87 | with dataset.path="data/tiered_imagenet/data" \ 88 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 89 | dataset.split_dir="split/tiered" \ 90 | model.num_classes=351 \ 91 | model.arch='resnet18' \ 92 | evaluate=True \ 93 | eval.method='tim_gd' \ 94 | tim.iter=1000 \ 95 | tim.loss_weights="[0.1, 1.0, 0.]" \ 96 | 97 | # No H(Y|X) 98 | python3 -m src.main \ 99 | -F logs/ablation/tim_gd/tiered \ 100 | with dataset.path="data/tiered_imagenet/data" \ 101 | ckpt_path="checkpoints/tiered/softmax/resnet18" \ 102 | dataset.split_dir="split/tiered" \ 103 | model.num_classes=351 \ 104 | model.arch='resnet18' \ 105 | evaluate=True \ 106 | eval.method='tim_gd' \ 107 | tim.iter=1000 \ 108 | tim.loss_weights="[0.1, 1.0, 0.1]" \ 109 | 110 | 111 | # ================> CUB <================= 112 | # ======================================== 113 | 114 | # Only XEnt 115 | python3 -m src.main \ 116 | -F logs/ablation/tim_gd/cub \ 117 | with dataset.path="data/cub/CUB_200_2011/images" \ 118 | 
ckpt_path="checkpoints/cub/softmax/resnet18" \ 119 | dataset.split_dir="split/cub" \ 120 | model.arch='resnet18' \ 121 | model.num_classes=100 \ 122 | evaluate=True \ 123 | eval.method='tim_gd' \ 124 | tim.iter=1000 \ 125 | tim.loss_weights="[0.1, 0., 0.]" \ 126 | 127 | 128 | # No H(Y) 129 | python3 -m src.main \ 130 | -F logs/ablation/tim_gd/cub \ 131 | with dataset.path="data/cub/CUB_200_2011/images" \ 132 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 133 | dataset.split_dir="split/cub" \ 134 | model.arch='resnet18' \ 135 | model.num_classes=100 \ 136 | evaluate=True \ 137 | eval.method='tim_gd' \ 138 | tim.iter=1000 \ 139 | tim.loss_weights="[0.1, 0., 0.1]" \ 140 | 141 | # No H(Y|X) 142 | python3 -m src.main \ 143 | -F logs/ablation/tim_gd/cub \ 144 | with dataset.path="data/cub/CUB_200_2011/images" \ 145 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 146 | dataset.split_dir="split/cub" \ 147 | model.arch='resnet18' \ 148 | model.num_classes=100 \ 149 | evaluate=True \ 150 | eval.method='tim_gd' \ 151 | tim.iter=1000 \ 152 | tim.loss_weights="[0.1, 1.0, 0.]" \ 153 | 154 | # Full loss 155 | python3 -m src.main \ 156 | -F logs/ablation/tim_gd/cub \ 157 | with dataset.path="data/cub/CUB_200_2011/images" \ 158 | ckpt_path="checkpoints/cub/softmax/resnet18" \ 159 | dataset.split_dir="split/cub" \ 160 | model.arch='resnet18' \ 161 | model.num_classes=100 \ 162 | evaluate=True \ 163 | eval.method='tim_gd' \ 164 | tim.iter=1000 \ 165 | tim.loss_weights="[0.1, 1.0, 0.1]" \ -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import sacred 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import torch.utils.data 7 | import torch.utils.data.distributed 8 | from visdom_logger import VisdomLogger 9 | from sacred import SETTINGS 10 | from sacred.utils import apply_backspaces_and_linefeeds 11 | from 
src.utils import warp_tqdm, save_checkpoint 12 | from src.trainer import Trainer, trainer_ingredient 13 | from src.eval import Evaluator 14 | from src.eval import eval_ingredient 15 | from src.tim import tim_ingredient 16 | from src.optim import optim_ingredient, get_optimizer, get_scheduler 17 | from src.datasets.ingredient import dataset_ingredient 18 | from src.models.ingredient import get_model, model_ingredient 19 | 20 | ex = sacred.Experiment('FSL training', 21 | ingredients=[trainer_ingredient, eval_ingredient, 22 | optim_ingredient, dataset_ingredient, 23 | model_ingredient, tim_ingredient]) 24 | # Filter backspaces and linefeeds 25 | SETTINGS.CAPTURE_MODE = 'sys' 26 | ex.captured_out_filter = apply_backspaces_and_linefeeds 27 | 28 | 29 | @ex.config 30 | def config(): 31 | ckpt_path = os.path.join('checkpoints') 32 | seed = 2020 33 | pretrain = False 34 | resume = False 35 | evaluate = False 36 | make_plot = False 37 | epochs = 90 38 | disable_tqdm = False 39 | visdom_port = None 40 | print_runtime = False 41 | cuda = True 42 | 43 | 44 | @ex.automain 45 | def main(seed, pretrain, resume, evaluate, print_runtime, 46 | epochs, disable_tqdm, visdom_port, ckpt_path, 47 | make_plot, cuda): 48 | device = torch.device("cuda" if cuda else "cpu") 49 | callback = None if visdom_port is None else VisdomLogger(port=visdom_port) 50 | if seed is not None: 51 | random.seed(seed) 52 | torch.manual_seed(seed) 53 | cudnn.deterministic = True 54 | torch.cuda.set_device(0) 55 | # create model 56 | print("=> Creating model '{}'".format(ex.current_run.config['model']['arch'])) 57 | model = torch.nn.DataParallel(get_model()).cuda() 58 | 59 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 60 | optimizer = get_optimizer(model) 61 | 62 | if pretrain: 63 | pretrain = os.path.join(pretrain, 'checkpoint.pth.tar') 64 | if os.path.isfile(pretrain): 65 | print("=> loading pretrained weight '{}'".format(pretrain)) 66 | checkpoint = 
torch.load(pretrain) 67 | model_dict = model.state_dict() 68 | params = checkpoint['state_dict'] 69 | params = {k: v for k, v in params.items() if k in model_dict} 70 | model_dict.update(params) 71 | model.load_state_dict(model_dict) 72 | else: 73 | print('[Warning]: Did not find pretrained model {}'.format(pretrain)) 74 | 75 | if resume: 76 | resume_path = ckpt_path + '/checkpoint.pth.tar' 77 | if os.path.isfile(resume_path): 78 | print("=> loading checkpoint '{}'".format(resume_path)) 79 | checkpoint = torch.load(resume_path) 80 | start_epoch = checkpoint['epoch'] 81 | best_prec1 = checkpoint['best_prec1'] 82 | # scheduler.load_state_dict(checkpoint['scheduler']) 83 | model.load_state_dict(checkpoint['state_dict']) 84 | optimizer.load_state_dict(checkpoint['optimizer']) 85 | print("=> loaded checkpoint '{}' (epoch {})" 86 | .format(resume_path, checkpoint['epoch'])) 87 | else: 88 | print('[Warning]: Did not find checkpoint {}'.format(resume_path)) 89 | else: 90 | start_epoch = 0 91 | best_prec1 = -1 92 | 93 | cudnn.benchmark = True 94 | 95 | # Data loading code 96 | evaluator = Evaluator(device=device, ex=ex) 97 | if evaluate: 98 | results = evaluator.run_full_evaluation(model=model, 99 | model_path=ckpt_path, 100 | callback=callback) 101 | return results 102 | 103 | # If this line is reached, then training the model 104 | trainer = Trainer(device=device, ex=ex) 105 | scheduler = get_scheduler(optimizer=optimizer, 106 | num_batches=len(trainer.train_loader), 107 | epochs=epochs) 108 | tqdm_loop = warp_tqdm(list(range(start_epoch, epochs)), 109 | disable_tqdm=disable_tqdm) 110 | for epoch in tqdm_loop: 111 | # Do one epoch 112 | trainer.do_epoch(model=model, optimizer=optimizer, epoch=epoch, 113 | scheduler=scheduler, disable_tqdm=disable_tqdm, 114 | callback=callback) 115 | 116 | # Evaluation on validation set 117 | prec1 = trainer.meta_val(model=model, disable_tqdm=disable_tqdm, 118 | epoch=epoch, callback=callback) 119 | print('Meta Val {}: {}'.format(epoch, 
prec1)) 120 | is_best = prec1 > best_prec1 121 | best_prec1 = max(prec1, best_prec1) 122 | if not disable_tqdm: 123 | tqdm_loop.set_description('Best Acc {:.2f}'.format(best_prec1 * 100.)) 124 | 125 | # Save checkpoint 126 | save_checkpoint(state={'epoch': epoch + 1, 127 | 'arch': ex.current_run.config['model']['arch'], 128 | 'state_dict': model.state_dict(), 129 | 'best_prec1': best_prec1, 130 | 'optimizer': optimizer.state_dict()}, 131 | is_best=is_best, 132 | folder=ckpt_path) 133 | if scheduler is not None: 134 | scheduler.step() 135 | 136 | # Final evaluation on test set 137 | results = evaluator.run_full_evaluation(model=model, model_path=ckpt_path) 138 | return results 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TIM: Transductive Information Maximization 2 | 3 | 4 | ## Introduction 5 | This repo contains the code for our NeurIPS 2020 paper "Transductive Infomation Maximization (TIM) for few-shot learning" available at https://arxiv.org/abs/2008.11297. Our method maximizes the mutual information between the query features and predictions of a few-shot task, subject to supervision constraint from the support set. Results provided in the paper can be reproduced with this repo. Code was developped under python 3.8.3 and pytorch 1.4.0. The code is parallelized over tasks (which makes the execution of the 10'000 tasks very efficient). 6 | 7 | 8 | ## 1. Getting started 9 | 10 | 11 | Please find the data and pretrained models at [icloud drive](https://www.icloud.com/iclouddrive/0f3PFO3rJK0fk0nkCe8sDKAiQ#TIM). 
Please use the `cat` command to reform the original file, and extract 12 | 13 | **Checkpoints:** The checkpoints/ directory should be placed in the root dir, and be structured as follows: 14 | 15 | ``` 16 | ├── mini 17 | │ └── softmax 18 | │ ├── densenet121 19 | │ │ ├── best 20 | │ │ ├── checkpoint.pth.tar 21 | │ │ └── model_best.pth.tar 22 | ``` 23 | 24 | **Data:** The data/ directory should be placed in the root directory, and have a structure like the following. Because of its size, tiered_imagenet has been sharded into 24 shards, 1 GB each. Use `cat tiered_imagenet_*` to reform the original file. Extract everything to data/. The data folder should be structured as follows: 25 | 26 | ``` 27 | ├── cub 28 | ├── attributes.txt 29 | └── CUB_200_2011 30 | ├── attributes 31 | ├── bounding_boxes.txt 32 | ├── classes.txt 33 | ├── image_class_labels.txt 34 | ├── images 35 | ├── images.txt 36 | ├── parts 37 | ├── README 38 | └── train_test_split.txt 39 | ├── mini_imagenet 40 | └── tiered_imagenet 41 | ├── class_names.txt 42 | ├── data 43 | ├── synsets.txt 44 | ├── test_images_png.pkl 45 | ├── test_labels.pkl 46 | ├── train_images_png.pkl 47 | ├── train_labels.pkl 48 | ├── val_images_png.pkl 49 | └── val_labels.pkl 50 | ``` 51 | 52 | All required libraries should be easily found online, except for visdom_logger, which you can download using: 53 | ``` 54 | pip install git+https://github.com/luizgh/visdom_logger 55 | ``` 56 | 57 | ## 2. Train models (optional) 58 | 59 | Instead of using the pre-trained models, you may want to train the models from scratch. Before anything, don't forget to activate the downloaded environment: 60 | ```python 61 | source env/bin/activate 62 | ``` 63 | Then, to visualize the results, turn on your local visdom server: 64 | ```python 65 | python -m visdom.server -port 8097 66 | ``` 67 | and open it in your browser: http://localhost:8097/.
Then, for instance, if you want to train a Resnet-18 on mini-Imagenet, go to the root of the directory, and execute: 68 | ```python 69 | bash scripts/train/resnet18.sh 70 | ``` 71 | 72 | **Important:** Whenever you have trained a new model yourself and want to test it, please specify the option `eval.fresh_start=True` in your test command. Otherwise, the code may use cached information (used to speed up experiments) from previously used models that are no longer valid. 73 | 74 | ## 3. Reproducing the main results 75 | 76 | Before anything, don't forget to activate the downloaded environment: 77 | ```python 78 | source env/bin/activate 79 | ``` 80 | 81 | ### 3.1 Benchmarks (Table 1. in paper) 82 | 83 | 84 | |(1 shot/5 shot)| Arch | mini-Imagenet | Tiered-Imagenet | 85 | | --- | --- | --- | --- | 86 | | TIM-ADM | Resnet-18 | 73.6 / **85.0** | **80.0** / **88.5** | 87 | | TIM-GD | Resnet-18 | **73.9** / **85.0** | 79.9 / **88.5** | 88 | | TIM-ADM | WRN28-10 | 77.5 / 87.2 | 82.0 / 89.7 | 89 | | TIM-GD | WRN28-10 | **77.8** / **87.4** | **82.1** / **89.8** | 90 | 91 | To reproduce the results from Table 1. in the paper, use the bash files at scripts/evaluate/. For instance, if you want to reproduce the methods on mini-Imagenet, go to the root of the directory and execute: 92 | ```python 93 | bash scripts/evaluate/<method>/mini.sh 94 | ``` 95 | This will reproduce the results for the three network architectures in the paper (Resnet-18/WideResNet28-10/DenseNet-121). Upon completion, exhaustive logs can be found in the logs/ folder. 96 | 97 | 98 | ### 3.2 Domain shift (Table 2.
in paper) 99 | 100 | |(5 shot) | Arch | CUB -> CUB | mini-Imagenet -> CUB | 101 | | --- | --- | --- | --- | 102 | | TIM-ADM | Resnet18 | 90.7 | 70.3 | 103 | | TIM-GD | Resnet18 | **90.8** | **71.0** | 104 | 105 | If you want to reproduce the methods on CUB -> CUB, go to the root of the directory and execute: 106 | ```python 107 | bash scripts/evaluate/<method>/cub.sh 108 | ``` 109 | If you want to reproduce the methods on mini -> CUB, go to the root of the directory and execute: 110 | ```python 111 | bash scripts/evaluate/<method>/mini2cub.sh 112 | ``` 113 | 114 | ### 3.3 Tasks with more ways (Table 3. in paper) 115 | 116 | If you want to reproduce the methods with more ways (10 and 20 ways) on mini-Imagenet, go to the root of the directory and execute: 117 | 118 | ```python 119 | bash scripts/evaluate/<method>/mini_10_20_ways.sh 120 | ``` 121 | 122 | |(1 shot/5 shot)| Arch | 10 ways | 20 ways | 123 | | --- | --- | --- | --- | 124 | | TIM-ADM | Resnet18 | 56.0 / **72.9** | **39.5** / 58.8 | 125 | | TIM-GD | Resnet18 | **56.1** / 72.8 | 39.3 / **59.5** | 126 | 127 | 128 | ### 3.4 Ablation study (Table 4. in paper) 129 | 130 | If you want to reproduce the 4 loss configurations of `<method>` on mini-Imagenet, Tiered-Imagenet and CUB, go to the root of the directory and execute: 131 | ```python 132 | bash scripts/ablation/<method>/weighting_effect.sh 133 | ``` 134 | for respectively TIM-ADM, TIM-GD {W} and TIM-GD {phi, W}. 135 | 136 | 137 | 138 | 139 | 140 | ## Contact 141 | 142 | For further questions or details, reach out to Malik Boudiaf (malik.boudiaf.1@etsmtl.net) 143 | 144 | ## Acknowledgements 145 | 146 | We would like to thank the authors of the SimpleShot code https://github.com/mileyan/simple_shot and LaplacianShot https://github.com/imtiazziko/LaplacianShot for giving access to their pre-trained models and to their code, from which this repo was inspired.
147 | -------------------------------------------------------------------------------- /src/models/DenseNet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['densenet121', 'densenet169', 'densenet201', 'densenet161'] 8 | 9 | 10 | class _DenseLayer(nn.Sequential): 11 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 12 | super(_DenseLayer, self).__init__() 13 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 14 | self.add_module('relu1', nn.ReLU(inplace=True)), 15 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 16 | growth_rate, kernel_size=1, stride=1, bias=False)), 17 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 18 | self.add_module('relu2', nn.ReLU(inplace=True)), 19 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 20 | kernel_size=3, stride=1, padding=1, bias=False)), 21 | self.drop_rate = drop_rate 22 | 23 | def forward(self, x): 24 | new_features = super(_DenseLayer, self).forward(x) 25 | if self.drop_rate > 0: 26 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 27 | return torch.cat([x, new_features], 1) 28 | 29 | 30 | class _DenseBlock(nn.Sequential): 31 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 32 | super(_DenseBlock, self).__init__() 33 | for i in range(num_layers): 34 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 35 | self.add_module('denselayer%d' % (i + 1), layer) 36 | 37 | 38 | class _Transition(nn.Sequential): 39 | def __init__(self, num_input_features, num_output_features): 40 | super(_Transition, self).__init__() 41 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 42 | self.add_module('relu', nn.ReLU(inplace=True)) 43 | 
self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 44 | kernel_size=1, stride=1, bias=False)) 45 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 46 | 47 | 48 | class DenseNet(nn.Module): 49 | r"""Densenet-BC model class, based on 50 | `"Densely Connected Convolutional Networks" `_ 51 | 52 | Args: 53 | growth_rate (int) - how many filters to add each layer (`k` in paper) 54 | block_config (list of 4 ints) - how many layers in each pooling block 55 | num_init_features (int) - the number of filters to learn in the first convolution layer 56 | bn_size (int) - multiplicative factor for number of bottle neck layers 57 | (i.e. bn_size * k features in the bottleneck layer) 58 | drop_rate (float) - dropout rate after each dense layer 59 | num_classes (int) - number of classification classes 60 | """ 61 | 62 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 63 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 64 | 65 | super(DenseNet, self).__init__() 66 | 67 | # First convolution 68 | self.features = nn.Sequential(OrderedDict([ 69 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)), 70 | ])) 71 | 72 | # Each denseblock 73 | num_features = num_init_features 74 | for i, num_layers in enumerate(block_config): 75 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 76 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 77 | self.features.add_module('denseblock%d' % (i + 1), block) 78 | num_features = num_features + num_layers * growth_rate 79 | if i != len(block_config) - 1: 80 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 81 | self.features.add_module('transition%d' % (i + 1), trans) 82 | num_features = num_features // 2 83 | 84 | # Final batch norm 85 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 86 | 87 | # Linear layer 88 | self.classifier = 
nn.Linear(num_features, num_classes) 89 | 90 | # Official init from torch repo. 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | nn.init.kaiming_normal_(m.weight) 94 | elif isinstance(m, nn.BatchNorm2d): 95 | nn.init.constant_(m.weight, 1) 96 | nn.init.constant_(m.bias, 0) 97 | elif isinstance(m, nn.Linear): 98 | nn.init.constant_(m.bias, 0) 99 | 100 | def forward(self, x, feature=False): 101 | features = self.features(x) 102 | out = F.relu(features, inplace=True) 103 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) 104 | if self.classifier is None: 105 | if feature: 106 | return out, None 107 | else: 108 | return out 109 | if feature: 110 | out1 = self.classifier(out) 111 | return out, out1 112 | 113 | out = self.classifier(out) 114 | return out 115 | 116 | 117 | def densenet121(**kwargs): 118 | r"""Densenet-121 model from 119 | `"Densely Connected Convolutional Networks" `_ 120 | """ 121 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 122 | **kwargs) 123 | return model 124 | 125 | 126 | def densenet169(**kwargs): 127 | r"""Densenet-169 model from 128 | `"Densely Connected Convolutional Networks" `_ 129 | """ 130 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 131 | **kwargs) 132 | return model 133 | 134 | 135 | def densenet201(**kwargs): 136 | r"""Densenet-201 model from 137 | `"Densely Connected Convolutional Networks" `_ 138 | """ 139 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 140 | **kwargs) 141 | return model 142 | 143 | 144 | def densenet161(**kwargs): 145 | r"""Densenet-161 model from 146 | `"Densely Connected Convolutional Networks" `_ 147 | """ 148 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 149 | **kwargs) 150 | return model 151 | -------------------------------------------------------------------------------- /src/models/ResNet.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | __all__ = ['resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 4 | 'resnet152'] 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | def conv1x1(in_planes, out_planes, stride=1): 14 | """1x1 convolution""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(inplanes, planes, stride) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | identity = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | identity = self.downsample(x) 43 | 44 | out += identity 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class Bottleneck(nn.Module): 51 | expansion = 4 52 | 53 | def __init__(self, inplanes, planes, stride=1, downsample=None): 54 | super(Bottleneck, self).__init__() 55 | self.conv1 = conv1x1(inplanes, planes) 56 | self.bn1 = nn.BatchNorm2d(planes) 57 | self.conv2 = conv3x3(planes, planes, stride) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | self.conv3 = conv1x1(planes, planes * self.expansion) 60 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | identity = x 67 | 68 | out = 
self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | 79 | if self.downsample is not None: 80 | identity = self.downsample(x) 81 | 82 | out += identity 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | 88 | class ResNet(nn.Module): 89 | 90 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 91 | super(ResNet, self).__init__() 92 | self.inplanes = 64 93 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, 94 | bias=False) 95 | self.bn1 = nn.BatchNorm2d(64) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.layer1 = self._make_layer(block, 64, layers[0]) 98 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 99 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 100 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 101 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 102 | self.fc = nn.Linear(512 * block.expansion, num_classes) 103 | 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 107 | elif isinstance(m, nn.BatchNorm2d): 108 | nn.init.constant_(m.weight, 1) 109 | nn.init.constant_(m.bias, 0) 110 | 111 | # Zero-initialize the last BN in each residual branch, 112 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
113 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 114 | if zero_init_residual: 115 | for m in self.modules(): 116 | if isinstance(m, Bottleneck): 117 | nn.init.constant_(m.bn3.weight, 0) 118 | elif isinstance(m, BasicBlock): 119 | nn.init.constant_(m.bn2.weight, 0) 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | conv1x1(self.inplanes, planes * block.expansion, stride), 126 | nn.BatchNorm2d(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample)) 131 | self.inplanes = planes * block.expansion 132 | for _ in range(1, blocks): 133 | layers.append(block(self.inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x, feature=False): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | x = self.avgpool(x) 148 | x = x.view(x.size(0), -1) 149 | if self.fc is None: 150 | if feature: 151 | return x, None 152 | else: 153 | return x 154 | if feature: 155 | x1 = self.fc(x) 156 | return x, x1 157 | x = self.fc(x) 158 | return x 159 | 160 | 161 | def resnet10(**kwargs): 162 | """Constructs a ResNet-10 model. 163 | """ 164 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 165 | return model 166 | 167 | 168 | def resnet18(**kwargs): 169 | """Constructs a ResNet-18 model. 170 | """ 171 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 172 | return model 173 | 174 | 175 | def resnet34(**kwargs): 176 | """Constructs a ResNet-34 model. 177 | """ 178 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 179 | return model 180 | 181 | 182 | def resnet50(**kwargs): 183 | """Constructs a ResNet-50 model. 
184 | """ 185 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 186 | return model 187 | 188 | 189 | def resnet101(**kwargs): 190 | """Constructs a ResNet-101 model. 191 | """ 192 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 193 | return model 194 | 195 | 196 | def resnet152(**kwargs): 197 | """Constructs a ResNet-152 model. 198 | """ 199 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 200 | return model 201 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import torch.nn as nn 4 | import numpy as np 5 | from sacred import Ingredient 6 | from src.utils import warp_tqdm, get_metric, AverageMeter 7 | from src.datasets.ingredient import get_dataloader 8 | 9 | trainer_ingredient = Ingredient('trainer') 10 | @trainer_ingredient.config 11 | def config(): 12 | print_freq = 10 13 | meta_val_way = 5 14 | meta_val_shot = 1 15 | meta_val_metric = 'cosine' # ('euclidean', 'cosine', 'l1', l2') 16 | meta_val_iter = 500 17 | meta_val_query = 15 18 | alpha = - 1.0 19 | label_smoothing = 0. 
20 | 21 | 22 | class Trainer: 23 | @trainer_ingredient.capture 24 | def __init__(self, device, meta_val_iter, meta_val_way, meta_val_shot, 25 | meta_val_query, ex): 26 | 27 | self.train_loader = get_dataloader(split='train', aug=True, shuffle=True) 28 | sample_info = [meta_val_iter, meta_val_way, meta_val_shot, meta_val_query] 29 | self.val_loader = get_dataloader(split='val', aug=False, sample=sample_info) 30 | self.device = device 31 | self.num_classes = ex.current_run.config['model']['num_classes'] 32 | 33 | def cross_entropy(self, logits, one_hot_targets, reduction='batchmean'): 34 | logsoftmax_fn = nn.LogSoftmax(dim=1) 35 | logsoftmax = logsoftmax_fn(logits) 36 | return - (one_hot_targets * logsoftmax).sum(1).mean() 37 | 38 | @trainer_ingredient.capture 39 | def do_epoch(self, epoch, scheduler, print_freq, disable_tqdm, callback, model, 40 | alpha, optimizer): 41 | batch_time = AverageMeter() 42 | losses = AverageMeter() 43 | top1 = AverageMeter() 44 | 45 | # switch to train mode 46 | model.train() 47 | steps_per_epoch = len(self.train_loader) 48 | end = time.time() 49 | tqdm_train_loader = warp_tqdm(self.train_loader, disable_tqdm) 50 | for i, (input, target, _) in enumerate(tqdm_train_loader): 51 | 52 | input, target = input.to(self.device), target.to(self.device, non_blocking=True) 53 | 54 | smoothed_targets = self.smooth_one_hot(target) 55 | assert (smoothed_targets.argmax(1) == target).float().mean() == 1.0 56 | # Forward pass 57 | if alpha > 0: # Mixup augmentation 58 | # generate mixed sample and targets 59 | lam = np.random.beta(alpha, alpha) 60 | rand_index = torch.randperm(input.size()[0]).cuda() 61 | target_a = smoothed_targets 62 | target_b = smoothed_targets[rand_index] 63 | mixed_input = lam * input + (1 - lam) * input[rand_index] 64 | 65 | output = model(mixed_input) 66 | loss = self.cross_entropy(output, target_a) * lam + self.cross_entropy(output, target_b) * (1. 
- lam) 67 | else: 68 | output = model(input) 69 | loss = self.cross_entropy(output, smoothed_targets) 70 | 71 | # Backward pass 72 | optimizer.zero_grad() 73 | loss.backward() 74 | optimizer.step() 75 | prec1 = (output.argmax(1) == target).float().mean() 76 | top1.update(prec1.item(), input.size(0)) 77 | if not disable_tqdm: 78 | tqdm_train_loader.set_description('Acc {:.2f}'.format(top1.avg)) 79 | 80 | # Measure accuracy and record loss 81 | losses.update(loss.item(), input.size(0)) 82 | batch_time.update(time.time() - end) 83 | end = time.time() 84 | 85 | if i % print_freq == 0: 86 | print('Epoch: [{0}][{1}/{2}]\t' 87 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 88 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 89 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format( 90 | epoch, i, len(self.train_loader), batch_time=batch_time, 91 | loss=losses, top1=top1)) 92 | if callback is not None: 93 | callback.scalar('train_loss', i / steps_per_epoch + epoch, losses.avg, title='Train loss') 94 | callback.scalar('@1', i / steps_per_epoch + epoch, top1.avg, title='Train Accuracy') 95 | for param_group in optimizer.param_groups: 96 | current_lr = param_group['lr'] 97 | if callback is not None: 98 | callback.scalar('lr', epoch, current_lr, title='Learning rate') 99 | 100 | @trainer_ingredient.capture 101 | def smooth_one_hot(self, targets, label_smoothing): 102 | assert 0 <= label_smoothing < 1 103 | with torch.no_grad(): 104 | new_targets = torch.empty(size=(targets.size(0), self.num_classes), device=self.device) 105 | new_targets.fill_(label_smoothing / (self.num_classes-1)) 106 | new_targets.scatter_(1, targets.unsqueeze(1), 1. 
- label_smoothing) 107 | return new_targets 108 | 109 | @trainer_ingredient.capture 110 | def meta_val(self, model, meta_val_way, meta_val_shot, disable_tqdm, callback, epoch): 111 | top1 = AverageMeter() 112 | model.eval() 113 | 114 | with torch.no_grad(): 115 | tqdm_test_loader = warp_tqdm(self.val_loader, disable_tqdm) 116 | for i, (inputs, target, _) in enumerate(tqdm_test_loader): 117 | inputs, target = inputs.to(self.device), target.to(self.device, non_blocking=True) 118 | output = model(inputs, feature=True)[0].cuda(0) 119 | train_out = output[:meta_val_way * meta_val_shot] 120 | train_label = target[:meta_val_way * meta_val_shot] 121 | test_out = output[meta_val_way * meta_val_shot:] 122 | test_label = target[meta_val_way * meta_val_shot:] 123 | train_out = train_out.reshape(meta_val_way, meta_val_shot, -1).mean(1) 124 | train_label = train_label[::meta_val_shot] 125 | prediction = self.metric_prediction(train_out, test_out, train_label) 126 | acc = (prediction == test_label).float().mean() 127 | top1.update(acc.item()) 128 | if not disable_tqdm: 129 | tqdm_test_loader.set_description('Acc {:.2f}'.format(top1.avg * 100)) 130 | 131 | if callback is not None: 132 | callback.scalar('val_acc', epoch + 1, top1.avg, title='Val acc') 133 | return top1.avg 134 | 135 | @trainer_ingredient.capture 136 | def metric_prediction(self, support, query, train_label, meta_val_metric): 137 | support = support.view(support.shape[0], -1) 138 | query = query.view(query.shape[0], -1) 139 | distance = get_metric(meta_val_metric)(support, query) 140 | predict = torch.argmin(distance, dim=1) 141 | predict = torch.take(train_label, predict) 142 | return predict -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sacred import Ingredient 3 | from src.utils import warp_tqdm, compute_confidence_interval, load_checkpoint 4 | from 
src.utils import load_pickle, save_pickle 5 | from src.datasets.ingredient import get_dataloader 6 | import os 7 | import torch 8 | import collections 9 | import torch.nn.functional as F 10 | from src.tim import TIM, TIM_ADM, TIM_GD 11 | 12 | eval_ingredient = Ingredient('eval') 13 | @eval_ingredient.config 14 | def config(): 15 | number_tasks = 10000 16 | n_ways = 5 17 | query_shots = 15 18 | method = 'baseline' 19 | model_tag = 'best' 20 | target_data_path = None # Only for cross-domain scenario 21 | target_split_dir = None # Only for cross-domain scenario 22 | plt_metrics = ['accs'] 23 | shots = [1, 5] 24 | used_set = 'test' # can also be val for hyperparameter tuning 25 | fresh_start = True 26 | 27 | 28 | class Evaluator: 29 | @eval_ingredient.capture 30 | def __init__(self, device, ex): 31 | self.device = device 32 | self.ex = ex 33 | 34 | @eval_ingredient.capture 35 | def run_full_evaluation(self, model, model_path, model_tag, shots, method, callback): 36 | """ 37 | Run the evaluation over all the tasks in parallel 38 | inputs: 39 | model : The loaded model containing the feature extractor 40 | loaders_dic : Dictionnary containing training and testing loaders 41 | model_path : Where was the model loaded from 42 | model_tag : Which model ('final' or 'best') to load 43 | method : Which method to use for inference ("baseline", "tim-gd" or "tim-adm") 44 | shots : Number of support shots to try 45 | 46 | returns : 47 | results : List of the mean accuracy for each number of support shots 48 | """ 49 | print("=> Runnning full evaluation with method: {}".format(method)) 50 | 51 | # Load pre-trained model 52 | load_checkpoint(model=model, model_path=model_path, type=model_tag) 53 | 54 | # Get loaders 55 | loaders_dic = self.get_loaders() 56 | 57 | # Extract features (just load them if already in memory) 58 | extracted_features_dic = self.extract_features(model=model, 59 | model_path=model_path, 60 | loaders_dic=loaders_dic) 61 | results = [] 62 | for shot in shots: 63 
| 64 | tasks = self.generate_tasks(extracted_features_dic=extracted_features_dic, 65 | shot=shot) 66 | logs = self.run_task(task_dic=tasks, 67 | model=model, 68 | callback=callback) 69 | 70 | l2n_mean, l2n_conf = compute_confidence_interval(logs['acc'][:, -1]) 71 | 72 | print('==> Meta Test: {} \nfeature\tL2N\n{}-shot \t{:.4f}({:.4f})'.format( 73 | model_tag.upper(), shot, l2n_mean, l2n_conf)) 74 | results.append(l2n_mean) 75 | return results 76 | 77 | def run_task(self, task_dic, model, callback): 78 | 79 | # Build the TIM classifier builder 80 | tim_builder = self.get_tim_builder(model=model) 81 | 82 | # Extract support and query 83 | y_s, y_q = task_dic['y_s'], task_dic['y_q'] 84 | z_s, z_q = task_dic['z_s'], task_dic['z_q'] 85 | 86 | # Transfer tensors to GPU if needed 87 | support = z_s.to(self.device) # [ N * (K_s + K_q), d] 88 | query = z_q.to(self.device) # [ N * (K_s + K_q), d] 89 | y_s = y_s.long().squeeze(2).to(self.device) 90 | y_q = y_q.long().squeeze(2).to(self.device) 91 | 92 | # Perform normalizations required 93 | support = F.normalize(support, dim=2) 94 | query = F.normalize(query, dim=2) 95 | 96 | # Initialize weights 97 | tim_builder.compute_lambda(support=support, query=query, y_s=y_s) 98 | tim_builder.init_weights(support=support, y_s=y_s, query=query, y_q=y_q) 99 | 100 | # Run adaptation 101 | tim_builder.run_adaptation(support=support, query=query, y_s=y_s, y_q=y_q, callback=callback) 102 | 103 | # Extract adaptation logs 104 | logs = tim_builder.get_logs() 105 | return logs 106 | 107 | @eval_ingredient.capture 108 | def get_tim_builder(self, model, method): 109 | # Initialize TIM classifier builder 110 | tim_info = {'model': model} 111 | if method == 'tim_adm': 112 | tim_builder = TIM_ADM(**tim_info) 113 | elif method == 'tim_gd': 114 | tim_builder = TIM_GD(**tim_info) 115 | elif method == 'baseline': 116 | tim_builder = TIM(**tim_info) 117 | else: 118 | raise ValueError("Method must be in ['tim_gd', 'tim_adm', 'baseline']") 119 | return 
tim_builder 120 | 121 | @eval_ingredient.capture 122 | def get_loaders(self, used_set, target_data_path, target_split_dir): 123 | # First, get loaders 124 | loaders_dic = {} 125 | loader_info = {'aug': False, 'shuffle': False, 'out_name': False} 126 | 127 | if target_data_path is not None: # This mean we are in the cross-domain scenario 128 | loader_info.update({'path': target_data_path, 129 | 'split_dir': target_split_dir}) 130 | 131 | train_loader = get_dataloader('train', **loader_info) 132 | loaders_dic['train_loader'] = train_loader 133 | 134 | test_loader = get_dataloader(used_set, **loader_info) 135 | loaders_dic.update({'test': test_loader}) 136 | return loaders_dic 137 | 138 | @eval_ingredient.capture 139 | def extract_features(self, model, model_path, model_tag, used_set, fresh_start, loaders_dic): 140 | """ 141 | inputs: 142 | model : The loaded model containing the feature extractor 143 | loaders_dic : Dictionnary containing training and testing loaders 144 | model_path : Where was the model loaded from 145 | model_tag : Which model ('final' or 'best') to load 146 | used_set : Set used between 'test' and 'val' 147 | n_ways : Number of ways for the task 148 | 149 | returns : 150 | extracted_features_dic : Dictionnary containing all extracted features and labels 151 | """ 152 | 153 | # Load features from memory if previously saved ... 154 | save_dir = os.path.join(model_path, model_tag, used_set) 155 | filepath = os.path.join(save_dir, 'output.plk') 156 | if os.path.isfile(filepath) and (not fresh_start): 157 | extracted_features_dic = load_pickle(filepath) 158 | print(" ==> Features loaded from {}".format(filepath)) 159 | return extracted_features_dic 160 | 161 | # ... 
otherwise just extract them 162 | else: 163 | print(" ==> Beginning feature extraction") 164 | os.makedirs(save_dir, exist_ok=True) 165 | 166 | model.eval() 167 | with torch.no_grad(): 168 | 169 | all_features = [] 170 | all_labels = [] 171 | for i, (inputs, labels, _) in enumerate(warp_tqdm(loaders_dic['test'], False)): 172 | inputs = inputs.to(self.device) 173 | outputs, _ = model(inputs, True) 174 | all_features.append(outputs.cpu()) 175 | all_labels.append(labels) 176 | all_features = torch.cat(all_features, 0) 177 | all_labels = torch.cat(all_labels, 0) 178 | extracted_features_dic = {'concat_features': all_features, 179 | 'concat_labels': all_labels 180 | } 181 | print(" ==> Saving features to {}".format(filepath)) 182 | save_pickle(filepath, extracted_features_dic) 183 | return extracted_features_dic 184 | 185 | @eval_ingredient.capture 186 | def get_task(self, shot, n_ways, query_shots, extracted_features_dic): 187 | """ 188 | inputs: 189 | extracted_features_dic : Dictionnary containing all extracted features and labels 190 | shot : Number of support shot per class 191 | n_ways : Number of ways for the task 192 | 193 | returns : 194 | task : Dictionnary : z_support : torch.tensor of shape [n_ways * shot, feature_dim] 195 | z_query : torch.tensor of shape [n_ways * query_shot, feature_dim] 196 | y_support : torch.tensor of shape [n_ways * shot] 197 | y_query : torch.tensor of shape [n_ways * query_shot] 198 | """ 199 | all_features = extracted_features_dic['concat_features'] 200 | all_labels = extracted_features_dic['concat_labels'] 201 | all_classes = torch.unique(all_labels) 202 | samples_classes = np.random.choice(a=all_classes, size=n_ways, replace=False) 203 | support_samples = [] 204 | query_samples = [] 205 | for each_class in samples_classes: 206 | class_indexes = torch.where(all_labels == each_class)[0] 207 | indexes = np.random.choice(a=class_indexes, size=shot + query_shots, replace=False) 208 | 
support_samples.append(all_features[indexes[:shot]]) 209 | query_samples.append(all_features[indexes[shot:]]) 210 | 211 | y_support = torch.arange(n_ways)[:, None].repeat(1, shot).reshape(-1) 212 | y_query = torch.arange(n_ways)[:, None].repeat(1, query_shots).reshape(-1) 213 | 214 | z_support = torch.cat(support_samples, 0) 215 | z_query = torch.cat(query_samples, 0) 216 | 217 | task = {'z_s': z_support, 'y_s': y_support, 218 | 'z_q': z_query, 'y_q': y_query} 219 | return task 220 | 221 | @eval_ingredient.capture 222 | def generate_tasks(self, extracted_features_dic, shot, number_tasks): 223 | """ 224 | inputs: 225 | extracted_features_dic : 226 | shot : Number of support shot per class 227 | number_tasks : Number of tasks to generate 228 | 229 | returns : 230 | merged_task : { z_support : torch.tensor of shape [number_tasks, n_ways * shot, feature_dim] 231 | z_query : torch.tensor of shape [number_tasks, n_ways * query_shot, feature_dim] 232 | y_support : torch.tensor of shape [number_tasks, n_ways * shot] 233 | y_query : torch.tensor of shape [number_tasks, n_ways * query_shot] } 234 | """ 235 | print(f" ==> Generating {number_tasks} tasks ...") 236 | tasks_dics = [] 237 | for _ in warp_tqdm(range(number_tasks), False): 238 | task_dic = self.get_task(shot=shot, extracted_features_dic=extracted_features_dic) 239 | tasks_dics.append(task_dic) 240 | 241 | # Now merging all tasks into 1 single dictionnary 242 | merged_tasks = {} 243 | n_tasks = len(tasks_dics) 244 | for key in tasks_dics[0].keys(): 245 | n_samples = tasks_dics[0][key].size(0) 246 | merged_tasks[key] = torch.cat([tasks_dics[i][key] for i in range(n_tasks)], dim=0).view(n_tasks, n_samples, -1) 247 | return merged_tasks 248 | -------------------------------------------------------------------------------- /src/tim.py: -------------------------------------------------------------------------------- 1 | from .utils import get_mi, get_cond_entropy, get_entropy, get_one_hot 2 | from tqdm import tqdm 3 | 
from sacred import Ingredient
import torch
import time

tim_ingredient = Ingredient('tim')
@tim_ingredient.config
def config():
    temp = 15
    loss_weights = [0.1, 1.0, 0.1]  # [Xent, H(Y), H(Y|X)]
    lr = 1e-4
    iter = 150
    alpha = 1.0


class TIM(object):
    """Transductive Information Maximization — base class.

    Holds the prototype weights and the shared machinery (logits, metric
    logging). On its own it performs no adaptation, i.e. it is the
    SimpleShot-style baseline.
    """

    @tim_ingredient.capture
    def __init__(self, temp, loss_weights, iter, model):
        self.temp = temp
        # Copy so that compute_lambda's in-place 'auto' resolution does not
        # mutate the shared sacred config list.
        self.loss_weights = loss_weights.copy()
        self.iter = iter
        self.model = model
        self.init_info_lists()

    def init_info_lists(self):
        # Per-iteration logs, concatenated in get_logs().
        self.timestamps = []
        self.mutual_infos = []
        self.entropy = []
        self.cond_entropy = []
        self.test_acc = []
        self.losses = []

    def get_logits(self, samples):
        """
        inputs:
            samples : torch.Tensor of shape [n_task, shot, feature_dim]

        returns :
            logits : torch.Tensor of shape [n_task, shot, num_class]
        """
        n_tasks = samples.size(0)
        # Negative squared Euclidean distance to each prototype, scaled by temp
        # (expanded as <s, w> - ||w||^2 / 2 - ||s||^2 / 2).
        logits = self.temp * (samples.matmul(self.weights.transpose(1, 2))
                              - 1 / 2 * (self.weights ** 2).sum(2).view(n_tasks, 1, -1)
                              - 1 / 2 * (samples ** 2).sum(2).view(n_tasks, -1, 1))
        return logits

    def get_preds(self, samples):
        """
        inputs:
            samples : torch.Tensor of shape [n_task, s_shot, feature_dim]

        returns :
            preds : torch.Tensor of shape [n_task, shot]
        """
        logits = self.get_logits(samples)
        preds = logits.argmax(2)
        return preds

    def init_weights(self, support, query, y_s, y_q):
        """
        Initialize prototypes as the mean of support features per class.

        inputs:
            support : torch.Tensor of shape [n_task, s_shot, feature_dim]
            query : torch.Tensor of shape [n_task, q_shot, feature_dim]
            y_s : torch.Tensor of shape [n_task, s_shot]
            y_q : torch.Tensor of shape [n_task, q_shot]

        updates :
            self.weights : torch.Tensor of shape [n_task, num_class, feature_dim]
        """
        self.model.eval()
        t0 = time.time()
        n_tasks = support.size(0)
        one_hot = get_one_hot(y_s)
        counts = one_hot.sum(1).view(n_tasks, -1, 1)
        weights = one_hot.transpose(1, 2).matmul(support)
        self.weights = weights / counts
        self.record_info(new_time=time.time() - t0,
                         support=support,
                         query=query,
                         y_s=y_s,
                         y_q=y_q)
        self.model.train()

    def compute_lambda(self, support, query, y_s):
        """
        Resolve the cross-entropy weight when configured as 'auto'.

        inputs:
            support : torch.Tensor of shape [n_task, s_shot, feature_dim]
            query : torch.Tensor of shape [n_task, q_shot, feature_dim]
            y_s : torch.Tensor of shape [n_task, s_shot]

        updates :
            self.loss_weights[0] : Scalar
        """
        self.N_s, self.N_q = support.size(1), query.size(1)
        self.num_classes = torch.unique(y_s).size(0)
        if self.loss_weights[0] == 'auto':
            self.loss_weights[0] = (1 + self.loss_weights[2]) * self.N_s / self.N_q

    def record_info(self, new_time, support, query, y_s, y_q):
        """
        Append query-set metrics (MI, entropies, accuracy) for this iteration.

        inputs:
            support : torch.Tensor of shape [n_task, s_shot, feature_dim]
            query : torch.Tensor of shape [n_task, q_shot, feature_dim]
            y_s : torch.Tensor of shape [n_task, s_shot]
            y_q : torch.Tensor of shape [n_task, q_shot]
        """
        logits_q = self.get_logits(query).detach()
        preds_q = logits_q.argmax(2)
        q_probs = logits_q.softmax(2)
        self.timestamps.append(new_time)
        self.mutual_infos.append(get_mi(probs=q_probs))
        self.entropy.append(get_entropy(probs=q_probs.detach()))
        self.cond_entropy.append(get_cond_entropy(probs=q_probs.detach()))
        self.test_acc.append((preds_q == y_q).float().mean(1, keepdim=True))

    def get_logs(self):
        """Concatenate per-iteration logs into numpy arrays of shape
        [n_task, n_iter] and return them as a dict."""
        self.test_acc = torch.cat(self.test_acc, dim=1).cpu().numpy()
        self.cond_entropy = torch.cat(self.cond_entropy, dim=1).cpu().numpy()
        self.entropy = torch.cat(self.entropy, dim=1).cpu().numpy()
        self.mutual_infos = torch.cat(self.mutual_infos, dim=1).cpu().numpy()
        return {'timestamps': self.timestamps, 'mutual_info': self.mutual_infos,
                'entropy': self.entropy, 'cond_entropy': self.cond_entropy,
                'acc': self.test_acc, 'losses': self.losses}

    def run_adaptation(self, support, query, y_s, y_q, callback):
        """
        Corresponds to the baseline (no transductive inference = SimpleShot)

        inputs:
            support : torch.Tensor of shape [n_task, s_shot, feature_dim]
            query : torch.Tensor of shape [n_task, q_shot, feature_dim]
            y_s : torch.Tensor of shape [n_task, s_shot]
            y_q : torch.Tensor of shape [n_task, q_shot]
            callback : VisdomLogger or None to plot in live our metrics

        updates :
            self.weights : torch.Tensor of shape [n_task, num_class, feature_dim]
        """
        pass


class TIM_GD(TIM):
    """TIM with gradient-descent adaptation of the prototypes."""

    @tim_ingredient.capture
    def __init__(self, lr, model):
        super().__init__(model=model)
        self.lr = lr

    def run_adaptation(self, support, query, y_s, y_q, callback):
        """
        Corresponds to the TIM-GD inference: minimizes
        lambda * CE(support) - (H(Y) - H(Y|X)) on the query set by Adam.

        inputs:
            support : torch.Tensor of shape [n_task, s_shot, feature_dim]
            query : torch.Tensor of shape [n_task, q_shot, feature_dim]
            y_s : torch.Tensor of shape [n_task, s_shot]
            y_q : torch.Tensor of shape [n_task, q_shot]
            callback : VisdomLogger or None to plot in live our metrics

        updates :
            self.weights : torch.Tensor of shape [n_task, num_class, feature_dim]
        """
        t0 = time.time()
        self.weights.requires_grad_()
        optimizer = torch.optim.Adam([self.weights], lr=self.lr)
        y_s_one_hot = get_one_hot(y_s)
        self.model.train()
        for i in tqdm(range(self.iter)):
            logits_s = self.get_logits(support)
            logits_q = self.get_logits(query)

            ce = - (y_s_one_hot * torch.log(logits_s.softmax(2) + 1e-12)).sum(2).mean(1).sum(0)
            q_probs = logits_q.softmax(2)
            q_cond_ent = - (q_probs * torch.log(q_probs + 1e-12)).sum(2).mean(1).sum(0)
            # FIX: add the same 1e-12 epsilon used in the other log() terms —
            # without it, a zero marginal probability yields log(0) = -inf
            # and NaN gradients.
            q_ent = - (q_probs.mean(1) * torch.log(q_probs.mean(1) + 1e-12)).sum(1).sum(0)
            loss = self.loss_weights[0] * ce - (self.loss_weights[1] * q_ent - self.loss_weights[2] * q_cond_ent)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t1 = time.time()
            self.model.eval()
            if callback is not None:
                P_q = self.get_logits(query).softmax(2).detach()
                prec = (P_q.argmax(2) == y_q).float().mean()
                callback.scalar('prec', i, prec, title='Precision')

            self.record_info(new_time=t1 - t0,
                             support=support,
                             query=query,
                             y_s=y_s,
                             y_q=y_q)
            self.model.train()
            t0 = time.time()


class TIM_ADM(TIM):
    """TIM with Alternating Direction Method updates (closed-form steps)."""

    @tim_ingredient.capture
    def __init__(self, model, alpha):
        super().__init__(model=model)
        # Step size of the damped weight update; unrelated to the mixup alpha.
        self.alpha = alpha

    def q_update(self, P):
        """
        Closed-form update of the auxiliary query assignments Q.

        inputs:
            P : torch.tensor of shape [n_tasks, q_shot, num_class]
                where P[i,j,k] = probability of point j in task i belonging to class k
                (according to our L2 classifier)
        """
        l1, l2 = self.loss_weights[1], self.loss_weights[2]
        l3 = 1.0  # Corresponds to the weight of the KL penalty
        alpha = l2 / l3
        beta = l1 / (l1 + l3)

        Q = (P ** (1 + alpha)) / ((P ** (1 + alpha)).sum(dim=1, keepdim=True)) ** beta
        # Renormalize over classes so each row of Q is a distribution.
        self.Q = (Q / Q.sum(dim=2, keepdim=True)).float()

    def weights_update(self, support, query, y_s_one_hot):
        """
        Corresponds to w_k updates (damped by self.alpha).

        inputs:
            support : torch.Tensor of shape [n_task, s_shot, feature_dim]
            query : torch.Tensor of shape [n_task, q_shot, feature_dim]
            y_s_one_hot : torch.Tensor of shape [n_task, s_shot, num_classes]

        updates :
            self.weights : torch.Tensor of shape [n_task, num_class, feature_dim]
        """
        n_tasks = support.size(0)
        P_s = self.get_logits(support).softmax(2)
        P_q = self.get_logits(query).softmax(2)
        src_part = self.loss_weights[0] / (1 + self.loss_weights[2]) * y_s_one_hot.transpose(1, 2).matmul(support)
        src_part += self.loss_weights[0] / (1 + self.loss_weights[2]) * (self.weights * P_s.sum(1, keepdim=True).transpose(1, 2)
                                                                        - P_s.transpose(1, 2).matmul(support))
        src_norm = self.loss_weights[0] / (1 + self.loss_weights[2]) * y_s_one_hot.sum(1).view(n_tasks, -1, 1)

        qry_part = self.N_s / self.N_q * self.Q.transpose(1, 2).matmul(query)
        qry_part += self.N_s / self.N_q * (self.weights * P_q.sum(1, keepdim=True).transpose(1, 2)
                                           - P_q.transpose(1, 2).matmul(query))
        qry_norm = self.N_s / self.N_q * self.Q.sum(1).view(n_tasks, -1, 1)

        new_weights = (src_part + qry_part) / (src_norm + qry_norm)
        self.weights = self.weights + self.alpha * (new_weights - self.weights)

    def run_adaptation(self, support, query, y_s, y_q, callback):
        """
        Corresponds to the TIM-ADM inference: alternate closed-form Q and
        weight updates for self.iter iterations.

        inputs:
            support : torch.Tensor of shape [n_task, s_shot, feature_dim]
            query : torch.Tensor of shape [n_task, q_shot, feature_dim]
            y_s : torch.Tensor of shape [n_task, s_shot]
            y_q : torch.Tensor of shape [n_task, q_shot]
            callback : VisdomLogger or None to plot in live our metrics

        updates :
            self.weights : torch.Tensor of shape [n_task, num_class, feature_dim]
        """
        t0 = time.time()
        y_s_one_hot = get_one_hot(y_s)
        for i in tqdm(range(self.iter)):
            P_q = self.get_logits(query).softmax(2)
            self.q_update(P=P_q)
            self.weights_update(support, query, y_s_one_hot)
            t1 = time.time()
            if callback is not None:
                callback.scalar('acc', i, self.test_acc[-1].mean(), title='Accuracy')
                callback.scalars(['cond_ent', 'marg_ent'], i, [self.cond_entropy[-1].mean(),
                                                               self.entropy[-1].mean()], title='Entropies')
            self.record_info(new_time=t1 - t0,
                             support=support,
                             query=query,
                             y_s=y_s,
                             y_q=y_q)
            t0 = time.time()
--------------------------------------------------------------------------------