├── eval_lib ├── eval │ ├── __init__.py │ ├── util.py │ ├── cls_eval.py │ └── meta_eval.py ├── rfs_models │ ├── __init__.py │ ├── convnet.py │ ├── util.py │ ├── wresnet.py │ ├── operations.py │ ├── resnet_new.py │ └── augment_cnn.py ├── rfs_util.py ├── rfs_dataset │ ├── transform_cfg.py │ ├── cifar.py │ ├── mini_imagenet.py │ └── tiered_imagenet.py └── flop_benchmark.py ├── lib ├── utils │ ├── __init__.py │ └── flop_benchmark.py ├── log_utils │ ├── __init__.py │ ├── time_utils.py │ └── logger.py ├── models │ ├── cell_infers │ │ ├── __init__.py │ │ ├── tiny_network.py │ │ ├── nasnet_cifar.py │ │ └── cells.py │ ├── shape_infers │ │ ├── shared_utils.py │ │ ├── __init__.py │ │ └── InferTinyCellNet.py │ ├── cell_searchs │ │ ├── __init__.py │ │ ├── search_cells.py │ │ ├── search_model_darts.py │ │ ├── search_model_darts_nasnet.py │ │ └── genotypes.py │ ├── SharedUtils.py │ └── __init__.py ├── config_utils │ ├── __init__.py │ └── configure_utils.py ├── procedures │ ├── __init__.py │ ├── starts.py │ └── .ipynb_checkpoints │ │ └── metantk_test-checkpoint.ipynb ├── datasets │ ├── __init__.py │ ├── test_utils.py │ ├── SearchDatasetWrap.py │ ├── DownsampledImageNet.py │ ├── transform_cfg.py │ ├── mini_imagenet.py │ └── tiered_imagenet.py └── nas_201_api │ └── __init__.py ├── requirements.txt ├── LICENSE ├── scripts ├── script_mini_5cell.sh ├── script_mini_8cell.sh ├── script_tiered_5cell.sh └── script_tiered_8cell.sh ├── .gitignore ├── README.md ├── prune_launch.py └── random_baseline.py /eval_lib/eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .flop_benchmark import get_model_infos 2 | -------------------------------------------------------------------------------- /lib/log_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import Logger, PrintLogger 2 | from .time_utils import time_string 3 | -------------------------------------------------------------------------------- /lib/models/cell_infers/__init__.py: -------------------------------------------------------------------------------- 1 | from .tiny_network import TinyNetwork 2 | from .nasnet_cifar import NASNetonCIFAR 3 | -------------------------------------------------------------------------------- /lib/config_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .configure_utils import load_config, load_config_dict, merge_config_dict, dict2config, configure2str 2 | -------------------------------------------------------------------------------- /lib/procedures/__init__.py: -------------------------------------------------------------------------------- 1 | from .starts import prepare_seed, prepare_logger, get_machine_info 2 | from .ntk_opacus import get_ntk_n, get_analytical_metantk_n 3 | from .linear_region_counter import Linear_Region_Collector 4 | -------------------------------------------------------------------------------- /lib/models/shape_infers/shared_utils.py: -------------------------------------------------------------------------------- 1 | def parse_channel_info(xstring): 2 | blocks = xstring.split(' ') 3 | blocks = [x.split('-') for x in blocks] 4 | blocks = [[int(_) for _ in x] for x in blocks] 5 | return blocks 6 | 
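# Usage sketch for parse_channel_info (hypothetical input, for illustration only): the string is split on spaces into blocks and on dashes into integer channel counts, so parse_channel_info('16-16 32-32-64') returns [[16, 16], [32, 32, 64]].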
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.9 2 | numpy==1.17.4 3 | opacus==0.14.0 4 | Pillow==9.0.1 5 | ptflops==0.6.6 6 | scikit_learn==1.0.2 7 | scipy==1.3.1 8 | torch==1.3.1 9 | torch_optimizer==0.1.0 10 | torchvision==0.4.2 11 | tqdm==4.36.1 12 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .get_dataset_with_transform import get_datasets, get_nas_search_loaders, Dataset2Class, CUTOUT, Lighting 2 | from .DownsampledImageNet import ImageNet16 3 | from .mini_imagenet import ImageNet, MetaImageNet 4 | from .tiered_imagenet import TieredImageNet, MetaTieredImageNet 5 | from .SearchDatasetWrap import SearchDataset 6 | -------------------------------------------------------------------------------- /lib/models/shape_infers/__init__.py: -------------------------------------------------------------------------------- 1 | from .InferCifarResNet_width import InferWidthCifarResNet 2 | from .InferImagenetResNet import InferImagenetResNet 3 | from .InferCifarResNet_depth import InferDepthCifarResNet 4 | from .InferCifarResNet import InferCifarResNet 5 | from .InferMobileNetV2 import InferMobileNetV2 6 | from .InferTinyCellNet import DynamicShapeTinyNet 7 | -------------------------------------------------------------------------------- /lib/nas_201_api/__init__.py: -------------------------------------------------------------------------------- 1 | ##################################################### 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # 3 | ##################################################### 4 | from .api import NASBench201API 5 | from .api import ArchResults, ResultsCount 6 | 7 | # NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25] 8 | # NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09] 9 | NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16] 10 | -------------------------------------------------------------------------------- /lib/models/cell_searchs/__init__.py: -------------------------------------------------------------------------------- 1 | # The macro structure is defined in NAS-Bench-201 2 | from .search_model_darts import TinyNetworkDarts 3 | from .genotypes import Structure as CellStructure, architectures as CellArchitectures 4 | # NASNet-based macro structure 5 | from .search_model_darts_nasnet import NASNetworkDARTS 6 | 7 | 8 | nas201_super_nets = {'DARTS-V1': TinyNetworkDarts, 9 | "DARTS-V2": TinyNetworkDarts} 10 | 11 | nasnet_super_nets = {"DARTS-V1": NASNetworkDARTS, 12 | "DARTS-V2": NASNetworkDARTS} 13 | -------------------------------------------------------------------------------- /lib/datasets/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def test_imagenet_data(imagenet): 5 | total_length = len(imagenet) 6 | assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length) 7 | map_id = {} 8 | for index in range(total_length): 9 | path, target = imagenet.imgs[index] 10 | folder, image_name = os.path.split(path) 11 | _, folder = os.path.split(folder) 12 | if folder not in map_id: 13 | map_id[folder] = target 14 | else: 15 | assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target) 16 | assert image_name.find(folder) == 
0, '{} is wrong.'.format(path) 17 | print ('Check ImageNet Dataset OK') 18 | -------------------------------------------------------------------------------- /eval_lib/rfs_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .convnet import convnet4 2 | from .resnet import resnet12 3 | from .resnet import seresnet12 4 | from .wresnet import wrn_28_10 5 | from .dartsmodel import NetworkMiniImageNet 6 | from .augment_cnn import AugmentCNN 7 | 8 | from .resnet_new import resnet50 9 | 10 | model_pool = [ 11 | 'convnet4', 12 | 'resnet12', 13 | 'seresnet12', 14 | 'wrn_28_10', 15 | 'dartsmodel', 16 | 'augmentcnn', 17 | 'convnet4small' 18 | ] 19 | 20 | model_dict = { 21 | 'wrn_28_10': wrn_28_10, 22 | 'convnet4': convnet4, 23 | 'convnet4small': convnet4, 24 | 'resnet12': resnet12, 25 | 'seresnet12': seresnet12, 26 | 'resnet50': resnet50, 27 | 'dartsmodel': NetworkMiniImageNet, 28 | 'augmentcnn': AugmentCNN 29 | } 30 | -------------------------------------------------------------------------------- /lib/models/SharedUtils.py: -------------------------------------------------------------------------------- 1 | def additive_func(A, B): 2 | assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size()) 3 | C = min(A.size(1), B.size(1)) 4 | if A.size(1) == B.size(1): 5 | return A + B 6 | elif A.size(1) < B.size(1): 7 | out = B.clone() 8 | out[:,:C] += A 9 | return out 10 | else: 11 | out = A.clone() 12 | out[:,:C] += B 13 | return out 14 | 15 | 16 | def change_key(key, value): 17 | def func(m): 18 | if hasattr(m, key): 19 | setattr(m, key, value) 20 | return func 21 | 22 | 23 | def parse_channel_info(xstring): 24 | blocks = xstring.split(' ') 25 | blocks = [x.split('-') for x in blocks] 26 | blocks = [[int(_) for _ in x] for x in blocks] 27 | return blocks 28 | -------------------------------------------------------------------------------- /eval_lib/eval/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | def accuracy(output, target, topk=(1,)): 24 | """Computes the accuracy over the k top predictions for the specified values of k""" 25 | with torch.no_grad(): 26 | maxk = max(topk) 27 | batch_size = target.size(0) 28 | 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 32 | 33 | res = [] 34 | for k in topk: 35 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 36 | res.append(correct_k.mul_(100.0 / batch_size)) 37 | return res 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yite Wang 4 | Copyright (c) 2022 Haoxiang Wang 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, 
copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /lib/log_utils/time_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | def time_for_file(): 5 | ISOTIMEFORMAT='%d-%h-at-%H-%M-%S' 6 | return '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 7 | 8 | 9 | def time_string(): 10 | ISOTIMEFORMAT='%Y-%m-%d %X' 11 | string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 12 | return string 13 | 14 | 15 | def time_string_short(): 16 | ISOTIMEFORMAT='%Y%m%d' 17 | string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 18 | return string 19 | 20 | 21 | def time_print(string, is_print=True): 22 | if (is_print): 23 | print('{} : {}'.format(time_string(), string)) 24 | 25 | 26 | def convert_secs2time(epoch_time, return_str=False): 27 | need_hour = int(epoch_time / 3600) 28 | need_mins = int((epoch_time - 3600*need_hour) / 60) 29 | need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) 30 | if return_str: 31 | time_str = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs) 32 | return time_str 33 | else: 34 | return need_hour, need_mins, need_secs 35 | 36 | 37 | def print_log(print_string, log): 38 | if hasattr(log, 'log'): log.log('{:}'.format(print_string)) 39 | else: 40 | print("{:}".format(print_string)) 41 | if log is not None: 42 | log.write('{:}\n'.format(print_string)) 43 | log.flush() 44 | -------------------------------------------------------------------------------- /scripts/script_mini_5cell.sh: -------------------------------------------------------------------------------- 1 | # This is a sample script for 5-cell experiments on MiniImageNet.
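# Usage sketch (assumed invocation, not documented in this file): run from the repository root, e.g. `bash scripts/script_mini_5cell.sh`, so that the final `python prune_launch.py "${args[@]}"` line can resolve prune_launch.py; the flags collected in the args array below are forwarded to prune_launch.py unchanged.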
2 | 3 | # Searched architecture examples: 4 | # (1) 5 | # Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], [('conv_3x3', 0), ('sep_conv_3x3', 2)]], 6 | # normal_concat=[2, 3, 4], reduce=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('conv_1x5_5x1', 1), ('dil_conv_3x3', 2)], [('dil_conv_3x3', 2), ('dil_conv_3x3', 3)]], 7 | # reduce_concat=[2, 3, 4]) 8 | 9 | # (2) 10 | # Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], [('conv_3x3', 2), ('sep_conv_3x3', 3)]], 11 | # normal_concat=[2, 3, 4], reduce=[[('sep_conv_3x3', 0), ('dil_conv_3x3', 1)], [('conv_1x5_5x1', 1), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 1), ('conv_1x5_5x1', 3)]], 12 | # reduce_concat=[2, 3, 4]) 13 | 14 | # (3) 15 | # Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('max_pool_3x3BN', 2)], [('dil_conv_3x3', 2), ('sep_conv_3x3', 3)]], 16 | # normal_concat=[2, 3, 4], reduce=[[('dil_conv_3x3', 0), ('sep_conv_3x3', 1)], [('conv_3x3', 1), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 3)]], 17 | # reduce_concat=[2, 3, 4]) 18 | 19 | args=( 20 | --gpu 0 \ 21 | 22 | # Search space settings 23 | --space darts_fewshot \ 24 | --max_nodes 3 \ 25 | 26 | # Dataset setting 27 | --dataset MetaMiniImageNet \ 28 | --dartsbs 3 \ 29 | 30 | # Random seed 31 | --seed -1 \ 32 | 33 | # Whether to use only linear regions 34 | --only_lrs false \ 35 | 36 | # NTK/MetaNTK setting 37 | --ntk_type MetaNTK_anl \ 38 | --algorithm MAML \ 39 | --inner_lr_time 1000 \ 40 | --reg_coef 1e-3 \ 41 | 42 | # Search/evaluate architecture setting 43 | --ntk_channels 48 \ 44 | --ntk_layers 5 \ 45 | --train_after_search true \ 46 | ) 47 | 48 | python prune_launch.py "${args[@]}" -------------------------------------------------------------------------------- /scripts/script_mini_8cell.sh: -------------------------------------------------------------------------------- 1 | # This is a sample script for 8-cell experiments on MiniImageNet.
2 | 3 | # Searched architecture examples: 4 | # (1) 5 | # Genotype(normal=[[('sep_conv_3x3', 0), ('max_pool_3x3BN', 1)], [('conv_1x5_5x1', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 2), ('conv_3x3', 3)]], 6 | # normal_concat=[2, 3, 4], reduce=[[('skip_connect', 0), ('sep_conv_3x3', 1)], [('conv_3x3', 1), ('sep_conv_3x3', 2)], [('dil_conv_3x3', 0), ('conv_1x5_5x1', 1)]], 7 | # reduce_concat=[2, 3, 4]) 8 | 9 | # (2) 10 | # Genotype(normal=[[('conv_3x3', 0), ('sep_conv_3x3', 1)], [('conv_1x5_5x1', 0), ('conv_3x3', 1)], [('sep_conv_3x3', 0), ('dil_conv_3x3', 2)]], 11 | # normal_concat=[2, 3, 4], reduce=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 2)]], 12 | # reduce_concat=[2, 3, 4]) 13 | 14 | # (3) 15 | # Genotype(normal=[[('sep_conv_3x3', 0), ('conv_1x5_5x1', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('conv_3x3', 0), ('sep_conv_3x3', 1)]], 16 | # normal_concat=[2, 3, 4], reduce=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('conv_1x5_5x1', 1), ('sep_conv_3x3', 2)], [('conv_3x3', 0), ('sep_conv_3x3', 3)]], 17 | # reduce_concat=[2, 3, 4]) 18 | 19 | 20 | args=( 21 | --gpu 0 \ 22 | 23 | # Search space settings 24 | --space darts_fewshot \ 25 | --max_nodes 3 \ 26 | 27 | # Dataset setting 28 | --dataset MetaMiniImageNet \ 29 | --dartsbs 3 \ 30 | 31 | # Random seed 32 | --seed -1 \ 33 | 34 | # Whether to use only linear regions 35 | --only_lrs false \ 36 | 37 | # NTK/MetaNTK setting 38 | --ntk_type MetaNTK_anl \ 39 | --algorithm ANIL \ 40 | --inner_lr_time 1 \ 41 | --reg_coef 1e-5 \ 42 | 43 | # Search/evaluate architecture setting 44 | --ntk_channels 48 \ 45 | --ntk_layers 8 \ 46 | --train_after_search true \ 47 | ) 48 | 49 | python prune_launch.py "${args[@]}" -------------------------------------------------------------------------------- /scripts/script_tiered_5cell.sh: -------------------------------------------------------------------------------- 1 | # This is a sample script for 5-cell experiments on TieredImageNet.
2 | 3 | # Searched architecture examples: 4 | # (1) 5 | # Genotype(normal=[[('conv_3x3', 0), ('dil_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 3)]], 6 | # normal_concat=[2, 3, 4], reduce=[[('skip_connect', 0), ('dil_conv_3x3', 1)], [('sep_conv_3x3', 0), ('conv_1x5_5x1', 2)], [('max_pool_3x3BN', 2), ('conv_1x5_5x1', 3)]], 7 | # reduce_concat=[2, 3, 4]) 8 | 9 | # (2) 10 | # Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('max_pool_3x3BN', 2)], [('conv_3x3', 2), ('conv_3x3', 3)]], 11 | # normal_concat=[2, 3, 4], reduce=[[('dil_conv_3x3', 0), ('conv_1x5_5x1', 1)], [('conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 3)]], 12 | # reduce_concat=[2, 3, 4]) 13 | 14 | # (3) 15 | # Genotype(normal=[[('sep_conv_3x3', 0), ('conv_1x5_5x1', 1)], [('dil_conv_3x3', 0), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 0), ('max_pool_3x3BN', 3)]], 16 | # normal_concat=[2, 3, 4], reduce=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('conv_1x5_5x1', 0), ('conv_1x5_5x1', 2)], [('dil_conv_3x3', 1), ('sep_conv_3x3', 3)]], 17 | # reduce_concat=[2, 3, 4]) 18 | 19 | args=( 20 | --gpu 0 \ 21 | 22 | # Search space settings 23 | --space darts_fewshot \ 24 | --max_nodes 3 \ 25 | 26 | # Dataset setting 27 | --dataset MetaTieredImageNet \ 28 | --dartsbs 3 \ 29 | 30 | # Random seed 31 | --seed -1 \ 32 | 33 | # Whether to use only linear regions 34 | --only_lrs false \ 35 | 36 | # NTK/MetaNTK setting 37 | --ntk_type MetaNTK_anl \ 38 | --algorithm MAML \ 39 | --inner_lr_time 1000 \ 40 | --reg_coef 1e-3 \ 41 | 42 | # Search/evaluate architecture setting 43 | --ntk_channels 48 \ 44 | --ntk_layers 5 \ 45 | --train_after_search true \ 46 | ) 47 | 48 | python prune_launch.py "${args[@]}" -------------------------------------------------------------------------------- /scripts/script_tiered_8cell.sh: -------------------------------------------------------------------------------- 1 | # This is a sample script for 8-cell experiments on TieredImageNet.
2 | 3 | # Searched architecture examples: 4 | # (1) 5 | # Genotype(normal=[[('conv_1x5_5x1', 0), ('sep_conv_3x3', 1)], [('conv_1x5_5x1', 0), ('conv_1x5_5x1', 1)], [('sep_conv_3x3', 0), ('conv_1x5_5x1', 2)]], 6 | # normal_concat=[2, 3, 4], reduce=[[('skip_connect', 0), ('conv_3x3', 1)], [('conv_3x3', 0), ('conv_1x5_5x1', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 3)]], 7 | # reduce_concat=[2, 3, 4]) 8 | 9 | # (2) 10 | # Genotype(normal=[[('conv_1x5_5x1', 0), ('dil_conv_3x3', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('conv_3x3', 2)]], 11 | # normal_concat=[2, 3, 4], reduce=[[('conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], [('conv_3x3', 0), ('sep_conv_3x3', 1)]], 12 | # reduce_concat=[2, 3, 4]) 13 | 14 | # (3) 15 | # Genotype(normal=[[('conv_1x5_5x1', 0), ('conv_1x5_5x1', 1)], [('sep_conv_3x3', 0), ('conv_1x5_5x1', 2)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 3)]], 16 | # normal_concat=[2, 3, 4], reduce=[[('conv_3x3', 0), ('conv_3x3', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('conv_1x5_5x1', 3)]], 17 | # reduce_concat=[2, 3, 4]) 18 | 19 | args=( 20 | --gpu 0 \ 21 | 22 | # Search space settings 23 | --space darts_fewshot \ 24 | --max_nodes 3 \ 25 | 26 | # Dataset setting 27 | --dataset MetaTieredImageNet \ 28 | --dartsbs 3 \ 29 | 30 | # Random seed 31 | --seed -1 \ 32 | 33 | # Whether to use only linear regions 34 | --only_lrs false \ 35 | 36 | # NTK/MetaNTK setting 37 | --ntk_type MetaNTK_anl \ 38 | --algorithm ANIL \ 39 | --inner_lr_time 1000 \ 40 | --reg_coef 1e-3 \ 41 | 42 | # Search/evaluate architecture setting 43 | --ntk_channels 48 \ 44 | --ntk_layers 8 \ 45 | --train_after_search true \ 46 | ) 47 | 48 | python prune_launch.py "${args[@]}" -------------------------------------------------------------------------------- /eval_lib/eval/cls_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import time 5 | import logging 6 | 7 | from .util import AverageMeter, accuracy 8 | 9 | 10 | def validate(val_loader, model, criterion, opt): 11 | """One epoch validation""" 12 | logger = logging.getLogger(__name__) 13 | 14 | batch_time = AverageMeter() 15 | losses = AverageMeter() 16 | top1 = AverageMeter() 17 | top5 = AverageMeter() 18 | 19 | # switch to evaluate mode 20 | model.eval() 21 | 22 | with torch.no_grad(): 23 | end = time.time() 24 | for idx, (input, target, _) in enumerate(val_loader): 25 | 26 | input = input.float() 27 | if torch.cuda.is_available(): 28 | input = input.cuda() 29 | target = target.cuda() 30 | 31 | # compute output 32 | output = model(input) 33 | loss = criterion(output, target) 34 | 35 | # measure accuracy and record loss 36 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 37 | losses.update(loss.item(), input.size(0)) 38 | top1.update(acc1[0], input.size(0)) 39 | top5.update(acc5[0], input.size(0)) 40 | 41 | # measure elapsed time 42 | batch_time.update(time.time() - end) 43 | end = time.time() 44 | 45 | if idx % opt.print_freq == 0: 46 | logger.info('Test: [{0}/{1}]\t' 47 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 48 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 49 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 50 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 51 | idx, len(val_loader), batch_time=batch_time, loss=losses, 52 | top1=top1, top5=top5)) 53 | 54 | logger.info(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 55 | .format(top1=top1, top5=top5)) 56 | 57 | 
return top1.avg, top5.avg, losses.avg 58 | -------------------------------------------------------------------------------- /lib/procedures/starts.py: -------------------------------------------------------------------------------- 1 | import os, sys, torch, random, PIL, copy, numpy as np 2 | 3 | 4 | def prepare_seed(rand_seed): 5 | random.seed(rand_seed) 6 | np.random.seed(rand_seed) 7 | torch.manual_seed(rand_seed) 8 | torch.cuda.manual_seed(rand_seed) 9 | torch.cuda.manual_seed_all(rand_seed) 10 | 11 | 12 | def prepare_logger(xargs): 13 | args = copy.deepcopy(xargs) 14 | from log_utils import Logger 15 | logger = Logger(args.save_dir, args.rand_seed) 16 | logger.log('Main Function with logger : {:}'.format(logger)) 17 | logger.log('Arguments : -------------------------------') 18 | for name, value in args._get_kwargs(): 19 | logger.log('{:16} : {:}'.format(name, value)) 20 | logger.log("Python Version : {:}".format(sys.version.replace('\n', ' '))) 21 | logger.log("Pillow Version : {:}".format(PIL.__version__)) 22 | logger.log("PyTorch Version : {:}".format(torch.__version__)) 23 | logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version())) 24 | logger.log("CUDA available : {:}".format(torch.cuda.is_available())) 25 | logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) 26 | logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None')) 27 | return logger 28 | 29 | 30 | def get_machine_info(): 31 | info = "Python Version : {:}".format(sys.version.replace('\n', ' ')) 32 | info += "\nPillow Version : {:}".format(PIL.__version__) 33 | info += "\nPyTorch Version : {:}".format(torch.__version__) 34 | info += "\ncuDNN Version : {:}".format(torch.backends.cudnn.version()) 35 | info += "\nCUDA available : {:}".format(torch.cuda.is_available()) 36 | info += "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count()) 37 | if 'CUDA_VISIBLE_DEVICES' in os.environ: 38 | info += "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES']) 39 | else: 40 | info += "\nDoes not set CUDA_VISIBLE_DEVICES" 41 | return info 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store/ 2 | .idea/ 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # IPython Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | 92 | # Pycharm project 93 | .idea 94 | snapshots 95 | *.pytorch 96 | *.tar.bz 97 | data 98 | .*.swp 99 | main_main.py 100 | *.pdf 101 | */*.pdf 102 | 103 | # Device 104 | scripts-nas/.nfs00* 105 | */.nfs00* 106 | *.DS_Store 107 | 108 | # logs and snapshots 109 | output 110 | logs 111 | 112 | # snapshot 113 | a.pth 114 | cal-merge*.sh 115 | GPU-*.sh 116 | cal.sh 117 | aaa 118 | cx.sh 119 | 120 | NAS-Bench-*-v1_0.pth 121 | lib/NAS-Bench-*-v1_0.pth 122 | others/TF 123 | scripts-search/l2s-algos 124 | TEMP-L.sh 125 | 126 | .nfs00* 127 | -------------------------------------------------------------------------------- /lib/log_utils/logger.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | if sys.version_info.major == 2: # Python 2.x 4 | from StringIO import StringIO as BIO 5 | else: # Python 3.x 6 | from io import BytesIO as BIO 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | class PrintLogger(object): 11 | 12 | def __init__(self): 13 | """Create a summary writer logging to log_dir.""" 14 | self.name = 'PrintLogger' 15 | 16 | def log(self, string): 17 | print (string) 18 | 19 | def close(self): 20 | print ('-'*30 + ' close printer ' + '-'*30) 21 | 22 | 23 | class Logger(object): 24 | 25 | def __init__(self, log_dir, seed, create_model_dir=True): 26 | """Create a summary writer logging to log_dir.""" 27 | self.seed = int(seed) 28 | self.log_dir = Path(log_dir) 29 | self.model_dir = Path(log_dir) / 'model' 30 | self.log_dir.mkdir (parents=True, exist_ok=True) 31 | if create_model_dir: 32 | self.model_dir.mkdir(parents=True, exist_ok=True) 33 | 34 | self.tensorboard_dir = self.log_dir 35 | self.logger_path = self.log_dir / 'seed-{:}.log'.format(self.seed) 36 | self.logger_file = open(self.logger_path, 'w') 37 | 38 | self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True) 39 | self.writer = SummaryWriter(str(self.tensorboard_dir)) 40 | 41 | def __repr__(self): 42 | return ('{name}(dir={log_dir}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__)) 43 | 44 | def extract_log(self): 45 | return self.logger_file 46 | 47 | def close(self): 48 | self.logger_file.close() 49 | if self.writer is not None: 50 | self.writer.close() 51 | 52 | def log(self, string, save=True, stdout=False): 53 | if stdout: 54 | sys.stdout.write(string); sys.stdout.flush() 55 | else: 56 | print(string) 57 | if save: 58 | self.logger_file.write('{:}\n'.format(string)) 59 | self.logger_file.flush() 60 | -------------------------------------------------------------------------------- 
/lib/datasets/SearchDatasetWrap.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch.utils.data as data 3 | 4 | 5 | class SearchDataset(data.Dataset): 6 | 7 | def __init__(self, name, data, train_split, valid_split, check=True): 8 | self.datasetname = name 9 | if isinstance(data, (list, tuple)): # new type of SearchDataset 10 | assert len(data) == 2, 'invalid length: {:}'.format( len(data) ) 11 | self.train_data = data[0] 12 | self.valid_data = data[1] 13 | self.train_split = train_split.copy() 14 | self.valid_split = valid_split.copy() 15 | self.mode_str = 'V2' # new mode 16 | else: 17 | self.mode_str = 'V1' # old mode 18 | self.data = data 19 | self.train_split = train_split.copy() 20 | self.valid_split = valid_split.copy() 21 | if check: 22 | intersection = set(train_split).intersection(set(valid_split)) 23 | assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection' 24 | self.length = len(self.train_split) 25 | 26 | def __repr__(self): 27 | return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str)) 28 | 29 | def __len__(self): 30 | return self.length 31 | 32 | def __getitem__(self, index): 33 | assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index) 34 | train_index = self.train_split[index] 35 | valid_index = random.choice( self.valid_split ) 36 | if self.mode_str == 'V1': 37 | train_image, train_label = self.data[train_index] 38 | valid_image, valid_label = self.data[valid_index] 39 | elif self.mode_str == 'V2': 40 | train_image, train_label = self.train_data[train_index] 41 | valid_image, valid_label = self.valid_data[valid_index] 42 | else: raise ValueError('invalid mode : {:}'.format(self.mode_str)) 43 | return train_image, train_label, valid_image, valid_label 44 | -------------------------------------------------------------------------------- /lib/models/shape_infers/InferTinyCellNet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Text, Any 2 | import torch.nn as nn 3 | from models.cell_operations import ResNetBasicblock 4 | from models.cell_infers.cells import InferCell 5 | 6 | 7 | class DynamicShapeTinyNet(nn.Module): 8 | 9 | def __init__(self, channels: List[int], genotype: Any, num_classes: int): 10 | super(DynamicShapeTinyNet, self).__init__() 11 | self._channels = channels 12 | if len(channels) % 3 != 2: 13 | raise ValueError('invalid number of layers : {:}'.format(len(channels))) 14 | self._num_stage = N = len(channels) // 3 15 | 16 | self.stem = nn.Sequential(nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False), 17 | nn.BatchNorm2d(channels[0])) 18 | 19 | # layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N 20 | layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N 21 | 22 | c_prev = channels[0] 23 | self.cells = nn.ModuleList() 24 | for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)): 25 | if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True) 26 | else : cell = InferCell(genotype, c_prev, c_curr, 1) 27 | self.cells.append( cell ) 28 | c_prev = cell.out_dim 29 | self._num_layer = len(self.cells) 30 | 31 | self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True)) 32 | self.global_pooling = 
nn.AdaptiveAvgPool2d(1) 33 | self.classifier = nn.Linear(c_prev, num_classes) 34 | 35 | def get_message(self) -> Text: 36 | string = self.extra_repr() 37 | for i, cell in enumerate(self.cells): 38 | string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) 39 | return string 40 | 41 | def extra_repr(self): 42 | return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__)) 43 | 44 | def forward(self, inputs): 45 | feature = self.stem(inputs) 46 | for i, cell in enumerate(self.cells): 47 | feature = cell(feature) 48 | 49 | out = self.lastact(feature) 50 | out = self.global_pooling( out ) 51 | out = out.view(out.size(0), -1) 52 | logits = self.classifier(out) 53 | 54 | return out, logits 55 | -------------------------------------------------------------------------------- /eval_lib/rfs_models/convnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ConvNet(nn.Module): 8 | 9 | def __init__(self, num_classes=-1, hidden_size=64): 10 | super(ConvNet, self).__init__() 11 | self.layer1 = nn.Sequential( 12 | nn.Conv2d(3, hidden_size, kernel_size=3, padding=1), 13 | nn.BatchNorm2d(hidden_size), 14 | nn.ReLU(), 15 | nn.MaxPool2d(2)) 16 | self.layer2 = nn.Sequential( 17 | nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1), 18 | nn.BatchNorm2d(hidden_size), 19 | nn.ReLU(), 20 | nn.MaxPool2d(2)) 21 | self.layer3 = nn.Sequential( 22 | nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(hidden_size), 24 | nn.ReLU(), 25 | nn.MaxPool2d(2)) 26 | self.layer4 = nn.Sequential( 27 | nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1), 28 | # nn.BatchNorm2d(64, momentum=1, affine=True, track_running_stats=False), 29 | nn.BatchNorm2d(hidden_size), 30 | nn.ReLU()) 31 | 32 | self.avgpool = nn.AdaptiveAvgPool2d(1) 33 | 34 | self.num_classes = num_classes 35 | if self.num_classes > 0: 36 | self.classifier = nn.Linear(hidden_size, self.num_classes) 37 | 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 41 | elif isinstance(m, nn.BatchNorm2d): 42 | nn.init.constant_(m.weight, 1) 43 | nn.init.constant_(m.bias, 0) 44 | 45 | def forward(self, x, is_feat=False): 46 | out = self.layer1(x) 47 | f0 = out 48 | out = self.layer2(out) 49 | f1 = out 50 | out = self.layer3(out) 51 | f2 = out 52 | out = self.layer4(out) 53 | f3 = out 54 | out = self.avgpool(out) 55 | out = out.view(out.size(0), -1) 56 | feat = out 57 | 58 | if self.num_classes > 0: 59 | out = self.classifier(out) 60 | 61 | if is_feat: 62 | return [f0, f1, f2, f3, feat], out 63 | else: 64 | return out 65 | 66 | 67 | def convnet4(**kwargs): 68 | """Four layer ConvNet 69 | """ 70 | model = ConvNet(**kwargs) 71 | return model 72 | 73 | 74 | if __name__ == '__main__': 75 | model = convnet4(num_classes=64) 76 | data = torch.randn(2, 3, 84, 84) 77 | feat, logit = model(data, is_feat=True) 78 | print(feat[-1].shape) 79 | print(logit.shape) 80 | -------------------------------------------------------------------------------- /lib/models/cell_infers/tiny_network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ..cell_operations import ResNetBasicblock 3 | from .cells import InferCell 4 | from pdb import set_trace as bp 5 | 6 | 7 | # The macro 
structure for architectures in NAS-Bench-201 8 | class TinyNetwork(nn.Module): 9 | 10 | def __init__(self, C, N, genotype, num_classes, C_in=3, depth=-1): 11 | super(TinyNetwork, self).__init__() 12 | self._C = C 13 | self._layerN = N 14 | # C_in: number of input channel 15 | # depth: number of cells to forward 16 | 17 | self.stem = nn.Sequential(nn.Conv2d(C_in, C, kernel_size=3, padding=1, bias=False), 18 | nn.BatchNorm2d(C)) 19 | 20 | layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N 21 | layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N 22 | 23 | if depth == -1: 24 | self.depth = len(layer_channels) 25 | else: 26 | self.depth = min(depth, len(layer_channels)) 27 | 28 | C_prev = C 29 | self.cells = nn.ModuleList() 30 | for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): 31 | if index >= self.depth: 32 | break 33 | if reduction: 34 | cell = ResNetBasicblock(C_prev, C_curr, 2, True) 35 | else: 36 | cell = InferCell(genotype, C_prev, C_curr, 1) 37 | self.cells.append( cell ) 38 | C_prev = cell.out_dim 39 | self._Layer= len(self.cells) 40 | 41 | self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) 42 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 43 | self.classifier = nn.Linear(C_prev, num_classes) 44 | 45 | def get_message(self): 46 | string = self.extra_repr() 47 | for i, cell in enumerate(self.cells): 48 | string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) 49 | return string 50 | 51 | def extra_repr(self): 52 | return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) 53 | 54 | def forward(self, inputs, return_features=False): 55 | features = [] 56 | feature = self.stem(inputs) 57 | features.append(feature) 58 | for i, cell in enumerate(self.cells): 59 | feature = cell(feature) 60 | features.append(feature) 61 | 62 | out = self.lastact(feature) 63 | features.append(out) 64 | out = self.global_pooling(out) 65 | features.append(out) 66 | out = out.view(out.size(0), -1) 67 | logits = self.classifier(out) 68 | 69 | if return_features: 70 | return out, logits, features 71 | else: 72 | return out, logits 73 | -------------------------------------------------------------------------------- /eval_lib/rfs_models/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | 4 | from . 
import model_dict 5 | 6 | from collections import namedtuple 7 | 8 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 9 | 10 | def create_model(name, n_cls, dataset='miniImageNet', args=None): 11 | """create model by name""" 12 | if dataset == 'miniImageNet' or dataset == 'tieredImageNet': 13 | if name.endswith('v2') or name.endswith('v3'): 14 | model = model_dict[name](num_classes=n_cls) 15 | elif name.startswith('resnet50'): 16 | print('use imagenet-style resnet50') 17 | model = model_dict[name](num_classes=n_cls) 18 | elif name.startswith('resnet') or name.startswith('seresnet'): 19 | model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=n_cls) 20 | elif name.startswith('wrn'): 21 | model = model_dict[name](num_classes=n_cls) 22 | elif name == 'convnet4small': 23 | model = model_dict[name](num_classes=n_cls, hidden_size=32) 24 | elif name.startswith('convnet'): 25 | model = model_dict[name](num_classes=n_cls) 26 | elif name.startswith('dartsmodel'): 27 | assert args is not None 28 | assert args.genotype != '' 29 | genotype = eval(args.genotype) 30 | model = model_dict[name](args, args.init_channels, n_cls, args.layers, criterion=None, auxiliary=None, genotype=genotype) 31 | elif name == 'augmentcnn': 32 | assert args is not None 33 | assert args.genotype != '' 34 | genotype = eval(args.genotype) 35 | model = model_dict[name](input_size=args.input_size, C_in=args.n_input_channels, C=args.init_channels, n_classes=n_cls, n_layers=args.layers, auxiliary=False, genotype=genotype, stem_multiplier=args.aug_stemm, feature_scale_rate=args.aug_fsr) 36 | else: 37 | raise NotImplementedError('model {} not supported in dataset {}:'.format(name, dataset)) 38 | elif dataset == 'CIFAR-FS' or dataset == 'FC100': 39 | if name.startswith('resnet') or name.startswith('seresnet'): 40 | model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=2, num_classes=n_cls) 41 | elif name.startswith('convnet'): 42 | model = model_dict[name](num_classes=n_cls) 43 | else: 44 | raise NotImplementedError('model {} not supported in dataset {}:'.format(name, dataset)) 45 | else: 46 | raise NotImplementedError('dataset not supported: {}'.format(dataset)) 47 | 48 | return model 49 | 50 | def count_params(net): 51 | return sum(p.numel() for p in net.parameters()) 52 | 53 | def get_teacher_name(model_path): 54 | """parse to get teacher model name""" 55 | segments = model_path.split('/')[-2].split('_') 56 | if ':' in segments[0]: 57 | return segments[0].split(':')[-1] 58 | else: 59 | if segments[0] != 'wrn': 60 | return segments[0] 61 | else: 62 | return segments[0] + '_' + segments[1] + '_' + segments[2] 63 | -------------------------------------------------------------------------------- /eval_lib/rfs_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class LabelSmoothing(nn.Module): 7 | """ 8 | NLL loss with label smoothing. 9 | """ 10 | def __init__(self, smoothing=0.0): 11 | """ 12 | Constructor for the LabelSmoothing module. 
13 | :param smoothing: label smoothing factor 14 | """ 15 | super(LabelSmoothing, self).__init__() 16 | self.confidence = 1.0 - smoothing 17 | self.smoothing = smoothing 18 | 19 | def forward(self, x, target): 20 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 21 | 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class BCEWithLogitsLoss(nn.Module): 30 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None, num_classes=64): 31 | super(BCEWithLogitsLoss, self).__init__() 32 | self.num_classes = num_classes 33 | self.criterion = nn.BCEWithLogitsLoss(weight=weight, 34 | size_average=size_average, 35 | reduce=reduce, 36 | reduction=reduction, 37 | pos_weight=pos_weight) 38 | def forward(self, input, target): 39 | target_onehot = torch.nn.functional.one_hot(target, num_classes=self.num_classes).float() 40 | return self.criterion(input, target_onehot) 41 | 42 | 43 | class AverageMeter(object): 44 | """Computes and stores the average and current value""" 45 | def __init__(self): 46 | self.reset() 47 | 48 | def reset(self): 49 | self.val = 0 50 | self.avg = 0 51 | self.sum = 0 52 | self.count = 0 53 | 54 | def update(self, val, n=1): 55 | self.val = val 56 | self.sum += val * n 57 | self.count += n 58 | self.avg = self.sum / self.count 59 | 60 | 61 | def adjust_learning_rate(epoch, opt, optimizer): 62 | """Sets the learning rate to the initial LR decayed by the decay rate at every decay epoch""" 63 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs)) 64 | if steps > 0: 65 | new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps) 66 | for param_group in optimizer.param_groups: 67 | param_group['lr'] = new_lr 68 | 69 | 70 | def accuracy(output, target, topk=(1,)): 71 | """Computes the accuracy over the k top predictions for the specified values of k""" 72 | with torch.no_grad(): 73 | maxk = max(topk) 74 | batch_size = target.size(0) 75 | 76 | _, pred = output.topk(maxk, 1, True, True) 77 | pred = pred.t() 78 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 79 | 80 | res = [] 81 | for k in topk: 82 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 83 | res.append(correct_k.mul_(100.0 / batch_size)) 84 | return res 85 | -------------------------------------------------------------------------------- /lib/models/cell_infers/nasnet_cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR 3 | 4 | 5 | # The macro structure is based on NASNet 6 | class NASNetonCIFAR(nn.Module): 7 | 8 | def __init__(self, C, N, stem_multiplier, num_classes, genotype, auxiliary, affine=True, track_running_stats=True): 9 | super(NASNetonCIFAR, self).__init__() 10 | self._C = C 11 | self._layerN = N 12 | self.stem = nn.Sequential(nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False), 13 | nn.BatchNorm2d(C*stem_multiplier)) 14 | 15 | # config for each layer 16 | layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1) 17 | layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1) 18 | 19 | C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False 20 | self.auxiliary_index = None 21 | self.auxiliary_head = None 22 | self.cells = 
nn.ModuleList() 23 | for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): 24 | cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) 25 | self.cells.append( cell ) 26 | C_prev_prev, C_prev, reduction_prev = C_prev, cell._multiplier*C_curr, reduction 27 | if reduction and C_curr == C*4 and auxiliary: 28 | self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) 29 | self.auxiliary_index = index 30 | self._Layer = len(self.cells) 31 | self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) 32 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 33 | self.classifier = nn.Linear(C_prev, num_classes) 34 | self.drop_path_prob = -1 35 | 36 | def update_drop_path(self, drop_path_prob): 37 | self.drop_path_prob = drop_path_prob 38 | 39 | def auxiliary_param(self): 40 | if self.auxiliary_head is None: return [] 41 | else: return list( self.auxiliary_head.parameters() ) 42 | 43 | def get_message(self): 44 | string = self.extra_repr() 45 | for i, cell in enumerate(self.cells): 46 | string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) 47 | return string 48 | 49 | def extra_repr(self): 50 | return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) 51 | 52 | def forward(self, inputs): 53 | stem_feature, logits_aux = self.stem(inputs), None 54 | cell_results = [stem_feature, stem_feature] 55 | for i, cell in enumerate(self.cells): 56 | cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) 57 | cell_results.append( cell_feature ) 58 | if self.auxiliary_index is not None and i == self.auxiliary_index and self.training: 59 | logits_aux = self.auxiliary_head( cell_results[-1] ) 60 | out = self.lastact(cell_results[-1]) 61 | out = self.global_pooling( out ) 62 | out = out.view(out.size(0), -1) 63 | logits = self.classifier(out) 64 | if logits_aux is None: return out, logits 65 | else: return out, [logits, logits_aux] 66 | -------------------------------------------------------------------------------- /eval_lib/rfs_dataset/transform_cfg.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import torchvision.transforms as transforms 6 | 7 | 8 | mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0] 9 | std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0] 10 | normalize = transforms.Normalize(mean=mean, std=std) 11 | 12 | 13 | transform_A = [ 14 | transforms.Compose([ 15 | transforms.ToPILImage(), 16 | transforms.RandomCrop(84, padding=8), 17 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 18 | transforms.RandomHorizontalFlip(), 19 | # lambda x: np.asarray(x), 20 | transforms.ToTensor(), 21 | normalize 22 | ]), 23 | 24 | transforms.Compose([ 25 | transforms.ToPILImage(), 26 | transforms.ToTensor(), 27 | normalize 28 | ]) 29 | ] 30 | 31 | 32 | transform_B = [ 33 | transforms.Compose([ 34 | transforms.ToPILImage(), 35 | transforms.RandomResizedCrop(84, scale=(0.2, 1.0)), 36 | transforms.RandomHorizontalFlip(), 37 | # lambda x: np.asarray(x), 38 | transforms.ToTensor(), 39 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 40 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 41 | ]), 42 | 43 | transforms.Compose([ 44 | transforms.ToPILImage(), 45 | transforms.Resize(92), 46 | 
transforms.CenterCrop(84), 47 | transforms.ToTensor(), 48 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 49 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 50 | ]) 51 | ] 52 | 53 | transform_C = [ 54 | transforms.Compose([ 55 | transforms.ToPILImage(), 56 | # transforms.Resize(92, interpolation = PIL.Image.BICUBIC), 57 | transforms.RandomResizedCrop(80), 58 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ToTensor(), 61 | # Lighting(0.1, imagenet_pca['eigval'], imagenet_pca['eigvec']), 62 | # normalize 63 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 64 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 65 | ]), 66 | 67 | transforms.Compose([ 68 | transforms.ToPILImage(), 69 | transforms.Resize(92), 70 | transforms.CenterCrop(80), 71 | transforms.ToTensor(), 72 | # normalize 73 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 74 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 75 | ]) 76 | ] 77 | 78 | # CIFAR style transformation 79 | mean = [0.5071, 0.4867, 0.4408] 80 | std = [0.2675, 0.2565, 0.2761] 81 | normalize_cifar100 = transforms.Normalize(mean=mean, std=std) 82 | transform_D = [ 83 | transforms.Compose([ 84 | transforms.ToPILImage(), 85 | transforms.RandomCrop(32, padding=4), 86 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 87 | transforms.RandomHorizontalFlip(), 88 | # lambda x: np.asarray(x), 89 | transforms.ToTensor(), 90 | normalize_cifar100 91 | ]), 92 | 93 | transforms.Compose([ 94 | transforms.ToPILImage(), 95 | transforms.ToTensor(), 96 | normalize_cifar100 97 | ]) 98 | ] 99 | 100 | 101 | transforms_list = ['A', 'B', 'C', 'D'] 102 | 103 | 104 | transforms_options = { 105 | 'A': transform_A, 106 | 'B': transform_B, 107 | 'C': transform_C, 108 | 'D': transform_D, 109 | } 110 | -------------------------------------------------------------------------------- /eval_lib/rfs_models/wresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import sys 8 | import numpy as np 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 13 | 14 | 15 | def conv_init(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('Conv') != -1: 18 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 19 | init.constant(m.bias, 0) 20 | elif classname.find('BatchNorm') != -1: 21 | init.constant(m.weight, 1) 22 | init.constant(m.bias, 0) 23 | 24 | 25 | class wide_basic(nn.Module): 26 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 27 | super(wide_basic, self).__init__() 28 | self.bn1 = nn.BatchNorm2d(in_planes) 29 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 30 | self.dropout = nn.Dropout(p=dropout_rate) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 33 | 34 | self.shortcut = nn.Sequential() 35 | if stride != 1 or in_planes != planes: 36 | self.shortcut = nn.Sequential( 37 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 38 | ) 39 | 40 | def forward(self, x): 41 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 42 | out = 
self.conv2(F.relu(self.bn2(out))) 43 | out += self.shortcut(x) 44 | 45 | return out 46 | 47 | 48 | class Wide_ResNet(nn.Module): 49 | def __init__(self, depth, widen_factor, dropout_rate, num_classes=-1): 50 | super(Wide_ResNet, self).__init__() 51 | self.in_planes = 16 52 | 53 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 54 | n = (depth-4) // 6 55 | k = widen_factor 56 | 57 | print('| Wide-Resnet %dx%d' %(depth, k)) 58 | nStages = [16, 16*k, 32*k, 64*k] 59 | 60 | self.conv1 = conv3x3(3,nStages[0]) 61 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 62 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 63 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 64 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 65 | 66 | self.num_classes = num_classes 67 | if self.num_classes > 0: 68 | self.classifier = nn.Linear(64*k, self.num_classes) 69 | 70 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 71 | strides = [stride] + [1]*(num_blocks-1) 72 | layers = [] 73 | 74 | for stride in strides: 75 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 76 | self.in_planes = planes 77 | 78 | return nn.Sequential(*layers) 79 | 80 | def forward(self, x, is_feat=False): 81 | out = self.conv1(x) 82 | out = self.layer1(out) 83 | out = self.layer2(out) 84 | out = self.layer3(out) 85 | out = F.relu(self.bn1(out)) 86 | out = F.adaptive_avg_pool2d(out, 1) 87 | out = out.view(out.size(0), -1) 88 | feat = out 89 | if self.num_classes > 0: 90 | out = self.classifier(out) 91 | 92 | if is_feat: 93 | return [feat], out 94 | else: 95 | return out 96 | 97 | 98 | def wrn_28_10(dropout_rate=0.3, num_classes=-1): 99 | return Wide_ResNet(28, 10, dropout_rate, num_classes) 100 | 101 | 102 | if __name__ == '__main__': 103 | net=Wide_ResNet(28, 10, 0.3) 104 | y = net(Variable(torch.randn(1,3,32,32))) 105 | 106 | print(y.size()) -------------------------------------------------------------------------------- /lib/config_utils/configure_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from pathlib import Path 4 | from collections import namedtuple 5 | 6 | support_types = ('str', 'int', 'bool', 'float', 'none') 7 | 8 | 9 | def convert_param(original_lists): 10 | assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists) 11 | ctype, value = original_lists[0], original_lists[1] 12 | assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types) 13 | is_list = isinstance(value, list) 14 | if not is_list: value = [value] 15 | outs = [] 16 | for x in value: 17 | if ctype == 'int': 18 | x = int(x) 19 | elif ctype == 'str': 20 | x = str(x) 21 | elif ctype == 'bool': 22 | x = bool(int(x)) 23 | elif ctype == 'float': 24 | x = float(x) 25 | elif ctype == 'none': 26 | if x.lower() != 'none': 27 | raise ValueError('For the none type, the value must be none instead of {:}'.format(x)) 28 | x = None 29 | else: 30 | raise TypeError('Does not know this type : {:}'.format(ctype)) 31 | outs.append(x) 32 | if not is_list: outs = outs[0] 33 | return outs 34 | 35 | 36 | def load_config_dict(path): 37 | path = str(path) 38 | assert os.path.exists(path), 'Can not find {:}'.format(path) 39 | # Reading data back 40 | with open(path, 'r') as f: 41 | data = json.load(f) 42 | content = {k: convert_param(v) for k, v in data.items()} 43 | return content 44 | 45 | 46 | def 
merge_config_dict(config_list): 47 | for config in config_list: 48 | assert isinstance(config, dict), 'invalid type of config: {:}'.format(type(config)) 49 | content = config_list[0] 50 | for config in config_list[1:]: 51 | content = {**content, **config} 52 | return content 53 | 54 | 55 | def load_config(path, extra, logger=None): 56 | if hasattr(logger, 'log'): 57 | logger.log(path) 58 | content = load_config_dict(path) 59 | assert extra is None or isinstance( 60 | extra, dict), 'invalid type of extra : {:}'.format(extra) 61 | if isinstance(extra, dict): 62 | content = {**content, **extra} 63 | Arguments = namedtuple('Configure', ' '.join(content.keys())) 64 | content = Arguments(**content) 65 | if hasattr(logger, 'log'): 66 | logger.log('{:}'.format(content)) 67 | return content 68 | 69 | 70 | def configure2str(config, xpath=None): 71 | if not isinstance(config, dict): 72 | config = config._asdict() 73 | def cstring(x): 74 | return "\"{:}\"".format(x) 75 | def gtype(x): 76 | if isinstance(x, list): x = x[0] 77 | if isinstance(x, str) : return 'str' 78 | elif isinstance(x, bool) : return 'bool' 79 | elif isinstance(x, int): return 'int' 80 | elif isinstance(x, float): return 'float' 81 | elif x is None : return 'none' 82 | else: raise ValueError('invalid : {:}'.format(x)) 83 | def cvalue(x, xtype): 84 | if isinstance(x, list): is_list = True 85 | else: 86 | is_list, x = False, [x] 87 | temps = [] 88 | for temp in x: 89 | if xtype == 'bool' : temp = cstring(int(temp)) 90 | elif xtype == 'none': temp = cstring('None') 91 | else : temp = cstring(temp) 92 | temps.append( temp ) 93 | if is_list: 94 | return "[{:}]".format( ', '.join( temps ) ) 95 | else: 96 | return temps[0] 97 | 98 | xstrings = [] 99 | for key, value in config.items(): 100 | xtype = gtype(value) 101 | string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype)) 102 | xstrings.append(string) 103 | Fstring = '{\n' + ',\n'.join(xstrings) + '\n}' 104 | if xpath is not None: 105 | parent = Path(xpath).resolve().parent 106 | parent.mkdir(parents=True, exist_ok=True) 107 | if os.path.isfile(xpath): os.remove(xpath) 108 | with open(xpath, "w") as text_file: 109 | text_file.write('{:}'.format(Fstring)) 110 | return Fstring 111 | 112 | 113 | def dict2config(xdict, logger=None): 114 | assert isinstance(xdict, dict), 'invalid type : {:}'.format(type(xdict)) 115 | Arguments = namedtuple('Configure', ' '.join(xdict.keys())) 116 | content = Arguments(**xdict) 117 | if hasattr(logger, 'log'): 118 | logger.log('{:}'.format(content)) 119 | return content 120 | -------------------------------------------------------------------------------- /lib/models/cell_searchs/search_cells.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from copy import deepcopy 4 | from pdb import set_trace as bp 5 | from ..cell_operations import OPS 6 | 7 | 8 | INF = 1000 9 | 10 | 11 | # This module is used for NAS-Bench-201, represents a small search space with a complete DAG 12 | class NAS201SearchCell(nn.Module): 13 | 14 | def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True): 15 | super(NAS201SearchCell, self).__init__() 16 | 17 | self.op_names = deepcopy(op_names) 18 | self.edges = nn.ModuleDict() 19 | self.max_nodes = max_nodes 20 | self.in_dim = C_in 21 | self.out_dim = C_out 22 | for i in range(1, max_nodes): 23 | for j in range(i): 24 | node_str = '{:}<-{:}'.format(i, j) 25 | if j == 0: 26 | xlists = 
[OPS[op_name](C_in, C_out, stride, affine, track_running_stats) for op_name in op_names] 27 | else: 28 | xlists = [OPS[op_name](C_in, C_out, 1, affine, track_running_stats) for op_name in op_names] 29 | self.edges[node_str] = nn.ModuleList(xlists) 30 | self.edge_keys = sorted(list(self.edges.keys())) 31 | self.edge2index = {key: i for i, key in enumerate(self.edge_keys)} 32 | self.num_edges = len(self.edges) 33 | 34 | def extra_repr(self): 35 | string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) 36 | return string 37 | 38 | def forward(self, inputs, weightss): 39 | nodes = [inputs] 40 | for i in range(1, self.max_nodes): 41 | inter_nodes = [] 42 | for j in range(i): 43 | node_str = '{:}<-{:}'.format(i, j) 44 | weights = weightss[self.edge2index[node_str]] 45 | inter_nodes.append(sum(layer(nodes[j]) * w if w > 0.01 else 0 for layer, w in zip(self.edges[node_str], weights))) # for pruning purpose 46 | nodes.append(sum(inter_nodes)) 47 | return nodes[-1] 48 | 49 | 50 | class MixedOp(nn.Module): 51 | 52 | def __init__(self, space, C, stride, affine, track_running_stats): 53 | super(MixedOp, self).__init__() 54 | self._ops = nn.ModuleList() 55 | for primitive in space: 56 | op = OPS[primitive](C, C, stride, affine, track_running_stats) 57 | self._ops.append(op) 58 | 59 | def forward_darts(self, x, weights): 60 | return sum(w * op(x) if w > 0.01 else 0 for w, op in zip(weights, self._ops)) # for pruning purpose 61 | 62 | 63 | # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 64 | class NASNetSearchCell(nn.Module): 65 | 66 | def __init__(self, space, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats): 67 | super(NASNetSearchCell, self).__init__() 68 | self.reduction = reduction 69 | self.op_names = deepcopy(space) 70 | if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats) 71 | else: self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats) 72 | self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats) 73 | self._steps = steps 74 | self._multiplier = multiplier 75 | 76 | self._ops = nn.ModuleList() 77 | self.edges = nn.ModuleDict() 78 | for i in range(self._steps): 79 | for j in range(2+i): 80 | node_str = '{:}<-{:}'.format(i, j) # indicate the edge from node-(j) to node-(i+2) 81 | stride = 2 if reduction and j < 2 else 1 82 | op = MixedOp(space, C, stride, affine, track_running_stats) 83 | self.edges[ node_str ] = op 84 | self.edge_keys = sorted(list(self.edges.keys())) 85 | self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} 86 | self.num_edges = len(self.edges) 87 | 88 | def forward_darts(self, s0, s1, weightss, alphass): 89 | s0 = self.preprocess0(s0) 90 | s1 = self.preprocess1(s1) 91 | 92 | states = [s0, s1] 93 | for i in range(self._steps): 94 | clist = [] 95 | for j, h in enumerate(states): 96 | node_str = '{:}<-{:}'.format(i, j) 97 | op = self.edges[ node_str ] 98 | weights = weightss[ self.edge2index[node_str] ] 99 | alphas = alphass[ self.edge2index[node_str] ] 100 | if sum(alphas) <= (-INF) * len(alphas): 101 | # all ops on this edge are masked out 102 | clist.append( 0 ) 103 | else: 104 | clist.append( op.forward_darts(h, weights) ) 105 | states.append( sum(clist) ) 106 | 107 | return torch.cat(states[-self._multiplier:], dim=1) -------------------------------------------------------------------------------- /lib/models/cell_infers/cells.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from copy import deepcopy 4 | from ..cell_operations import OPS 5 | 6 | 7 | # Cell for NAS-Bench-201 8 | class InferCell(nn.Module): 9 | 10 | def __init__(self, genotype, C_in, C_out, stride): 11 | super(InferCell, self).__init__() 12 | 13 | self.layers = nn.ModuleList() 14 | self.node_IN = [] 15 | self.node_IX = [] 16 | self.genotype = deepcopy(genotype) 17 | for i in range(1, len(genotype)): 18 | node_info = genotype[i-1] 19 | cur_index = [] 20 | cur_innod = [] 21 | for (op_name, op_in) in node_info: 22 | if op_in == 0: 23 | layer = OPS[op_name](C_in , C_out, stride, True, True) 24 | else: 25 | layer = OPS[op_name](C_out, C_out, 1, True, True) 26 | cur_index.append( len(self.layers) ) 27 | cur_innod.append( op_in ) 28 | self.layers.append( layer ) 29 | self.node_IX.append( cur_index ) 30 | self.node_IN.append( cur_innod ) 31 | self.nodes = len(genotype) 32 | self.in_dim = C_in 33 | self.out_dim = C_out 34 | 35 | def extra_repr(self): 36 | string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) 37 | laystr = [] 38 | for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): 39 | y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)] 40 | x = '{:}<-({:})'.format(i+1, ','.join(y)) 41 | laystr.append( x ) 42 | return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr()) 43 | 44 | def forward(self, inputs): 45 | nodes = [inputs] 46 | for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): 47 | node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) ) 48 | nodes.append( node_feature ) 49 | return nodes[-1] 50 | 51 | 52 | 53 | # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 54 | class NASNetInferCell(nn.Module): 55 | 56 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats): 57 | super(NASNetInferCell, self).__init__() 58 | self.reduction = reduction 59 | if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats) 60 | else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats) 61 | self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats) 62 | 63 | if not reduction: 64 | nodes, concats = genotype['normal'], genotype['normal_concat'] 65 | else: 66 | nodes, concats = genotype['reduce'], genotype['reduce_concat'] 67 | self._multiplier = len(concats) 68 | self._concats = concats 69 | self._steps = len(nodes) 70 | self._nodes = nodes 71 | self.edges = nn.ModuleDict() 72 | for i, node in enumerate(nodes): 73 | for in_node in node: 74 | name, j = in_node[0], in_node[1] 75 | stride = 2 if reduction and j < 2 else 1 76 | node_str = '{:}<-{:}'.format(i+2, j) 77 | self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats) 78 | 79 | # [TODO] to support drop_prob in this function.. 
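# A possible sketch for the TODO above (not the original implementation): following the
# DARTS convention, drop-path would be applied to each edge output before the summation
# in forward(), e.g.
#     op = self.edges[node_str]
#     out = op(states[j])
#     if self.training and drop_prob > 0. and not isinstance(op, Identity):
#         out = drop_path_(out, drop_prob, self.training)
# assuming helpers like drop_path_() and Identity from eval_lib/rfs_models/operations.py
# are available here (identity/skip edges are conventionally exempt from drop-path);
# the forward() below currently ignores unused_drop_prob.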
80 | def forward(self, s0, s1, unused_drop_prob): 81 | s0 = self.preprocess0(s0) 82 | s1 = self.preprocess1(s1) 83 | 84 | states = [s0, s1] 85 | for i, node in enumerate(self._nodes): 86 | clist = [] 87 | for in_node in node: 88 | name, j = in_node[0], in_node[1] 89 | node_str = '{:}<-{:}'.format(i+2, j) 90 | op = self.edges[ node_str ] 91 | clist.append( op(states[j]) ) 92 | states.append( sum(clist) ) 93 | return torch.cat([states[x] for x in self._concats], dim=1) 94 | 95 | 96 | class AuxiliaryHeadCIFAR(nn.Module): 97 | 98 | def __init__(self, C, num_classes): 99 | """assuming input size 8x8""" 100 | super(AuxiliaryHeadCIFAR, self).__init__() 101 | self.features = nn.Sequential( 102 | nn.ReLU(inplace=True), 103 | nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 104 | nn.Conv2d(C, 128, 1, bias=False), 105 | nn.BatchNorm2d(128), 106 | nn.ReLU(inplace=True), 107 | nn.Conv2d(128, 768, 2, bias=False), 108 | nn.BatchNorm2d(768), 109 | nn.ReLU(inplace=True) 110 | ) 111 | self.classifier = nn.Linear(768, num_classes) 112 | 113 | def forward(self, x): 114 | x = self.features(x) 115 | x = self.classifier(x.view(x.size(0),-1)) 116 | return x 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MetaNTK-NAS: Global Convergence of MAML and Theory-Inspired Neural Architecture Search for Few-Shot Learning [Paper](https://arxiv.org/abs/2203.09137) 2 | 3 | Haoxiang Wang*, Yite Wang*, Ruoyu Sun, Bo Li 4 | 5 | In CVPR 2022. 6 | 7 | If you find this repo useful for your research, please consider citing our paper: 8 | ``` 9 | @inproceedings{MetaNTK-NAS, 10 | title={Global Convergence of MAML and Theory-Inspired Neural Architecture Search for Few-Shot Learning}, 11 | author={Wang, Haoxiang and Wang, Yite and Sun, Ruoyu and Li, Bo}, 12 | booktitle={CVPR}, 13 | year={2022} 14 | } 15 | ``` 16 | 17 | ## Overview 18 | 19 | This is the PyTorch implementation of MetaNTK-NAS, a training-free NAS method for few-shot learning based on Meta Neural Tangent Kernels (MetaNTK). 20 | 21 | ## Installation 22 | 23 | This repository has been tested on RedHat with PyTorch 1.3.1 on NVIDIA V100 GPUs, and on Ubuntu with PyTorch 1.10 on NVIDIA RTX 3090 and V100 GPUs. For other platforms, the configurations may need to be changed. 24 | 25 | #### Required packages 26 | 27 | - Common packages: numpy, scipy, scikit-learn, easydict, pillow, etc. 28 | - PyTorch packages: [PyTorch](https://pytorch.org/), Torchvision, [torch-optimizer](https://github.com/jettify/pytorch-optimizer). 29 | - Packages for efficient gradient computation: [Opacus](https://opacus.ai/). 30 | - Packages for counting operations and parameters of architectures: [ptflops](https://pypi.org/project/ptflops/). 31 | 32 | Or you can simply install all dependencies using: 33 | 34 | `pip install -r requirements.txt` 35 | 36 | ## Usage 37 | 38 | ### 0. Prepare the dataset 39 | 40 | * Please download the MiniImageNet and TieredImageNet datasets from [RFS](https://github.com/WangYueFt/rfs). 41 | * Please properly set the `data_paths` in `prune_launch.py`. 42 | 43 | ### 1. Search 44 | 45 | #### [DARTS_fewshot Space](https://arxiv.org/pdf/1911.11090.pdf) 46 | 47 | You may want to check the sample scripts in the `scripts` folder; they call `prune_launch.py` with predefined configurations. A sample invocation is sketched below, followed by the arguments you may want to modify to replicate our experimental results.
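As a minimal sketch (the flag values here are the defaults described below, not necessarily the exact settings of every experiment), a MetaNTK-NAS search on MiniImageNet can be launched with:

```bash
python prune_launch.py --gpu 0 \
    --space darts_fewshot --dataset MetaMiniImageNet \
    --ntk_type MetaNTK_anl --algorithm MAML \
    --max_nodes 3 --dartsbs 3 \
    --ntk_channels 48 --ntk_layers 5 \
    --inner_lr_time 1000 --reg_coef 1e-3 \
    --seed -1 --train_after_search true
```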
48 | 49 | - `--gpu`: Which GPU to use during search/train/evaluation. 50 | - `--space`: Which search space to use. In our implementation, we use `'darts_fewshot'`. Choose from `['darts', 'darts_fewshot']`. 51 | - `--dataset`: Dataset to use for search/train/evaluation. If you want to search with NTK, choose one of `['MiniImageNet', 'TieredImageNet']`. If you want to search with MetaNTK instead, prepend 'Meta' to the name, e.g., `'MetaMiniImageNet'`. 52 | - `--seed`: Manual seed. Set it to `-1` for a random seed. 53 | - `--max_nodes`: Number of intermediate nodes in each cell. In our experiments, we use `3` as the default value. 54 | - `--dartsbs`: (Meta) batch size used for searching; make sure the total number of samples used for NTK and MetaNTK is the same. We use a meta batch size of `3` as the default for MetaNTK-NAS. Since each meta batch consists of 5-way tasks with 1 query and 1 support sample per way (10 samples per task, 30 in total), a batch size of `30` should be used for NTK for a fair comparison. 55 | - `--ntk_type`: Whether to search based on the condition number of the NTK or of the MetaNTK. Choose from `['NTK', 'MetaNTK_anl']`. 56 | - `--ntk_channels`: Initial number of channels of the architecture during search/train/evaluation. We set this to `48` for all experiments. 57 | - `--ntk_layers`: Number of cells to stack for the final architecture during search/train/evaluation. In our experiments, we use `5` and `8`. 58 | - `--algorithm`: The algorithm used to construct the MetaNTK kernel. Choose from `['MAML','ANIL']`. 59 | - `--inner_lr_time`: The product of the inner-loop learning rate and the training time. The default value is `1000` (we treat any `inner_lr_time` larger than 1000 as infinity). 60 | - `--reg_coef`: The regularization coefficient for the inner-loop optimization; we suggest using a value larger than `1e-5`. The default value is `1e-3`. 61 | - `--train_after_search`: Whether to automatically train/evaluate the searched architecture. Choose from `['true', 'false']`. 62 | 63 | You may also call `prune_metantknas.py` directly, which gives you much more flexibility. Check the file for more details. 64 | 65 | ### 2. Evaluation 66 | 67 | * You can set `train_after_search` to `true` in the provided scripts so that architecture evaluation runs automatically after the search is done. 68 | * You can also use `eval_searched_arch.py` to train/evaluate obtained architectures (specified by their genotypes). 69 | 70 | ## To-Do 71 | - [x] Provide searched architectures (in a Dropbox or Google Drive folder) 72 | - [ ] Provide optimal hyperparameters (in `scripts/`) for the fine-tuning part of evaluation, i.e., fine-tuning a linear classifier on top of the trained searched architecture for test few-shot tasks. 73 | 74 | ## Acknowledgement 75 | 76 | * Code base from: 77 | * [MetaNAS](https://github.com/boschresearch/metanas): We mainly use the model provided by MetaNAS. Check `eval_lib/rfs_models/augment_cnn.py`. 78 | * [TE-NAS](https://github.com/VITA-Group/TENAS): Our architecture search code is developed based on TE-NAS. 79 | * [RFS](https://github.com/WangYueFt/rfs): Our architecture evaluation code is developed based on RFS.
80 | * **This work utilizes resources supported by the National Science Foundation’s Major Research Instrumentation program, grant #1725729, as well as the University of Illinois at Urbana-Champaign.** 81 | -------------------------------------------------------------------------------- /eval_lib/eval/meta_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import scipy 5 | from scipy.stats import t 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from sklearn import metrics 10 | from sklearn.svm import SVC, LinearSVC 11 | from sklearn.linear_model import LogisticRegression 12 | from sklearn.neighbors import KNeighborsClassifier 13 | from sklearn.ensemble import RandomForestClassifier 14 | 15 | from sklearn.pipeline import make_pipeline 16 | from sklearn.preprocessing import StandardScaler 17 | 18 | 19 | def mean_confidence_interval(data, confidence=0.95): 20 | a = 1.0 * np.array(data) 21 | n = len(a) 22 | m, se = np.mean(a), scipy.stats.sem(a) 23 | h = se * t._ppf((1+confidence)/2., n-1) 24 | return m, h 25 | 26 | 27 | def normalize(x): 28 | norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2) 29 | out = x.div(norm) 30 | return out 31 | 32 | 33 | def meta_test(net, testloader, use_logit=True, is_norm=True, classifier='LR', opt=None, C=1.0): 34 | net = net.eval() 35 | acc = [] 36 | 37 | with torch.no_grad(): 38 | for idx, data in tqdm(enumerate(testloader)): 39 | support_xs, support_ys, query_xs, query_ys = data 40 | support_xs = support_xs.cuda() 41 | query_xs = query_xs.cuda() 42 | batch_size, _, channel, height, width = support_xs.size() 43 | support_xs = support_xs.view(-1, channel, height, width) 44 | query_xs = query_xs.view(-1, channel, height, width) 45 | 46 | if use_logit: 47 | support_features = net(support_xs).view(support_xs.size(0), -1) 48 | query_features = net(query_xs).view(query_xs.size(0), -1) 49 | else: 50 | feat_support, _ = net(support_xs, is_feat=True) 51 | support_features = feat_support[-1].view(support_xs.size(0), -1) 52 | feat_query, _ = net(query_xs, is_feat=True) 53 | query_features = feat_query[-1].view(query_xs.size(0), -1) 54 | 55 | if is_norm: 56 | support_features = normalize(support_features) 57 | query_features = normalize(query_features) 58 | 59 | support_features = support_features.detach().cpu().numpy() 60 | query_features = query_features.detach().cpu().numpy() 61 | 62 | support_ys = support_ys.view(-1).numpy() 63 | query_ys = query_ys.view(-1).numpy() 64 | 65 | # clf = SVC(gamma='auto', C=0.1) 66 | if classifier == 'LR': 67 | clf = LogisticRegression(penalty='l2', 68 | random_state=0, 69 | C=C, 70 | solver='lbfgs', 71 | max_iter=1000, 72 | multi_class='multinomial') 73 | clf.fit(support_features, support_ys) 74 | query_ys_pred = clf.predict(query_features) 75 | elif classifier == 'SVM': 76 | clf = make_pipeline(StandardScaler(), SVC(gamma='auto', 77 | C=1, 78 | kernel='linear', 79 | decision_function_shape='ovr')) 80 | clf.fit(support_features, support_ys) 81 | query_ys_pred = clf.predict(query_features) 82 | elif classifier == 'NN': 83 | query_ys_pred = NN(support_features, support_ys, query_features) 84 | elif classifier == 'Cosine': 85 | query_ys_pred = Cosine(support_features, support_ys, query_features) 86 | elif classifier == 'Proto': 87 | query_ys_pred = Proto(support_features, support_ys, query_features, opt) 88 | else: 89 | raise NotImplementedError('classifier not supported: {}'.format(classifier)) 90 | 91 | 
acc.append(metrics.accuracy_score(query_ys, query_ys_pred)) 92 | 93 | return mean_confidence_interval(acc) 94 | 95 | 96 | def Proto(support, support_ys, query, opt): 97 | """Protonet classifier""" 98 | nc = support.shape[-1] 99 | support = np.reshape(support, (-1, 1, opt.n_ways, opt.n_shots, nc)) 100 | support = support.mean(axis=3) 101 | batch_size = support.shape[0] 102 | query = np.reshape(query, (batch_size, -1, 1, nc)) 103 | logits = - ((query - support)**2).sum(-1) 104 | pred = np.argmax(logits, axis=-1) 105 | pred = np.reshape(pred, (-1,)) 106 | return pred 107 | 108 | 109 | def NN(support, support_ys, query): 110 | """nearest classifier""" 111 | support = np.expand_dims(support.transpose(), 0) 112 | query = np.expand_dims(query, 2) 113 | 114 | diff = np.multiply(query - support, query - support) 115 | distance = diff.sum(1) 116 | min_idx = np.argmin(distance, axis=1) 117 | pred = [support_ys[idx] for idx in min_idx] 118 | return pred 119 | 120 | 121 | def Cosine(support, support_ys, query): 122 | """Cosine classifier""" 123 | support_norm = np.linalg.norm(support, axis=1, keepdims=True) 124 | support = support / support_norm 125 | query_norm = np.linalg.norm(query, axis=1, keepdims=True) 126 | query = query / query_norm 127 | 128 | cosine_distance = query @ support.transpose() 129 | max_idx = np.argmax(cosine_distance, axis=1) 130 | pred = [support_ys[idx] for idx in max_idx] 131 | return pred 132 | -------------------------------------------------------------------------------- /lib/datasets/DownsampledImageNet.py: -------------------------------------------------------------------------------- 1 | import os, sys, hashlib 2 | import numpy as np 3 | from PIL import Image 4 | import torch.utils.data as data 5 | if sys.version_info[0] == 2: 6 | import cPickle as pickle 7 | else: 8 | import pickle 9 | 10 | 11 | def calculate_md5(fpath, chunk_size=1024 * 1024): 12 | md5 = hashlib.md5() 13 | with open(fpath, 'rb') as f: 14 | for chunk in iter(lambda: f.read(chunk_size), b''): 15 | md5.update(chunk) 16 | return md5.hexdigest() 17 | 18 | 19 | def check_md5(fpath, md5, **kwargs): 20 | return md5 == calculate_md5(fpath, **kwargs) 21 | 22 | 23 | def check_integrity(fpath, md5=None): 24 | if not os.path.isfile(fpath): return False 25 | if md5 is None: return True 26 | else : return check_md5(fpath, md5) 27 | 28 | 29 | class ImageNet16(data.Dataset): 30 | # http://image-net.org/download-images 31 | # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets 32 | # https://arxiv.org/pdf/1707.08819.pdf 33 | 34 | train_list = [ 35 | ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'], 36 | ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'], 37 | ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'], 38 | ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'], 39 | ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'], 40 | ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'], 41 | ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'], 42 | ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'], 43 | ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'], 44 | ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'], 45 | ] 46 | valid_list = [ 47 | ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'], 48 | ] 49 | 50 | def __init__(self, root, train, transform, use_num_of_class_only=None): 51 | self.root = root 52 | self.transform = transform 53 | self.train = train # training set or valid set 54 | if not self._check_integrity(): raise 
RuntimeError('Dataset not found or corrupted.') 55 | 56 | if self.train: downloaded_list = self.train_list 57 | else : downloaded_list = self.valid_list 58 | self.data = [] 59 | self.targets = [] 60 | 61 | # now load the picked numpy arrays 62 | for i, (file_name, checksum) in enumerate(downloaded_list): 63 | file_path = os.path.join(self.root, file_name) 64 | #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) 65 | with open(file_path, 'rb') as f: 66 | if sys.version_info[0] == 2: 67 | entry = pickle.load(f) 68 | else: 69 | entry = pickle.load(f, encoding='latin1') 70 | self.data.append(entry['data']) 71 | self.targets.extend(entry['labels']) 72 | self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) 73 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 74 | if use_num_of_class_only is not None: 75 | assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only) 76 | new_data, new_targets = [], [] 77 | for I, L in zip(self.data, self.targets): 78 | if 1 <= L <= use_num_of_class_only: 79 | new_data.append( I ) 80 | new_targets.append( L ) 81 | self.data = new_data 82 | self.targets = new_targets 83 | # self.mean.append(entry['mean']) 84 | #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) 85 | #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) 86 | #print ('Mean : {:}'.format(self.mean)) 87 | #temp = self.data - np.reshape(self.mean, (1, 1, 1, 3)) 88 | #std_data = np.std(temp, axis=0) 89 | #std_data = np.mean(np.mean(std_data, axis=0), axis=0) 90 | #print ('Std : {:}'.format(std_data)) 91 | 92 | def __getitem__(self, index): 93 | img, target = self.data[index], self.targets[index] - 1 94 | 95 | img = Image.fromarray(img) 96 | 97 | if self.transform is not None: 98 | img = self.transform(img) 99 | 100 | return img, target 101 | 102 | def __len__(self): 103 | return len(self.data) 104 | 105 | def _check_integrity(self): 106 | root = self.root 107 | for fentry in (self.train_list + self.valid_list): 108 | filename, md5 = fentry[0], fentry[1] 109 | fpath = os.path.join(root, filename) 110 | if not check_integrity(fpath, md5): 111 | return False 112 | return True 113 | 114 | 115 | if __name__ == '__main__': 116 | train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None) 117 | valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None) 118 | 119 | print ( len(train) ) 120 | print ( len(valid) ) 121 | image, label = train[111] 122 | trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200) 123 | validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200) 124 | print ( len(trainX) ) 125 | print ( len(validX) ) 126 | #import pdb; pdb.set_trace() 127 | -------------------------------------------------------------------------------- /eval_lib/rfs_models/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | OPS = { 6 | 'none': lambda C, stride, affine: Zero(stride), 7 | 'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 8 | 'max_pool_3x3': lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 9 | 'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, 
affine=affine), 10 | 'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 11 | 'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 12 | 'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), 13 | 'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 14 | 'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), 15 | 'conv_7x1_1x7': lambda C, stride, affine: nn.Sequential( 16 | nn.ReLU(inplace=False), 17 | nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False), 18 | nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False), 19 | nn.BatchNorm2d(C, affine=affine) 20 | ), 21 | } 22 | 23 | def drop_path_(x, drop_prob, training): 24 | if training and drop_prob > 0.0: 25 | keep_prob = 1.0 - drop_prob 26 | # per data point mask; assuming x in cuda. 27 | mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob) 28 | x.div_(keep_prob).mul_(mask) 29 | 30 | return x 31 | 32 | class StdConv(nn.Module): 33 | """Standard conv 34 | ReLU - Conv - BN 35 | """ 36 | 37 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 38 | super().__init__() 39 | self.net = nn.Sequential( 40 | nn.ReLU(), 41 | nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False), 42 | nn.BatchNorm2d(C_out, affine=affine), 43 | ) 44 | 45 | def forward(self, x): 46 | return self.net(x) 47 | 48 | class DropPath_(nn.Module): 49 | def __init__(self, p=0.0): 50 | """[!] DropPath is inplace module 51 | Args: 52 | p: probability of an path to be zeroed. 53 | """ 54 | super().__init__() 55 | self.p = p 56 | 57 | def extra_repr(self): 58 | return f"p={self.p}, inplace" 59 | 60 | def forward(self, x): 61 | drop_path_(x, self.p, self.training) 62 | 63 | return x 64 | 65 | class ReLUConvBN(nn.Module): 66 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 67 | super(ReLUConvBN, self).__init__() 68 | self.op = nn.Sequential( 69 | nn.ReLU(inplace=False), 70 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 71 | nn.BatchNorm2d(C_out, affine=affine) 72 | ) 73 | 74 | def forward(self, x): 75 | return self.op(x) 76 | 77 | 78 | class DilConv(nn.Module): 79 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 80 | super(DilConv, self).__init__() 81 | self.op = nn.Sequential( 82 | nn.ReLU(inplace=False), 83 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, 84 | groups=C_in, bias=False), 85 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 86 | nn.BatchNorm2d(C_out, affine=affine), 87 | ) 88 | 89 | def forward(self, x): 90 | return self.op(x) 91 | 92 | 93 | class SepConv(nn.Module): 94 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 95 | super(SepConv, self).__init__() 96 | self.op = nn.Sequential( 97 | nn.ReLU(inplace=False), 98 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 99 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 100 | nn.BatchNorm2d(C_in, affine=affine), 101 | nn.ReLU(inplace=False), 102 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 103 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 104 | nn.BatchNorm2d(C_out, affine=affine), 105 | ) 106 | 107 | def forward(self, x): 108 | 
return self.op(x) 109 | 110 | 111 | class Identity(nn.Module): 112 | def __init__(self): 113 | super(Identity, self).__init__() 114 | 115 | def forward(self, x): 116 | return x 117 | 118 | 119 | class Zero(nn.Module): 120 | def __init__(self, stride): 121 | super(Zero, self).__init__() 122 | self.stride = stride 123 | 124 | def forward(self, x): 125 | if self.stride == 1: 126 | return x.mul(0.) 127 | return x[:, :, ::self.stride, ::self.stride].mul(0.) 128 | 129 | 130 | class FactorizedReduce(nn.Module): 131 | def __init__(self, C_in, C_out, affine=True): 132 | super(FactorizedReduce, self).__init__() 133 | assert C_out % 2 == 0 134 | self.relu = nn.ReLU(inplace=False) 135 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 136 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 137 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 138 | 139 | def forward(self, x): 140 | x = self.relu(x) 141 | if x.size(2)%2!=0: 142 | x = F.pad(x, (1,0,1,0), "constant", 0) 143 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1) 144 | out = self.bn(out) 145 | return out -------------------------------------------------------------------------------- /lib/models/cell_searchs/search_model_darts.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | from copy import deepcopy 5 | from ..cell_operations import ResNetBasicblock 6 | from .search_cells import NAS201SearchCell as SearchCell 7 | from .genotypes import Structure 8 | import torch.nn.functional as F 9 | from pdb import set_trace as bp 10 | 11 | 12 | def cal_entropy(logit: torch.Tensor, dim=-1) -> torch.Tensor: 13 | """ 14 | :param logit: An unnormalized vector. 15 | :param dim: ~ 16 | :return: entropy 17 | """ 18 | prob = F.softmax(logit, dim=dim) 19 | log_prob = F.log_softmax(logit, dim=dim) 20 | 21 | entropy = -(log_prob * prob).sum(-1, keepdim=False) 22 | 23 | return entropy 24 | 25 | 26 | class TinyNetworkDarts(nn.Module): 27 | 28 | def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats, depth=-1, use_stem=True): 29 | super(TinyNetworkDarts, self).__init__() 30 | self._C = C 31 | self._layerN = N # number of stacked cell at each stage 32 | self.max_nodes = max_nodes 33 | self.use_stem = use_stem 34 | self.stem = nn.Sequential(nn.Conv2d(min(3, C), C, kernel_size=3, padding=1, bias=False), 35 | nn.BatchNorm2d(C)) 36 | 37 | layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N 38 | layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N 39 | 40 | C_prev, num_edge, edge2index = C, None, None 41 | self.cells = nn.ModuleList() 42 | for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): 43 | if depth > 0 and index >= depth: break 44 | if reduction: 45 | cell = ResNetBasicblock(C_prev, C_curr, 2) 46 | else: 47 | cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) 48 | if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index 49 | else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. 
{:}.'.format(num_edge, cell.num_edges) 50 | self.cells.append( cell ) 51 | C_prev = cell.out_dim 52 | self.op_names = deepcopy( search_space ) 53 | self._Layer = len(self.cells) 54 | self.edge2index = edge2index 55 | self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) 56 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 57 | self.classifier = nn.Linear(C_prev, num_classes) 58 | self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) 59 | 60 | def entropy(self, mean=True): 61 | if mean: 62 | return cal_entropy(self.arch_parameters, -1).mean().view(-1) 63 | else: 64 | return cal_entropy(self.arch_parameters, -1) 65 | 66 | def get_weights(self): 67 | xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) 68 | xlist += list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) 69 | xlist += list( self.classifier.parameters() ) 70 | return xlist 71 | 72 | def get_alphas(self): 73 | return [self.arch_parameters] 74 | 75 | def set_alphas(self, arch_parameters): 76 | self.arch_parameters.data.copy_(arch_parameters[0].data) 77 | 78 | def show_alphas(self): 79 | with torch.no_grad(): 80 | return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() ) 81 | 82 | def get_message(self): 83 | string = self.extra_repr() 84 | for i, cell in enumerate(self.cells): 85 | string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) 86 | return string 87 | 88 | def extra_repr(self): 89 | return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) 90 | 91 | def genotype(self, get_random=False, hardwts=None): 92 | genotypes = [] 93 | for i in range(1, self.max_nodes): 94 | xlist = [] 95 | for j in range(i): 96 | node_str = '{:}<-{:}'.format(i, j) 97 | with torch.no_grad(): 98 | if hardwts is not None: 99 | weights = hardwts[ self.edge2index[node_str] ] 100 | op_name = self.op_names[ weights.argmax().item() ] 101 | elif get_random: 102 | op_name = random.choice(self.op_names) 103 | else: 104 | weights = self.arch_parameters[ self.edge2index[node_str] ] 105 | op_name = self.op_names[ weights.argmax().item() ] 106 | xlist.append((op_name, j)) 107 | genotypes.append( tuple(xlist) ) 108 | return Structure( genotypes ) 109 | 110 | def forward(self, inputs, return_features=False): 111 | alphas = nn.functional.softmax(self.arch_parameters, dim=-1) 112 | features_all = [] 113 | if self.use_stem: 114 | feature = self.stem(inputs) 115 | else: 116 | feature = inputs 117 | features_all.append(feature.detach()) 118 | for i, cell in enumerate(self.cells): 119 | if isinstance(cell, SearchCell): 120 | feature = cell(feature, alphas) 121 | else: 122 | feature = cell(feature) 123 | features_all.append(feature.detach()) 124 | 125 | out = self.lastact(feature) 126 | out = self.global_pooling( out ) 127 | out = out.view(out.size(0), -1) 128 | logits = self.classifier(out) 129 | 130 | if return_features: 131 | return out, logits, features_all 132 | else: 133 | return out, logits 134 | -------------------------------------------------------------------------------- /lib/datasets/transform_cfg.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import torchvision.transforms as transforms 6 | 7 | 8 | mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0] 9 | std = [70.68188272 / 
255.0, 68.27635443 / 255.0, 72.54505529 / 255.0] 10 | normalize = transforms.Normalize(mean=mean, std=std) 11 | 12 | 13 | transform_A = [ 14 | transforms.Compose([ 15 | lambda x: Image.fromarray(x), 16 | transforms.RandomCrop(84, padding=8), 17 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 18 | transforms.RandomHorizontalFlip(), 19 | lambda x: np.asarray(x), 20 | transforms.ToTensor(), 21 | normalize 22 | ]), 23 | 24 | transforms.Compose([ 25 | lambda x: Image.fromarray(x), 26 | transforms.ToTensor(), 27 | normalize 28 | ]) 29 | ] 30 | 31 | 32 | transform_B = [ 33 | transforms.Compose([ 34 | lambda x: Image.fromarray(x), 35 | transforms.RandomResizedCrop(84, scale=(0.2, 1.0)), 36 | transforms.RandomHorizontalFlip(), 37 | lambda x: np.asarray(x), 38 | transforms.ToTensor(), 39 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 40 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 41 | ]), 42 | 43 | transforms.Compose([ 44 | lambda x: Image.fromarray(x), 45 | transforms.Resize(92), 46 | transforms.CenterCrop(84), 47 | transforms.ToTensor(), 48 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 49 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 50 | ]) 51 | ] 52 | 53 | transform_C = [ 54 | transforms.Compose([ 55 | lambda x: Image.fromarray(x), 56 | # transforms.Resize(92, interpolation = PIL.Image.BICUBIC), 57 | transforms.RandomResizedCrop(80), 58 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ToTensor(), 61 | # Lighting(0.1, imagenet_pca['eigval'], imagenet_pca['eigvec']), 62 | # normalize 63 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 64 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 65 | ]), 66 | 67 | transforms.Compose([ 68 | lambda x: Image.fromarray(x), 69 | transforms.Resize(92), 70 | transforms.CenterCrop(80), 71 | transforms.ToTensor(), 72 | # normalize 73 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 74 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 75 | ]) 76 | ] 77 | 78 | transform_D = [ 79 | transforms.Compose([ 80 | lambda x: Image.fromarray(x), 81 | # transforms.Resize(92, interpolation = PIL.Image.BICUBIC), 82 | transforms.RandomResizedCrop(80), 83 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 84 | transforms.RandomHorizontalFlip(), 85 | transforms.ToTensor(), 86 | # Lighting(0.1, imagenet_pca['eigval'], imagenet_pca['eigvec']), 87 | # normalize 88 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 89 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 90 | ]), 91 | 92 | transforms.Compose([ 93 | lambda x: Image.fromarray(x), 94 | transforms.Resize(92), 95 | transforms.CenterCrop(80), 96 | transforms.ToTensor(), 97 | # normalize 98 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 99 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 100 | ]) 101 | ] 102 | 103 | # Transform adopted from TE-NAS imagenet-1K 104 | transform_TENAS = [ 105 | transforms.Compose([ 106 | # lambda x: Image.fromarray(x), 107 | transforms.ToPILImage(), 108 | # transforms.Resize(92, interpolation = PIL.Image.BICUBIC), 109 | transforms.Resize((32, 32), interpolation=2), 110 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 111 | transforms.RandomHorizontalFlip(), 112 | transforms.RandomCrop(32, padding=4), 113 | # lambda x: np.asarray(x), 114 | 
transforms.ToTensor(), 115 | # Lighting(0.1, imagenet_pca['eigval'], imagenet_pca['eigvec']), 116 | # normalize 117 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 118 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 119 | ]), 120 | 121 | transforms.Compose([ 122 | # lambda x: Image.fromarray(x), 123 | transforms.ToPILImage(), 124 | transforms.Resize((32, 32), interpolation=2), 125 | # lambda x: np.asarray(x), 126 | transforms.ToTensor(), 127 | # normalize 128 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 129 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 130 | ]) 131 | ] 132 | 133 | # CIFAR style transformation 134 | mean = [0.5071, 0.4867, 0.4408] 135 | std = [0.2675, 0.2565, 0.2761] 136 | normalize_cifar100 = transforms.Normalize(mean=mean, std=std) 137 | transform_D = [ 138 | transforms.Compose([ 139 | lambda x: Image.fromarray(x), 140 | transforms.RandomCrop(32, padding=4), 141 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 142 | transforms.RandomHorizontalFlip(), 143 | lambda x: np.asarray(x), 144 | transforms.ToTensor(), 145 | normalize_cifar100 146 | ]), 147 | 148 | transforms.Compose([ 149 | lambda x: Image.fromarray(x), 150 | transforms.ToTensor(), 151 | normalize_cifar100 152 | ]) 153 | ] 154 | 155 | 156 | transforms_list = ['A', 'B', 'C', 'D', 'TENAS'] 157 | 158 | 159 | transforms_options = { 160 | 'A': transform_A, 161 | 'B': transform_B, 162 | 'C': transform_C, 163 | 'D': transform_D, 164 | 'TENAS': transform_TENAS, 165 | } 166 | -------------------------------------------------------------------------------- /lib/procedures/.ipynb_checkpoints/metantk_test-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import scipy\n", 11 | "import torch\n", 12 | "from metantk import get_analytical_metantk_n,check_symmetric" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "n_tasks = 10\n", 22 | "n_queries = 5\n", 23 | "n_support = 10\n", 24 | "N = n_tasks*(n_queries+n_support)\n", 25 | "D_ntk = 1000\n", 26 | "D_K = 200\n", 27 | "grad = np.random.randn(N,D_ntk)\n", 28 | "ntk = np.matmul(grad, np.transpose(grad))\n", 29 | "grad = np.random.randn(N,D_K)\n", 30 | "K = np.matmul(grad, np.transpose(grad))\n", 31 | "ntk = torch.from_numpy(ntk)\n", 32 | "K = torch.from_numpy(K)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 7, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "150" 44 | ] 45 | }, 46 | "execution_count": 7, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "np.linalg.matrix_rank(K)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "result = get_analytical_metantk_n(ntk, K, n_tasks, n_queries, n_support, inner_lr_time=1, reg_coef=1e-3, algorithm='ANIL')" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "data": { 71 | "text/plain": [ 72 | "{'metantk': array([[1007.88965584, -32.38164702, -31.97877812, ..., 5.30027156,\n", 73 | " 2.07870873, 37.68027217],\n", 74 | " [ -32.38164702, 1065.36747308, -17.41538723, ..., 
-56.67162961,\n", 75 | " -17.8426898 , -31.2356895 ],\n", 76 | " [ -31.97877812, -17.41538723, 1047.2717075 , ..., -47.0635837 ,\n", 77 | " 7.61168848, 24.41002514],\n", 78 | " ...,\n", 79 | " [ 5.30027156, -56.67162961, -47.0635837 , ..., 1008.23734564,\n", 80 | " -50.74651123, 12.27198125],\n", 81 | " [ 2.07870873, -17.8426898 , 7.61168848, ..., -50.74651123,\n", 82 | " 1182.44392767, -16.24570712],\n", 83 | " [ 37.68027217, -31.2356895 , 24.41002514, ..., 12.27198125,\n", 84 | " -16.24570712, 921.09015251]]),\n", 85 | " 'eigenvalues': array([ 641.96964711, 669.22847567, 687.57091895, 713.29101724,\n", 86 | " 739.1053324 , 751.44989869, 771.39610582, 781.90397226,\n", 87 | " 814.47845034, 821.26775373, 836.61040734, 850.28986802,\n", 88 | " 857.46909142, 875.49676558, 889.59759649, 901.79305259,\n", 89 | " 919.61371814, 939.47265872, 946.91758548, 955.73719118,\n", 90 | " 980.94519674, 1007.23893921, 1018.03949701, 1031.14224978,\n", 91 | " 1037.32058679, 1053.67300246, 1059.12098904, 1081.4282596 ,\n", 92 | " 1098.5623824 , 1114.73258368, 1121.9328007 , 1134.57518305,\n", 93 | " 1154.2962264 , 1175.16227106, 1184.90156945, 1202.97946011,\n", 94 | " 1213.41754243, 1241.95945499, 1262.99493235, 1271.09277241,\n", 95 | " 1283.97616971, 1307.39140311, 1327.17502647, 1374.34725423,\n", 96 | " 1402.59562677, 1438.15132772, 1469.79654234, 1485.07734017,\n", 97 | " 1505.21843646, 1574.42802538])}" 98 | ] 99 | }, 100 | "execution_count": 4, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "result" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 18, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "[[array([[ 0, 1, 2, 3, 4],\n", 118 | " [10, 11, 12, 13, 14],\n", 119 | " [20, 21, 22, 23, 24]]),\n", 120 | " array([[ 5, 6, 7, 8, 9],\n", 121 | " [15, 16, 17, 18, 19],\n", 122 | " [25, 26, 27, 28, 29]])],\n", 123 | " [array([[30, 31, 32, 33, 34],\n", 124 | " [40, 41, 42, 43, 44],\n", 125 | " [50, 51, 52, 53, 54]]),\n", 126 | " array([[35, 36, 37, 38, 39],\n", 127 | " [45, 46, 47, 48, 49],\n", 128 | " [55, 56, 57, 58, 59]])]]" 129 | ] 130 | }, 131 | "execution_count": 18, 132 | "metadata": {}, 133 | "output_type": "execute_result" 134 | } 135 | ], 136 | "source": [ 137 | "[np.hsplit(sub_K,2) for sub_K in K_vsplit]" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 24, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "array([[ 3, 4, 5, 6, 7, 8, 9],\n", 149 | " [13, 14, 15, 16, 17, 18, 19],\n", 150 | " [23, 24, 25, 26, 27, 28, 29]])" 151 | ] 152 | }, 153 | "execution_count": 24, 154 | "metadata": {}, 155 | "output_type": "execute_result" 156 | } 157 | ], 158 | "source": [ 159 | "K[:3,3:]" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [] 168 | } 169 | ], 170 | "metadata": { 171 | "kernelspec": { 172 | "display_name": "Python 3", 173 | "language": "python", 174 | "name": "python3" 175 | }, 176 | "language_info": { 177 | "codemirror_mode": { 178 | "name": "ipython", 179 | "version": 3 180 | }, 181 | "file_extension": ".py", 182 | "mimetype": "text/x-python", 183 | "name": "python", 184 | "nbconvert_exporter": "python", 185 | "pygments_lexer": "ipython3", 186 | "version": "3.7.4" 187 | } 188 | }, 189 | "nbformat": 4, 190 | "nbformat_minor": 4 191 | } 192 | 
-------------------------------------------------------------------------------- /eval_lib/rfs_models/resnet_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | __all__ = ['ResNet', 'resnet50'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class Normalize(nn.Module): 24 | 25 | def __init__(self, power=2): 26 | super(Normalize, self).__init__() 27 | self.power = power 28 | 29 | def forward(self, x): 30 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 31 | out = x.div(norm) 32 | return out 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1, downsample=None): 39 | super(BasicBlock, self).__init__() 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = nn.BatchNorm2d(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None): 71 | super(Bottleneck, self).__init__() 72 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(planes) 74 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 75 | padding=1, bias=False) 76 | self.bn2 = nn.BatchNorm2d(planes) 77 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 78 | self.bn3 = nn.BatchNorm2d(planes * 4) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class ResNet(nn.Module): 107 | 108 | def __init__(self, block, layers, in_channel=3, width=1, num_classes=64): 109 | self.inplanes = 64 110 | super(ResNet, self).__init__() 111 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, 112 | bias=False) 113 | self.bn1 = nn.BatchNorm2d(64) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | self.base = int(64 * width) 117 | 118 | self.maxpool = 
nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 119 | self.layer1 = self._make_layer(block, self.base, layers[0]) 120 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) 121 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) 122 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) 123 | self.avgpool = nn.AvgPool2d(3, stride=1) 124 | self.classifier = nn.Linear(self.base * 8 * block.expansion, num_classes) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | m.weight.data.normal_(0, math.sqrt(2. / n)) 130 | elif isinstance(m, nn.BatchNorm2d): 131 | m.weight.data.fill_(1) 132 | m.bias.data.zero_() 133 | 134 | def _make_layer(self, block, planes, blocks, stride=1): 135 | downsample = None 136 | if stride != 1 or self.inplanes != planes * block.expansion: 137 | downsample = nn.Sequential( 138 | nn.Conv2d(self.inplanes, planes * block.expansion, 139 | kernel_size=1, stride=stride, bias=False), 140 | nn.BatchNorm2d(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x, is_feat=False): 152 | x = self.conv1(x) 153 | x = self.bn1(x) 154 | x = self.relu(x) 155 | x = self.maxpool(x) 156 | x = self.layer1(x) 157 | x = self.layer2(x) 158 | x = self.layer3(x) 159 | x = self.layer4(x) 160 | x = self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | 163 | if is_feat: 164 | return [x], x 165 | 166 | x = self.classifier(x) 167 | return x 168 | 169 | 170 | def resnet50(pretrained=False, **kwargs): 171 | """Constructs a ResNet-50 model. 
172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 178 | return model 179 | 180 | 181 | if __name__ == '__main__': 182 | model = resnet50(num_classes=200) 183 | 184 | data = torch.randn(2, 3, 84, 84) 185 | -------------------------------------------------------------------------------- /prune_launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | 5 | # TODO: please configure data_paths before running; please leave TORCH_HOME empty 6 | TORCH_HOME = "" 7 | data_paths = { 8 | "MiniImageNet": "~/data/miniImageNet", 9 | "TieredImageNet": "~/data/tieredImageNet", 10 | "MetaMiniImageNet": "~/data/miniImageNet", 11 | "MetaTieredImageNet": "~/data/tieredImageNet", 12 | } 13 | 14 | 15 | parser = argparse.ArgumentParser("MetaNTK_NAS_launch") 16 | parser.add_argument('--gpu', default=0, type=int, help='use the GPU with this CUDA index') 17 | parser.add_argument('--space', default='darts_fewshot', type=str, choices=['darts', 'darts_fewshot'], help='which nas search space to use') 18 | parser.add_argument('--dataset', default='MiniImageNet', type=str, choices=['MiniImageNet', 'TieredImageNet', 'MetaMiniImageNet', 'MetaTieredImageNet'], help='Choose from (Meta)MiniImageNet/(Meta)TieredImageNet') 19 | parser.add_argument('--seed', default=-1, type=int, help='manual seed, set to -1 for random seed') 20 | parser.add_argument('--max_nodes', default=3, type=int, help='number of max nodes, 4 for darts and 3 for darts_fewshot') 21 | parser.add_argument('--dartsbs', default=3, type=int, help = 'Batch size of NTK/MetaNTK when using darts or darts_fewshot search space on imagenet subset (mini, imagenet-1k), default 3.') 22 | parser.add_argument('--ntk_type', default='MetaNTK_anl', type=str, choices = ['NTK', 'MetaNTK_anl'], help = 'Whether to compute NTK or MetaNTK') 23 | parser.add_argument('--ntk_channels', type=int, default=48, help='initial channels of small network for computing NTKs. To use Opacus, use 16n channels.') 24 | parser.add_argument('--ntk_layers', type=int, default=5, help='number of layers of small network for computing NTKs') 25 | parser.add_argument('--only_lrs', choices = ['true', 'false'], default='false', help='Use only linear regions') 26 | 27 | # Arguments for computing analytical MetaNTK 28 | parser.add_argument('--algorithm', type=str, default='MAML', choices = ['ANIL', 'MAML'], help='Algorithm for computing analytical MetaNTK') 29 | parser.add_argument('--inner_lr_time', type=float, default=1000.0, help='the product of inner loop learning rate & training time') 30 | parser.add_argument('--reg_coef', type=float, default=1e-3, help='the regularization coefficient for the inner loop optimization. 
suggest >=1e-5') 31 | 32 | # Train after search 33 | parser.add_argument('--train_after_search', choices = ['true', 'false'], default='false', help='If directly train after search with RFS.') 34 | parser.add_argument('--aug_dp', type=float, default=0.2, help='Drop probability of augmentCNN') 35 | parser.add_argument('--aug_channels', type=int, default=48, help='Init channels for network during augmentation') 36 | parser.add_argument('--aug_layers', type=int, default=5, help='Number of layers for network during augmentation') 37 | parser.add_argument('--aug_lr', type=float, default=0.02, help='Learning rate for network during augmentation') 38 | parser.add_argument('--aug_batchsize', type=int, default=64, help='Batch size for network during augmentation') 39 | parser.add_argument('--aug_epochs', type=int, default=100, help='Total number of epochs for network during augmentation') 40 | parser.add_argument('--aug_lr_decay_epochs', type=str, default='60,80', help='Learning rate decay epochs during augmentation') 41 | 42 | 43 | args = parser.parse_args() 44 | 45 | ##### Basic Settings 46 | precision = 3 47 | # init = 'normal' 48 | # init = 'kaiming_uniform' 49 | init = 'kaiming_normal' 50 | 51 | space = args.space 52 | super_type = "nasnet-super" 53 | batch_size = args.dartsbs 54 | 55 | if args.ntk_type == 'MetaNTK_anl': 56 | assert args.dataset in ['MetaMiniImageNet', 'MetaTieredImageNet'], 'To use MetaNTK-NAS, please use meta version of the dataset.' 57 | 58 | # ONLY TRAIN AFTER SEARCH FOR OUR SETTINGS 59 | if args.train_after_search == 'true': 60 | assert (args.ntk_layers in [5,8]) and (args.ntk_channels == 48) 61 | args.aug_channels = args.ntk_channels 62 | args.aug_layers = args.ntk_layers 63 | if args.dataset in ['MiniImageNet', 'MetaMiniImageNet']: 64 | args.aug_dp = 0.2 65 | if args.ntk_layers == 8: 66 | args.aug_batchsize = 40 # Or change to 64 if your memory is big enough 67 | else: 68 | args.aug_batchsize = 64 69 | args.aug_lr = 0.02 70 | args.aug_epochs = 100 71 | args.aug_lr_decay_epochs = '60,80' 72 | elif args.dataset in ['TieredImageNet', 'MetaTieredImageNet']: 73 | args.aug_dp = 0.1 74 | if args.ntk_layers == 8: 75 | args.aug_batchsize = 56 # Or change to 64 if your memory is big enough 76 | else: 77 | args.aug_batchsize = 64 78 | args.aug_lr = 0.01 79 | args.aug_epochs = 60 80 | args.aug_lr_decay_epochs = '30,40,50' 81 | else: 82 | raise NotImplementedError 83 | 84 | timestamp = "{:}".format(time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time()))) 85 | 86 | 87 | core_cmd = "CUDA_VISIBLE_DEVICES={gpuid} OMP_NUM_THREADS=4 python ./prune_metantknas.py \ 88 | --save_dir {save_dir} --max_nodes {max_nodes} \ 89 | --dataset {dataset} \ 90 | --data_path {data_path} \ 91 | --search_space_name {space} \ 92 | --super_type {super_type} \ 93 | --arch_nas_dataset {TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \ 94 | --track_running_stats 1 \ 95 | --workers 0 --rand_seed {seed} \ 96 | --timestamp {timestamp} \ 97 | --precision {precision} \ 98 | --init {init} \ 99 | --repeat 3 \ 100 | --batch_size {batch_size} \ 101 | --ntk_type {ntk_type} \ 102 | --algorithm {algorithm} \ 103 | --inner_lr_time {inner_lr_time} \ 104 | --reg_coef {reg_coef} \ 105 | --train_after_search {train_after_search} \ 106 | --ntk_channels {ntk_channels} \ 107 | --ntk_layers {ntk_layers} \ 108 | --only_lrs {only_lrs} \ 109 | --aug_dp {aug_dp} \ 110 | --aug_channels {aug_channels} \ 111 | --aug_layers {aug_layers} \ 112 | --aug_lr {aug_lr} \ 113 | --aug_batchsize {aug_batchsize} \ 114 | --aug_epochs {aug_epochs} \ 115 | 
--aug_lr_decay_epochs {aug_lr_decay_epochs} \ 116 | ".format( 117 | gpuid=args.gpu, 118 | save_dir="./output/prune-{space}/{dataset}".format(space=space, dataset=args.dataset), 119 | max_nodes=args.max_nodes, 120 | data_path=data_paths[args.dataset], 121 | dataset=args.dataset, 122 | TORCH_HOME=TORCH_HOME, 123 | space=space, 124 | super_type=super_type, 125 | seed=args.seed, 126 | timestamp=timestamp, 127 | precision=precision, 128 | init=init, 129 | batch_size=batch_size, 130 | ntk_type=args.ntk_type, 131 | algorithm=args.algorithm, 132 | inner_lr_time=args.inner_lr_time, 133 | reg_coef=args.reg_coef, 134 | train_after_search=args.train_after_search, 135 | ntk_channels=args.ntk_channels, 136 | ntk_layers=args.ntk_layers, 137 | only_lrs=args.only_lrs, 138 | aug_dp=args.aug_dp, 139 | aug_channels=args.aug_channels, 140 | aug_layers=args.aug_layers, 141 | aug_lr=args.aug_lr, 142 | aug_batchsize=args.aug_batchsize, 143 | aug_epochs=args.aug_epochs, 144 | aug_lr_decay_epochs=args.aug_lr_decay_epochs, 145 | ) 146 | 147 | os.system(core_cmd) 148 | -------------------------------------------------------------------------------- /random_baseline.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, argparse 2 | import math 3 | import random 4 | from easydict import EasyDict as edict 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | lib_dir = (Path(__file__).parent / 'lib').resolve() 11 | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) 12 | from procedures import prepare_seed, prepare_logger 13 | from models import get_cell_based_tiny_net, get_search_spaces # , nas_super_nets 14 | import eval_searched_arch 15 | 16 | def prepare_seed(rand_seed): # note: this local definition shadows the prepare_seed imported from procedures above 17 | random.seed(rand_seed) 18 | np.random.seed(rand_seed) 19 | torch.manual_seed(rand_seed) 20 | torch.cuda.manual_seed(rand_seed) 21 | torch.cuda.manual_seed_all(rand_seed) 22 | 23 | def main(xargs): 24 | PID = os.getpid() 25 | prepare_seed(xargs.rand_seed) 26 | assert torch.cuda.is_available(), 'CUDA is not available.' 
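# The three cuDNN settings below trade speed for reproducibility: with benchmarking
# disabled and deterministic kernels enabled, repeated runs with the same seed
# produce the same results.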
27 | torch.backends.cudnn.enabled = True 28 | torch.backends.cudnn.benchmark = False 29 | torch.backends.cudnn.deterministic = True 30 | 31 | if xargs.timestamp == 'none': 32 | xargs.timestamp = "{:}".format(time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time()))) 33 | 34 | ##### logging ##### 35 | xargs.save_dir = xargs.save_dir + \ 36 | "Randombaseline" + \ 37 | "/{:}/seed{:}".format(xargs.timestamp, xargs.rand_seed) 38 | logger = prepare_logger(xargs) 39 | ############### 40 | 41 | search_space = get_search_spaces('cell', xargs.search_space_name) 42 | if xargs.search_space_name == 'nas-bench-201': 43 | model_config = edict({'name': 'DARTS-V1', 44 | 'C': 3, 'N': 1, 'depth': -1, 'use_stem': True, 45 | 'max_nodes': xargs.max_nodes, 'num_classes': 1, 46 | 'space': search_space, 47 | 'affine': True, 'track_running_stats': bool(xargs.track_running_stats), 48 | 'ntk_type': 'NTK', 49 | }) 50 | 51 | elif xargs.search_space_name in ['darts', 'darts_fewshot']: 52 | model_config = edict({'name': 'DARTS-V1', 53 | 'C': 1, 'N': 1, 'depth': 2, 'use_stem': True, 'stem_multiplier': 1, 54 | 'feature_scale_rate': 2, 55 | 'num_classes': 1, 56 | 'space': search_space, 57 | 'affine': True, 'track_running_stats': bool(xargs.track_running_stats), 58 | 'super_type': xargs.super_type, 59 | 'steps': xargs.max_nodes, 60 | 'multiplier': xargs.max_nodes, 61 | 'ntk_type': 'NTK', 62 | }) 63 | else: raise ValueError('invalid search_space_name : {:}'.format(xargs.search_space_name)) 64 | network = get_cell_based_tiny_net(model_config) 65 | 66 | logger.log('<<<--->>> End: {:}'.format(network.genotype())) 67 | 68 | if xargs.dataset in ['MiniImageNet', 'MetaMiniImageNet']: 69 | dataset_for_eval = 'miniImageNet' 70 | elif xargs.dataset in ['TieredImageNet', 'MetaTieredImageNet']: 71 | dataset_for_eval = 'tieredImageNet' 72 | else: 73 | raise NotImplementedError('Only miniImageNet and tieredImageNet are supported') 74 | 75 | if xargs.train_method == 'rfs': 76 | evaluation_cmd = ['--model', 'augmentcnn', 77 | '--dataset', dataset_for_eval, '--data_root', os.path.dirname(xargs.data_path), 78 | '--init_channels', str(xargs.aug_channels), '--layers', str(xargs.aug_layers), '--aug_dp', str(xargs.aug_dp), 79 | '--aug_stemm', str(xargs.aug_stemm), '--aug_fsr', str(xargs.aug_fsr), 80 | '--lr_decay_epochs', str(xargs.aug_lr_decay_epochs), '--epochs', str(xargs.aug_epochs), 81 | '--learning_rate', str(xargs.aug_lr), 82 | '--seed', '-1', 83 | '--batch_size', str(xargs.aug_batchsize), '--genotype', str(network.genotype()), 84 | '--tb_path', os.path.join(str(xargs.save_dir),'logs'), '--model_path', os.path.join(str(xargs.save_dir),'model'),] 85 | # 5 way 5 shot 86 | eval_searched_arch.main(evaluation_cmd) 87 | else: 88 | raise NotImplementedError('Only rfs training is supported for now.') 89 | 90 | logger.close() 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser("Random_baseline") 95 | parser.add_argument('--data_path', type=str, help='Path to dataset') 96 | parser.add_argument('--dataset', type=str, choices=['MiniImageNet', 'MetaMiniImageNet', 'TieredImageNet', 'MetaTieredImageNet'], help='Choose dataset') 97 | parser.add_argument('--search_space_name', type=str, default='darts_fewshot', help='space of operator candidates: nas-bench-201 or darts or darts_fewshot.') 98 | parser.add_argument('--max_nodes', type=int, choices=[3,4], default=3, help='The maximum number of nodes, choose from 3 and 4') 99 | parser.add_argument('--track_running_stats', type=int, choices=[0, 1], help='Whether to use track_running_stats or not in the BN layer.') 100 | parser.add_argument('--workers', type=int, default=0, help='number of data loading workers (default: 0)') 101 | parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') 102 | parser.add_argument('--timestamp', default='none', type=str, help='timestamp for logging naming') 103 | parser.add_argument('--super_type', type=str, default='nasnet-super', help='type of supernet: basic or nasnet-super') 104 | parser.add_argument('--rand_seed', type=int, help='manual seed') 105 | 106 | # Train after search for RFS 107 | parser.add_argument('--train_after_search', choices=['true', 'false'], default='true', help='Whether to directly train after search with RFS') 108 | parser.add_argument('--train_method', default='rfs', choices=['rfs'], help='Evaluation method used to train the searched architecture') 109 | parser.add_argument('--aug_dp', type=float, default=0.2, help='Drop probability of augmentCNN') 110 | parser.add_argument('--aug_channels', type=int, default=48, help='Init channels for network during augmentation') 111 | parser.add_argument('--aug_layers', type=int, default=5, help='Number of layers for network during augmentation') 112 | parser.add_argument('--aug_lr', type=float, default=0.02, help='Learning rate for network during augmentation') 113 | parser.add_argument('--aug_batchsize', type=int, default=64, help='Batch size for network during augmentation') 114 | parser.add_argument('--aug_epochs', type=int, default=100, help='Total number of epochs for network during augmentation') 115 | parser.add_argument('--aug_lr_decay_epochs', type=str, default='60,80', help='Learning rate decay epochs during augmentation') 116 | parser.add_argument('--aug_stemm', type=int, default=1, help='Stem multiplier during augmentation') 117 | parser.add_argument('--aug_fsr', type=int, default=2, help='Feature scaling ratio during augmentation') 118 | 119 | args = parser.parse_args() 120 | 121 | if args.rand_seed is None or args.rand_seed < 0: 122 | args.rand_seed = random.randint(1, 100000) 123 | 124 | main(args) 125 |
-------------------------------------------------------------------------------- /lib/models/cell_searchs/search_model_darts_nasnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from copy import deepcopy 4 | from typing import List, Text, Dict 5 | from .search_cells import NASNetSearchCell as SearchCell 6 | from pdb import set_trace as bp 7 | # from lib.models import cell_operations 8 | 9 | # The macro structure is based on NASNet 10 | class NASNetworkDARTS(nn.Module): 11 | 12 | def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, feature_scale_rate: int, 13 | num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool, 14 | depth=-1, use_stem=True): 15 | super(NASNetworkDARTS, self).__init__() 16 | self._C = C 17 | self._layerN = N # number of stacked cells at each stage 18 | self._steps = steps 19 | self._multiplier = multiplier 20 | self.depth = depth 21 | self.use_stem = use_stem 22 | self.stem = nn.Sequential( 23 | nn.Conv2d(3 if use_stem else min(3, C), C*stem_multiplier, kernel_size=3, padding=1, bias=False), 24 | nn.BatchNorm2d(C*stem_multiplier)) 25 | # cell_operations.BatchNorm_scratch(C*stem_multiplier,4)) 26 | 27 | # config for each layer 28 | # layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1) 29 | # layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1) 30 | # layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N 31 | #
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N 32 | fsr = feature_scale_rate 33 | layer_channels = [C ] * N + [C*fsr ] + [C*fsr ] * N + [C*fsr**2 ] + [C*fsr**2 ] * N 34 | layer_reductions = [False] * N + [True ] + [False ] * N + [True ] + [False ] * N 35 | 36 | num_edge, edge2index = None, None 37 | C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False 38 | 39 | self.cells = nn.ModuleList() 40 | for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): 41 | if depth > 0 and index >= depth: break 42 | cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) 43 | if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index 44 | else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) 45 | self.cells.append( cell ) 46 | C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction 47 | self.op_names = deepcopy( search_space ) 48 | self._Layer = len(self.cells) 49 | self.edge2index = edge2index 50 | self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) 51 | # self.lastact = nn.Sequential(cell_operations.BatchNorm_scratch(C_prev,4), nn.ReLU(inplace=True)) 52 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 53 | self.classifier = nn.Linear(C_prev, num_classes) 54 | self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) 55 | self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) 56 | 57 | def get_weights(self) -> List[torch.nn.Parameter]: 58 | xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) 59 | xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) 60 | xlist+= list( self.classifier.parameters() ) 61 | return xlist 62 | 63 | def get_alphas(self) -> List[torch.nn.Parameter]: 64 | return [self.arch_normal_parameters, self.arch_reduce_parameters] 65 | 66 | def set_alphas(self, arch_parameters): 67 | self.arch_normal_parameters.data.copy_(arch_parameters[0].data) 68 | self.arch_reduce_parameters.data.copy_(arch_parameters[1].data) 69 | 70 | def show_alphas(self) -> Text: 71 | with torch.no_grad(): 72 | A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() ) 73 | B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() ) 74 | return '{:}\n{:}'.format(A, B) 75 | 76 | def get_message(self) -> Text: 77 | string = self.extra_repr() 78 | for i, cell in enumerate(self.cells): 79 | string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) 80 | return string 81 | 82 | def extra_repr(self) -> Text: 83 | return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) 84 | 85 | def genotype2cmd(self, genotype): 86 | cmd = "Genotype(normal=%s, normal_concat=%s, reduce=%s, reduce_concat=%s)"%(genotype['normal'], genotype['normal_concat'], genotype['reduce'], genotype['reduce_concat']) 87 | return cmd 88 | 89 | def genotype(self) -> Text: 90 | def _parse(weights): 91 | gene = [] 92 | n = 2; start = 0 93 | for i in range(self._steps): 94 | end = start + n 95 | W = weights[start:end].copy() 96 | selected_edges = [] 97 | _edge_indice = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != self.op_names.index('none')))[:2] 98 | for _edge_index in _edge_indice: 99 | _op_indice = list(range(W.shape[1])) 100 | _op_indice.remove(self.op_names.index('none')) 101 | _op_index = sorted(_op_indice, key=lambda x: -W[_edge_index][x])[0] 102 | selected_edges.append( (self.op_names[_op_index], _edge_index) ) 103 | gene.append(selected_edges) 104 | start = end; n += 1 105 | return gene 106 | with torch.no_grad(): 107 | gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()) 108 | gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()) 109 | return self.genotype2cmd({'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)), 'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))}) 110 | 111 | def forward(self, inputs): 112 | 113 | normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1).detach().clone() 114 | # normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1) 115 | reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1).detach().clone() 116 | # reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1) 117 | 118 | normal_a = self.arch_normal_parameters.detach().clone() 119 | reduce_a = self.arch_reduce_parameters.detach().clone() 120 | 121 | if self.use_stem: 122 | s0 = s1 = self.stem(inputs) 123 | else: 124 | s0 = s1 = inputs 125 | for i, cell in enumerate(self.cells): 126 | if cell.reduction: ww, aa = reduce_w, reduce_a 127 | else : ww, aa = normal_w, normal_a 128 | s0, s1 = s1, cell.forward_darts(s0, s1, ww, aa) 129 | out = self.lastact(s1) 130 | out = self.global_pooling(out) 131 | out = out.view(out.size(0), -1) 132 | logits = self.classifier(out) 133 | 134 | return out, logits 135 | 136 |
-------------------------------------------------------------------------------- /eval_lib/flop_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def count_parameters_in_MB(model): 7 | if isinstance(model, nn.Module): 8 | return np.sum(np.prod(v.size()) for v in model.parameters())/1e6 9 | else: 10 | return np.sum(np.prod(v.size()) for v in model)/1e6 11 | 12 | 13 | def get_model_infos(model, shape): 14 | #model = copy.deepcopy( model ) 15 | 16 | model = add_flops_counting_methods(model) 17 | #model = model.cuda() 18 | model.eval() 19 | 20 | #cache_inputs = torch.zeros(*shape).cuda() 21 | #cache_inputs = torch.zeros(*shape) 22 | cache_inputs = torch.rand(*shape) 23 | if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda() 24 | #print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log) 25 | with torch.no_grad(): 26 | _ = model(cache_inputs) 27 | FLOPs = compute_average_flops_cost( model ) / 1e6 28 | Param = count_parameters_in_MB(model) 29 | 30 | if hasattr(model, 'auxiliary_param'): 31 | aux_params = count_parameters_in_MB(model.auxiliary_param()) 32 | print ('The auxiliary params of this model are : {:}'.format(aux_params)) 33 | print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param)) 34 | Param = Param - aux_params 35 | 36 | #print_log('FLOPs : {:} MB'.format(FLOPs), log) 37 | torch.cuda.empty_cache() 38 | model.apply( remove_hook_function ) 39 | return FLOPs, Param 40 | 41 | 42 | # ---- Public functions 43 | def
add_flops_counting_methods( model ): 44 | model.__batch_counter__ = 0 45 | add_batch_counter_hook_function( model ) 46 | model.apply( add_flops_counter_variable_or_reset ) 47 | model.apply( add_flops_counter_hook_function ) 48 | return model 49 | 50 | 51 | 52 | def compute_average_flops_cost(model): 53 | """ 54 | A method that will be available after add_flops_counting_methods() is called on a desired net object. 55 | Returns current mean flops consumption per image. 56 | """ 57 | batches_count = model.__batch_counter__ 58 | flops_sum = 0 59 | #or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \ 60 | for module in model.modules(): 61 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \ 62 | or isinstance(module, torch.nn.Conv1d) \ 63 | or hasattr(module, 'calculate_flop_self'): 64 | flops_sum += module.__flops__ 65 | return flops_sum / batches_count 66 | 67 | 68 | # ---- Internal functions 69 | def pool_flops_counter_hook(pool_module, inputs, output): 70 | batch_size = inputs[0].size(0) 71 | kernel_size = pool_module.kernel_size 72 | out_C, output_height, output_width = output.shape[1:] 73 | assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size()) 74 | 75 | overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size 76 | pool_module.__flops__ += overall_flops 77 | 78 | 79 | def self_calculate_flops_counter_hook(self_module, inputs, output): 80 | overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape) 81 | self_module.__flops__ += overall_flops 82 | 83 | 84 | def fc_flops_counter_hook(fc_module, inputs, output): 85 | batch_size = inputs[0].size(0) 86 | xin, xout = fc_module.in_features, fc_module.out_features 87 | assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout) 88 | overall_flops = batch_size * xin * xout 89 | if fc_module.bias is not None: 90 | overall_flops += batch_size * xout 91 | fc_module.__flops__ += overall_flops 92 | 93 | 94 | def conv1d_flops_counter_hook(conv_module, inputs, outputs): 95 | batch_size = inputs[0].size(0) 96 | outL = outputs.shape[-1] 97 | [kernel] = conv_module.kernel_size 98 | in_channels = conv_module.in_channels 99 | out_channels = conv_module.out_channels 100 | groups = conv_module.groups 101 | conv_per_position_flops = kernel * in_channels * out_channels / groups 102 | 103 | active_elements_count = batch_size * outL 104 | overall_flops = conv_per_position_flops * active_elements_count 105 | 106 | if conv_module.bias is not None: 107 | overall_flops += out_channels * active_elements_count 108 | conv_module.__flops__ += overall_flops 109 | 110 | 111 | def conv2d_flops_counter_hook(conv_module, inputs, output): 112 | batch_size = inputs[0].size(0) 113 | output_height, output_width = output.shape[2:] 114 | 115 | kernel_height, kernel_width = conv_module.kernel_size 116 | in_channels = conv_module.in_channels 117 | out_channels = conv_module.out_channels 118 | groups = conv_module.groups 119 | conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups 120 | 121 | active_elements_count = batch_size * output_height * output_width 122 | overall_flops = conv_per_position_flops * active_elements_count 123 | 124 | if conv_module.bias is not None: 125 | overall_flops += out_channels * active_elements_count 126 | conv_module.__flops__ += overall_flops 127 | 128 | 129 | def batch_counter_hook(module, inputs, output): 130 | # Can have multiple inputs, 
getting the first one 131 | inputs = inputs[0] 132 | batch_size = inputs.shape[0] 133 | module.__batch_counter__ += batch_size 134 | 135 | 136 | def add_batch_counter_hook_function(module): 137 | if not hasattr(module, '__batch_counter_handle__'): 138 | handle = module.register_forward_hook(batch_counter_hook) 139 | module.__batch_counter_handle__ = handle 140 | 141 | 142 | def add_flops_counter_variable_or_reset(module): 143 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \ 144 | or isinstance(module, torch.nn.Conv1d) \ 145 | or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \ 146 | or hasattr(module, 'calculate_flop_self'): 147 | module.__flops__ = 0 148 | 149 | 150 | def add_flops_counter_hook_function(module): 151 | if isinstance(module, torch.nn.Conv2d): 152 | if not hasattr(module, '__flops_handle__'): 153 | handle = module.register_forward_hook(conv2d_flops_counter_hook) 154 | module.__flops_handle__ = handle 155 | elif isinstance(module, torch.nn.Conv1d): 156 | if not hasattr(module, '__flops_handle__'): 157 | handle = module.register_forward_hook(conv1d_flops_counter_hook) 158 | module.__flops_handle__ = handle 159 | elif isinstance(module, torch.nn.Linear): 160 | if not hasattr(module, '__flops_handle__'): 161 | handle = module.register_forward_hook(fc_flops_counter_hook) 162 | module.__flops_handle__ = handle 163 | elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d): 164 | if not hasattr(module, '__flops_handle__'): 165 | handle = module.register_forward_hook(pool_flops_counter_hook) 166 | module.__flops_handle__ = handle 167 | elif hasattr(module, 'calculate_flop_self'): # self-defined module 168 | if not hasattr(module, '__flops_handle__'): 169 | handle = module.register_forward_hook(self_calculate_flops_counter_hook) 170 | module.__flops_handle__ = handle 171 | 172 | 173 | def remove_hook_function(module): 174 | hookers = ['__batch_counter_handle__', '__flops_handle__'] 175 | for hooker in hookers: 176 | if hasattr(module, hooker): 177 | handle = getattr(module, hooker) 178 | handle.remove() 179 | keys = ['__flops__', '__batch_counter__'] + hookers 180 | for ckey in keys: 181 | if hasattr(module, ckey): delattr(module, ckey) 182 |
-------------------------------------------------------------------------------- /lib/utils/flop_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def count_parameters_in_MB(model): 7 | if isinstance(model, nn.Module): 8 | return np.sum(np.prod(v.size()) for v in model.parameters())/1e6 9 | else: 10 | return np.sum(np.prod(v.size()) for v in model)/1e6 11 | 12 | 13 | def get_model_infos(model, shape): 14 | #model = copy.deepcopy( model ) 15 | 16 | model = add_flops_counting_methods(model) 17 | #model = model.cuda() 18 | model.eval() 19 | 20 | #cache_inputs = torch.zeros(*shape).cuda() 21 | #cache_inputs = torch.zeros(*shape) 22 | cache_inputs = torch.rand(*shape) 23 | if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda() 24 | #print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log) 25 | with torch.no_grad(): 26 | _ = model(cache_inputs) 27 | FLOPs = compute_average_flops_cost( model ) / 1e6 28 | Param = count_parameters_in_MB(model) 29 | 30 | if hasattr(model, 'auxiliary_param'): 31 | aux_params = count_parameters_in_MB(model.auxiliary_param()) 32 | print ('The auxiliary params of this model are : {:}'.format(aux_params)) 33 | print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param)) 34 | Param = Param - aux_params 35 | 36 | #print_log('FLOPs : {:} MB'.format(FLOPs), log) 37 | torch.cuda.empty_cache() 38 | model.apply( remove_hook_function ) 39 | return FLOPs, Param 40 | 41 | 42 | # ---- Public functions 43 | def add_flops_counting_methods( model ): 44 | model.__batch_counter__ = 0 45 | add_batch_counter_hook_function( model ) 46 | model.apply( add_flops_counter_variable_or_reset ) 47 | model.apply( add_flops_counter_hook_function ) 48 | return model 49 | 50 | 51 | 52 | def compute_average_flops_cost(model): 53 | """ 54 | A method that will be available after add_flops_counting_methods() is called on a desired net object. 55 | Returns current mean flops consumption per image. 56 | """ 57 | batches_count = model.__batch_counter__ 58 | flops_sum = 0 59 | #or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \ 60 | for module in model.modules(): 61 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \ 62 | or isinstance(module, torch.nn.Conv1d) \ 63 | or hasattr(module, 'calculate_flop_self'): 64 | flops_sum += module.__flops__ 65 | return flops_sum / batches_count 66 | 67 | 68 | # ---- Internal functions 69 | def pool_flops_counter_hook(pool_module, inputs, output): 70 | batch_size = inputs[0].size(0) 71 | kernel_size = pool_module.kernel_size 72 | out_C, output_height, output_width = output.shape[1:] 73 | assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size()) 74 | 75 | overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size 76 | pool_module.__flops__ += overall_flops 77 | 78 | 79 | def self_calculate_flops_counter_hook(self_module, inputs, output): 80 | overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape) 81 | self_module.__flops__ += overall_flops 82 | 83 | 84 | def fc_flops_counter_hook(fc_module, inputs, output): 85 | batch_size = inputs[0].size(0) 86 | xin, xout = fc_module.in_features, fc_module.out_features 87 | assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout) 88 | overall_flops = batch_size * xin * xout 89 | if fc_module.bias is not None: 90 | overall_flops += batch_size * xout 91 | fc_module.__flops__ += overall_flops 92 | 93 | 94 | def conv1d_flops_counter_hook(conv_module, inputs, outputs): 95 | batch_size = inputs[0].size(0) 96 | outL = outputs.shape[-1] 97 | [kernel] = conv_module.kernel_size 98 | in_channels = conv_module.in_channels 99 | out_channels = conv_module.out_channels 100 | groups = conv_module.groups 101 | conv_per_position_flops = kernel * in_channels * out_channels / groups 102 | 103 | active_elements_count = batch_size * outL 104 | overall_flops = conv_per_position_flops * active_elements_count 105 | 106 | if conv_module.bias is not None: 107 | overall_flops += out_channels * active_elements_count 108 | conv_module.__flops__ += overall_flops 109 | 110 | 111 | def conv2d_flops_counter_hook(conv_module, inputs, output): 112 | batch_size = inputs[0].size(0) 113 | output_height, output_width = output.shape[2:] 114 | 115 | kernel_height, kernel_width = conv_module.kernel_size 116 | in_channels = conv_module.in_channels 117 | out_channels = conv_module.out_channels 118 | groups = conv_module.groups 119 | conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups 120 | 121 | active_elements_count = batch_size * output_height * output_width 122 | overall_flops = conv_per_position_flops * active_elements_count 123 | 124 | if conv_module.bias is not None: 125 | overall_flops += out_channels * active_elements_count 126 | conv_module.__flops__ += overall_flops 127 | 128 | 129 | def batch_counter_hook(module, inputs, output): 130 | # Can have multiple inputs, getting the first one 131 | inputs = inputs[0] 132 | batch_size = inputs.shape[0] 133 | module.__batch_counter__ += batch_size 134 | 135 | 136 | def add_batch_counter_hook_function(module): 137 | if not hasattr(module, '__batch_counter_handle__'): 138 | handle = module.register_forward_hook(batch_counter_hook) 139 | module.__batch_counter_handle__ = handle 140 | 141 | 142 | def add_flops_counter_variable_or_reset(module): 143 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \ 144 | or isinstance(module, torch.nn.Conv1d) \ 145 | or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \ 146 | or hasattr(module, 'calculate_flop_self'): 147 | module.__flops__ = 0 148 | 149 | 150 | def add_flops_counter_hook_function(module): 151 | if isinstance(module, torch.nn.Conv2d): 152 | if not hasattr(module, '__flops_handle__'): 153 | handle = module.register_forward_hook(conv2d_flops_counter_hook) 154 | module.__flops_handle__ = handle 155 | elif isinstance(module, torch.nn.Conv1d): 156 | if not hasattr(module, '__flops_handle__'): 157 | handle = module.register_forward_hook(conv1d_flops_counter_hook) 158 | module.__flops_handle__ = handle 159 | elif isinstance(module, torch.nn.Linear): 160 | if not hasattr(module, '__flops_handle__'): 161 | handle = module.register_forward_hook(fc_flops_counter_hook) 162 | module.__flops_handle__ = handle 163 | elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d): 164 | if not hasattr(module, '__flops_handle__'): 165 | handle = module.register_forward_hook(pool_flops_counter_hook) 166 | module.__flops_handle__ = handle 167 | elif hasattr(module, 'calculate_flop_self'): # self-defined module 168 | if not hasattr(module, '__flops_handle__'): 169 | handle = module.register_forward_hook(self_calculate_flops_counter_hook) 170 | module.__flops_handle__ = handle 171 | 172 | 173 | def remove_hook_function(module): 174 | hookers = ['__batch_counter_handle__', '__flops_handle__'] 175 | for hooker in hookers: 176 | if hasattr(module, hooker): 177 | handle = getattr(module, hooker) 178 | handle.remove() 179 | keys = ['__flops__', '__batch_counter__'] + hookers 180 | for ckey in keys: 181 | if hasattr(module, ckey): delattr(module, ckey) 182 |
-------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from typing import List, Text 3 | import torch 4 | 5 | __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ 6 | 'obtain_model', 7 | 'CellStructure', 'CellArchitectures' 8 | ] 9 | 10 | # useful modules 11 | from .SharedUtils import change_key 12 | from .cell_searchs import CellStructure, CellArchitectures 13 | from .cell_infers import TinyNetwork 14 | from .cell_searchs import nas201_super_nets as nas_super_nets 15 | 16 | 17 | # Cell-based NAS Models 18 | def get_cell_based_tiny_net(config): 19 | super_type = getattr(config, 'super_type', 'basic') 20 |
group_names = ['DARTS-V1', 'DARTS-V2'] 21 | 22 | # Now MetaNTK only supports the following two types: 23 | if config.ntk_type == 'MetaNTK': 24 | assert ((super_type == 'basic' and config.name in group_names) or (super_type == 'nasnet-super')) 25 | 26 | if super_type == 'basic' and config.name in group_names: 27 | from .cell_searchs import nas201_super_nets as nas_super_nets 28 | try: 29 | return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats, config.depth, config.use_stem) 30 | except Exception: 31 | return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.depth, config.use_stem) 32 | elif super_type == 'nasnet-super': 33 | from .cell_searchs import nasnet_super_nets as nas_super_nets 34 | return nas_super_nets[config.name](config.C, config.N, config.steps, config.multiplier, \ 35 | config.stem_multiplier, config.feature_scale_rate, config.num_classes, config.space, config.affine, config.track_running_stats, config.depth, config.use_stem) 36 | elif config.name == 'infer.tiny': 37 | from .cell_infers import TinyNetwork 38 | if hasattr(config, 'genotype'): 39 | genotype = config.genotype 40 | elif hasattr(config, 'arch_str'): 41 | genotype = CellStructure.str2structure(config.arch_str) 42 | else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) 43 | return TinyNetwork(config.C, config.N, genotype, config.num_classes) 44 | elif config.name == 'infer.shape.tiny': 45 | from .shape_infers import DynamicShapeTinyNet 46 | if isinstance(config.channels, str): 47 | channels = tuple([int(x) for x in config.channels.split(':')]) 48 | else: channels = config.channels 49 | genotype = CellStructure.str2structure(config.genotype) 50 | return DynamicShapeTinyNet(channels, genotype, config.num_classes) 51 | elif config.name == 'infer.nasnet-cifar': 52 | from .cell_infers import NASNetonCIFAR 53 | raise NotImplementedError 54 | else: 55 | raise ValueError('invalid network name : {:}'.format(config.name)) 56 | 57 | 58 | # obtain the search space, i.e., a list of candidate operation names for the given space 59 | def get_search_spaces(xtype, name) -> List[Text]: 60 | if xtype == 'cell': 61 | from .cell_operations import SearchSpaceNames 62 | assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()) 63 | return SearchSpaceNames[name] 64 | else: 65 | raise ValueError('invalid search-space type : {:}'.format(xtype)) 66 | 67 | 68 | def get_cifar_models(config, extra_path=None): 69 | super_type = getattr(config, 'super_type', 'basic') 70 | if super_type == 'basic': 71 | from .CifarResNet import CifarResNet 72 | from .CifarDenseNet import DenseNet 73 | from .CifarWideResNet import CifarWideResNet 74 | if config.arch == 'resnet': 75 | return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual) 76 | elif config.arch == 'densenet': 77 | return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck) 78 | elif config.arch == 'wideresnet': 79 | return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout) 80 | else: 81 | raise ValueError('invalid module type : {:}'.format(config.arch)) 82 | elif super_type.startswith('infer'): 83 | from .shape_infers import InferWidthCifarResNet 84 | from .shape_infers import InferDepthCifarResNet 85 | from .shape_infers import InferCifarResNet 86 | from
.cell_infers import NASNetonCIFAR 87 | assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) 88 | infer_mode = super_type.split('-')[1] 89 | if infer_mode == 'width': 90 | return InferWidthCifarResNet(config.module, config.depth, config.xchannels, config.class_num, config.zero_init_residual) 91 | elif infer_mode == 'depth': 92 | return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual) 93 | elif infer_mode == 'shape': 94 | return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual) 95 | elif infer_mode == 'nasnet.cifar': 96 | genotype = config.genotype 97 | if extra_path is not None: # reload genotype by extra_path 98 | if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path)) 99 | xdata = torch.load(extra_path) 100 | current_epoch = xdata['epoch'] 101 | genotype = xdata['genotypes'][current_epoch-1] 102 | C = config.C if hasattr(config, 'C') else config.ichannel 103 | N = config.N if hasattr(config, 'N') else config.layers 104 | return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary) 105 | else: 106 | raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) 107 | else: 108 | raise ValueError('invalid super-type : {:}'.format(super_type)) 109 | 110 | 111 | def get_imagenet_models(config): 112 | super_type = getattr(config, 'super_type', 'basic') 113 | if super_type == 'basic': 114 | from .ImageNet_ResNet import ResNet 115 | from .ImageNet_MobileNetV2 import MobileNetV2 116 | if config.arch == 'resnet': 117 | return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group) 118 | elif config.arch == 'mobilenet_v2': 119 | return MobileNetV2(config.class_num, config.width_multi, config.input_channel, config.last_channel, 'InvertedResidual', config.dropout) 120 | else: 121 | raise ValueError('invalid arch : {:}'.format( config.arch )) 122 | elif super_type.startswith('infer'): # NAS searched architecture 123 | assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) 124 | infer_mode = super_type.split('-')[1] 125 | if infer_mode == 'shape': 126 | from .shape_infers import InferImagenetResNet 127 | from .shape_infers import InferMobileNetV2 128 | if config.arch == 'resnet': 129 | return InferImagenetResNet(config.block_name, config.layers, config.xblocks, config.xchannels, config.deep_stem, config.class_num, config.zero_init_residual) 130 | elif config.arch == "MobileNetV2": 131 | return InferMobileNetV2(config.class_num, config.xchannels, config.xblocks, config.dropout) 132 | else: 133 | raise ValueError('invalid arch-mode : {:}'.format(config.arch)) 134 | else: 135 | raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) 136 | else: 137 | raise ValueError('invalid super-type : {:}'.format(super_type)) 138 | 139 | 140 | # Try to obtain the network by config. 
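# A minimal usage sketch for obtain_model (illustrative only -- the exact field
# set is dictated by the branches of get_cifar_models above, and the concrete
# values here are assumptions, not shipped defaults):
#   from easydict import EasyDict as edict
#   cfg = edict({'dataset': 'cifar', 'super_type': 'basic', 'arch': 'wideresnet',
#                'depth': 28, 'wide_factor': 10, 'class_num': 100, 'dropout': 0.3})
#   net = obtain_model(cfg)   # dispatches to get_cifar_models(cfg)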
141 | def obtain_model(config, extra_path=None): 142 | if config.dataset == 'cifar': 143 | return get_cifar_models(config, extra_path) 144 | elif config.dataset == 'imagenet': 145 | return get_imagenet_models(config) 146 | else: 147 | raise ValueError('invalid dataset in the model config : {:}'.format(config)) 148 | -------------------------------------------------------------------------------- /eval_lib/rfs_models/augment_cnn.py: -------------------------------------------------------------------------------- 1 | """ CNN for network augmentation 2 | Copyright (c) 2021 Robert Bosch GmbH 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Affero General Public License as published 6 | by the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU Affero General Public License for more details. 13 | 14 | """ 15 | 16 | 17 | """ 18 | Based on https://github.com/khanrc/pt.darts 19 | which is licensed under MIT License, 20 | cf. 3rd-party-licenses.txt in root directory. 21 | """ 22 | 23 | import torch 24 | import torch.nn as nn 25 | 26 | from .fewshot_operations import * 27 | 28 | def to_dag(C_in, gene, reduction): 29 | """ generate discrete ops from gene """ 30 | dag = nn.ModuleList() 31 | for edges in gene: 32 | row = nn.ModuleList() 33 | for op_name, s_idx in edges: 34 | # reduction cell & from input nodes => stride = 2 35 | stride = 2 if reduction and s_idx < 2 else 1 36 | op = OPS[op_name](C_in, C_in, stride, True, True) 37 | if not isinstance(op, Identity): # Identity does not use drop path 38 | op = nn.Sequential(op, DropPath_()) 39 | op.s_idx = s_idx 40 | row.append(op) 41 | dag.append(row) 42 | 43 | return dag 44 | 45 | class AugmentCNN(nn.Module): 46 | """ Augmented CNN model """ 47 | 48 | def __init__( 49 | self, 50 | input_size, 51 | C_in, 52 | C, 53 | n_classes, 54 | n_layers, 55 | auxiliary, 56 | genotype, 57 | stem_multiplier=3, 58 | feature_scale_rate=2, 59 | reduction_layers=[], 60 | ): 61 | """ 62 | Args: 63 | input_size: size of height and width (assuming height = width) 64 | C_in: # of input channels 65 | C: # of starting model channels 66 | """ 67 | super().__init__() 68 | self.C_in = C_in 69 | self.C = C 70 | self.n_classes = n_classes 71 | self.n_layers = n_layers 72 | self.aux_pos = 2 * n_layers // 3 if auxiliary else -1 73 | 74 | C_cur = stem_multiplier * C 75 | self.stem = nn.Sequential( 76 | nn.Conv2d(C_in, C_cur, 3, 1, 1, bias=False), nn.BatchNorm2d(C_cur) 77 | ) 78 | 79 | C_pp, C_p, C_cur = C_cur, C_cur, C 80 | 81 | self.cells = nn.ModuleList() 82 | reduction_p = False 83 | 84 | if not reduction_layers: 85 | reduction_layers = [n_layers // 3, (2 * n_layers) // 3] 86 | 87 | for i in range(n_layers): 88 | if i in reduction_layers: 89 | C_cur *= feature_scale_rate 90 | reduction = True 91 | else: 92 | reduction = False 93 | 94 | cell = AugmentCell(genotype, C_pp, C_p, C_cur, reduction_p, reduction) 95 | reduction_p = reduction 96 | self.cells.append(cell) 97 | C_cur_out = C_cur * len(cell.concat) 98 | C_pp, C_p = C_p, C_cur_out 99 | 100 | if i == self.aux_pos: 101 | # [!] 
this auxiliary head is ignored in computing parameter size 102 | # by the name 'aux_head' 103 | self.aux_head = AuxiliaryHead(input_size // 4, C_p, n_classes) 104 | 105 | # self.lastact = nn.Sequential(nn.BatchNorm2d(C_p), nn.ReLU(inplace=True)) 106 | self.gap = nn.AdaptiveAvgPool2d(1) 107 | self.classifier = nn.Linear(C_p, n_classes) 108 | 109 | self.criterion = nn.CrossEntropyLoss() 110 | 111 | ####### dummy alphas 112 | self.alpha_normal = nn.ParameterList() 113 | self.alpha_reduce = nn.ParameterList() 114 | 115 | for i in range(2): 116 | self.alpha_normal.append(nn.Parameter(1e-3 * torch.randn(1, 5))) 117 | self.alpha_reduce.append(nn.Parameter(1e-3 * torch.randn(1, 5))) 118 | 119 | # setup alphas list 120 | self._alphas = [] 121 | for n, p in self.named_parameters(): 122 | if "alpha" in n: 123 | self._alphas.append((n, p)) 124 | 125 | self.alpha_prune_threshold = 0.0 126 | 127 | def forward(self, x, is_feat=False): 128 | s0 = s1 = self.stem(x) 129 | 130 | aux_logits = None 131 | for i, cell in enumerate(self.cells): 132 | s0, s1 = s1, cell(s0, s1) 133 | if i == self.aux_pos and self.training: 134 | aux_logits = self.aux_head(s1) 135 | 136 | # out = self.lastact(s1) 137 | # out = self.gap(out) 138 | 139 | out = self.gap(s1) 140 | out = out.view(out.size(0), -1) # flatten 141 | logits = self.classifier(out) 142 | 143 | if is_feat: 144 | return [out], logits 145 | else: 146 | if self.aux_pos == -1: # no auxiliary head 147 | return logits 148 | else: 149 | return logits, aux_logits 150 | 151 | def drop_path_prob(self, p): 152 | """ Set drop path probability """ 153 | for module in self.modules(): 154 | if isinstance(module, DropPath_): 155 | module.p = p 156 | 157 | def weights(self): 158 | return self.parameters() 159 | 160 | def named_weights(self): 161 | return self.named_parameters() 162 | 163 | def alphas(self): 164 | for n, p in self._alphas: 165 | yield p 166 | # return None 167 | 168 | def named_alphas(self): 169 | for n, p in self._alphas: 170 | yield n, p 171 | # return None 172 | 173 | def genotype(self): 174 | return None 175 | 176 | def loss(self, X, y): 177 | logits = self.forward(X) 178 | return self.criterion(logits, y) 179 | 180 | def get_sparse_num_params( 181 | self, alpha_prune_threshold=0.0 182 | ): # dummy function to not break code 183 | """Get number of parameters for sparse one-shot-model (in this case just the number of parameters of the model) 184 | 185 | Returns: 186 | None (dummy return, kept only for API compatibility) 187 | """ 188 | return None 189 | 190 | 191 | class AuxiliaryHead(nn.Module): 192 | """ Auxiliary head at the 2/3 position of the network to help gradients flow """ 193 | 194 | def __init__(self, input_size, C, n_classes): 195 | """ assuming input size 7x7 or 8x8 """ 196 | assert input_size in [7, 8] 197 | super().__init__() 198 | self.net = nn.Sequential( 199 | nn.ReLU(inplace=True), 200 | # 2x2 out 201 | nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), 202 | nn.Conv2d(C, 128, kernel_size=1, bias=False), 203 | nn.BatchNorm2d(128), 204 | nn.ReLU(inplace=True), 205 | # 1x1 out 206 | nn.Conv2d(128, 768, kernel_size=2, bias=False), 207 | nn.BatchNorm2d(768), 208 | nn.ReLU(inplace=True), 209 | ) 210 | self.classifier = nn.Linear(768, n_classes) 211 | 212 | def forward(self, x): 213 | out = self.net(x) 214 | out = out.view(out.size(0), -1) # flatten 215 | logits = self.classifier(out) 216 | return logits 217 | 218 | 219 | class AugmentCell(nn.Module): 220 | """Cell for augmentation 221 | Each edge is discrete.
222 | """ 223 | 224 | def __init__(self, genotype, C_pp, C_p, C, reduction_p, reduction): 225 | super().__init__() 226 | self.reduction = reduction 227 | self.n_nodes = len(genotype.normal) 228 | 229 | if reduction_p: 230 | self.preproc0 = FactorizedReduce(C_in=C_pp, C_out=C, stride=2, affine=True, track_running_stats=True) 231 | else: 232 | self.preproc0 = StdConv(C_pp, C, 1, 1, 0) 233 | self.preproc1 = StdConv(C_p, C, 1, 1, 0) 234 | 235 | # generate dag 236 | if reduction: 237 | gene = genotype.reduce 238 | self.concat = genotype.reduce_concat 239 | else: 240 | gene = genotype.normal 241 | self.concat = genotype.normal_concat 242 | 243 | self.dag = to_dag(C, gene, reduction) 244 | 245 | def forward(self, s0, s1): 246 | s0 = self.preproc0(s0) 247 | s1 = self.preproc1(s1) 248 | 249 | states = [s0, s1] 250 | for edges in self.dag: 251 | s_cur = sum(op(states[op.s_idx]) for op in edges) 252 | states.append(s_cur) 253 | 254 | s_out = torch.cat([states[i] for i in self.concat], dim=1) 255 | 256 | return s_out 257 | -------------------------------------------------------------------------------- /lib/datasets/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class ImageNet(Dataset): 11 | def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096, 12 | transform=None): 13 | super(Dataset, self).__init__() 14 | self.data_path = args.data_path 15 | self.partition = partition 16 | self.data_aug = args.data_aug 17 | self.mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0] 18 | self.std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0] 19 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std) 20 | self.pretrain = pretrain 21 | 22 | if transform is None: 23 | if self.partition == 'train' and self.data_aug: 24 | self.transform = transforms.Compose([ 25 | lambda x: Image.fromarray(x), 26 | transforms.RandomCrop(84, padding=8), 27 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 28 | transforms.RandomHorizontalFlip(), 29 | lambda x: np.asarray(x), 30 | transforms.ToTensor(), 31 | self.normalize 32 | ]) 33 | else: 34 | self.transform = transforms.Compose([ 35 | lambda x: Image.fromarray(x), 36 | transforms.ToTensor(), 37 | self.normalize 38 | ]) 39 | else: 40 | self.transform = transform 41 | 42 | if self.pretrain: 43 | self.file_pattern = 'miniImageNet_category_split_train_phase_%s.pickle' 44 | else: 45 | self.file_pattern = 'miniImageNet_category_split_%s.pickle' 46 | self.data = {} 47 | with open(os.path.join(self.data_path, self.file_pattern % partition), 'rb') as f: 48 | data = pickle.load(f, encoding='latin1') 49 | self.imgs = data['data'] 50 | self.labels = data['labels'] 51 | 52 | # pre-process for contrastive sampling 53 | self.k = k 54 | self.is_sample = is_sample 55 | if self.is_sample: 56 | self.labels = np.asarray(self.labels) 57 | self.labels = self.labels - np.min(self.labels) 58 | num_classes = np.max(self.labels) + 1 59 | 60 | self.cls_positive = [[] for _ in range(num_classes)] 61 | for i in range(len(self.imgs)): 62 | self.cls_positive[self.labels[i]].append(i) 63 | 64 | self.cls_negative = [[] for _ in range(num_classes)] 65 | for i in range(num_classes): 66 | for j in range(num_classes): 67 | if j == i: 68 | continue 69 | 
self.cls_negative[i].extend(self.cls_positive[j]) 70 | 71 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)] 72 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)] 73 | self.cls_positive = np.asarray(self.cls_positive) 74 | self.cls_negative = np.asarray(self.cls_negative) 75 | 76 | def __getitem__(self, item): 77 | img = np.asarray(self.imgs[item]).astype('uint8') 78 | img = self.transform(img) 79 | target = self.labels[item] - min(self.labels) 80 | 81 | if not self.is_sample: 82 | return img, target, item 83 | else: 84 | pos_idx = item 85 | replace = True if self.k > len(self.cls_negative[target]) else False 86 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace) 87 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 88 | return img, target, item, sample_idx 89 | 90 | def __len__(self): 91 | return len(self.labels) 92 | 93 | 94 | class MetaImageNet(ImageNet): 95 | 96 | def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True): 97 | super(MetaImageNet, self).__init__(args, partition, True) 98 | self.fix_seed = fix_seed 99 | self.n_ways = args.n_ways 100 | self.n_shots = args.n_shots 101 | self.n_queries = args.n_queries 102 | self.classes = list(self.data.keys()) 103 | self.n_test_runs = args.n_test_runs 104 | self.n_aug_support_samples = args.n_aug_support_samples 105 | if train_transform is None: 106 | self.train_transform = transforms.Compose([ 107 | lambda x: Image.fromarray(x), 108 | transforms.RandomCrop(84, padding=8), 109 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 110 | transforms.RandomHorizontalFlip(), 111 | lambda x: np.asarray(x), 112 | transforms.ToTensor(), 113 | self.normalize 114 | ]) 115 | else: 116 | self.train_transform = train_transform 117 | 118 | if test_transform is None: 119 | self.test_transform = transforms.Compose([ 120 | lambda x: Image.fromarray(x), 121 | transforms.ToTensor(), 122 | self.normalize 123 | ]) 124 | else: 125 | self.test_transform = test_transform 126 | 127 | self.data = {} 128 | for idx in range(self.imgs.shape[0]): 129 | if self.labels[idx] not in self.data: 130 | self.data[self.labels[idx]] = [] 131 | self.data[self.labels[idx]].append(self.imgs[idx]) 132 | self.classes = list(self.data.keys()) 133 | 134 | def __getitem__(self, item): 135 | if self.fix_seed: 136 | np.random.seed(item) 137 | cls_sampled = np.random.choice(self.classes, self.n_ways, False) 138 | support_xs = [] 139 | support_ys = [] 140 | query_xs = [] 141 | query_ys = [] 142 | for idx, cls in enumerate(cls_sampled): 143 | imgs = np.asarray(self.data[cls]).astype('uint8') 144 | support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False) 145 | support_xs.append(imgs[support_xs_ids_sampled]) 146 | support_ys.append([idx] * self.n_shots) 147 | query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled) 148 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False) 149 | query_xs.append(imgs[query_xs_ids]) 150 | query_ys.append([idx] * query_xs_ids.shape[0]) 151 | support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array( 152 | query_xs), np.array(query_ys) 153 | num_ways, n_queries_per_way, height, width, channel = query_xs.shape 154 | query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel)) 155 | query_ys = query_ys.reshape((num_ways * n_queries_per_way, )) 156 | 157 | support_xs = 
support_xs.reshape((-1, height, width, channel)) 158 | if self.n_aug_support_samples > 1: 159 | support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1)) 160 | support_ys = np.tile(support_ys.reshape((-1, )), (self.n_aug_support_samples)) 161 | support_xs = np.split(support_xs, support_xs.shape[0], axis=0) 162 | query_xs = query_xs.reshape((-1, height, width, channel)) 163 | query_xs = np.split(query_xs, query_xs.shape[0], axis=0) 164 | 165 | support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs))) 166 | query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs))) 167 | 168 | return support_xs, support_ys, query_xs, query_ys 169 | 170 | def __len__(self): 171 | return self.n_test_runs 172 | 173 | 174 | if __name__ == '__main__': 175 | args = lambda x: None 176 | args.n_ways = 5 177 | args.n_shots = 1 178 | args.n_queries = 12 179 | args.data_path = 'data' 180 | args.data_aug = True 181 | args.n_test_runs = 5 182 | args.n_aug_support_samples = 1 183 | imagenet = ImageNet(args, 'val') 184 | print(len(imagenet)) 185 | print(imagenet.__getitem__(500)[0].shape) 186 | 187 | metaimagenet = MetaImageNet(args) 188 | print(len(metaimagenet)) 189 | print(metaimagenet.__getitem__(500)[0].size()) 190 | print(metaimagenet.__getitem__(500)[1].shape) 191 | print(metaimagenet.__getitem__(500)[2].size()) 192 | print(metaimagenet.__getitem__(500)[3].shape) 193 | -------------------------------------------------------------------------------- /eval_lib/rfs_dataset/cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import pickle 5 | from PIL import Image 6 | import numpy as np 7 | 8 | import torch 9 | import torchvision.transforms as transforms 10 | from torch.utils.data import Dataset 11 | 12 | 13 | class CIFAR100(Dataset): 14 | """support FC100 and CIFAR-FS""" 15 | def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096, 16 | transform=None): 17 | super(Dataset, self).__init__() 18 | self.data_root = args.data_root 19 | self.partition = partition 20 | self.data_aug = args.data_aug 21 | self.mean = [0.5071, 0.4867, 0.4408] 22 | self.std = [0.2675, 0.2565, 0.2761] 23 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std) 24 | self.pretrain = pretrain 25 | 26 | if transform is None: 27 | if self.partition == 'train' and self.data_aug: 28 | self.transform = transforms.Compose([ 29 | lambda x: Image.fromarray(x), 30 | transforms.RandomCrop(32, padding=4), 31 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 32 | transforms.RandomHorizontalFlip(), 33 | lambda x: np.asarray(x), 34 | transforms.ToTensor(), 35 | self.normalize 36 | ]) 37 | else: 38 | self.transform = transforms.Compose([ 39 | lambda x: Image.fromarray(x), 40 | transforms.ToTensor(), 41 | self.normalize 42 | ]) 43 | else: 44 | self.transform = transform 45 | 46 | if self.pretrain: 47 | self.file_pattern = '%s.pickle' 48 | else: 49 | self.file_pattern = '%s.pickle' 50 | self.data = {} 51 | 52 | with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f: 53 | data = pickle.load(f, encoding='latin1') 54 | self.imgs = data['data'] 55 | labels = data['labels'] 56 | # adjust sparse labels to labels from 0 to n. 
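# e.g. raw labels [3, 3, 7, 2] become [0, 0, 1, 2]: each distinct label is
# assigned the next free index in order of first appearance (3 -> 0, 7 -> 1, 2 -> 2).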
57 | cur_class = 0 58 | label2label = {} 59 | for idx, label in enumerate(labels): 60 | if label not in label2label: 61 | label2label[label] = cur_class 62 | cur_class += 1 63 | new_labels = [] 64 | for idx, label in enumerate(labels): 65 | new_labels.append(label2label[label]) 66 | self.labels = new_labels 67 | 68 | # pre-process for contrastive sampling 69 | self.k = k 70 | self.is_sample = is_sample 71 | if self.is_sample: 72 | self.labels = np.asarray(self.labels) 73 | self.labels = self.labels - np.min(self.labels) 74 | num_classes = np.max(self.labels) + 1 75 | 76 | self.cls_positive = [[] for _ in range(num_classes)] 77 | for i in range(len(self.imgs)): 78 | self.cls_positive[self.labels[i]].append(i) 79 | 80 | self.cls_negative = [[] for _ in range(num_classes)] 81 | for i in range(num_classes): 82 | for j in range(num_classes): 83 | if j == i: 84 | continue 85 | self.cls_negative[i].extend(self.cls_positive[j]) 86 | 87 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)] 88 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)] 89 | self.cls_positive = np.asarray(self.cls_positive) 90 | self.cls_negative = np.asarray(self.cls_negative) 91 | 92 | def __getitem__(self, item): 93 | img = np.asarray(self.imgs[item]).astype('uint8') 94 | img = self.transform(img) 95 | target = self.labels[item] - min(self.labels) 96 | 97 | if not self.is_sample: 98 | return img, target, item 99 | else: 100 | pos_idx = item 101 | replace = True if self.k > len(self.cls_negative[target]) else False 102 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace) 103 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 104 | return img, target, item, sample_idx 105 | 106 | def __len__(self): 107 | return len(self.labels) 108 | 109 | 110 | class MetaCIFAR100(CIFAR100): 111 | 112 | def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True): 113 | super(MetaCIFAR100, self).__init__(args, partition, False) 114 | self.fix_seed = fix_seed 115 | self.n_ways = args.n_ways 116 | self.n_shots = args.n_shots 117 | self.n_queries = args.n_queries 118 | self.classes = list(self.data.keys()) 119 | self.n_test_runs = args.n_test_runs 120 | self.n_aug_support_samples = args.n_aug_support_samples 121 | if train_transform is None: 122 | self.train_transform = transforms.Compose([ 123 | lambda x: Image.fromarray(x), 124 | transforms.RandomCrop(32, padding=4), 125 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 126 | transforms.RandomHorizontalFlip(), 127 | lambda x: np.asarray(x), 128 | transforms.ToTensor(), 129 | self.normalize 130 | ]) 131 | else: 132 | self.train_transform = train_transform 133 | 134 | if test_transform is None: 135 | self.test_transform = transforms.Compose([ 136 | lambda x: Image.fromarray(x), 137 | transforms.ToTensor(), 138 | self.normalize 139 | ]) 140 | else: 141 | self.test_transform = test_transform 142 | 143 | self.data = {} 144 | for idx in range(self.imgs.shape[0]): 145 | if self.labels[idx] not in self.data: 146 | self.data[self.labels[idx]] = [] 147 | self.data[self.labels[idx]].append(self.imgs[idx]) 148 | self.classes = list(self.data.keys()) 149 | 150 | def __getitem__(self, item): 151 | if self.fix_seed: 152 | np.random.seed(item) 153 | cls_sampled = np.random.choice(self.classes, self.n_ways, False) 154 | support_xs = [] 155 | support_ys = [] 156 | query_xs = [] 157 | query_ys = [] 158 | for idx, cls in enumerate(cls_sampled): 159 | 
imgs = np.asarray(self.data[cls]).astype('uint8') 160 | support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False) 161 | support_xs.append(imgs[support_xs_ids_sampled]) 162 | support_ys.append([idx] * self.n_shots) 163 | query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled) 164 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False) 165 | query_xs.append(imgs[query_xs_ids]) 166 | query_ys.append([idx] * query_xs_ids.shape[0]) 167 | support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array( 168 | query_xs), np.array(query_ys) 169 | num_ways, n_queries_per_way, height, width, channel = query_xs.shape 170 | query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel)) 171 | query_ys = query_ys.reshape((num_ways * n_queries_per_way,)) 172 | 173 | support_xs = support_xs.reshape((-1, height, width, channel)) 174 | if self.n_aug_support_samples > 1: 175 | support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1)) 176 | support_ys = np.tile(support_ys.reshape((-1,)), (self.n_aug_support_samples)) 177 | support_xs = np.split(support_xs, support_xs.shape[0], axis=0) 178 | query_xs = query_xs.reshape((-1, height, width, channel)) 179 | query_xs = np.split(query_xs, query_xs.shape[0], axis=0) 180 | 181 | support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs))) 182 | query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs))) 183 | 184 | return support_xs, support_ys, query_xs, query_ys 185 | 186 | def __len__(self): 187 | return self.n_test_runs 188 | 189 | 190 | if __name__ == '__main__': 191 | args = lambda x: None 192 | args.n_ways = 5 193 | args.n_shots = 1 194 | args.n_queries = 12 195 | # args.data_root = 'data' 196 | args.data_root = '/home/yonglong/Downloads/FC100' 197 | args.data_aug = True 198 | args.n_test_runs = 5 199 | args.n_aug_support_samples = 1 200 | imagenet = CIFAR100(args, 'train') 201 | print(len(imagenet)) 202 | print(imagenet.__getitem__(500)[0].shape) 203 | 204 | metaimagenet = MetaCIFAR100(args, 'train') 205 | print(len(metaimagenet)) 206 | print(metaimagenet.__getitem__(500)[0].size()) 207 | print(metaimagenet.__getitem__(500)[1].shape) 208 | print(metaimagenet.__getitem__(500)[2].size()) 209 | print(metaimagenet.__getitem__(500)[3].shape) -------------------------------------------------------------------------------- /lib/models/cell_searchs/genotypes.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | 4 | def get_combination(space, num): 5 | combs = [] 6 | for i in range(num): 7 | if i == 0: 8 | for func in space: 9 | combs.append( [(func, i)] ) 10 | else: 11 | new_combs = [] 12 | for string in combs: 13 | for func in space: 14 | xstring = string + [(func, i)] 15 | new_combs.append( xstring ) 16 | combs = new_combs 17 | return combs 18 | 19 | 20 | class Structure: 21 | 22 | def __init__(self, genotype): 23 | assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype)) 24 | self.node_num = len(genotype) + 1 25 | self.nodes = [] 26 | self.node_N = [] 27 | for idx, node_info in enumerate(genotype): 28 | assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info)) 29 | assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info)) 30 | for node_in in node_info: 31 | 
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
32 |         assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
33 |       self.node_N.append( len(node_info) )
34 |       self.nodes.append( tuple(deepcopy(node_info)) )
35 | 
36 |   def tolist(self, remove_str):
37 |     # Convert this class to a list; if remove_str is 'none', drop the 'none' operations.
38 |     # Note that the input nodes of each cell node are re-ordered in this function.
39 |     # Returns (genotype-list, success); if unsuccessful, the genotype is not a valid connectivity.
40 |     genotypes = []
41 |     for node_info in self.nodes:
42 |       node_info = list( node_info )
43 |       node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
44 |       node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
45 |       if len(node_info) == 0: return None, False
46 |       genotypes.append( node_info )
47 |     return genotypes, True
48 | 
49 |   def node(self, index):
50 |     assert 0 < index < len(self), 'invalid index={:}, node_num={:}'.format(index, len(self))
51 |     return self.nodes[index - 1]  # node-0 is the input, so node-i lives at self.nodes[i-1]
52 | 
53 |   def tostr(self):
54 |     strings = []
55 |     for node_info in self.nodes:
56 |       string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
57 |       string = '|{:}|'.format(string)
58 |       strings.append( string )
59 |     return '+'.join(strings)
60 | 
61 |   def check_valid(self):
62 |     nodes = {0: True}
63 |     for i, node_info in enumerate(self.nodes):
64 |       sums = []
65 |       for op, xin in node_info:
66 |         if op == 'none' or nodes[xin] is False: x = False
67 |         else: x = True
68 |         sums.append( x )
69 |       nodes[i+1] = sum(sums) > 0
70 |     return nodes[len(self.nodes)]
71 | 
72 |   def to_unique_str(self, consider_zero=False):
73 |     # This is used to identify isomorphic cells, which requires prior knowledge of the operations;
74 |     # two operations are special, i.e., none and skip_connect.
75 |     nodes = {0: '0'}
76 |     for i_node, node_info in enumerate(self.nodes):
77 |       cur_node = []
78 |       for op, xin in node_info:
79 |         if consider_zero is None:
80 |           x = '('+nodes[xin]+')' + '@{:}'.format(op)
81 |         elif consider_zero:
82 |           if op == 'none' or nodes[xin] == '#': x = '#' # zero
83 |           elif op == 'skip_connect': x = nodes[xin]
84 |           else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
85 |         else:
86 |           if op == 'skip_connect': x = nodes[xin]
87 |           else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
88 |         cur_node.append(x)
89 |       nodes[i_node+1] = '+'.join( sorted(cur_node) )
90 |     return nodes[ len(self.nodes) ]
91 | 
92 |   def check_valid_op(self, op_names):
93 |     for node_info in self.nodes:
94 |       for inode_edge in node_info:
95 |         #assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
96 |         if inode_edge[0] not in op_names: return False
97 |     return True
98 | 
99 |   def __repr__(self):
100 |     return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
101 | 
102 |   def __len__(self):
103 |     return len(self.nodes) + 1
104 | 
105 |   def __getitem__(self, index):
106 |     return self.nodes[index]
107 | 
108 |   @staticmethod
109 |   def str2structure(xstr):
110 |     assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
111 |     nodestrs = xstr.split('+')
112 |     genotypes = []
113 |     for i, node_str in enumerate(nodestrs):
114 |       inputs = list(filter(lambda x: x != '', node_str.split('|')))
115 |       for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
116 |       inputs = ( xi.split('~') for xi in inputs )
117 |       input_infos = tuple(
(op, int(IDX)) for (op, IDX) in inputs) 118 | genotypes.append( input_infos ) 119 | return Structure( genotypes ) 120 | 121 | @staticmethod 122 | def str2fullstructure(xstr, default_name='none'): 123 | assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) 124 | nodestrs = xstr.split('+') 125 | genotypes = [] 126 | for i, node_str in enumerate(nodestrs): 127 | inputs = list(filter(lambda x: x != '', node_str.split('|'))) 128 | for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) 129 | inputs = ( xi.split('~') for xi in inputs ) 130 | input_infos = list( (op, int(IDX)) for (op, IDX) in inputs) 131 | all_in_nodes= list(x[1] for x in input_infos) 132 | for j in range(i): 133 | if j not in all_in_nodes: input_infos.append((default_name, j)) 134 | node_info = sorted(input_infos, key=lambda x: (x[1], x[0])) 135 | genotypes.append( tuple(node_info) ) 136 | return Structure( genotypes ) 137 | 138 | @staticmethod 139 | def gen_all(search_space, num, return_ori): 140 | assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space)) 141 | assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num) 142 | all_archs = get_combination(search_space, 1) 143 | for i, arch in enumerate(all_archs): 144 | all_archs[i] = [ tuple(arch) ] 145 | 146 | for inode in range(2, num): 147 | cur_nodes = get_combination(search_space, inode) 148 | new_all_archs = [] 149 | for previous_arch in all_archs: 150 | for cur_node in cur_nodes: 151 | new_all_archs.append( previous_arch + [tuple(cur_node)] ) 152 | all_archs = new_all_archs 153 | if return_ori: 154 | return all_archs 155 | else: 156 | return [Structure(x) for x in all_archs] 157 | 158 | 159 | 160 | ResNet_CODE = Structure( 161 | [(('nor_conv_3x3', 0), ), # node-1 162 | (('nor_conv_3x3', 1), ), # node-2 163 | (('skip_connect', 0), ('skip_connect', 2))] # node-3 164 | ) 165 | 166 | AllConv3x3_CODE = Structure( 167 | [(('nor_conv_3x3', 0), ), # node-1 168 | (('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2 169 | (('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3 170 | ) 171 | 172 | AllFull_CODE = Structure( 173 | [(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1 174 | (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2 175 | (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3 176 | ) 177 | 178 | AllConv1x1_CODE = Structure( 179 | [(('nor_conv_1x1', 0), ), # node-1 180 | (('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2 181 | (('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3 182 | ) 183 | 184 | AllIdentity_CODE = Structure( 185 | [(('skip_connect', 0), ), # node-1 186 | (('skip_connect', 0), ('skip_connect', 1)), # node-2 187 | (('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3 188 | ) 189 | 190 | architectures = {'resnet' : ResNet_CODE, 191 | 'all_c3x3': AllConv3x3_CODE, 192 | 'all_c1x1': AllConv1x1_CODE, 193 | 'all_idnt': AllIdentity_CODE, 194 | 'all_full': AllFull_CODE} 195 | 
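Editor's note: the snippet below is a minimal usage sketch for the Structure class above; it is not part of the repository. It assumes the repository's lib directory is importable (e.g., on PYTHONPATH) and uses the NAS-Bench-201 operation names that appear in the *_CODE genotypes.

from lib.models.cell_searchs.genotypes import Structure

# Round-trip between the string encoding and the Structure object.
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~1|+|skip_connect~0|skip_connect~2|'
arch = Structure.str2structure(arch_str)
assert arch.tostr() == arch_str

print(len(arch))            # 4: one input node plus three computed nodes
print(arch.check_valid())   # True: the output node is reachable through non-'none' ops
print(arch.check_valid_op(['nor_conv_3x3', 'skip_connect']))  # True for this cell

# Canonical form for detecting isomorphic cells; with consider_zero=True,
# 'none' edges collapse to '#' and 'skip_connect' forwards its input's string.
print(arch.to_unique_str(consider_zero=True))

Because tostr() joins nodes with '+' and in-edges with '|', the same encoding is what str2fullstructure() consumes when it pads missing in-edges with the default 'none' operation.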
-------------------------------------------------------------------------------- /eval_lib/rfs_dataset/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class ImageNet(Dataset): 11 | def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096, 12 | transform=None): 13 | super(Dataset, self).__init__() 14 | self.data_root = args.data_root 15 | self.partition = partition 16 | self.data_aug = args.data_aug 17 | self.mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0] 18 | self.std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0] 19 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std) 20 | self.pretrain = pretrain 21 | 22 | if transform is None: 23 | if self.partition == 'train' and self.data_aug: 24 | self.transform = transforms.Compose([ 25 | lambda x: Image.fromarray(x), 26 | transforms.RandomCrop(84, padding=8), 27 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 28 | transforms.RandomHorizontalFlip(), 29 | lambda x: np.asarray(x), 30 | transforms.ToTensor(), 31 | self.normalize 32 | ]) 33 | else: 34 | self.transform = transforms.Compose([ 35 | lambda x: Image.fromarray(x), 36 | transforms.ToTensor(), 37 | self.normalize 38 | ]) 39 | else: 40 | self.transform = transform 41 | 42 | if self.pretrain: 43 | self.file_pattern = 'miniImageNet_category_split_train_phase_%s.pickle' 44 | else: 45 | self.file_pattern = 'miniImageNet_category_split_%s.pickle' 46 | self.data = {} 47 | with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f: 48 | data = pickle.load(f, encoding='latin1') 49 | self.imgs = data['data'] 50 | self.labels = data['labels'] 51 | 52 | # pre-process for contrastive sampling 53 | self.k = k 54 | self.is_sample = is_sample 55 | if self.is_sample: 56 | self.labels = np.asarray(self.labels) 57 | self.labels = self.labels - np.min(self.labels) 58 | num_classes = np.max(self.labels) + 1 59 | 60 | self.cls_positive = [[] for _ in range(num_classes)] 61 | for i in range(len(self.imgs)): 62 | self.cls_positive[self.labels[i]].append(i) 63 | 64 | self.cls_negative = [[] for _ in range(num_classes)] 65 | for i in range(num_classes): 66 | for j in range(num_classes): 67 | if j == i: 68 | continue 69 | self.cls_negative[i].extend(self.cls_positive[j]) 70 | 71 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)] 72 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)] 73 | self.cls_positive = np.asarray(self.cls_positive) 74 | self.cls_negative = np.asarray(self.cls_negative) 75 | 76 | def __getitem__(self, item): 77 | img = np.asarray(self.imgs[item]).astype('uint8') 78 | img = self.transform(img) 79 | target = self.labels[item] - min(self.labels) 80 | 81 | if not self.is_sample: 82 | return img, target, item 83 | else: 84 | pos_idx = item 85 | replace = True if self.k > len(self.cls_negative[target]) else False 86 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace) 87 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 88 | return img, target, item, sample_idx 89 | 90 | def __len__(self): 91 | return len(self.labels) 92 | 93 | 94 | class MetaImageNet(ImageNet): 95 | 96 | def __init__(self, args, partition='train', 
train_transform=None, test_transform=None, fix_seed=True, 97 | n_queries=None, n_test_runs=None, n_shots=None, n_aug_support_samples=None): 98 | super(MetaImageNet, self).__init__(args, partition, False) 99 | self.fix_seed = fix_seed 100 | self.n_ways = args.n_ways 101 | if n_shots is not None: 102 | self.n_shots = n_shots 103 | else: 104 | self.n_shots = args.n_shots 105 | if n_queries is not None: 106 | self.n_queries = n_queries 107 | else: 108 | self.n_queries = args.n_queries 109 | self.classes = list(self.data.keys()) 110 | if n_test_runs is not None: 111 | self.n_test_runs = n_test_runs 112 | else: 113 | self.n_test_runs = args.n_test_runs 114 | if n_aug_support_samples is not None: 115 | self.n_aug_support_samples = n_aug_support_samples 116 | else: 117 | self.n_aug_support_samples = args.n_aug_support_samples 118 | if train_transform is None: 119 | self.train_transform = transforms.Compose([ 120 | lambda x: Image.fromarray(x), 121 | transforms.RandomCrop(84, padding=8), 122 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 123 | transforms.RandomHorizontalFlip(), 124 | lambda x: np.asarray(x), 125 | transforms.ToTensor(), 126 | self.normalize 127 | ]) 128 | else: 129 | self.train_transform = train_transform 130 | 131 | if test_transform is None: 132 | self.test_transform = transforms.Compose([ 133 | lambda x: Image.fromarray(x), 134 | transforms.ToTensor(), 135 | self.normalize 136 | ]) 137 | else: 138 | self.test_transform = test_transform 139 | 140 | self.data = {} 141 | for idx in range(self.imgs.shape[0]): 142 | if self.labels[idx] not in self.data: 143 | self.data[self.labels[idx]] = [] 144 | self.data[self.labels[idx]].append(self.imgs[idx]) 145 | self.classes = list(self.data.keys()) 146 | 147 | def __getitem__(self, item): 148 | if self.fix_seed: 149 | np.random.seed(item) 150 | cls_sampled = np.random.choice(self.classes, self.n_ways, False) 151 | support_xs = [] 152 | support_ys = [] 153 | query_xs = [] 154 | query_ys = [] 155 | for idx, cls in enumerate(cls_sampled): 156 | imgs = np.asarray(self.data[cls]).astype('uint8') 157 | support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False) 158 | support_xs.append(imgs[support_xs_ids_sampled]) 159 | support_ys.append([idx] * self.n_shots) 160 | query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled) 161 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False) 162 | query_xs.append(imgs[query_xs_ids]) 163 | query_ys.append([idx] * query_xs_ids.shape[0]) 164 | support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array( 165 | query_xs), np.array(query_ys) 166 | num_ways, n_queries_per_way, height, width, channel = query_xs.shape 167 | query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel)) 168 | query_ys = query_ys.reshape((num_ways * n_queries_per_way, )) 169 | 170 | support_xs = support_xs.reshape((-1, height, width, channel)) 171 | if self.n_aug_support_samples > 1: 172 | support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1)) 173 | support_ys = np.tile(support_ys.reshape((-1, )), (self.n_aug_support_samples)) 174 | support_xs = np.split(support_xs, support_xs.shape[0], axis=0) 175 | query_xs = query_xs.reshape((-1, height, width, channel)) 176 | query_xs = np.split(query_xs, query_xs.shape[0], axis=0) 177 | 178 | support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs))) 179 | query_xs = torch.stack(list(map(lambda x: 
self.test_transform(x.squeeze()), query_xs))) 180 | 181 | return support_xs, support_ys, query_xs, query_ys 182 | 183 | def __len__(self): 184 | return self.n_test_runs 185 | 186 | 187 | if __name__ == '__main__': 188 | args = lambda x: None 189 | args.n_ways = 5 190 | args.n_shots = 1 191 | args.n_queries = 12 192 | args.data_root = 'data' 193 | args.data_aug = True 194 | args.n_test_runs = 5 195 | args.n_aug_support_samples = 1 196 | imagenet = ImageNet(args, 'val') 197 | print(len(imagenet)) 198 | print(imagenet.__getitem__(500)[0].shape) 199 | 200 | metaimagenet = MetaImageNet(args) 201 | print(len(metaimagenet)) 202 | print(metaimagenet.__getitem__(500)[0].size()) 203 | print(metaimagenet.__getitem__(500)[1].shape) 204 | print(metaimagenet.__getitem__(500)[2].size()) 205 | print(metaimagenet.__getitem__(500)[3].shape) 206 | -------------------------------------------------------------------------------- /lib/datasets/tiered_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class TieredImageNet(Dataset): 11 | def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096, 12 | transform=None): 13 | super(Dataset, self).__init__() 14 | self.data_path = args.data_path 15 | self.partition = partition 16 | self.data_aug = args.data_aug 17 | self.mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0] 18 | self.std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0] 19 | 20 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std) 21 | self.pretrain = pretrain 22 | 23 | if transform is None: 24 | if self.partition == 'train' and self.data_aug: 25 | self.transform = transforms.Compose([ 26 | lambda x: Image.fromarray(x), 27 | transforms.RandomCrop(84, padding=8), 28 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 29 | transforms.RandomHorizontalFlip(), 30 | lambda x: np.asarray(x), 31 | transforms.ToTensor(), 32 | self.normalize 33 | ]) 34 | else: 35 | self.transform = transforms.Compose([ 36 | lambda x: Image.fromarray(x), 37 | transforms.ToTensor(), 38 | self.normalize 39 | ]) 40 | else: 41 | self.transform = transform 42 | 43 | if self.pretrain: 44 | self.image_file_pattern = '%s_images.npz' 45 | self.label_file_pattern = '%s_labels.pkl' 46 | else: 47 | self.image_file_pattern = '%s_images.npz' 48 | self.label_file_pattern = '%s_labels.pkl' 49 | 50 | self.data = {} 51 | 52 | # modified code to load tieredImageNet 53 | image_file = os.path.join(self.data_path, self.image_file_pattern % partition) 54 | self.imgs = np.load(image_file)['images'] 55 | label_file = os.path.join(self.data_path, self.label_file_pattern % partition) 56 | self.labels = self._load_labels(label_file)['labels'] 57 | 58 | # pre-process for contrastive sampling 59 | self.k = k 60 | self.is_sample = is_sample 61 | if self.is_sample: 62 | self.labels = np.asarray(self.labels) 63 | self.labels = self.labels - np.min(self.labels) 64 | num_classes = np.max(self.labels) + 1 65 | 66 | self.cls_positive = [[] for _ in range(num_classes)] 67 | for i in range(len(self.imgs)): 68 | self.cls_positive[self.labels[i]].append(i) 69 | 70 | self.cls_negative = [[] for _ in range(num_classes)] 71 | for i in range(num_classes): 72 | for j in range(num_classes): 73 | if j == i: 74 | continue 75 | 
self.cls_negative[i].extend(self.cls_positive[j]) 76 | 77 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)] 78 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)] 79 | self.cls_positive = np.asarray(self.cls_positive) 80 | self.cls_negative = np.asarray(self.cls_negative) 81 | 82 | def __getitem__(self, item): 83 | img = np.asarray(self.imgs[item]).astype('uint8') 84 | img = self.transform(img) 85 | target = self.labels[item] - min(self.labels) 86 | 87 | if not self.is_sample: 88 | return img, target, item 89 | else: 90 | pos_idx = item 91 | replace = True if self.k > len(self.cls_negative[target]) else False 92 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace) 93 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 94 | return img, target, item, sample_idx 95 | 96 | def __len__(self): 97 | return len(self.labels) 98 | 99 | @staticmethod 100 | def _load_labels(file): 101 | try: 102 | with open(file, 'rb') as fo: 103 | data = pickle.load(fo) 104 | return data 105 | except: 106 | with open(file, 'rb') as f: 107 | u = pickle._Unpickler(f) 108 | u.encoding = 'latin1' 109 | data = u.load() 110 | return data 111 | 112 | 113 | class MetaTieredImageNet(TieredImageNet): 114 | 115 | def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True): 116 | super(MetaTieredImageNet, self).__init__(args, partition, True) 117 | self.fix_seed = fix_seed 118 | self.n_ways = args.n_ways 119 | self.n_shots = args.n_shots 120 | self.n_queries = args.n_queries 121 | self.classes = list(self.data.keys()) 122 | self.n_test_runs = args.n_test_runs 123 | self.n_aug_support_samples = args.n_aug_support_samples 124 | if train_transform is None: 125 | self.train_transform = transforms.Compose([ 126 | lambda x: Image.fromarray(x), 127 | transforms.RandomCrop(84, padding=8), 128 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 129 | transforms.RandomHorizontalFlip(), 130 | lambda x: np.asarray(x), 131 | transforms.ToTensor(), 132 | self.normalize 133 | ]) 134 | else: 135 | self.train_transform = train_transform 136 | 137 | if test_transform is None: 138 | self.test_transform = transforms.Compose([ 139 | lambda x: Image.fromarray(x), 140 | transforms.ToTensor(), 141 | self.normalize 142 | ]) 143 | else: 144 | self.test_transform = test_transform 145 | 146 | self.data = {} 147 | for idx in range(self.imgs.shape[0]): 148 | if self.labels[idx] not in self.data: 149 | self.data[self.labels[idx]] = [] 150 | self.data[self.labels[idx]].append(self.imgs[idx]) 151 | self.classes = list(self.data.keys()) 152 | 153 | def __getitem__(self, item): 154 | if self.fix_seed: 155 | np.random.seed(item) 156 | cls_sampled = np.random.choice(self.classes, self.n_ways, False) 157 | support_xs = [] 158 | support_ys = [] 159 | query_xs = [] 160 | query_ys = [] 161 | for idx, cls in enumerate(cls_sampled): 162 | imgs = np.asarray(self.data[cls]).astype('uint8') 163 | support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False) 164 | support_xs.append(imgs[support_xs_ids_sampled]) 165 | support_ys.append([idx] * self.n_shots) 166 | query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled) 167 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False) 168 | query_xs.append(imgs[query_xs_ids]) 169 | query_ys.append([idx] * query_xs_ids.shape[0]) 170 | support_xs, support_ys, query_xs, query_ys = np.array(support_xs), 
np.array(support_ys), np.array( 171 | query_xs), np.array(query_ys) 172 | num_ways, n_queries_per_way, height, width, channel = query_xs.shape 173 | query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel)) 174 | query_ys = query_ys.reshape((num_ways * n_queries_per_way,)) 175 | 176 | support_xs = support_xs.reshape((-1, height, width, channel)) 177 | if self.n_aug_support_samples > 1: 178 | support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1)) 179 | support_ys = np.tile(support_ys.reshape((-1,)), (self.n_aug_support_samples)) 180 | support_xs = np.split(support_xs, support_xs.shape[0], axis=0) 181 | query_xs = query_xs.reshape((-1, height, width, channel)) 182 | query_xs = np.split(query_xs, query_xs.shape[0], axis=0) 183 | 184 | support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs))) 185 | query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs))) 186 | 187 | return support_xs, support_ys, query_xs, query_ys 188 | 189 | def __len__(self): 190 | return self.n_test_runs 191 | 192 | 193 | if __name__ == '__main__': 194 | args = lambda x: None 195 | args.n_ways = 5 196 | args.n_shots = 1 197 | args.n_queries = 12 198 | # args.data_path = 'data' 199 | args.data_path = '/home/yonglong/Data/tiered-imagenet-kwon' 200 | args.data_aug = True 201 | args.n_test_runs = 5 202 | args.n_aug_support_samples = 1 203 | imagenet = TieredImageNet(args, 'train') 204 | print(len(imagenet)) 205 | print(imagenet.__getitem__(500)[0].shape) 206 | 207 | metaimagenet = MetaTieredImageNet(args) 208 | print(len(metaimagenet)) 209 | print(metaimagenet.__getitem__(500)[0].size()) 210 | print(metaimagenet.__getitem__(500)[1].shape) 211 | print(metaimagenet.__getitem__(500)[2].size()) 212 | print(metaimagenet.__getitem__(500)[3].shape) 213 | -------------------------------------------------------------------------------- /eval_lib/rfs_dataset/tiered_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class TieredImageNet(Dataset): 11 | def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096, 12 | transform=None): 13 | super(Dataset, self).__init__() 14 | self.data_root = args.data_root 15 | self.partition = partition 16 | self.data_aug = args.data_aug 17 | self.mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0] 18 | self.std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0] 19 | 20 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std) 21 | self.pretrain = pretrain 22 | 23 | if transform is None: 24 | if self.partition == 'train' and self.data_aug: 25 | self.transform = transforms.Compose([ 26 | lambda x: Image.fromarray(x), 27 | transforms.RandomCrop(84, padding=8), 28 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 29 | transforms.RandomHorizontalFlip(), 30 | lambda x: np.asarray(x), 31 | transforms.ToTensor(), 32 | self.normalize 33 | ]) 34 | else: 35 | self.transform = transforms.Compose([ 36 | lambda x: Image.fromarray(x), 37 | transforms.ToTensor(), 38 | self.normalize 39 | ]) 40 | else: 41 | self.transform = transform 42 | 43 | if self.pretrain: 44 | self.image_file_pattern = '%s_images.npz' 45 | self.label_file_pattern = '%s_labels.pkl' 46 | else: 47 | 
self.image_file_pattern = '%s_images.npz' 48 | self.label_file_pattern = '%s_labels.pkl' 49 | 50 | self.data = {} 51 | 52 | # modified code to load tieredImageNet 53 | image_file = os.path.join(self.data_root, self.image_file_pattern % partition) 54 | self.imgs = np.load(image_file)['images'] 55 | label_file = os.path.join(self.data_root, self.label_file_pattern % partition) 56 | self.labels = self._load_labels(label_file)['labels'] 57 | 58 | # pre-process for contrastive sampling 59 | self.k = k 60 | self.is_sample = is_sample 61 | if self.is_sample: 62 | self.labels = np.asarray(self.labels) 63 | self.labels = self.labels - np.min(self.labels) 64 | num_classes = np.max(self.labels) + 1 65 | 66 | self.cls_positive = [[] for _ in range(num_classes)] 67 | for i in range(len(self.imgs)): 68 | self.cls_positive[self.labels[i]].append(i) 69 | 70 | self.cls_negative = [[] for _ in range(num_classes)] 71 | for i in range(num_classes): 72 | for j in range(num_classes): 73 | if j == i: 74 | continue 75 | self.cls_negative[i].extend(self.cls_positive[j]) 76 | 77 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)] 78 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)] 79 | self.cls_positive = np.asarray(self.cls_positive) 80 | self.cls_negative = np.asarray(self.cls_negative) 81 | 82 | def __getitem__(self, item): 83 | img = np.asarray(self.imgs[item]).astype('uint8') 84 | img = self.transform(img) 85 | target = self.labels[item] - min(self.labels) 86 | 87 | if not self.is_sample: 88 | return img, target, item 89 | else: 90 | pos_idx = item 91 | replace = True if self.k > len(self.cls_negative[target]) else False 92 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace) 93 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 94 | return img, target, item, sample_idx 95 | 96 | def __len__(self): 97 | return len(self.labels) 98 | 99 | @staticmethod 100 | def _load_labels(file): 101 | try: 102 | with open(file, 'rb') as fo: 103 | data = pickle.load(fo) 104 | return data 105 | except: 106 | with open(file, 'rb') as f: 107 | u = pickle._Unpickler(f) 108 | u.encoding = 'latin1' 109 | data = u.load() 110 | return data 111 | 112 | 113 | class MetaTieredImageNet(TieredImageNet): 114 | 115 | def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True): 116 | super(MetaTieredImageNet, self).__init__(args, partition, False) 117 | self.fix_seed = fix_seed 118 | self.n_ways = args.n_ways 119 | self.n_shots = args.n_shots 120 | self.n_queries = args.n_queries 121 | self.classes = list(self.data.keys()) 122 | self.n_test_runs = args.n_test_runs 123 | self.n_aug_support_samples = args.n_aug_support_samples 124 | if train_transform is None: 125 | self.train_transform = transforms.Compose([ 126 | lambda x: Image.fromarray(x), 127 | transforms.RandomCrop(84, padding=8), 128 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 129 | transforms.RandomHorizontalFlip(), 130 | lambda x: np.asarray(x), 131 | transforms.ToTensor(), 132 | self.normalize 133 | ]) 134 | else: 135 | self.train_transform = train_transform 136 | 137 | if test_transform is None: 138 | self.test_transform = transforms.Compose([ 139 | lambda x: Image.fromarray(x), 140 | transforms.ToTensor(), 141 | self.normalize 142 | ]) 143 | else: 144 | self.test_transform = test_transform 145 | 146 | self.data = {} 147 | for idx in range(self.imgs.shape[0]): 148 | if self.labels[idx] not in self.data: 149 | 
self.data[self.labels[idx]] = [] 150 | self.data[self.labels[idx]].append(self.imgs[idx]) 151 | self.classes = list(self.data.keys()) 152 | 153 | def __getitem__(self, item): 154 | if self.fix_seed: 155 | np.random.seed(item) 156 | cls_sampled = np.random.choice(self.classes, self.n_ways, False) 157 | support_xs = [] 158 | support_ys = [] 159 | query_xs = [] 160 | query_ys = [] 161 | for idx, cls in enumerate(cls_sampled): 162 | imgs = np.asarray(self.data[cls]).astype('uint8') 163 | support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False) 164 | support_xs.append(imgs[support_xs_ids_sampled]) 165 | support_ys.append([idx] * self.n_shots) 166 | query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled) 167 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False) 168 | query_xs.append(imgs[query_xs_ids]) 169 | query_ys.append([idx] * query_xs_ids.shape[0]) 170 | support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array( 171 | query_xs), np.array(query_ys) 172 | num_ways, n_queries_per_way, height, width, channel = query_xs.shape 173 | query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel)) 174 | query_ys = query_ys.reshape((num_ways * n_queries_per_way,)) 175 | 176 | support_xs = support_xs.reshape((-1, height, width, channel)) 177 | if self.n_aug_support_samples > 1: 178 | support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1)) 179 | support_ys = np.tile(support_ys.reshape((-1,)), (self.n_aug_support_samples)) 180 | support_xs = np.split(support_xs, support_xs.shape[0], axis=0) 181 | query_xs = query_xs.reshape((-1, height, width, channel)) 182 | query_xs = np.split(query_xs, query_xs.shape[0], axis=0) 183 | 184 | support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs))) 185 | query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs))) 186 | 187 | return support_xs, support_ys, query_xs, query_ys 188 | 189 | def __len__(self): 190 | return self.n_test_runs 191 | 192 | 193 | if __name__ == '__main__': 194 | args = lambda x: None 195 | args.n_ways = 5 196 | args.n_shots = 1 197 | args.n_queries = 12 198 | # args.data_root = 'data' 199 | args.data_root = '/home/yonglong/Data/tiered-imagenet-kwon' 200 | args.data_aug = True 201 | args.n_test_runs = 5 202 | args.n_aug_support_samples = 1 203 | imagenet = TieredImageNet(args, 'train') 204 | print(len(imagenet)) 205 | print(imagenet.__getitem__(500)[0].shape) 206 | 207 | metaimagenet = MetaTieredImageNet(args) 208 | print(len(metaimagenet)) 209 | print(metaimagenet.__getitem__(500)[0].size()) 210 | print(metaimagenet.__getitem__(500)[1].shape) 211 | print(metaimagenet.__getitem__(500)[2].size()) 212 | print(metaimagenet.__getitem__(500)[3].shape) 213 | --------------------------------------------------------------------------------
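Editor's note: a minimal sketch, not part of the repository, of how the Meta* datasets above are typically consumed. Each __getitem__ call samples one complete few-shot episode (with fix_seed=True, episode item is deterministic because of np.random.seed(item)), so a DataLoader with batch_size=1 iterates over n_test_runs episodes. The Namespace fields mirror the __main__ blocks; the data_root value is a placeholder.

from argparse import Namespace
from torch.utils.data import DataLoader
from eval_lib.rfs_dataset.tiered_imagenet import MetaTieredImageNet  # assumes the repo root is importable

args = Namespace(n_ways=5, n_shots=1, n_queries=12,
                 data_root='data',  # placeholder: directory holding %s_images.npz / %s_labels.pkl
                 data_aug=True, n_test_runs=5, n_aug_support_samples=1)

meta_set = MetaTieredImageNet(args, partition='train')
loader = DataLoader(meta_set, batch_size=1, shuffle=False)

for support_xs, support_ys, query_xs, query_ys in loader:
    # support_xs: [1, n_ways * n_shots * n_aug_support_samples, 3, 84, 84]
    # query_xs:   [1, n_ways * n_queries, 3, 84, 84]
    support_xs, query_xs = support_xs.squeeze(0), query_xs.squeeze(0)
    break

The same pattern applies to MetaCIFAR100 and MetaImageNet; only the image size (32 vs. 84) and the args fields differ (lib/datasets/tiered_imagenet.py reads args.data_path rather than args.data_root).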