├── .gitignore ├── README.md ├── autodl ├── __init__.py ├── nas_201_api │ ├── __init__.py │ ├── api_201.py │ ├── api_301.py │ └── api_utils.py ├── procedures │ ├── __init__.py │ ├── basic_main.py │ ├── funcs_nasbench.py │ ├── optimizers.py │ ├── search_main.py │ ├── search_main_v2.py │ ├── simple_KD_main.py │ └── starts.py └── utils │ ├── __init__.py │ ├── affine_utils.py │ ├── evaluation_utils.py │ ├── flop_benchmark.py │ ├── gpu_manager.py │ ├── nas_utils.py │ └── weight_watcher.py ├── config_utils ├── __init__.py ├── attention_args.py ├── basic_args.py ├── cifar-split.txt ├── cls_init_args.py ├── cls_kd_args.py ├── configure_utils.py ├── pruning_args.py ├── random_baseline.py ├── search_args.py ├── search_single_args.py └── share_args.py ├── datasets ├── DownsampledImageNet.py ├── LandmarkDataset.py ├── SearchDatasetWrap.py ├── __init__.py ├── data.py ├── get_dataset_with_transform.py ├── landmark_utils │ ├── __init__.py │ └── point_meta.py └── test_utils.py ├── env.yml ├── models ├── CifarDenseNet.py ├── CifarResNet.py ├── CifarWideResNet.py ├── ImageNet_MobileNetV2.py ├── ImageNet_ResNet.py ├── SharedUtils.py ├── __init__.py ├── cell_infers │ ├── __init__.py │ ├── cells.py │ ├── nasnet_cifar.py │ └── tiny_network.py ├── cell_operations.py ├── cell_searchs │ ├── __init__.py │ ├── _test_module.py │ ├── genotypes.py │ ├── search_cells.py │ ├── search_model_darts.py │ ├── search_model_darts_nasnet.py │ ├── search_model_enas.py │ ├── search_model_enas_utils.py │ ├── search_model_gdas.py │ ├── search_model_gdas_nasnet.py │ ├── search_model_random.py │ ├── search_model_setn.py │ └── search_model_setn_nasnet.py ├── clone_weights.py ├── initialization.py ├── shape_infers │ ├── InferCifarResNet.py │ ├── InferCifarResNet_depth.py │ ├── InferCifarResNet_width.py │ ├── InferImagenetResNet.py │ ├── InferMobileNetV2.py │ ├── InferTinyCellNet.py │ ├── __init__.py │ └── shared_utils.py └── shape_searchs │ ├── SearchCifarResNet.py │ ├── SearchCifarResNet_depth.py │ ├── 
SearchCifarResNet_width.py │ ├── SearchImagenetResNet.py │ ├── SearchSimResNet_width.py │ ├── SoftSelect.py │ ├── __init__.py │ └── test.py ├── nas_101_api ├── __init__.py ├── base_ops.py ├── graph_util.py ├── model.py └── model_spec.py ├── nas_201_api ├── __init__.py ├── api.py ├── api_201.py └── api_utils.py ├── nasspace.py ├── plot_scores.py ├── pycls ├── core │ ├── __init__.py │ ├── benchmark.py │ ├── builders.py │ ├── checkpoint.py │ ├── config.py │ ├── distributed.py │ ├── io.py │ ├── logging.py │ ├── meters.py │ ├── net.py │ ├── optimizer.py │ ├── plotting.py │ ├── timer.py │ └── trainer.py └── models │ ├── __init__.py │ ├── anynet.py │ ├── common.py │ ├── effnet.py │ ├── nas │ ├── genotypes.py │ ├── nas.py │ └── operations.py │ ├── regnet.py │ └── resnet.py ├── score_networks.py ├── scorehook.sh ├── scores.py ├── search.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | __pycache__ 3 | *.t7 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Architecture Search Without Training 2 | 3 | :warning: Note: this repository has been updated to reflect the [second version](https://arxiv.org/abs/2006.04647) of the paper 4 | For the [original version of the paper](https://arxiv.org/abs/2006.04647v1), refer to the tag [v1.0](https://github.com/BayesWatch/nas-without-training/releases/tag/v1.0).:warning: 5 | 6 | ## Usage 7 | 8 | Create a conda environment using the env.yml file 9 | 10 | ```bash 11 | conda env create -f env.yml 12 | ``` 13 | 14 | Activate the environment and follow the instructions to install 15 | 16 | Install nasbench (see https://github.com/google-research/nasbench) 17 | 18 | Download the NDS data from https://github.com/facebookresearch/nds and place the json files in naswot-codebase/nds_data/ 19 | Download the NASbench101 data (see 
https://github.com/google-research/nasbench) 20 | Download the NASbench201 data (see https://github.com/D-X-Y/NAS-Bench-201) 21 | 22 | Reproduce all of the results by running 23 | 24 | ```bash 25 | ./scorehook.sh 26 | ``` 27 | 28 | The code is licensed under the MIT licence. 29 | 30 | ## Citing us 31 | 32 | If you use or build on our work, please consider citing us: 33 | 34 | ```bibtex 35 | @inproceedings{mellor2021neural, 36 | title={Neural Architecture Search without Training}, 37 | author={Joseph Mellor and Jack Turner and Amos Storkey and Elliot J. Crowley}, 38 | year={2021}, 39 | booktitle={International Conference on Machine Learning} 40 | } 41 | ``` 42 | -------------------------------------------------------------------------------- /autodl/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /autodl/nas_201_api/__init__.py: -------------------------------------------------------------------------------- 1 | ##################################################### 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # 3 | ##################################################### 4 | from .api_utils import ArchResults, ResultsCount 5 | from .api_201 import NASBench201API 6 | from .api_301 import NASBench301API 7 | 8 | # NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25] 9 | # NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09] 10 | # NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16] 11 | NAS_BENCH_201_API_VERSION="v2.0" # [2020.06.30] 12 | -------------------------------------------------------------------------------- /autodl/procedures/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | from .starts import prepare_seed, 
def get_procedures(procedure):
    """Look up the (train, valid) function pair registered under `procedure`.

    The implementations are imported lazily inside the function, so the
    sub-modules are only loaded when a procedure is actually requested.

    Args:
        procedure: one of 'basic', 'search', 'Simple-KD' or 'search-v2'.

    Returns:
        A 2-tuple ``(train_func, valid_func)``.

    Raises:
        KeyError: if `procedure` is not one of the registered names.
    """
    from .basic_main import basic_train, basic_valid
    from .search_main import search_train, search_valid
    from .search_main_v2 import search_train_v2
    from .simple_KD_main import simple_KD_train, simple_KD_valid

    # 'search-v2' has its own training loop but shares the validation
    # routine of the original 'search' procedure.
    registry = {
        'basic': (basic_train, basic_valid),
        'search': (search_train, search_valid),
        'Simple-KD': (simple_KD_train, simple_KD_valid),
        'search-v2': (search_train_v2, search_valid),
    }
    return registry[procedure]
extra_info, print_freq, logger) 17 | return loss, acc1, acc5 18 | 19 | 20 | def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger): 21 | data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 22 | if mode == 'train': 23 | network.train() 24 | elif mode == 'valid': 25 | network.eval() 26 | else: raise ValueError("The mode is not right : {:}".format(mode)) 27 | 28 | #logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message())) 29 | logger.log('[{:5s}] config :: auxiliary={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1)) 30 | end = time.time() 31 | for i, (inputs, targets) in enumerate(xloader): 32 | if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) 33 | # measure data loading time 34 | data_time.update(time.time() - end) 35 | # calculate prediction and loss 36 | targets = targets.cuda(non_blocking=True) 37 | 38 | if mode == 'train': optimizer.zero_grad() 39 | 40 | features, logits = network(inputs) 41 | if isinstance(logits, list): 42 | assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits)) 43 | logits, logits_aux = logits 44 | else: 45 | logits, logits_aux = logits, None 46 | loss = criterion(logits, targets) 47 | if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0: 48 | loss_aux = criterion(logits_aux, targets) 49 | loss += config.auxiliary * loss_aux 50 | 51 | if mode == 'train': 52 | loss.backward() 53 | optimizer.step() 54 | 55 | # record 56 | prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) 57 | losses.update(loss.item(), inputs.size(0)) 58 | top1.update (prec1.item(), inputs.size(0)) 59 | top5.update (prec5.item(), inputs.size(0)) 60 | 61 | # measure elapsed time 62 | batch_time.update(time.time() - end) 63 | end 
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
    """Build the FLOP regularisation term for the architecture update.

    Args:
        expected_flop: tensor of expected FLOPs; it is averaged to a scalar.
        flop_cur: FLOP of the currently selected architecture.
        flop_need: target FLOP.
        flop_tolerant: slack allowed *below* the target. No slack is applied
            above the target: anything larger than `flop_need` is penalised.

    Returns:
        ``(loss, loss_value)``:
        * ``-log(mean(expected_flop))`` when ``flop_cur < flop_need - flop_tolerant``,
        * ``+log(mean(expected_flop))`` when ``flop_cur > flop_need``,
        * the plain integers ``(0, 0)`` otherwise.
        In the first two cases `loss` is a tensor and `loss_value` is
        ``loss.item()``.
    """
    mean_flop = torch.mean(expected_flop)

    too_small = flop_cur < flop_need - flop_tolerant
    too_large = flop_cur > flop_need
    if not (too_small or too_large):
        # FLOP already sits inside the accepted window: nothing to optimise.
        return 0, 0

    log_flop = torch.log(mean_flop)
    # The "too small" case takes precedence, mirroring an if/elif ordering.
    penalty = -log_flop if too_small else log_flop
    return penalty, penalty.item()
criterion(logits, arch_targets) 60 | arch_loss = acls_loss + flop_loss * flop_weight 61 | arch_loss.backward() 62 | arch_optimizer.step() 63 | 64 | # record 65 | arch_losses.update(arch_loss.item(), arch_inputs.size(0)) 66 | arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) 67 | arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0)) 68 | 69 | # measure elapsed time 70 | batch_time.update(time.time() - end) 71 | end = time.time() 72 | if step % print_freq == 0 or (step+1) == len(search_loader): 73 | Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader)) 74 | Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) 75 | Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5) 76 | Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses) 77 | logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr) 78 | #num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0 79 | #logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6)) 80 | #Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) 81 | #logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) 82 | #print(network.module.get_arch_info()) 83 | #print(network.module.width_attentions[0]) 84 | #print(network.module.width_attentions[1]) 85 | 86 | logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, 
def loss_KD_fn(criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature):
    """Combine the supervised loss with a temperature-scaled distillation loss.

    Args:
        criterion: callable ``criterion(logits, targets)`` giving the supervised loss.
        student_logits: logits of the student; softmax is taken over dim=1.
        teacher_logits: logits of the teacher; softmax is taken over dim=1.
        studentFeatures: accepted for interface compatibility, not used here.
        teacherFeatures: accepted for interface compatibility, not used here.
        targets: ground-truth labels forwarded to `criterion`.
        alpha: weight of the distillation term; the supervised term gets ``1 - alpha``.
        temperature: divisor applied to both sets of logits before softmax.

    Returns:
        ``(1 - alpha) * criterion(student_logits, targets)
        + alpha * temperature**2 * KL(teacher || student)``, where the KL term
        uses ``reduction='batchmean'``.
    """
    hard_term = criterion(student_logits, targets) * (1. - alpha)
    soft_term = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction='batchmean',
    ) * (alpha * temperature * temperature)
    return hard_term + soft_term
63 | 64 | if mode == 'train': 65 | loss.backward() 66 | optimizer.step() 67 | 68 | # record 69 | sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) 70 | losses.update(loss.item(), inputs.size(0)) 71 | top1.update (sprec1.item(), inputs.size(0)) 72 | top5.update (sprec5.item(), inputs.size(0)) 73 | # teacher 74 | tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5)) 75 | Ttop1.update (tprec1.item(), inputs.size(0)) 76 | Ttop5.update (tprec5.item(), inputs.size(0)) 77 | 78 | # measure elapsed time 79 | batch_time.update(time.time() - end) 80 | end = time.time() 81 | 82 | if i % print_freq == 0 or (i+1) == len(xloader): 83 | Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader)) 84 | if scheduler is not None: 85 | Sstr += ' {:}'.format(scheduler.get_min_info()) 86 | Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) 87 | Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5) 88 | Lstr+= ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(Ttop1.avg, Ttop5.avg) 89 | Istr = 'Size={:}'.format(list(inputs.size())) 90 | logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr) 91 | 92 | logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg)) 93 | logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg)) 94 | return losses.avg, top1.avg, top5.avg 95 | -------------------------------------------------------------------------------- /autodl/procedures/starts.py: 
def prepare_seed(rand_seed):
    """Seed every random number generator used by the code base.

    Applies `rand_seed`, in this order, to Python's `random`, NumPy's global
    generator, `torch.manual_seed`, `torch.cuda.manual_seed` and
    `torch.cuda.manual_seed_all`.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(rand_seed)
def save_checkpoint(state, filename, logger):
    """Serialize `state` to `filename` without risking the previous checkpoint.

    The data is first written to a temporary sibling file and only moved over
    `filename` once `torch.save` has succeeded. If saving fails (unpicklable
    object, full disk, ...), an already existing checkpoint at `filename` is
    therefore left untouched instead of being deleted up front.

    Args:
        state: any object accepted by `torch.save`.
        filename: destination path of the checkpoint.
        logger: object with a ``log(str)`` method; anything without a `log`
            attribute (e.g. ``None``) disables logging.

    Returns:
        `filename`, unchanged.

    Raises:
        Whatever `torch.save` / `os.replace` raise; the temporary file is
        removed before the exception propagates.
    """
    can_log = hasattr(logger, 'log')
    if can_log and osp.isfile(filename):
        logger.log('Find {:} exist, it will be replaced once the new checkpoint is written'.format(filename))

    # Same directory as the target so that os.replace is a same-filesystem
    # rename; the pid keeps concurrent writers from sharing a temp file.
    tmp_filename = '{:}.tmp-{:}'.format(filename, os.getpid())
    try:
        torch.save(state, tmp_filename)
        os.replace(tmp_filename, filename)
    finally:
        # Only present when saving failed half-way; never leave it behind.
        if osp.isfile(tmp_filename):
            os.remove(tmp_filename)

    assert osp.isfile(filename), 'save filename : {:} failed, which is not found.'.format(filename)
    if can_log:
        logger.log('save checkpoint into {:}'.format(filename))
    return filename
# clockwise rotate image = counterclockwise rotate the rectangle
def rotate2affine(degree):
    """Return the 3x3 homogeneous matrix of a rotation by ``-degree`` degrees.

    Args:
        degree: rotation angle in degrees; must lie in ``[0, 360]``.

    Returns:
        A 3x3 tensor ``[[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]]`` evaluated
        at the negated angle (converted to radians).

    Raises:
        AssertionError: if `degree` is outside ``[0, 360]``.
    """
    assert degree >= 0 and degree <= 360, 'Invalid degree : {:}'.format(degree)
    radian = degree / 180 * math.pi
    cos_t = math.cos(-radian)
    sin_t = math.sin(-radian)
    return torch.tensor([
        [cos_t, -sin_t, 0.0],
        [sin_t, cos_t, 0.0],
        [0.0, 0.0, 1.0],
    ])
# shape is a tuple [H, W]
def normalize_points_batch(shape, points):
    """Normalize a batch of (x, y) points given in pixel coordinates.

    Args:
        shape: ``(H, W)`` tuple or list with the image height and width.
        points: tensor whose last dimension has size 2, holding ``(x, y)``.

    Returns:
        A new tensor of the same shape as `points`, where x has been passed
        through ``normalize_L(x, W)`` and y through ``normalize_L(y, H)``.
        The input tensor is not modified.

    Raises:
        AssertionError: if `shape` is not a 2-element tuple/list or `points`
            is not a tensor with a trailing dimension of size 2.
    """
    assert isinstance(shape, (tuple, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
    assert isinstance(points, torch.Tensor) and points.size(-1) == 2, 'points are wrong : {:}'.format(points.shape)
    height, width = shape
    coords = points.clone()
    # Channel 0 is x (scaled by the width), channel 1 is y (scaled by the height).
    normalized = [
        normalize_L(coords[..., axis], extent)
        for axis, extent in enumerate((width, height))
    ]
    return torch.stack(normalized, dim=-1)
def obtain_accuracy(output, target, topk=(1,)):
    """Compute the precision@k (in percent) for every k in `topk`.

    Args:
        output: 2-D tensor of scores; the top-k search runs along dim 1.
        target: 1-D tensor of ground-truth class indices, one per row of `output`.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element float tensor per entry of `topk`, holding
        the percentage of samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred: (maxk, batch) after the transpose, best prediction in row 0.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`,
        # so `.view(-1)` raises a RuntimeError for k > 1 on recent PyTorch;
        # `.reshape(-1)` copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
8 | 9 | def get_info(self, ctype): 10 | cmd = 'nvidia-smi --query-gpu={} --format=csv,noheader'.format(ctype) 11 | lines = os.popen(cmd).readlines() 12 | lines = [line.strip('\n') for line in lines] 13 | return lines 14 | 15 | def query_gpu(self, show=True): 16 | num_gpus = len( self.get_info('index') ) 17 | all_gpus = [ {} for i in range(num_gpus) ] 18 | for query in self.queries: 19 | infos = self.get_info(query) 20 | for idx, info in enumerate(infos): 21 | all_gpus[idx][query] = info 22 | 23 | if 'CUDA_VISIBLE_DEVICES' in os.environ: 24 | CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES'].split(',') 25 | selected_gpus = [] 26 | for idx, CUDA_VISIBLE_DEVICE in enumerate(CUDA_VISIBLE_DEVICES): 27 | find = False 28 | for gpu in all_gpus: 29 | if gpu['index'] == CUDA_VISIBLE_DEVICE: 30 | assert not find, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE) 31 | find = True 32 | selected_gpus.append( gpu.copy() ) 33 | selected_gpus[-1]['index'] = '{}'.format(idx) 34 | assert find, 'Does not find the device : {}'.format(CUDA_VISIBLE_DEVICE) 35 | all_gpus = selected_gpus 36 | 37 | if show: 38 | allstrings = '' 39 | for gpu in all_gpus: 40 | string = '| ' 41 | for query in self.queries: 42 | if query.find('memory') == 0: xinfo = '{:>9}'.format(gpu[query]) 43 | else: xinfo = gpu[query] 44 | string = string + query + ' : ' + xinfo + ' | ' 45 | allstrings = allstrings + string + '\n' 46 | return allstrings 47 | else: 48 | return all_gpus 49 | 50 | def select_by_memory(self, numbers=1): 51 | all_gpus = self.query_gpu(False) 52 | assert numbers <= len(all_gpus), 'Require {} gpus more than you have'.format(numbers) 53 | alls = [] 54 | for idx, gpu in enumerate(all_gpus): 55 | free_memory = gpu['memory.free'] 56 | free_memory = free_memory.split(' ')[0] 57 | free_memory = int(free_memory) 58 | index = gpu['index'] 59 | alls.append((free_memory, index)) 60 | alls.sort(reverse = True) 61 | alls = [ int(alls[i][1]) for i in range(numbers) ] 62 | return 
def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
    """Correlate a one-shot model's architecture scores with benchmark accuracy.

    Every architecture from ``CellStructure.gen_all`` is shuffled with ``seed``
    and then scored twice: by the sum of the log-softmax architecture
    parameters of its selected operations, and by the accuracy of the shared
    model on one batch from ``xloader``. Pearson correlations against the
    accuracies returned by ``api`` are printed. The model's weights are
    snapshotted on entry and reloaded before returning.

    Args:
        model: search model; this function uses ``arch_parameters``,
            ``op_names``, ``max_nodes``, ``edge2index``, ``set_cal_mode`` and
            calls it to obtain ``(_, logits)``.
        xloader: iterable of ``(inputs, targets)`` batches; it is restarted
            whenever it is exhausted.
        api: benchmark API providing ``query_index_by_arch`` and
            ``get_more_info``.
        cal_mode: value forwarded to ``model.train``.
        seed: seed for shuffling the architecture list.

    Returns:
        Tuple ``(archs, probs, accuracies)`` aligned by position.
    """
    print ('This is an old version of codes to use NAS-Bench-API, and should be modified to align with the new version. Please contact me for more details if you use this function.')
    weights = deepcopy(model.state_dict())
    model.train(cal_mode)
    with torch.no_grad():
        logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
        archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
        probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], []
        loader_iter = iter(xloader)
        random.seed(seed)
        random.shuffle(archs)
        for arch in archs:
            arch_index = api.query_index_by_arch( arch )
            metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
            gt_accs_10_valid.append( metrics['valid-accuracy'] )
            metrics = api.get_more_info(arch_index, 'cifar10', None, False, False)
            gt_accs_10_test.append( metrics['test-accuracy'] )
            # Log-probability of the architecture = sum over its edges of the
            # log-softmax weight of the chosen operation.
            select_logits = []
            for i, node_info in enumerate(arch.nodes):
                for op, xin in node_info:
                    node_str = '{:}<-{:}'.format(i+1, xin)
                    op_index = model.op_names.index(op)
                    select_logits.append( logits[model.edge2index[node_str], op_index] )
            cur_prob = sum(select_logits).item()
            probs.append( cur_prob )
        cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0,1]
        cor_prob_test = np.corrcoef(probs, gt_accs_10_test )[0,1]
        print ('{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test'.format(time_string(), cor_prob_valid, cor_prob_test))

        for idx, arch in enumerate(archs):
            model.set_cal_mode('dynamic', arch)
            try:
                inputs, targets = next(loader_iter)
            except StopIteration:
                # Only loader exhaustion should restart the iterator; any
                # other error (or Ctrl-C) must propagate.
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)
            _, batch_logits = model(inputs.cuda())
            _, preds = torch.max(batch_logits, dim=-1)
            correct = (preds == targets.cuda() ).float()
            accuracies.append( correct.mean().item() )
            if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
                cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[:idx+1])[0,1]
                cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test [:idx+1])[0,1]
                print ('{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs_valid, cor_accs_test))
    model.load_state_dict(weights)
    return archs, probs, accuracies
def obtain_attention_args():
    """Parse the command line for the attention training entry point.

    Registers the attention-specific options, the shared options from
    ``add_shared_args`` and ``--batch_size``, then parses ``sys.argv``.
    A missing or negative ``rand_seed`` is replaced by a random integer in
    [1, 100000]; ``save_dir`` must be supplied.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description='Train a classification model on typical image classification datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flag, type, help) in registration order.
    option_table = (
        ('--resume', str, 'Resume path.'),
        ('--init_model', str, 'The initialization model path.'),
        ('--model_config', str, 'The path to the model configuration'),
        ('--optim_config', str, 'The path to the optimizer configuration'),
        ('--procedure', str, 'The procedure basic prefix.'),
        ('--att_channel', int, '.'),
        ('--att_spatial', str, '.'),
        ('--att_active', str, '.'),
    )
    for flag, value_type, help_text in option_table:
        parser.add_argument(flag, type=value_type, help=help_text)
    add_shared_args(parser)
    # Optimization options
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
    args = parser.parse_args()

    seed = args.rand_seed
    if seed is None or seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, 'save-path argument can not be None'
    return args
def obtain_cls_init_args():
    """Parse the command line for classification training from an initial checkpoint.

    Registers the run options (including ``--init_checkpoint``), the shared
    options from ``add_shared_args`` and ``--batch_size``, then parses
    ``sys.argv``. A missing or negative ``rand_seed`` is replaced by a random
    integer in [1, 100000]; ``save_dir`` must be supplied.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description='Train a classification model on typical image classification datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # All run options are plain strings; keep the registration order.
    string_options = [
        ('--resume', 'Resume path.'),
        ('--init_model', 'The initialization model path.'),
        ('--model_config', 'The path to the model configuration'),
        ('--optim_config', 'The path to the optimizer configuration'),
        ('--procedure', 'The procedure basic prefix.'),
        ('--init_checkpoint', 'The checkpoint path to the initial model.'),
    ]
    for flag, help_text in string_options:
        parser.add_argument(flag, type=str, help=help_text)
    add_shared_args(parser)
    # Optimization options
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
    args = parser.parse_args()

    needs_seed = args.rand_seed is None or args.rand_seed < 0
    if needs_seed:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, 'save-path argument can not be None'
    return args
def convert_param(original_lists):
    """Convert one ``[type_name, value]`` configuration entry to a typed value.

    Args:
        original_lists: a list whose first item is one of ``'str'``, ``'int'``,
            ``'bool'``, ``'float'`` or ``'none'`` and whose second item is
            either a single raw value or a list of raw values.

    Returns:
        The converted value, or a list of converted values when the raw value
        was a list. ``'bool'`` goes through ``int`` first (so ``"0"`` is
        False); ``'none'`` always yields ``None``.

    Raises:
        AssertionError: if the entry is not a list or the type name is unknown.
        ValueError: if a ``'none'`` entry holds anything other than ``None`` or
            a spelling of "none" (case-insensitive), or a numeric cast fails.
    """
    assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
    ctype, value = original_lists[0], original_lists[1]

    def to_none(x):
        # A JSON null is loaded as None and other non-string values have no
        # .lower(); compare via str() so they fail with the intended
        # ValueError instead of AttributeError.
        if x is not None and str(x).lower() != 'none':
            raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
        return None

    # Single source of truth for the supported type names, so validation and
    # dispatch cannot drift apart.
    converters = {
        'str': str,
        'int': int,
        'bool': lambda x: bool(int(x)),
        'float': float,
        'none': to_none,
    }
    supported = tuple(converters)
    assert ctype in supported, 'Ctype={:}, support={:}'.format(ctype, supported)
    convert = converters[ctype]

    if isinstance(value, list):
        return [convert(x) for x in value]
    return convert(value)
def dict2config(xdict, logger):
    """Wrap a plain dict in an immutable ``Configure`` namedtuple.

    Args:
        xdict: mapping whose keys become the field names (in iteration order)
            and whose values become the field values.
        logger: any object; when it has a ``log`` attribute the resulting
            configuration is logged through it, otherwise it is ignored.

    Returns:
        A ``Configure`` namedtuple instance holding the dict's values.

    Raises:
        AssertionError: if ``xdict`` is not a dict.
    """
    assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) )
    field_names = ' '.join(xdict.keys())
    config_type = namedtuple('Configure', field_names)
    config = config_type(**xdict)
    if hasattr(logger, 'log'):
        logger.log(str(config))
    return config
def obtain_RandomSearch_args():
    """Parse the command line for the random-search baseline.

    Registers the baseline's own options, the shared options from
    ``add_shared_args`` and ``--batch_size``, then parses ``sys.argv``.
    A missing or negative ``rand_seed`` is replaced by a random integer in
    [1, 100000]; ``save_dir`` must be supplied.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description='Train a classification model on typical image classification datasets.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flag, keyword arguments) in registration order.
    option_table = (
        ('--resume', dict(type=str, help='Resume path.')),
        ('--init_model', dict(type=str, help='The initialization model path.')),
        ('--expect_flop', dict(type=float, help='The expected flop keep ratio.')),
        ('--arch_nums', dict(type=int, help='The maximum number of running random arch generating..')),
        ('--model_config', dict(type=str, help='The path to the model configuration')),
        ('--optim_config', dict(type=str, help='The path to the optimizer configuration')),
        ('--random_mode', dict(type=str, choices=['random', 'fix'], help='The path to the optimizer configuration')),
        ('--procedure', dict(type=str, help='The procedure basic prefix.')),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    add_shared_args(parser)
    # Optimization options
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
    args = parser.parse_args()

    seed = args.rand_seed
    if seed is None or seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, 'save-path argument can not be None'
    return args
def obtain_search_single_args():
    """Parse the command line for the single-shape search entry point.

    Registers the search options, the shared options from ``add_shared_args``
    and ``--batch_size``, then parses ``sys.argv``. A missing or negative
    ``rand_seed`` is replaced by a random integer in [1, 100000].

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        AssertionError: if ``save_dir``, ``gumbel_tau_max`` or
            ``gumbel_tau_min`` is missing, or ``FLOP_tolerant`` is missing or
            not strictly positive.
    """
    parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--resume', type=str, help='Resume path.')
    parser.add_argument('--model_config', type=str, help='The path to the model configuration')
    parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
    parser.add_argument('--split_path', type=str, help='The split file path.')
    parser.add_argument('--search_shape', type=str, help='The shape to be searched.')
    parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.')
    parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.')
    parser.add_argument('--procedure', type=str, help='The procedure basic prefix.')
    parser.add_argument('--FLOP_ratio', type=float, help='The expected FLOP ratio.')
    parser.add_argument('--FLOP_weight', type=float, help='The loss weight for FLOP.')
    parser.add_argument('--FLOP_tolerant', type=float, help='The tolerant range for FLOP.')
    add_shared_args(parser)
    # Optimization options
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, 'save-path argument can not be None'
    assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
    # The message must read the value from `args`; the bare name
    # `FLOP_tolerant` does not exist and turned a failed check into NameError.
    assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(args.FLOP_tolerant)
    return args
class ImageNet16(data.Dataset):
    """Downsampled 16x16 ImageNet stored as pickled batch files.

    See "A Downsampled Variant of ImageNet as an Alternative to the CIFAR
    datasets" (https://arxiv.org/pdf/1707.08819.pdf); source images come from
    http://image-net.org/download-images.

    Labels in the files start at 1; ``__getitem__`` returns them shifted to
    start at 0.
    """

    # [file name, md5 checksum] for every training batch file.
    train_list = [
        ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
        ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
        ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
        ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
        ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
        ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
        ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
        ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
        ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
        ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'],
    ]
    # [file name, md5 checksum] for the validation file.
    valid_list = [
        ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
    ]

    def __init__(self, root, train, transform, use_num_of_class_only=None):
        """Load the training or validation split from ``root``.

        Args:
            root: directory containing the batch files.
            train: truthy for the training split, falsy for validation.
            transform: optional callable applied to each PIL image.
            use_num_of_class_only: when given (an int in (0, 1000)), keep only
                samples whose raw label lies in [1, use_num_of_class_only].

        Raises:
            RuntimeError: if any listed file is missing or fails its checksum.
        """
        self.root = root
        self.transform = transform
        self.train = train  # training set or valid set
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.')

        file_list = self.train_list if self.train else self.valid_list
        arrays, labels = [], []
        # Load the pickled numpy arrays batch by batch.
        for file_name, _checksum in file_list:
            with open(os.path.join(self.root, file_name), 'rb') as handle:
                if sys.version_info[0] == 2:
                    entry = pickle.load(handle)
                else:
                    entry = pickle.load(handle, encoding='latin1')
                arrays.append(entry['data'])
                labels.extend(entry['labels'])
        # Stack to N x 3 x 16 x 16, then convert to HWC for PIL.
        self.data = np.vstack(arrays).reshape(-1, 3, 16, 16).transpose((0, 2, 3, 1))
        self.targets = labels

        if use_num_of_class_only is not None:
            assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
            kept = [
                (image, label)
                for image, label in zip(self.data, self.targets)
                if 1 <= label <= use_num_of_class_only
            ]
            self.data = [image for image, _ in kept]
            self.targets = [label for _, label in kept]

    def __getitem__(self, index):
        """Return ``(image, label)`` with the label shifted to start at 0."""
        target = self.targets[index] - 1
        img = Image.fromarray(self.data[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Number of samples currently held."""
        return len(self.data)

    def _check_integrity(self):
        """Return True only if every train and valid file passes its md5 check."""
        return all(
            check_integrity(os.path.join(self.root, entry[0]), entry[1])
            for entry in (self.train_list + self.valid_list)
        )
class SearchDataset(data.Dataset):
    """Pair each training sample with a randomly drawn validation sample.

    Two layouts are supported:
      * 'V1' - ``data`` is a single indexable dataset and both splits index
        into it;
      * 'V2' - ``data`` is a 2-item list/tuple ``(train_data, valid_data)``
        and each split indexes into its own dataset.
    """

    def __init__(self, name, data, train_split, valid_split, check=True):
        """Store the dataset(s) and private copies of both index splits.

        Raises:
            AssertionError: if a list/tuple ``data`` does not have exactly two
                items, or (V1 with ``check``) the two splits share an index.
        """
        self.datasetname = name
        paired = isinstance(data, (list, tuple))
        if paired:  # new type of SearchDataset
            assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
            self.train_data, self.valid_data = data[0], data[1]
        else:
            self.data = data
        self.mode_str = 'V2' if paired else 'V1'
        self.train_split = train_split.copy()
        self.valid_split = valid_split.copy()
        # NOTE(review): the disjointness check is applied to the single-dataset
        # (V1) layout only - confirm against the original indentation.
        if not paired and check:
            shared = set(train_split) & set(valid_split)
            assert len(shared) == 0, 'the splitted train and validation sets should have no intersection'
        self.length = len(self.train_split)

    def __repr__(self):
        return '{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(
            name=self.__class__.__name__,
            datasetname=self.datasetname,
            tr_L=len(self.train_split),
            val_L=len(self.valid_split),
            ver=self.mode_str,
        )

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``(train_image, train_label, valid_image, valid_label)``.

        The training sample is the ``index``-th entry of the training split;
        the validation sample is drawn uniformly at random on every call.
        """
        assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
        train_index = self.train_split[index]
        valid_index = random.choice( self.valid_split )
        if self.mode_str == 'V1':
            train_source = valid_source = self.data
        elif self.mode_str == 'V2':
            train_source, valid_source = self.train_data, self.valid_data
        else:
            raise ValueError('invalid mode : {:}'.format(self.mode_str))
        train_image, train_label = train_source[train_index]
        valid_image, valid_label = valid_source[valid_index]
        return train_image, train_label, valid_image, valid_label
def get_data(dataset, data_loc, trainval, batch_size, augtype, repeat, args, pin_memory=True):
    """Build the training ``DataLoader`` used for scoring networks.

    Args:
        dataset: dataset name forwarded to ``get_datasets``.
        data_loc: dataset location forwarded to ``get_datasets``.
        trainval: when truthy and ``'cifar10'`` occurs in ``dataset``, sample
            only the indices of the ``train`` split stored in
            ``config_utils/cifar-split.txt`` (path is relative to the current
            working directory).
        batch_size: loader batch size.
        augtype: ``'gaussnoise'``, ``'cutout'`` or ``'none'`` drop the first
            two transforms of the training pipeline; the first two then append
            Gaussian noise (std ``args.sigma``) or random erasing. Any other
            value leaves the pipeline untouched.
        repeat: when > 0, every sampled index is yielded ``repeat`` times in a
            row via ``RepeatSampler``.
        args: only ``args.sigma`` is read, and only for ``'gaussnoise'``.
        pin_memory: forwarded to the ``DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the training data.
    """
    train_data, _, _, _ = get_datasets(dataset, data_loc, cutout=0)

    if augtype in ('gaussnoise', 'cutout', 'none'):
        train_data.transform.transforms = train_data.transform.transforms[2:]
    if augtype == 'gaussnoise':
        train_data.transform.transforms.append(AddGaussianNoise(std=args.sigma))
    elif augtype == 'cutout':
        train_data.transform.transforms.append(torchvision.transforms.RandomErasing(p=0.9, scale=(0.02, 0.04)))

    # Choose the index sampler; None means "let the loader shuffle everything".
    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler
    if trainval and 'cifar10' in dataset:
        cifar_split = load_config('config_utils/cifar-split.txt', None, None)
        sampler = subset_sampler(cifar_split.train)
    elif repeat > 0:
        sampler = subset_sampler(range(len(train_data)))
    else:
        sampler = None
    if sampler is not None and repeat > 0:
        sampler = RepeatSampler(sampler, repeat)

    # `shuffle` and `sampler` are mutually exclusive in DataLoader.
    return torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=sampler is None,
        num_workers=0,
        pin_memory=pin_memory,
        sampler=sampler,
    )
class PointMeta():
    """Landmark annotation for one image.

    ``points`` is a ``3 x num_point`` array whose rows are x, y and an
    occlusion/visibility indicator; ``box`` is an optional 4-element
    list/tuple. Both are stored as torch tensors (or ``None``).
    """

    def __init__(self, num_point, points, box, image_path, dataset_name):

        self.num_point = num_point
        if box is not None:
            assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4
            self.box = torch.Tensor(box)
        else:
            self.box = None
        if points is None:
            self.points = points
        else:
            assert len(points.shape) == 2 and points.shape[0] == 3 and points.shape[1] == self.num_point, 'The shape of point is not right : {}'.format( points )
            self.points = torch.Tensor(points.copy())
        self.image_path = image_path
        self.datasets = dataset_name

    def __repr__(self):
        if self.box is None:
            boxstr = 'None'
        else:
            boxstr = 'box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]'.format(*self.box.tolist())
        return ('{name}(points={num_point}, '.format(name=self.__class__.__name__, **self.__dict__) + boxstr + ')')

    def get_box(self, return_diagonal=False):
        """Return a copy of the box, or its diagonal length; ``None`` without a box."""
        if self.box is None:
            return None
        if not return_diagonal:
            return self.box.clone()
        W = (self.box[2]-self.box[0]).item()
        H = (self.box[3]-self.box[1]).item()
        return math.sqrt(H*H+W*W)

    def get_points(self, ignore_indicator=False):
        """Return a copy of the points (zeros if unset); drop row 2 if ``ignore_indicator``."""
        last = 2 if ignore_indicator else 3
        if self.points is not None:
            return self.points.clone()[:last, :]
        return torch.zeros((last, self.num_point))

    def is_none(self):
        return self.points is None

    def copy(self):
        return copy.deepcopy(self)

    def visiable_pts_num(self):
        """Count the points whose indicator row is positive."""
        with torch.no_grad():
            ans = self.points[2,:] > 0
            ans = torch.sum(ans)
            ans = ans.item()
        return ans

    def special_fun(self, indicator):
        """Apply a named in-place conversion; only ``'68to49'`` is supported.

        Raises:
            ValueError: for any other ``indicator``.
        """
        if indicator == '68to49':  # For 300W or 300VW, convert the default 68 points to 49 points.
            assert self.num_point == 68, 'num-point must be 68 vs. {:}'.format(self.num_point)
            self.num_point = 49
            # A bool mask selects columns; uint8 masks are deprecated for indexing.
            keep = torch.ones((68), dtype=torch.bool)
            keep[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,60,64]] = False
            if self.points is not None:
                self.points = self.points.clone()[:, keep]
        else:
            raise ValueError('Invalid indicator : {:}'.format( indicator ))

    def apply_horizontal_flip(self):
        """Permute the point columns to their mirrored counterparts.

        Only the column order changes here; coordinates are not mirrored.

        Raises:
            ValueError: if the dataset name has no known permutation.
        """
        #self.points[0, :] = width - self.points[0, :] - 1
        # Mugsy spefic or Synthetic
        if self.datasets.startswith('HandsyROT'):
            ori = np.array(list(range(0, 42)))
            pos = np.array(list(range(21,42)) + list(range(0,21)))
            self.points[:, pos] = self.points[:, ori]
        elif self.datasets.startswith('face68'):
            ori = np.array(list(range(0, 68)))
            pos = np.array([17,16,15,14,13,12,11,10, 9, 8,7,6,5,4,3,2,1, 27,26,25,24,23,22,21,20,19,18, 28,29,30,31, 36,35,34,33,32, 46,45,44,43,48,47, 40,39,38,37,42,41, 55,54,53,52,51,50,49,60,59,58,57,56,65,64,63,62,61,68,67,66])-1
            self.points[:, ori] = self.points[:, pos]
        else:
            raise ValueError('Does not support {:}'.format(self.datasets))


def _solve_linear(A, B):
    """Return ``X`` such that ``A @ X = B`` with the solver this torch build provides."""
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'solve'):
        return linalg.solve(A, B)
    # Older releases only offer torch.solve(B, A) -> (solution, LU).
    solution, _ = torch.solve(B, A)
    return solution


# shape = (H,W)
def apply_affine2point(points, theta, shape):
    """Map the visible points through the inverse of ``theta``.

    ``points`` must have 3 rows; columns whose third row equals 1 are treated
    as visible. NOTE: rows 0-1 of ``points`` are overwritten in place with
    their ``normalize_points(shape, ...)`` values, as in the original code.

    Returns:
        A float tensor shaped like ``points``: visible columns hold the
        solution of ``theta @ X = points[:, visible]``, the rest are zero.
    """
    assert points.size(0) == 3, 'invalid points shape : {:}'.format(points.size())
    with torch.no_grad():
        ok_points = points[2,:] == 1
        assert torch.sum(ok_points).item() > 0, 'there is no visiable point'
        points[:2,:] = normalize_points(shape, points[:2,:])

    norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()

    # torch.gesv(B, A) solved A X = B; it has been removed from torch.
    trans_points = _solve_linear(theta, points[:, ok_points])

    norm_trans_points[:, ok_points] = trans_points

    return norm_trans_points


def apply_boundary(norm_trans_points):
    """Return a copy whose indicator row is 1 only for points strictly inside (-1, 1)^2
    that already had a positive indicator."""
    with torch.no_grad():
        norm_trans_points = norm_trans_points.clone()
        oks = torch.stack((norm_trans_points[0]>-1, norm_trans_points[0]<1, norm_trans_points[1]>-1, norm_trans_points[1]<1, norm_trans_points[2]>0))
        oks = torch.sum(oks, dim=0) == 5
        norm_trans_points[2, :] = oks
    return norm_trans_points
class Bottleneck(nn.Module):
    """DenseNet bottleneck layer: BN-ReLU-1x1 conv, BN-ReLU-3x3 conv, then concat with the input."""

    def __init__(self, nChannels, growthRate):
        super().__init__()
        hidden = 4 * growthRate
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, hidden, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.conv2 = nn.Conv2d(hidden, growthRate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        squeezed = self.conv1(F.relu(self.bn1(x)))
        grown = self.conv2(F.relu(self.bn2(squeezed)))
        # New feature maps are appended along the channel axis.
        return torch.cat((x, grown), 1)


class SingleLayer(nn.Module):
    """DenseNet basic layer: BN-ReLU-3x3 conv, then concat with the input."""

    def __init__(self, nChannels, growthRate):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        grown = self.conv1(F.relu(self.bn1(x)))
        return torch.cat((x, grown), 1)


class Transition(nn.Module):
    """Transition between dense blocks: BN-ReLU-1x1 conv followed by 2x2 average pooling."""

    def __init__(self, nChannels, nOutChannels):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)

    def forward(self, x):
        projected = self.conv1(F.relu(self.bn1(x)))
        return F.avg_pool2d(projected, 2)
class Downsample(nn.Module):
    """Shortcut that halves the resolution (2x2 average pool) and doubles the channels (1x1 conv)."""

    def __init__(self, nIn, nOut, stride):
        super().__init__()
        assert stride == 2 and nOut == 2*nIn, 'stride:{} IO:{},{}'.format(stride, nIn, nOut)
        self.in_dim = nIn
        self.out_dim = nOut
        self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        return self.conv(self.avg(x))


class ConvBNReLU(nn.Module):
    """Conv2d followed by BatchNorm2d and an optional in-place ReLU."""

    def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu):
        super().__init__()
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(nOut)
        self.relu = nn.ReLU(inplace=True) if relu else None
        self.out_dim = nOut
        self.num_conv = 1

    def forward(self, x):
        normed = self.bn(self.conv(x))
        if not self.relu:
            return normed
        return self.relu(normed)


class ResNetBasicblock(nn.Module):
    """Two 3x3 conv units with a residual shortcut (projected when shapes differ)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride):
        super().__init__()
        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
        self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True)
        self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, False)
        # Pick the shortcut: strided downsample, 1x1 projection, or identity.
        if stride == 2:
            shortcut = Downsample(inplanes, planes, stride)
        elif inplanes != planes:
            shortcut = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False)
        else:
            shortcut = None
        self.downsample = shortcut
        self.out_dim = planes
        self.num_conv = 2

    def forward(self, inputs):
        branch = self.conv_b(self.conv_a(inputs))
        residual = inputs if self.downsample is None else self.downsample(inputs)
        return F.relu(additive_func(residual, branch), inplace=True)


class ResNetBottleneck(nn.Module):
    """1x1 - 3x3 - 1x1 (x4 channels) bottleneck with a residual shortcut."""

    expansion = 4

    def __init__(self, inplanes, planes, stride):
        super().__init__()
        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
        widened = planes * self.expansion
        self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True)
        self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, True)
        self.conv_1x4 = ConvBNReLU(planes, widened, 1, 1, 0, False, False)
        if stride == 2:
            shortcut = Downsample(inplanes, widened, stride)
        elif inplanes != widened:
            shortcut = ConvBNReLU(inplanes, widened, 1, 1, 0, False, False)
        else:
            shortcut = None
        self.downsample = shortcut
        self.out_dim = widened
        self.num_conv = 3

    def forward(self, inputs):
        branch = self.conv_1x4(self.conv_3x3(self.conv_1x1(inputs)))
        residual = inputs if self.downsample is None else self.downsample(inputs)
        return F.relu(additive_func(residual, branch), inplace=True)
zero_init_residual): 107 | super(CifarResNet, self).__init__() 108 | 109 | #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 110 | if block_name == 'ResNetBasicblock': 111 | block = ResNetBasicblock 112 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' 113 | layer_blocks = (depth - 2) // 6 114 | elif block_name == 'ResNetBottleneck': 115 | block = ResNetBottleneck 116 | assert (depth - 2) % 9 == 0, 'depth should be one of 164' 117 | layer_blocks = (depth - 2) // 9 118 | else: 119 | raise ValueError('invalid block : {:}'.format(block_name)) 120 | 121 | self.message = 'CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}'.format(block_name, depth, layer_blocks) 122 | self.num_classes = num_classes 123 | self.channels = [16] 124 | self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, True) ] ) 125 | for stage in range(3): 126 | for iL in range(layer_blocks): 127 | iC = self.channels[-1] 128 | planes = 16 * (2**stage) 129 | stride = 2 if stage > 0 and iL == 0 else 1 130 | module = block(iC, planes, stride) 131 | self.channels.append( module.out_dim ) 132 | self.layers.append ( module ) 133 | self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) 134 | 135 | self.avgpool = nn.AvgPool2d(8) 136 | self.classifier = nn.Linear(module.out_dim, num_classes) 137 | assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) 138 | 139 | self.apply(initialize_resnet) 140 | if zero_init_residual: 141 | for m in self.modules(): 142 | if isinstance(m, ResNetBasicblock): 143 | nn.init.constant_(m.conv_b.bn.weight, 0) 144 | elif isinstance(m, ResNetBottleneck): 145 | nn.init.constant_(m.conv_1x4.bn.weight, 0) 146 | 147 | def get_message(self): 148 | return self.message 149 | 150 | def forward(self, inputs): 
class WideBasicblock(nn.Module):
    """Pre-activation residual block: (BN-ReLU-conv3x3) x 2 with optional dropout.

    A 1x1 projection shortcut is created whenever the block changes the
    channel count or the spatial resolution, so the residual addition always
    sees matching shapes.
    """

    def __init__(self, inplanes, planes, stride, dropout=False):
        super(WideBasicblock, self).__init__()

        self.bn_a = nn.BatchNorm2d(inplanes)
        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.bn_b = nn.BatchNorm2d(planes)
        if dropout:
            self.dropout = nn.Dropout2d(p=0.5, inplace=True)
        else:
            self.dropout = None
        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        # The identity shortcut is only valid when neither channels nor
        # resolution change; a strided block needs a projection even when
        # inplanes == planes.
        if inplanes != planes or stride != 1:
            self.downsample = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False)
        else:
            self.downsample = None

    def forward(self, x):

        basicblock = self.bn_a(x)
        basicblock = F.relu(basicblock)
        basicblock = self.conv_a(basicblock)

        basicblock = self.bn_b(basicblock)
        basicblock = F.relu(basicblock)
        if self.dropout is not None:
            basicblock = self.dropout(basicblock)
        basicblock = self.conv_b(basicblock)

        if self.downsample is not None:
            x = self.downsample(x)

        return x + basicblock


class CifarWideResNet(nn.Module):
    """
    Wide ResNet for CIFAR-sized inputs (three stages of WideBasicblock),
    see "Wide Residual Networks", https://arxiv.org/abs/1605.07146

    ``forward`` returns ``(features, logits)`` where ``features`` is the
    flattened pooled activation fed to the classifier.
    """
    def __init__(self, depth, widen_factor, num_classes, dropout):
        super(CifarWideResNet, self).__init__()

        # Three stages of `layer_blocks` blocks with two convs each, plus the
        # stem conv and the per-stage projections.
        assert (depth - 4) % 6 == 0, 'depth should satisfy (depth - 4) % 6 == 0, e.g. 16, 22, 28, 40'
        layer_blocks = (depth - 4) // 6
        print ('CifarPreResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks))

        self.num_classes = num_classes
        self.dropout = dropout
        self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)

        self.message = 'Wide ResNet : depth={:}, widen_factor={:}, class={:}'.format(depth, widen_factor, num_classes)
        self.inplanes = 16
        self.stage_1 = self._make_layer(WideBasicblock, 16*widen_factor, layer_blocks, 1)
        self.stage_2 = self._make_layer(WideBasicblock, 32*widen_factor, layer_blocks, 2)
        self.stage_3 = self._make_layer(WideBasicblock, 64*widen_factor, layer_blocks, 2)
        self.lastact = nn.Sequential(nn.BatchNorm2d(64*widen_factor), nn.ReLU(inplace=True))
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(64*widen_factor, num_classes)

        self.apply(initialize_resnet)

    def get_message(self):
        return self.message

    def _make_layer(self, block, planes, blocks, stride):
        """Stack ``blocks`` blocks; only the first one uses ``stride`` and changes channels."""
        layers = [block(self.inplanes, planes, stride, self.dropout)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1, self.dropout))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_3x3(x)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.lastact(x)
        x = self.avgpool(x)
        features = x.view(x.size(0), -1)
        outs = self.classifier(features)
        return features, outs
class ConvBNReLU(nn.Module):
    """Conv2d ("same"-style padding for odd kernels) + BatchNorm2d + ReLU6."""

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class InvertedResidual(nn.Module):
    """MobileNetV2 block: optional 1x1 expansion, depthwise 3x3, linear 1x1 projection.

    The input is added back only when the block keeps both resolution
    (stride 1) and channel count.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        # pw (skipped when there is nothing to expand)
        expand = [ConvBNReLU(inp, hidden_dim, kernel_size=1)] if expand_ratio != 1 else []
        body = [
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*(expand + body))

    def forward(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            return x + out
        return out
def additive_func(A, B):
    """Add two tensors that may differ in channel count (dim 1).

    When the channel counts differ, the narrower tensor is added onto the
    leading channels of a copy of the wider one; neither input is modified.
    """
    assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
    C = min(A.size(1), B.size(1))
    if A.size(1) == B.size(1):
        return A + B
    elif A.size(1) < B.size(1):
        out = B.clone()
        out[:,:C] += A
        return out
    else:
        out = A.clone()
        out[:,:C] += B
        return out


def change_key(key, value):
    """Return a function that sets ``m.<key> = value`` on objects already having that attribute."""
    def func(m):
        if hasattr(m, key):
            setattr(m, key, value)
    return func


def parse_channel_info(xstring):
    """Parse ``'a-b c-d-e'`` into ``[[a, b], [c, d, e]]`` (lists of ints).

    Blocks are separated by any run of whitespace, so repeated, leading or
    trailing blanks no longer produce an empty block that fails ``int('')``.

    Raises:
        ValueError: if a '-'-separated field is not an integer.
    """
    return [[int(field) for field in block.split('-')] for block in xstring.split()]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):
    """Fixed-genotype NASNet-style cell with two inputs (the two preceding cell outputs).

    ``genotype`` provides, for normal and reduction cells, the per-node list
    of ``(op_name, input_index)`` edges and the node indices to concatenate.
    """

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
        super().__init__()
        self.reduction = reduction
        # s0 comes from two cells back; it needs a stride-2 op when the
        # previous cell was a reduction cell.
        if reduction_prev:
            self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
        else:
            self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
        self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)

        kind = 'reduce' if reduction else 'normal'
        nodes, concats = genotype[kind], genotype[kind + '_concat']
        self._multiplier = len(concats)
        self._concats = concats
        self._steps = len(nodes)
        self._nodes = nodes
        self.edges = nn.ModuleDict()
        for node_idx, node in enumerate(nodes):
            for edge in node:
                op_name, src = edge[0], edge[1]
                # Only edges reading the two cell inputs are strided in a reduction cell.
                edge_stride = 2 if reduction and src < 2 else 1
                key = '{:}<-{:}'.format(node_idx+2, src)
                self.edges[key] = OPS[op_name](C, C, edge_stride, affine, track_running_stats)

    # [TODO] to support drop_prob in this function..
    def forward(self, s0, s1, unused_drop_prob):
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        for node_idx, node in enumerate(self._nodes):
            incoming = []
            for edge in node:
                src = edge[1]
                op = self.edges['{:}<-{:}'.format(node_idx+2, src)]
                incoming.append(op(states[src]))
            states.append(sum(incoming))
        return torch.cat([states[x] for x in self._concats], dim=1)


class AuxiliaryHeadCIFAR(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map."""

    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super().__init__()
        stages = [
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)
# The macro structure is based on NASNet
class NASNetonCIFAR(nn.Module):
    """CIFAR network that stacks inference cells built from a fixed genotype.

    The stack is made of three stages with C, 2C and 4C channels; the first
    cell of the second and third stage is a reduction cell. When ``auxiliary``
    is truthy, an auxiliary head is attached after the 4C reduction cell.
    """

    def __init__(self, C, N, stem_multiplier, num_classes, genotype, auxiliary, affine=True, track_running_stats=True):
        super().__init__()
        self._C = C
        self._layerN = N
        stem_channels = C * stem_multiplier
        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(stem_channels))

        # per-cell output width and whether the cell reduces resolution
        layer_channels = [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
        layer_reductions = [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)

        width_prev_prev, width_prev, was_reduction = stem_channels, stem_channels, False
        self.auxiliary_index = None
        self.auxiliary_head = None
        self.cells = nn.ModuleList()
        for position, (width, is_reduction) in enumerate(zip(layer_channels, layer_reductions)):
            cell = InferCell(genotype, width_prev_prev, width_prev, width, is_reduction,
                             was_reduction, affine, track_running_stats)
            self.cells.append(cell)
            width_prev_prev, width_prev = width_prev, cell._multiplier * width
            was_reduction = is_reduction
            if auxiliary and is_reduction and width == C * 4:
                self.auxiliary_head = AuxiliaryHeadCIFAR(width_prev, num_classes)
                self.auxiliary_index = position
        self._Layer = len(self.cells)
        self.lastact = nn.Sequential(nn.BatchNorm2d(width_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(width_prev, num_classes)
        self.drop_path_prob = -1

    def update_drop_path(self, drop_path_prob):
        """Store the drop-path probability that forward() passes to every cell."""
        self.drop_path_prob = drop_path_prob

    def auxiliary_param(self):
        """Return the auxiliary head's parameters, or ``[]`` when there is no head."""
        if self.auxiliary_head is None:
            return []
        return list(self.auxiliary_head.parameters())

    def get_message(self):
        """Return ``extra_repr()`` followed by one line per cell."""
        total = len(self.cells)
        pieces = [self.extra_repr()]
        for position, cell in enumerate(self.cells):
            pieces.append('\n {:02d}/{:02d} :: {:}'.format(position, total, cell.extra_repr()))
        return ''.join(pieces)

    def extra_repr(self):
        """Summarise the base width, cells per stage and total number of cells."""
        template = '{name}(C={_C}, N={_layerN}, L={_Layer})'
        return template.format(name=self.__class__.__name__, **self.__dict__)

    def forward(self, inputs):
        """Return ``(features, logits)``.

        When an auxiliary head exists and the module is in training mode, the
        second element is the list ``[logits, logits_aux]`` instead.
        """
        stem_feature = self.stem(inputs)
        logits_aux = None
        # every cell consumes the two most recent feature maps
        older, newer = stem_feature, stem_feature
        for position, cell in enumerate(self.cells):
            older, newer = newer, cell(older, newer, self.drop_path_prob)
            if self.training and self.auxiliary_index is not None and position == self.auxiliary_index:
                logits_aux = self.auxiliary_head(newer)
        out = self.global_pooling(self.lastact(newer))
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        if logits_aux is None:
            return out, logits
        return out, [logits, logits_aux]
| C_prev = C 25 | self.cells = nn.ModuleList() 26 | for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): 27 | if reduction: 28 | cell = ResNetBasicblock(C_prev, C_curr, 2, True) 29 | else: 30 | cell = InferCell(genotype, C_prev, C_curr, 1) 31 | self.cells.append( cell ) 32 | C_prev = cell.out_dim 33 | self._Layer= len(self.cells) 34 | 35 | self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) 36 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 37 | self.classifier = nn.Linear(C_prev, num_classes) 38 | 39 | def get_message(self): 40 | string = self.extra_repr() 41 | for i, cell in enumerate(self.cells): 42 | string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) 43 | return string 44 | 45 | def extra_repr(self): 46 | return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) 47 | 48 | def forward(self, inputs): 49 | feature = self.stem(inputs) 50 | for i, cell in enumerate(self.cells): 51 | feature = cell(feature) 52 | 53 | out = self.lastact(feature) 54 | out = self.global_pooling( out ) 55 | out = out.view(out.size(0), -1) 56 | logits = self.classifier(out) 57 | 58 | return logits, out 59 | -------------------------------------------------------------------------------- /models/cell_searchs/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | # The macro structure is defined in NAS-Bench-201 5 | from .search_model_darts import TinyNetworkDarts 6 | from .search_model_gdas import TinyNetworkGDAS 7 | from .search_model_setn import TinyNetworkSETN 8 | from .search_model_enas import TinyNetworkENAS 9 | from .search_model_random import TinyNetworkRANDOM 10 | from .genotypes import Structure as CellStructure, architectures as 
def main():
    """Smoke test: build a Controller with arguments (6, 4) and call it once."""
    sampler = Controller(6, 4)
    sampler()
19 | self._layerN = N 20 | self.max_nodes = max_nodes 21 | self.stem = nn.Sequential( 22 | nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), 23 | nn.BatchNorm2d(C)) 24 | 25 | layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N 26 | layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N 27 | 28 | C_prev, num_edge, edge2index = C, None, None 29 | self.cells = nn.ModuleList() 30 | for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): 31 | if reduction: 32 | cell = ResNetBasicblock(C_prev, C_curr, 2) 33 | else: 34 | cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) 35 | if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index 36 | else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) 37 | self.cells.append( cell ) 38 | C_prev = cell.out_dim 39 | self.op_names = deepcopy( search_space ) 40 | self._Layer = len(self.cells) 41 | self.edge2index = edge2index 42 | self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) 43 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 44 | self.classifier = nn.Linear(C_prev, num_classes) 45 | self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) 46 | 47 | def get_weights(self): 48 | xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) 49 | xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) 50 | xlist+= list( self.classifier.parameters() ) 51 | return xlist 52 | 53 | def get_alphas(self): 54 | return [self.arch_parameters] 55 | 56 | def show_alphas(self): 57 | with torch.no_grad(): 58 | return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() ) 59 | 60 | def get_message(self): 61 | string = self.extra_repr() 62 | for i, cell in enumerate(self.cells): 63 | string += '\n {:02d}/{:02d} :: 
# The macro structure is based on NASNet
class NASNetworkDARTS(nn.Module):
    """DARTS super-network with a NASNet-style macro structure.

    All normal cells share ``arch_normal_parameters`` and all reduction cells
    share ``arch_reduce_parameters``; each is a (num_edge, num_ops) matrix of
    architecture logits.
    """

    def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
                 num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
        super(NASNetworkDARTS, self).__init__()
        self._C = C
        self._layerN = N
        self._steps = steps
        self._multiplier = multiplier
        self.stem = nn.Sequential(
            nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C*stem_multiplier))

        # config for each layer
        layer_channels = [C] * N + [C*2] + [C*2] * (N-1) + [C*4] + [C*4] * (N-1)
        layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)

        num_edge, edge2index = None, None
        C_prev_prev, C_prev, reduction_prev = C*stem_multiplier, C*stem_multiplier, False

        self.cells = nn.ModuleList()
        for C_curr, reduction in zip(layer_channels, layer_reductions):
            cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr,
                              reduction, reduction_prev, affine, track_running_stats)
            # every cell must expose the same edge layout so one alpha matrix fits all
            if num_edge is None:
                num_edge, edge2index = cell.num_edges, cell.edge2index
            else:
                assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        self.arch_normal_parameters = nn.Parameter(1e-3*torch.randn(num_edge, len(search_space)))
        self.arch_reduce_parameters = nn.Parameter(1e-3*torch.randn(num_edge, len(search_space)))

    def get_weights(self) -> List[torch.nn.Parameter]:
        """Return the network weights (stem, cells, head) without the architecture logits."""
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(self.global_pooling.parameters())
        xlist += list(self.classifier.parameters())
        return xlist

    def get_alphas(self) -> List[torch.nn.Parameter]:
        """Return the architecture logits as ``[normal, reduce]``."""
        return [self.arch_normal_parameters, self.arch_reduce_parameters]

    def show_alphas(self) -> Text:
        """Return a printable softmax of both architecture-logit matrices."""
        with torch.no_grad():
            A = 'arch-normal-parameters :\n{:}'.format(nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu())
            B = 'arch-reduce-parameters :\n{:}'.format(nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu())
        return '{:}\n{:}'.format(A, B)

    def get_message(self) -> Text:
        """Return ``extra_repr()`` followed by one line per cell."""
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self) -> Text:
        return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

    def genotype(self) -> Dict[Text, List]:
        """Derive the discrete architecture from the current architecture logits.

        For every intermediate node the two strongest incoming edges are kept,
        each coming from a *different* input node and carrying that input's
        strongest non-'none' operation.

        Returns:
            A dict with keys 'normal', 'normal_concat', 'reduce' and
            'reduce_concat'. Each gene is a list with one tuple per node; the
            tuple holds up to two ``(op_name, input_index, weight)`` entries
            sorted by decreasing weight.
        """
        def _parse(weights):
            gene = []
            for i in range(self._steps):
                edges = []
                for j in range(2+i):
                    node_str = '{:}<-{:}'.format(i, j)
                    ws = weights[self.edge2index[node_str]]
                    candidates = [(op_name, j, ws[k])
                                  for k, op_name in enumerate(self.op_names) if op_name != 'none']
                    # Keep one candidate per input node: two ops from the same
                    # input would collapse onto a single '<node><-<input>' edge
                    # when the cell is later rebuilt from this genotype.
                    if candidates:
                        edges.append(max(candidates, key=lambda x: x[-1]))
                edges = sorted(edges, key=lambda x: -x[-1])
                gene.append(tuple(edges[:2]))
            return gene
        with torch.no_grad():
            gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
            gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
        concat = list(range(2+self._steps-self._multiplier, self._steps+2))
        return {'normal': gene_normal, 'normal_concat': list(concat),
                'reduce': gene_reduce, 'reduce_concat': list(concat)}

    def forward(self, inputs):
        """Return ``(features, logits)`` using softmax-relaxed operation weights."""
        normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
        reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)

        s0 = s1 = self.stem(inputs)
        for cell in self.cells:
            ww = reduce_w if cell.reduction else normal_w
            s0, s1 = s1, cell.forward_darts(s0, s1, ww)
        out = self.lastact(s1)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
max_nodes, search_space, affine, track_running_stats) 36 | if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index 37 | else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) 38 | self.cells.append( cell ) 39 | C_prev = cell.out_dim 40 | self.op_names = deepcopy( search_space ) 41 | self._Layer = len(self.cells) 42 | self.edge2index = edge2index 43 | self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) 44 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 45 | self.classifier = nn.Linear(C_prev, num_classes) 46 | # to maintain the sampled architecture 47 | self.sampled_arch = None 48 | 49 | def update_arch(self, _arch): 50 | if _arch is None: 51 | self.sampled_arch = None 52 | elif isinstance(_arch, Structure): 53 | self.sampled_arch = _arch 54 | elif isinstance(_arch, (list, tuple)): 55 | genotypes = [] 56 | for i in range(1, self.max_nodes): 57 | xlist = [] 58 | for j in range(i): 59 | node_str = '{:}<-{:}'.format(i, j) 60 | op_index = _arch[ self.edge2index[node_str] ] 61 | op_name = self.op_names[ op_index ] 62 | xlist.append((op_name, j)) 63 | genotypes.append( tuple(xlist) ) 64 | self.sampled_arch = Structure(genotypes) 65 | else: 66 | raise ValueError('invalid type of input architecture : {:}'.format(_arch)) 67 | return self.sampled_arch 68 | 69 | def create_controller(self): 70 | return Controller(len(self.edge2index), len(self.op_names)) 71 | 72 | def get_message(self): 73 | string = self.extra_repr() 74 | for i, cell in enumerate(self.cells): 75 | string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) 76 | return string 77 | 78 | def extra_repr(self): 79 | return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) 80 | 81 | def forward(self, inputs): 82 | 83 | feature = self.stem(inputs) 84 | for i, cell in enumerate(self.cells): 85 | if 
class Controller(nn.Module):
    """LSTM controller that samples one operation index per edge.

    Each step feeds the previous choice's embedding (or the learned
    ``input_vars`` for the first step) through the LSTM, turns the output into
    temperature-scaled, tanh-bounded logits and samples an operation.
    """
    # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
    def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
        super(Controller, self).__init__()
        # assign the attributes
        self.num_edge = num_edge
        self.num_ops = num_ops
        self.lstm_size = lstm_size
        self.lstm_N = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        # create parameters
        self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size)))
        self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
        self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
        self.w_pred = nn.Linear(self.lstm_size, self.num_ops)

        nn.init.uniform_(self.input_vars, -0.1, 0.1)
        # Initialise the weights of EVERY LSTM layer; touching only layer 0
        # would leave the deeper layers on nn.LSTM's default initialisation.
        for layer in range(self.lstm_N):
            nn.init.uniform_(getattr(self.w_lstm, 'weight_hh_l{:}'.format(layer)), -0.1, 0.1)
            nn.init.uniform_(getattr(self.w_lstm, 'weight_ih_l{:}'.format(layer)), -0.1, 0.1)
        nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
        nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)

    def forward(self):
        """Sample one architecture.

        Returns:
            ``(log_prob, entropy, sampled_arch)`` -- the summed log-probability
            and summed entropy of all sampled choices (scalar tensors), and the
            list of ``num_edge`` sampled operation indices (Python ints).
        """
        inputs, h0 = self.input_vars, None
        log_probs, entropys, sampled_arch = [], [], []
        for iedge in range(self.num_edge):
            outputs, h0 = self.w_lstm(inputs, h0)

            logits = self.w_pred(outputs)
            logits = logits / self.temperature
            logits = self.tanh_constant * torch.tanh(logits)
            # distribution
            op_distribution = Categorical(logits=logits)
            op_index = op_distribution.sample()
            sampled_arch.append(op_index.item())

            op_log_prob = op_distribution.log_prob(op_index)
            log_probs.append(op_log_prob.view(-1))
            op_entropy = op_distribution.entropy()
            entropys.append(op_entropy.view(-1))

            # obtain the input embedding for the next step
            inputs = self.w_embd(op_index)
        return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch
19 | self.max_nodes = max_nodes 20 | self.stem = nn.Sequential( 21 | nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), 22 | nn.BatchNorm2d(C)) 23 | 24 | layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N 25 | layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N 26 | 27 | C_prev, num_edge, edge2index = C, None, None 28 | self.cells = nn.ModuleList() 29 | for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): 30 | if reduction: 31 | cell = ResNetBasicblock(C_prev, C_curr, 2) 32 | else: 33 | cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) 34 | if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index 35 | else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) 36 | self.cells.append( cell ) 37 | C_prev = cell.out_dim 38 | self.op_names = deepcopy( search_space ) 39 | self._Layer = len(self.cells) 40 | self.edge2index = edge2index 41 | self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) 42 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 43 | self.classifier = nn.Linear(C_prev, num_classes) 44 | self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) 45 | self.tau = 10 46 | 47 | def get_weights(self): 48 | xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) 49 | xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) 50 | xlist+= list( self.classifier.parameters() ) 51 | return xlist 52 | 53 | def set_tau(self, tau): 54 | self.tau = tau 55 | 56 | def get_tau(self): 57 | return self.tau 58 | 59 | def get_alphas(self): 60 | return [self.arch_parameters] 61 | 62 | def show_alphas(self): 63 | with torch.no_grad(): 64 | return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() ) 65 | 66 | def get_message(self): 67 | string = 
# The macro structure is based on NASNet
class NASNetworkGDAS(nn.Module):
    """GDAS super-network with a NASNet-style macro structure.

    All normal cells share ``arch_normal_parameters`` and all reduction cells
    share ``arch_reduce_parameters``. Each forward pass samples one operation
    per edge with the Gumbel-softmax trick at temperature ``tau``.
    """

    def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
        super(NASNetworkGDAS, self).__init__()
        self._C = C
        self._layerN = N
        self._steps = steps
        self._multiplier = multiplier
        self.stem = nn.Sequential(
            nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C*stem_multiplier))

        # config for each layer
        layer_channels = [C] * N + [C*2] + [C*2] * (N-1) + [C*4] + [C*4] * (N-1)
        layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)

        num_edge, edge2index = None, None
        C_prev_prev, C_prev, reduction_prev = C*stem_multiplier, C*stem_multiplier, False

        self.cells = nn.ModuleList()
        for C_curr, reduction in zip(layer_channels, layer_reductions):
            cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr,
                              reduction, reduction_prev, affine, track_running_stats)
            # every cell must expose the same edge layout so one alpha matrix fits all
            if num_edge is None:
                num_edge, edge2index = cell.num_edges, cell.edge2index
            else:
                assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        self.arch_normal_parameters = nn.Parameter(1e-3*torch.randn(num_edge, len(search_space)))
        self.arch_reduce_parameters = nn.Parameter(1e-3*torch.randn(num_edge, len(search_space)))
        self.tau = 10

    def get_weights(self):
        """Return the network weights (stem, cells, head) without the architecture logits."""
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(self.global_pooling.parameters())
        xlist += list(self.classifier.parameters())
        return xlist

    def set_tau(self, tau):
        """Set the Gumbel-softmax temperature used by forward()."""
        self.tau = tau

    def get_tau(self):
        """Return the current Gumbel-softmax temperature."""
        return self.tau

    def get_alphas(self):
        """Return the architecture logits as ``[normal, reduce]``."""
        return [self.arch_normal_parameters, self.arch_reduce_parameters]

    def show_alphas(self):
        """Return a printable softmax of both architecture-logit matrices."""
        with torch.no_grad():
            A = 'arch-normal-parameters :\n{:}'.format(nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu())
            B = 'arch-reduce-parameters :\n{:}'.format(nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu())
        return '{:}\n{:}'.format(A, B)

    def get_message(self):
        """Return ``extra_repr()`` followed by one line per cell."""
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

    def genotype(self):
        """Derive the discrete architecture from the current architecture logits.

        For every intermediate node the two strongest incoming edges are kept,
        each coming from a *different* input node and carrying that input's
        strongest non-'none' operation.

        Returns:
            A dict with keys 'normal', 'normal_concat', 'reduce' and
            'reduce_concat'. Each gene is a list with one tuple per node; the
            tuple holds up to two ``(op_name, input_index, weight)`` entries
            sorted by decreasing weight.
        """
        def _parse(weights):
            gene = []
            for i in range(self._steps):
                edges = []
                for j in range(2+i):
                    node_str = '{:}<-{:}'.format(i, j)
                    ws = weights[self.edge2index[node_str]]
                    candidates = [(op_name, j, ws[k])
                                  for k, op_name in enumerate(self.op_names) if op_name != 'none']
                    # Keep one candidate per input node: two ops from the same
                    # input would collapse onto a single '<node><-<input>' edge
                    # when the cell is later rebuilt from this genotype.
                    if candidates:
                        edges.append(max(candidates, key=lambda x: x[-1]))
                edges = sorted(edges, key=lambda x: -x[-1])
                gene.append(tuple(edges[:2]))
            return gene
        with torch.no_grad():
            gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
            gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
        concat = list(range(2+self._steps-self._multiplier, self._steps+2))
        return {'normal': gene_normal, 'normal_concat': list(concat),
                'reduce': gene_reduce, 'reduce_concat': list(concat)}

    def forward(self, inputs):
        """Return ``(features, logits)`` with one sampled operation per edge."""
        def get_gumbel_prob(xins):
            # Resample until the noise and the resulting probabilities are finite.
            while True:
                gumbels = -torch.empty_like(xins).exponential_().log()
                logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
                probs = nn.functional.softmax(logits, dim=1)
                index = probs.max(-1, keepdim=True)[1]
                one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
                # value equals the one-hot choice; the gradient flows through probs
                hardwts = one_h - probs.detach() + probs
                if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
                    continue
                else:
                    break
            return hardwts, index

        normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
        reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)

        s0 = s1 = self.stem(inputs)
        for cell in self.cells:
            if cell.reduction:
                hardwts, index = reduce_hardwts, reduce_index
            else:
                hardwts, index = normal_hardwts, normal_index
            s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
        out = self.lastact(s1)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
class TinyNetworkRANDOM(nn.Module):
    """NAS-Bench-201 style super-network driven by randomly drawn architectures.

    forward() evaluates whichever architecture is stored in ``arch_cache``;
    ``random_genotype(set_cache=True)`` draws a new one and stores it there.
    """

    def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
        super().__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C))

        # (output width, is-reduction) for every block: three stages of N
        # search cells separated by two residual reduction blocks
        stage_plan = (
            [(C, False)] * N + [(C * 2, True)] + [(C * 2, False)] * N
            + [(C * 4, True)] + [(C * 4, False)] * N
        )

        width, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for out_width, is_reduction in stage_plan:
            if is_reduction:
                cell = ResNetBasicblock(width, out_width, 2)
            else:
                cell = SearchCell(width, out_width, 1, max_nodes, search_space, affine, track_running_stats)
                # all search cells must share one edge layout
                if num_edge is None:
                    num_edge, edge2index = cell.num_edges, cell.edge2index
                else:
                    assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            self.cells.append(cell)
            width = cell.out_dim
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(width), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(width, num_classes)
        self.arch_cache = None

    def get_message(self):
        """Return ``extra_repr()`` followed by one line per cell."""
        total = len(self.cells)
        pieces = [self.extra_repr()]
        for position, cell in enumerate(self.cells):
            pieces.append('\n {:02d}/{:02d} :: {:}'.format(position, total, cell.extra_repr()))
        return ''.join(pieces)

    def extra_repr(self):
        """Summarise the base width, node count, cells per stage and total cells."""
        template = '{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'
        return template.format(name=self.__class__.__name__, **self.__dict__)

    def random_genotype(self, set_cache):
        """Draw a uniformly random operation for every edge.

        Args:
            set_cache: when truthy, also store the result in ``arch_cache``.

        Returns:
            The sampled architecture as a ``Structure``.
        """
        genotypes = []
        for node in range(1, self.max_nodes):
            # one random.choice call per incoming edge, in increasing source order
            incoming = tuple((random.choice(self.op_names), src) for src in range(node))
            genotypes.append(incoming)
        arch = Structure(genotypes)
        if set_cache:
            self.arch_cache = arch
        return arch

    def forward(self, inputs):
        """Return ``(features, logits)`` for the architecture held in ``arch_cache``."""
        feature = self.stem(inputs)
        for cell in self.cells:
            if isinstance(cell, SearchCell):
                feature = cell.forward_dynamic(feature, self.arch_cache)
            else:
                feature = cell(feature)

        out = self.global_pooling(self.lastact(feature))
        out = out.view(out.size(0), -1)
        return out, self.classifier(out)
def copy_bn(module, init):
    """Copy the leading ``module.num_features`` BatchNorm2d entries from ``init``.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are copied in
    place, each only when the destination tensor exists.  The copies are
    in-place writes; the caller in this file runs them under
    ``torch.no_grad()``.
    """
    assert isinstance(module, nn.BatchNorm2d), 'invalid module : {:}'.format(module)
    assert isinstance(init, nn.BatchNorm2d), 'invalid module : {:}'.format(init)
    keep = module.num_features
    for field in ('weight', 'bias', 'running_mean', 'running_var'):
        destination = getattr(module, field)
        if destination is not None:
            destination.copy_(getattr(init, field).detach()[:keep])


def copy_fc(module, init):
    """Copy the top-left ``out_features x in_features`` slice of ``init`` into ``module``.

    The bias is copied (first ``out_features`` entries) only when ``module``
    has one.
    """
    assert isinstance(module, nn.Linear), 'invalid module : {:}'.format(module)
    assert isinstance(init, nn.Linear), 'invalid module : {:}'.format(init)
    n_in = module.in_features
    n_out = module.out_features
    source_weight = init.weight.detach()
    module.weight.copy_(source_weight[:n_out, :n_in])
    if module.bias is not None:
        source_bias = init.bias.detach()
        module.bias.copy_(source_bias[:n_out])


def copy_base(module, init):
    """Copy the ``conv`` and ``bn`` sub-modules of a ConvBNReLU / Downsample pair.

    The check is on the class *name*, so any class called ``ConvBNReLU`` or
    ``Downsample`` is accepted.  Sub-modules that are ``None`` on the
    destination are skipped.
    """
    accepted = ('ConvBNReLU', 'Downsample')
    assert type(module).__name__ in accepted, 'invalid module : {:}'.format(module)
    assert type(init).__name__ in accepted, 'invalid module : {:}'.format(init)
    if module.conv is not None:
        copy_conv(module.conv, init.conv)
    if module.bn is not None:
        copy_bn(module.bn, init.bn)
def initialize_resnet(m):
    """Initialise one module in place; meant to be passed to ``nn.Module.apply``.

    * ``nn.Conv2d``: Kaiming-normal weights (``fan_out``, ReLU), zero bias.
    * ``nn.BatchNorm2d``: weight 1, bias 0.
    * ``nn.Linear``: weights drawn from N(0, 0.01), zero bias.

    Any other module type is left untouched.  Optional tensors that are
    ``None`` (``bias=False`` layers, ``affine=False`` batch norm) are
    skipped instead of being handed to ``nn.init``, which would fail on
    ``None``.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # weight and bias are both None when the layer was built with affine=False.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        # Same guard as the Conv2d branch: bias is None for bias=False layers.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
class InvertedResidual(nn.Module):
    """Depth-wise block: optional 1x1 conv, 3x3 depth-wise conv, linear 1x1 conv.

    Args:
        channels: two or three ints.  With three entries a 1x1 ConvBNReLU
            from ``channels[0]`` to ``channels[1]`` is placed first; with two
            entries the block starts at the depth-wise conv.
        stride: 1 or 2, applied by the depth-wise conv.
        expand_ratio: accepted for interface compatibility; not read here.
        additive: when truthy and ``channels[0] != channels[-1]``, a 1x1
            projection shortcut is created and added to the output.
    """

    def __init__(self, channels, stride, expand_ratio, additive):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2], 'invalid stride : {:}'.format(stride)
        assert len(channels) in [2, 3], 'invalid channels : {:}'.format(channels)

        if len(channels) == 2:
            layers = []
        else:
            layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
        layers.extend([
            # dw
            ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
            # pw-linear
            ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
        ])
        self.conv = nn.Sequential(*layers)
        self.additive = additive
        # NOTE(review): when ``additive`` is set but input and output widths
        # are equal, no shortcut is built, so forward() adds nothing to the
        # conv output.  Left unchanged to keep numerical behaviour; confirm
        # this is intended.
        if self.additive and channels[0] != channels[-1]:
            self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
        else:
            self.shortcut = None
        self.out_dim = channels[-1]

    def forward(self, x):
        out = self.conv(x)
        # Explicit None test instead of relying on nn.Module truthiness.
        if self.shortcut is not None:
            return out + self.shortcut(x)
        return out


class InferMobileNetV2(nn.Module):
    """MobileNetV2-style network built from an explicit per-layer channel list.

    Args:
        num_classes: output size of the final linear layer.
        xchannels: channel description handed to ``parse_channel_info``;
            presumably a string of space-separated, dash-joined ints -- TODO
            confirm against the helper imported by this file.  After
            parsing, entry 0 describes the first conv, the following entries
            the inverted-residual blocks and the last entry the final 1x1
            conv.
        xblocks: number of blocks to keep in each of the seven stages; each
            value must not exceed the stage's default depth.
        dropout: probability passed to ``nn.Dropout`` in the classifier.
    """

    def __init__(self, num_classes, xchannels, xblocks, dropout):
        super(InferMobileNetV2, self).__init__()
        block = InvertedResidual
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        assert len(inverted_residual_setting) == len(xblocks), 'invalid number of layers : {:} vs {:}'.format(len(inverted_residual_setting), len(xblocks))
        for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
            assert block_num <= ir_setting[2], '{:} vs {:}'.format(block_num, ir_setting)
        xchannels = parse_channel_info(xchannels)
        self.xchannels = xchannels
        self.message = 'InferMobileNetV2 : xblocks={:}'.format(xblocks)
        # building first layer
        features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
        last_channel_idx = 1

        # building inverted residual blocks
        for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
            for i in range(n):
                stride = s if i == 0 else 1
                additv = i > 0
                module = block(self.xchannels[last_channel_idx], stride, t, additv)
                features.append(module)
                self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c)
                last_channel_idx += 1
                if i + 1 == xblocks[stage]:
                    # Skip the channel entries of the blocks dropped from
                    # this stage, then rewire the next layer's input width to
                    # the output of the last block that was kept.
                    last_channel_idx += n - (i + 1)
                    self.xchannels[last_channel_idx][0] = module.out_dim
                    break
        # building last several layers
        features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1))
        assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
        )

        # weight initialization
        self.apply(initialize_resnet)

    def get_message(self):
        """Return the construction log assembled in ``__init__``."""
        return self.message

    def forward(self, inputs):
        """Return ``(feature_map, predictions)``.

        The feature map is averaged over dimensions 2 and 3 before it is
        passed to the classifier.
        """
        features = self.features(inputs)
        vectors = features.mean([2, 3])
        predicts = self.classifier(vectors)
        return features, predicts
def parse_channel_info(xstring):
    """Parse a channel description into nested lists of ints.

    Groups are separated by a single space and the numbers inside a group
    by ``-``; e.g. ``'3-32 32-16-24'`` becomes ``[[3, 32], [32, 16, 24]]``.
    """
    return [[int(token) for token in group.split('-')]
            for group in xstring.split(' ')]
def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
    """Sample ``num`` distinct columns per row and renormalise their scores.

    Args:
        logits: 2-D tensor; softmax is taken along dimension 1.
        tau: temperature.  ``tau <= 0`` uses the plain softmax of
            ``logits``; otherwise Gumbel noise is added and the result is
            divided by ``tau``.
        just_prob: when truthy, return only the probability tensor.
        num: number of columns drawn per row (without replacement).
        eps: small constant added to the probabilities before sampling.

    Returns:
        ``probs`` when ``just_prob`` is truthy, otherwise
        ``(selected_index, selected_probs)`` where the probabilities are the
        softmax over the (possibly perturbed) logits of the chosen columns.
    """
    if tau <= 0:
        new_logits = logits
        probs = nn.functional.softmax(new_logits, dim=1)
    else:
        # Redraw until neither the noise nor the probabilities contain
        # inf / nan values.
        while True:
            noise = -torch.empty_like(logits).exponential_().log()
            new_logits = (logits.log_softmax(dim=1) + noise) / tau
            probs = nn.functional.softmax(new_logits, dim=1)
            invalid = (torch.isinf(noise).any()
                       or torch.isinf(probs).any()
                       or torch.isnan(probs).any())
            if not invalid:
                break

    if just_prob:
        return probs

    # Sampling runs on the CPU without gradient tracking; eps is added to
    # the weights before torch.multinomial is called.
    with torch.no_grad():
        cpu_probs = probs.cpu()
        drawn = torch.multinomial(cpu_probs + eps, num, False)
        selected_index = drawn.to(logits.device)
    chosen_logits = torch.gather(new_logits, 1, selected_index)
    return selected_index, nn.functional.softmax(chosen_logits, dim=1)


def ChannelWiseInter(inputs, oC, mode='v2'):
    """Dispatch to ``ChannelWiseInterV1`` or ``ChannelWiseInterV2`` by ``mode``.

    Raises:
        ValueError: if ``mode`` is neither ``'v1'`` nor ``'v2'``.
    """
    if mode not in ('v1', 'v2'):
        raise ValueError('invalid mode : {:}'.format(mode))
    if mode == 'v1':
        return ChannelWiseInterV1(inputs, oC)
    return ChannelWiseInterV2(inputs, oC)
def ChannelWiseInterV2(inputs, oC):
    """Resize dimension 1 of a 4-D tensor to ``oC`` by adaptive average pooling.

    The input is returned unchanged when it already has ``oC`` channels;
    otherwise ``adaptive_avg_pool3d`` is applied with target size
    ``(oC, H, W)``, so the last two dimensions keep their size.
    """
    assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size())
    _, C, H, W = inputs.size()
    if C == oC:
        return inputs
    return nn.functional.adaptive_avg_pool3d(inputs, (oC, H, W))


def linear_forward(inputs, linear):
    """Apply ``linear`` using only its first ``inputs.size(1)`` input columns.

    Returns ``inputs`` unchanged when ``linear`` is ``None``.  The layer's
    bias (possibly ``None``) is passed through as is.
    """
    if linear is None:
        return inputs
    iC = inputs.size(1)
    weight = linear.weight[:, :iC]
    return nn.functional.linear(inputs, weight, linear.bias)


def get_width_choices(nOut):
    """Return candidate widths for a layer with ``nOut`` channels.

    With ``nOut=None`` the number of ratios (8) is returned.  Otherwise the
    result is the sorted tuple of distinct ``int(nOut * ratio)`` values for
    ratios 0.3 .. 1.0, which can be shorter than 8 for small ``nOut``.
    """
    xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    if nOut is None:
        return len(xsrange)
    return tuple(sorted({int(nOut * ratio) for ratio in xsrange}))


def get_depth_choices(nDepth):
    """Return the three candidate depths for a stage of ``nDepth`` layers.

    With ``nDepth=None`` the number of candidates (3) is returned; otherwise
    ``(nDepth // 3, nDepth * 2 // 3, nDepth)``.

    Raises:
        AssertionError: if ``nDepth`` is not at least 3.  The exception type
            matches the original ``assert``; it is raised explicitly so the
            check also holds when Python runs with ``-O``.
    """
    if nDepth is None:
        return 3
    # The former branches for depths 1 and 2 could never run because this
    # check already rejects those values; they have been removed.
    if not nDepth >= 3:
        raise AssertionError('nDepth should be greater than 2 vs {:}'.format(nDepth))
    return (nDepth // 3, nDepth * 2 // 3, nDepth)


def drop_path(x, drop_prob):
    """Zero whole samples of ``x`` at random and rescale the survivors.

    For ``drop_prob > 0`` a Bernoulli mask with one entry per element of
    dimension 0 is drawn with keep probability ``1 - drop_prob``; ``x`` is
    multiplied by ``mask / keep_prob``.  For ``drop_prob <= 0`` the input is
    returned unchanged.  The mask is built with four dimensions, so ``x`` is
    expected to broadcast against shape ``(N, 1, 1, 1)``.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        mask = x.new_zeros(x.size(0), 1, 1, 1)
        mask = mask.bernoulli_(keep_prob)
        x = x * (mask / keep_prob)
    return x
class ConvBnRelu(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super(ConvBnRelu, self).__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv_bn_relu = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv_bn_relu(x)


class Conv3x3BnRelu(nn.Module):
    """3x3 convolution with batch norm and ReLU activation."""

    def __init__(self, in_channels, out_channels):
        super(Conv3x3BnRelu, self).__init__()
        # kernel 3, stride 1, padding 1
        self.conv3x3 = ConvBnRelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.conv3x3(x)


class Conv1x1BnRelu(nn.Module):
    """1x1 convolution with batch norm and ReLU activation."""

    def __init__(self, in_channels, out_channels):
        super(Conv1x1BnRelu, self).__init__()
        self.conv1x1 = ConvBnRelu(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        return self.conv1x1(x)


class MaxPool3x3(nn.Module):
    """3x3 max pool with no subsampling.

    ``in_channels`` and ``out_channels`` are accepted but not used by the
    pooling layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(MaxPool3x3, self).__init__()
        self.maxpool = nn.MaxPool2d(kernel_size, stride, padding)

    def forward(self, x):
        return self.maxpool(x)


# Maps operation names to the module classes above.
# Commas should not be used in op names
OP_MAP = {
    'conv3x3-bn-relu': Conv3x3BnRelu,
    'conv1x1-bn-relu': Conv1x1BnRelu,
    'maxpool3x3': MaxPool3x3,
}
def is_full_dag(matrix):
    """Full DAG == all vertices on a path from vert 0 to (V-1).

    i.e. no disconnected or "hanging" vertices.  Two conditions are checked:
      1) no all-zero row except the last one (only the output vertex may
         lack out-edges)
      2) no all-zero column except column 0 (only the input vertex may lack
         in-edges)

    Args:
      matrix: V x V upper-triangular adjacency matrix (NumPy array).

    Returns:
      True if there are no dangling vertices.
    """
    num_vertices = np.shape(matrix)[0]

    # A vertex other than the output with no outgoing edge.
    has_dead_end = np.any(np.all(matrix[:num_vertices - 1, :] == 0, axis=1))
    # A vertex other than the input with no incoming edge.
    has_orphan = np.any(np.all(matrix[:, 1:] == 0, axis=0))

    return not (has_dead_end or has_orphan)


def num_edges(matrix):
    """Return the sum of all entries of the adjacency matrix (its edge count)."""
    total = np.sum(matrix)
    return total
def permute_graph(graph, label, permutation):
    """Permutes the graph and labels based on permutation.

    Args:
      graph: np.ndarray adjacency matrix.
      label: list of labels of same length as graph dimensions.
      permutation: a permutation list of ints of same length as graph
        dimensions.

    Returns:
      Tuple ``(new_matrix, new_label)`` in which vertex ``permutation[v]`` is
      vertex ``v`` of the original graph.  ``new_matrix`` holds the result of
      comparing the original entries with 1.
    """
    # inverse[new_vertex] == old_vertex
    pairs = sorted(zip(permutation, range(len(permutation))))
    inverse = [old for _, old in pairs]
    size = len(label)

    def has_edge(row, col):
        return graph[inverse[row], inverse[col]] == 1

    new_matrix = np.fromfunction(np.vectorize(has_edge), (size, size), dtype=np.int8)
    new_label = [label[inverse[position]] for position in range(size)]
    return new_matrix, new_label


def is_isomorphic(graph1, graph2):
    """Exhaustively checks if 2 graphs are isomorphic.

    Each graph is a ``(matrix, labels)`` pair.  Every vertex permutation of
    the first graph is tried, so the cost grows factorially with the number
    of vertices.
    """
    matrix1, label1 = np.array(graph1[0]), graph1[1]
    matrix2, label2 = np.array(graph2[0]), graph2[1]
    assert np.shape(matrix1) == np.shape(matrix2)
    assert len(label1) == len(label2)

    vertices = np.shape(matrix1)[0]

    def matches(perm):
        candidate_matrix, candidate_label = permute_graph(matrix1, label1, perm)
        return np.array_equal(candidate_matrix, matrix2) and candidate_label == label2

    # Note: input and output in our constrained graphs always map to
    # themselves but this script does not enforce that.
    return any(matches(perm) for perm in itertools.permutations(range(0, vertices)))
import graph_util 15 | 16 | # Graphviz is optional and only required for visualization. 17 | try: 18 | import graphviz # pylint: disable=g-import-not-at-top 19 | except ImportError: 20 | pass 21 | 22 | 23 | class ModelSpec(object): 24 | """Model specification given adjacency matrix and labeling.""" 25 | 26 | def __init__(self, matrix, ops, data_format='channels_last'): 27 | """Initialize the module spec. 28 | 29 | Args: 30 | matrix: ndarray or nested list with shape [V, V] for the adjacency matrix. 31 | ops: V-length list of labels for the base ops used. The first and last 32 | elements are ignored because they are the input and output vertices 33 | which have no operations. The elements are retained to keep consistent 34 | indexing. 35 | data_format: channels_last or channels_first. 36 | 37 | Raises: 38 | ValueError: invalid matrix or ops 39 | """ 40 | if not isinstance(matrix, np.ndarray): 41 | matrix = np.array(matrix) 42 | shape = np.shape(matrix) 43 | if len(shape) != 2 or shape[0] != shape[1]: 44 | raise ValueError('matrix must be square') 45 | if shape[0] != len(ops): 46 | raise ValueError('length of ops must match matrix dimensions') 47 | if not is_upper_triangular(matrix): 48 | raise ValueError('matrix must be upper triangular') 49 | 50 | # Both the original and pruned matrices are deep copies of the matrix and 51 | # ops so any changes to those after initialization are not recognized by the 52 | # spec. 53 | self.original_matrix = copy.deepcopy(matrix) 54 | self.original_ops = copy.deepcopy(ops) 55 | 56 | self.matrix = copy.deepcopy(matrix) 57 | self.ops = copy.deepcopy(ops) 58 | self.valid_spec = True 59 | self._prune() 60 | 61 | self.data_format = data_format 62 | 63 | def _prune(self): 64 | """Prune the extraneous parts of the graph. 65 | 66 | General procedure: 67 | 1) Remove parts of graph not connected to input. 68 | 2) Remove parts of graph not connected to output. 69 | 3) Reorder the vertices so that they are consecutive after steps 1 and 2. 
70 | 71 | These 3 steps can be combined by deleting the rows and columns of the 72 | vertices that are not reachable from both the input and output (in reverse). 73 | """ 74 | num_vertices = np.shape(self.original_matrix)[0] 75 | 76 | # DFS forward from input 77 | visited_from_input = set([0]) 78 | frontier = [0] 79 | while frontier: 80 | top = frontier.pop() 81 | for v in range(top + 1, num_vertices): 82 | if self.original_matrix[top, v] and v not in visited_from_input: 83 | visited_from_input.add(v) 84 | frontier.append(v) 85 | 86 | # DFS backward from output 87 | visited_from_output = set([num_vertices - 1]) 88 | frontier = [num_vertices - 1] 89 | while frontier: 90 | top = frontier.pop() 91 | for v in range(0, top): 92 | if self.original_matrix[v, top] and v not in visited_from_output: 93 | visited_from_output.add(v) 94 | frontier.append(v) 95 | 96 | # Any vertex that isn't connected to both input and output is extraneous to 97 | # the computation graph. 98 | extraneous = set(range(num_vertices)).difference( 99 | visited_from_input.intersection(visited_from_output)) 100 | 101 | # If the non-extraneous graph is less than 2 vertices, the input is not 102 | # connected to the output and the spec is invalid. 103 | if len(extraneous) > num_vertices - 2: 104 | self.matrix = None 105 | self.ops = None 106 | self.valid_spec = False 107 | return 108 | 109 | self.matrix = np.delete(self.matrix, list(extraneous), axis=0) 110 | self.matrix = np.delete(self.matrix, list(extraneous), axis=1) 111 | for index in sorted(extraneous, reverse=True): 112 | del self.ops[index] 113 | 114 | def hash_spec(self, canonical_ops): 115 | """Computes the isomorphism-invariant graph hash of this spec. 116 | 117 | Args: 118 | canonical_ops: list of operations in the canonical ordering which they 119 | were assigned (i.e. the order provided in the config['available_ops']). 120 | 121 | Returns: 122 | MD5 hash of this spec which can be used to query the dataset. 
123 | """ 124 | # Invert the operations back to integer label indices used in graph gen. 125 | labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2] 126 | return graph_util.hash_module(self.matrix, labeling) 127 | 128 | def visualize(self): 129 | """Creates a dot graph. Can be visualized in colab directly.""" 130 | num_vertices = np.shape(self.matrix)[0] 131 | g = graphviz.Digraph() 132 | g.node(str(0), 'input') 133 | for v in range(1, num_vertices - 1): 134 | g.node(str(v), self.ops[v]) 135 | g.node(str(num_vertices - 1), 'output') 136 | 137 | for src in range(num_vertices - 1): 138 | for dst in range(src + 1, num_vertices): 139 | if self.matrix[src, dst]: 140 | g.edge(str(src), str(dst)) 141 | 142 | return g 143 | 144 | 145 | def is_upper_triangular(matrix): 146 | """True if matrix is 0 on diagonal and below.""" 147 | for src in range(np.shape(matrix)[0]): 148 | for dst in range(0, src + 1): 149 | if matrix[src, dst] != 0: 150 | return False 151 | 152 | return True 153 | -------------------------------------------------------------------------------- /nas_201_api/__init__.py: -------------------------------------------------------------------------------- 1 | ##################################################### 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # 3 | ##################################################################### 4 | # This API will not be updated after 2020.09.16. # 5 | # Please use our new API in NATS-Bench, which is # 6 | # more efficient and contains info of more architecture candidates. 
# 7 | ##################################################################### 8 | from .api_utils import ArchResults, ResultsCount 9 | from .api_201 import NASBench201API 10 | 11 | # NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25] 12 | # NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09] 13 | # NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16] 14 | NAS_BENCH_201_API_VERSION="v2.0" # [2020.06.30] 15 | 16 | -------------------------------------------------------------------------------- /pycls/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BayesWatch/nas-without-training/b3a82a6642564df115f989ff940ec6b8ef9ca9d3/pycls/core/__init__.py -------------------------------------------------------------------------------- /pycls/core/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Benchmarking functions.""" 9 | 10 | import pycls.core.logging as logging 11 | import pycls.datasets.loader as loader 12 | import torch 13 | from pycls.core.config import cfg 14 | from pycls.core.timer import Timer 15 | 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | @torch.no_grad() 21 | def compute_time_eval(model): 22 | """Computes precise model forward test time using dummy data.""" 23 | # Use eval mode 24 | model.eval() 25 | # Generate a dummy mini-batch and copy data to GPU 26 | im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS) 27 | if cfg.TASK == "jig": 28 | inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False) 29 | else: 30 | inputs = torch.zeros(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False) 31 | # Compute precise forward pass time 32 | timer = Timer() 33 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER 34 | for cur_iter in range(total_iter): 35 | # Reset the timers after the warmup phase 36 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER: 37 | timer.reset() 38 | # Forward 39 | timer.tic() 40 | model(inputs) 41 | torch.cuda.synchronize() 42 | timer.toc() 43 | return timer.average_time 44 | 45 | 46 | def compute_time_train(model, loss_fun): 47 | """Computes precise model forward + backward time using dummy data.""" 48 | # Use train mode 49 | model.train() 50 | # Generate a dummy mini-batch and copy data to GPU 51 | im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS) 52 | if cfg.TASK == "jig": 53 | inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False) 54 | else: 55 | inputs = torch.rand(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False) 56 | if cfg.TASK in ['col', 'seg']: 57 | labels = torch.zeros(batch_size, im_size, im_size, 
dtype=torch.int64).cuda(non_blocking=False) 58 | else: 59 | labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False) 60 | # Cache BatchNorm2D running stats 61 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] 62 | bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns] 63 | # Compute precise forward backward pass time 64 | fw_timer, bw_timer = Timer(), Timer() 65 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER 66 | for cur_iter in range(total_iter): 67 | # Reset the timers after the warmup phase 68 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER: 69 | fw_timer.reset() 70 | bw_timer.reset() 71 | # Forward 72 | fw_timer.tic() 73 | preds = model(inputs) 74 | if isinstance(preds, tuple): 75 | loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels) 76 | preds = preds[0] 77 | else: 78 | loss = loss_fun(preds, labels) 79 | torch.cuda.synchronize() 80 | fw_timer.toc() 81 | # Backward 82 | bw_timer.tic() 83 | loss.backward() 84 | torch.cuda.synchronize() 85 | bw_timer.toc() 86 | # Restore BatchNorm2D running stats 87 | for bn, (mean, var) in zip(bns, bn_stats): 88 | bn.running_mean, bn.running_var = mean, var 89 | return fw_timer.average_time, bw_timer.average_time 90 | 91 | 92 | def compute_time_loader(data_loader): 93 | """Computes loader time.""" 94 | timer = Timer() 95 | loader.shuffle(data_loader, 0) 96 | data_loader_iterator = iter(data_loader) 97 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER 98 | total_iter = min(total_iter, len(data_loader)) 99 | for cur_iter in range(total_iter): 100 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER: 101 | timer.reset() 102 | timer.tic() 103 | next(data_loader_iterator) 104 | timer.toc() 105 | return timer.average_time 106 | 107 | 108 | def compute_time_full(model, loss_fun, train_loader, test_loader): 109 | """Times model and data loader.""" 110 | logger.info("Computing model and loader timings...") 111 | # 
Compute timings 112 | test_fw_time = compute_time_eval(model) 113 | train_fw_time, train_bw_time = compute_time_train(model, loss_fun) 114 | train_fw_bw_time = train_fw_time + train_bw_time 115 | train_loader_time = compute_time_loader(train_loader) 116 | # Output iter timing 117 | iter_times = { 118 | "test_fw_time": test_fw_time, 119 | "train_fw_time": train_fw_time, 120 | "train_bw_time": train_bw_time, 121 | "train_fw_bw_time": train_fw_bw_time, 122 | "train_loader_time": train_loader_time, 123 | } 124 | logger.info(logging.dump_log_data(iter_times, "iter_times")) 125 | # Output epoch timing 126 | epoch_times = { 127 | "test_fw_time": test_fw_time * len(test_loader), 128 | "train_fw_time": train_fw_time * len(train_loader), 129 | "train_bw_time": train_bw_time * len(train_loader), 130 | "train_fw_bw_time": train_fw_bw_time * len(train_loader), 131 | "train_loader_time": train_loader_time * len(train_loader), 132 | } 133 | logger.info(logging.dump_log_data(epoch_times, "epoch_times")) 134 | # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1) 135 | overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time 136 | logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100)) 137 | -------------------------------------------------------------------------------- /pycls/core/builders.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 

"""Model and loss construction functions."""

import torch
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
from pycls.models.resnet import ResNet
from pycls.models.nas.nas import NAS
from pycls.models.nas.nas_search import NAS_Search
from pycls.models.nas_bench.model_builder import NAS_Bench


class LabelSmoothedCrossEntropyLoss(torch.nn.Module):
    """CrossEntropyLoss with label smoothing.

    The smoothing strength and the number of classes are read from the global
    config (``cfg.MODEL.LABEL_SMOOTHING_EPS`` and ``cfg.MODEL.NUM_CLASSES``)
    when the module is constructed.
    """

    def __init__(self):
        super(LabelSmoothedCrossEntropyLoss, self).__init__()
        self.eps = cfg.MODEL.LABEL_SMOOTHING_EPS
        self.num_classes = cfg.MODEL.NUM_CLASSES

    def forward(self, logits, target):
        """Computes the mean label-smoothed cross entropy.

        Args:
            logits: unnormalized scores; log-softmax is taken over the last
                dimension.
            target: integer class indices. It is unsqueezed on the last axis
                and used as the scatter index into that same dimension.

        Returns:
            Scalar tensor: the loss summed over the last dimension and
            averaged over all remaining dimensions.
        """
        pred = logits.log_softmax(dim=-1)
        # The target distribution is a constant, so keep it out of autograd.
        with torch.no_grad():
            # Every wrong class receives eps / (num_classes - 1) ...
            target_dist = torch.ones_like(pred) * self.eps / (self.num_classes - 1)
            # ... and the labelled class receives the remaining 1 - eps.
            target_dist.scatter_(-1, target.unsqueeze(-1), 1 - self.eps)
        return (-target_dist * pred).sum(dim=-1).mean()


# Supported models (maps cfg.MODEL.TYPE to a model constructor)
_models = {
    "anynet": AnyNet,
    "effnet": EffNet,
    "resnet": ResNet,
    "regnet": RegNet,
    "nas": NAS,
    "nas_search": NAS_Search,
    "nas_bench": NAS_Bench,
}

# Supported loss functions (maps cfg.MODEL.LOSS_FUN to a loss constructor)
_loss_funs = {
    "cross_entropy": torch.nn.CrossEntropyLoss,
    "label_smoothed_cross_entropy": LabelSmoothedCrossEntropyLoss,
}


def get_model():
    """Gets the model class specified in the config.

    Returns:
        The constructor registered under ``cfg.MODEL.TYPE``.

    Raises:
        AssertionError: if ``cfg.MODEL.TYPE`` is not a registered model type.
    """
    err_str = "Model type '{}' not supported"
    assert cfg.MODEL.TYPE in _models, err_str.format(cfg.MODEL.TYPE)
    return _models[cfg.MODEL.TYPE]


def get_loss_fun():
    """Gets the loss function class specified in the config.

    Returns:
        The constructor registered under ``cfg.MODEL.LOSS_FUN``.

    Raises:
        AssertionError: if ``cfg.MODEL.LOSS_FUN`` is not a registered loss.
    """
    err_str = "Loss function type '{}' not supported"
    # Report the same key that is validated. The message was previously built
    # from cfg.TRAIN.LOSS, which is not the value being checked.
    assert cfg.MODEL.LOSS_FUN in _loss_funs, err_str.format(cfg.MODEL.LOSS_FUN)
    return _loss_funs[cfg.MODEL.LOSS_FUN]

67 | 68 | def build_model(): 69 | """Builds the model.""" 70 | return get_model()() 71 | 72 | 73 | def build_loss_fun(): 74 | """Build the loss function.""" 75 | if cfg.TASK == "seg": 76 | return get_loss_fun()(ignore_index=255) 77 | else: 78 | return get_loss_fun()() 79 | 80 | 81 | def register_model(name, ctor): 82 | """Registers a model dynamically.""" 83 | _models[name] = ctor 84 | 85 | 86 | def register_loss_fun(name, ctor): 87 | """Registers a loss function dynamically.""" 88 | _loss_funs[name] = ctor 89 | -------------------------------------------------------------------------------- /pycls/core/checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Functions that handle saving and loading of checkpoints.""" 9 | 10 | import os 11 | 12 | import pycls.core.distributed as dist 13 | import torch 14 | from pycls.core.config import cfg 15 | 16 | 17 | # Common prefix for checkpoint file names 18 | _NAME_PREFIX = "model_epoch_" 19 | # Checkpoints directory name 20 | _DIR_NAME = "checkpoints" 21 | 22 | 23 | def get_checkpoint_dir(): 24 | """Retrieves the location for storing checkpoints.""" 25 | return os.path.join(cfg.OUT_DIR, _DIR_NAME) 26 | 27 | 28 | def get_checkpoint(epoch): 29 | """Retrieves the path to a checkpoint file.""" 30 | name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch) 31 | return os.path.join(get_checkpoint_dir(), name) 32 | 33 | 34 | def get_last_checkpoint(): 35 | """Retrieves the most recent checkpoint (highest epoch number).""" 36 | checkpoint_dir = get_checkpoint_dir() 37 | # Checkpoint file names are in lexicographic order 38 | checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f] 39 | last_checkpoint_name = sorted(checkpoints)[-1] 40 | return 
os.path.join(checkpoint_dir, last_checkpoint_name) 41 | 42 | 43 | def has_checkpoint(): 44 | """Determines if there are checkpoints available.""" 45 | checkpoint_dir = get_checkpoint_dir() 46 | if not os.path.exists(checkpoint_dir): 47 | return False 48 | return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir)) 49 | 50 | 51 | def save_checkpoint(model, optimizer, epoch): 52 | """Saves a checkpoint.""" 53 | # Save checkpoints only from the master process 54 | if not dist.is_master_proc(): 55 | return 56 | # Ensure that the checkpoint dir exists 57 | os.makedirs(get_checkpoint_dir(), exist_ok=True) 58 | # Omit the DDP wrapper in the multi-gpu setting 59 | sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict() 60 | # Record the state 61 | if isinstance(optimizer, list): 62 | checkpoint = { 63 | "epoch": epoch, 64 | "model_state": sd, 65 | "optimizer_w_state": optimizer[0].state_dict(), 66 | "optimizer_a_state": optimizer[1].state_dict(), 67 | "cfg": cfg.dump(), 68 | } 69 | else: 70 | checkpoint = { 71 | "epoch": epoch, 72 | "model_state": sd, 73 | "optimizer_state": optimizer.state_dict(), 74 | "cfg": cfg.dump(), 75 | } 76 | # Write the checkpoint 77 | checkpoint_file = get_checkpoint(epoch + 1) 78 | torch.save(checkpoint, checkpoint_file) 79 | return checkpoint_file 80 | 81 | 82 | def load_checkpoint(checkpoint_file, model, optimizer=None): 83 | """Loads the checkpoint from the given file.""" 84 | err_str = "Checkpoint '{}' not found" 85 | assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file) 86 | # Load the checkpoint on CPU to avoid GPU mem spike 87 | checkpoint = torch.load(checkpoint_file, map_location="cpu") 88 | # Account for the DDP wrapper in the multi-gpu setting 89 | ms = model.module if cfg.NUM_GPUS > 1 else model 90 | ms.load_state_dict(checkpoint["model_state"]) 91 | # Load the optimizer state (commonly not done when fine-tuning) 92 | if optimizer: 93 | if isinstance(optimizer, list): 94 | 
optimizer[0].load_state_dict(checkpoint["optimizer_w_state"]) 95 | optimizer[1].load_state_dict(checkpoint["optimizer_a_state"]) 96 | else: 97 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 98 | return checkpoint["epoch"] 99 | -------------------------------------------------------------------------------- /pycls/core/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Distributed helpers.""" 9 | 10 | import multiprocessing 11 | import os 12 | import signal 13 | import threading 14 | import traceback 15 | 16 | import torch 17 | from pycls.core.config import cfg 18 | 19 | 20 | def is_master_proc(): 21 | """Determines if the current process is the master process. 22 | 23 | Master process is responsible for logging, writing and loading checkpoints. In 24 | the multi GPU setting, we assign the master role to the rank 0 process. When 25 | training using a single GPU, there is a single process which is considered master. 26 | """ 27 | return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0 28 | 29 | 30 | def init_process_group(proc_rank, world_size): 31 | """Initializes the default process group.""" 32 | # Set the GPU to use 33 | torch.cuda.set_device(proc_rank) 34 | # Initialize the process group 35 | torch.distributed.init_process_group( 36 | backend=cfg.DIST_BACKEND, 37 | init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT), 38 | world_size=world_size, 39 | rank=proc_rank, 40 | ) 41 | 42 | 43 | def destroy_process_group(): 44 | """Destroys the default process group.""" 45 | torch.distributed.destroy_process_group() 46 | 47 | 48 | def scaled_all_reduce(tensors): 49 | """Performs the scaled all_reduce operation on the provided tensors. 
50 | 51 | The input tensors are modified in-place. Currently supports only the sum 52 | reduction operator. The reduced values are scaled by the inverse size of the 53 | process group (equivalent to cfg.NUM_GPUS). 54 | """ 55 | # There is no need for reduction in the single-proc case 56 | if cfg.NUM_GPUS == 1: 57 | return tensors 58 | # Queue the reductions 59 | reductions = [] 60 | for tensor in tensors: 61 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 62 | reductions.append(reduction) 63 | # Wait for reductions to finish 64 | for reduction in reductions: 65 | reduction.wait() 66 | # Scale the results 67 | for tensor in tensors: 68 | tensor.mul_(1.0 / cfg.NUM_GPUS) 69 | return tensors 70 | 71 | 72 | class ChildException(Exception): 73 | """Wraps an exception from a child process.""" 74 | 75 | def __init__(self, child_trace): 76 | super(ChildException, self).__init__(child_trace) 77 | 78 | 79 | class ErrorHandler(object): 80 | """Multiprocessing error handler (based on fairseq's). 81 | 82 | Listens for errors in child processes and propagates the tracebacks to the parent. 
83 | """ 84 | 85 | def __init__(self, error_queue): 86 | # Shared error queue 87 | self.error_queue = error_queue 88 | # Children processes sharing the error queue 89 | self.children_pids = [] 90 | # Start a thread listening to errors 91 | self.error_listener = threading.Thread(target=self.listen, daemon=True) 92 | self.error_listener.start() 93 | # Register the signal handler 94 | signal.signal(signal.SIGUSR1, self.signal_handler) 95 | 96 | def add_child(self, pid): 97 | """Registers a child process.""" 98 | self.children_pids.append(pid) 99 | 100 | def listen(self): 101 | """Listens for errors in the error queue.""" 102 | # Wait until there is an error in the queue 103 | child_trace = self.error_queue.get() 104 | # Put the error back for the signal handler 105 | self.error_queue.put(child_trace) 106 | # Invoke the signal handler 107 | os.kill(os.getpid(), signal.SIGUSR1) 108 | 109 | def signal_handler(self, _sig_num, _stack_frame): 110 | """Signal handler.""" 111 | # Kill children processes 112 | for pid in self.children_pids: 113 | os.kill(pid, signal.SIGINT) 114 | # Propagate the error from the child process 115 | raise ChildException(self.error_queue.get()) 116 | 117 | 118 | def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs): 119 | """Runs a function from a child process.""" 120 | try: 121 | # Initialize the process group 122 | init_process_group(proc_rank, world_size) 123 | # Run the function 124 | fun(*fun_args, **fun_kwargs) 125 | except KeyboardInterrupt: 126 | # Killed by the parent process 127 | pass 128 | except Exception: 129 | # Propagate exception to the parent process 130 | error_queue.put(traceback.format_exc()) 131 | finally: 132 | # Destroy the process group 133 | destroy_process_group() 134 | 135 | 136 | def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None): 137 | """Runs a function in a multi-proc setting (unless num_proc == 1).""" 138 | # There is no need for multi-proc in the single-proc case 139 | fun_kwargs = 
fun_kwargs if fun_kwargs else {} 140 | if num_proc == 1: 141 | fun(*fun_args, **fun_kwargs) 142 | return 143 | # Handle errors from training subprocesses 144 | error_queue = multiprocessing.SimpleQueue() 145 | error_handler = ErrorHandler(error_queue) 146 | # Run each training subprocess 147 | ps = [] 148 | for i in range(num_proc): 149 | p_i = multiprocessing.Process( 150 | target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs) 151 | ) 152 | ps.append(p_i) 153 | p_i.start() 154 | error_handler.add_child(p_i.pid) 155 | # Wait for each subprocess to finish 156 | for p in ps: 157 | p.join() 158 | -------------------------------------------------------------------------------- /pycls/core/io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """IO utilities (adapted from Detectron)""" 9 | 10 | import logging 11 | import os 12 | import re 13 | import sys 14 | from urllib import request as urlrequest 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls" 20 | 21 | 22 | def cache_url(url_or_file, cache_dir): 23 | """Download the file specified by the URL to the cache_dir and return the path to 24 | the cached file. If the argument is not a URL, simply return it as is. 
25 | """ 26 | is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None 27 | if not is_url: 28 | return url_or_file 29 | url = url_or_file 30 | err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}" 31 | assert url.startswith(_PYCLS_BASE_URL), err_str.format(_PYCLS_BASE_URL) 32 | cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir) 33 | if os.path.exists(cache_file_path): 34 | return cache_file_path 35 | cache_file_dir = os.path.dirname(cache_file_path) 36 | if not os.path.exists(cache_file_dir): 37 | os.makedirs(cache_file_dir) 38 | logger.info("Downloading remote file {} to {}".format(url, cache_file_path)) 39 | download_url(url, cache_file_path) 40 | return cache_file_path 41 | 42 | 43 | def _progress_bar(count, total): 44 | """Report download progress. Credit: 45 | https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113 46 | """ 47 | bar_len = 60 48 | filled_len = int(round(bar_len * count / float(total))) 49 | percents = round(100.0 * count / float(total), 1) 50 | bar = "=" * filled_len + "-" * (bar_len - filled_len) 51 | sys.stdout.write( 52 | " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024) 53 | ) 54 | sys.stdout.flush() 55 | if count >= total: 56 | sys.stdout.write("\n") 57 | 58 | 59 | def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar): 60 | """Download url and write it to dst_file_path. 
Credit: 61 | https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook 62 | """ 63 | req = urlrequest.Request(url) 64 | response = urlrequest.urlopen(req) 65 | total_size = response.info().get("Content-Length").strip() 66 | total_size = int(total_size) 67 | bytes_so_far = 0 68 | with open(dst_file_path, "wb") as f: 69 | while 1: 70 | chunk = response.read(chunk_size) 71 | bytes_so_far += len(chunk) 72 | if not chunk: 73 | break 74 | if progress_hook: 75 | progress_hook(bytes_so_far, total_size) 76 | f.write(chunk) 77 | return bytes_so_far 78 | -------------------------------------------------------------------------------- /pycls/core/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Logging.""" 9 | 10 | import builtins 11 | import decimal 12 | import logging 13 | import os 14 | import sys 15 | 16 | import pycls.core.distributed as dist 17 | import simplejson 18 | from pycls.core.config import cfg 19 | 20 | 21 | # Show filename and line number in logs 22 | _FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s" 23 | 24 | # Log file name (for cfg.LOG_DEST = 'file') 25 | _LOG_FILE = "stdout.log" 26 | 27 | # Data output with dump_log_data(data, data_type) will be tagged w/ this 28 | _TAG = "json_stats: " 29 | 30 | # Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type 31 | _TYPE = "_type" 32 | 33 | 34 | def _suppress_print(): 35 | """Suppresses printing from the current process.""" 36 | 37 | def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False): 38 | pass 39 | 40 | builtins.print = ignore 41 | 42 | 43 | def setup_logging(): 44 | """Sets up the logging.""" 45 | # Enable logging only for the master process 46 | if dist.is_master_proc(): 
47 | # Clear the root logger to prevent any existing logging config 48 | # (e.g. set by another module) from messing with our setup 49 | logging.root.handlers = [] 50 | # Construct logging configuration 51 | logging_config = {"level": logging.INFO, "format": _FORMAT} 52 | # Log either to stdout or to a file 53 | if cfg.LOG_DEST == "stdout": 54 | logging_config["stream"] = sys.stdout 55 | else: 56 | logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE) 57 | # Configure logging 58 | logging.basicConfig(**logging_config) 59 | else: 60 | _suppress_print() 61 | 62 | 63 | def get_logger(name): 64 | """Retrieves the logger.""" 65 | return logging.getLogger(name) 66 | 67 | 68 | def dump_log_data(data, data_type, prec=4): 69 | """Covert data (a dictionary) into tagged json string for logging.""" 70 | data[_TYPE] = data_type 71 | data = float_to_decimal(data, prec) 72 | data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True) 73 | return "{:s}{:s}".format(_TAG, data_json) 74 | 75 | 76 | def float_to_decimal(data, prec=4): 77 | """Convert floats to decimals which allows for fixed width json.""" 78 | if isinstance(data, dict): 79 | return {k: float_to_decimal(v, prec) for k, v in data.items()} 80 | if isinstance(data, float): 81 | return decimal.Decimal(("{:." 
+ str(prec) + "f}").format(data)) 82 | else: 83 | return data 84 | 85 | 86 | def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE): 87 | """Get all log files in directory containing subdirs of trained models.""" 88 | names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n] 89 | files = [os.path.join(log_dir, n, log_file) for n in names] 90 | f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)] 91 | files, names = zip(*f_n_ps) if f_n_ps else ([], []) 92 | return files, names 93 | 94 | 95 | def load_log_data(log_file, data_types_to_skip=()): 96 | """Loads log data into a dictionary of the form data[data_type][metric][index].""" 97 | # Load log_file 98 | assert os.path.exists(log_file), "Log file not found: {}".format(log_file) 99 | with open(log_file, "r") as f: 100 | lines = f.readlines() 101 | # Extract and parse lines that start with _TAG and have a type specified 102 | lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l] 103 | lines = [simplejson.loads(l) for l in lines] 104 | lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip] 105 | # Generate data structure accessed by data[data_type][index][metric] 106 | data_types = [l[_TYPE] for l in lines] 107 | data = {t: [] for t in data_types} 108 | for t, line in zip(data_types, lines): 109 | del line[_TYPE] 110 | data[t].append(line) 111 | # Generate data structure accessed by data[data_type][metric][index] 112 | for t in data: 113 | metrics = sorted(data[t][0].keys()) 114 | err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics) 115 | assert all(sorted(d.keys()) == metrics for d in data[t]), err_str 116 | data[t] = {m: [d[m] for d in data[t]] for m in metrics} 117 | return data 118 | 119 | 120 | def sort_log_data(data): 121 | """Sort each data[data_type][metric] by epoch or keep only first instance.""" 122 | for t in data: 123 | if "epoch" in data[t]: 124 | assert "epoch_ind" not in data[t] and "epoch_max" not in 
data[t] 125 | data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]] 126 | data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]] 127 | epoch = data[t]["epoch_ind"] 128 | if "iter" in data[t]: 129 | assert "iter_ind" not in data[t] and "iter_max" not in data[t] 130 | data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]] 131 | data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]] 132 | itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"]) 133 | epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr] 134 | for m in data[t]: 135 | data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))] 136 | else: 137 | data[t] = {m: d[0] for m, d in data[t].items()} 138 | return data 139 | -------------------------------------------------------------------------------- /pycls/core/net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
def init_weights(m):
    """Performs ResNet-style weight initialization.

    Conv2d weights are drawn from a zero-mean normal with std
    sqrt(2 / fan_out); BatchNorm2d gets unit gamma (zero gamma for layers
    flagged with a truthy ``final_bn`` attribute when
    cfg.BN.ZERO_INIT_FINAL_GAMMA is set) and zero beta; Linear layers get
    N(0, 0.01) weights and a zero bias. Other modules are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        # Note that there is no bias due to BN
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
        zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
        m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(mean=0.0, std=0.01)
        # Linear layers built with bias=False have no bias tensor to reset
        if m.bias is not None:
            m.bias.data.zero_()


@torch.no_grad()
def compute_precise_bn_stats(model, loader):
    """Computes precise BN stats on training data.

    Runs up to cfg.BN.NUM_SAMPLES_PRECISE samples from ``loader`` through
    ``model`` and replaces the running mean/variance of every BatchNorm2d
    layer with the average of the per-batch statistics. Nothing is changed
    when there is no BatchNorm2d layer or no batch to evaluate.
    """
    # Compute the number of minibatches to use
    num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    # Without data (or BN layers) the accumulators would stay at zero and
    # overwrite valid running stats, so leave the model untouched instead
    if num_iter <= 0 or not bns:
        return
    # Initialize stats storage
    mus = [torch.zeros_like(bn.running_mean) for bn in bns]
    sqs = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values
    moms = [bn.momentum for bn in bns]
    # Disable momentum so the running stats equal the stats of the last batch
    for bn in bns:
        bn.momentum = 1.0
    try:
        # Accumulate the stats across the data samples
        for inputs, _labels in itertools.islice(loader, num_iter):
            model(inputs.cuda())
            # Accumulate the stats for each BN layer
            for i, bn in enumerate(bns):
                m, v = bn.running_mean, bn.running_var
                sqs[i] += (v + m * m) / num_iter
                mus[i] += m / num_iter
    finally:
        # Restore momentum values even if a forward pass failed
        for bn, mom in zip(bns, moms):
            bn.momentum = mom
    # Set the stats: Var[x] = E[x^2] - E[x]^2
    for i, bn in enumerate(bns):
        bn.running_var = sqs[i] - mus[i] * mus[i]
        bn.running_mean = mus[i]


def reset_bn_stats(model):
    """Resets running BN stats of every BatchNorm2d layer in the model."""
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.reset_running_stats()


def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
    """Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts).

    Returns a new dict with the output spatial size and the updated counters;
    the input dict is not modified.
    """
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    h = (h + 2 * padding - k) // stride + 1
    w = (w + 2 * padding - k) // stride + 1
    flops += k * k * w_in * w_out * h * w // groups
    params += k * k * w_in * w_out // groups
    flops += w_out if bias else 0
    params += w_out if bias else 0
    acts += w_out * h * w
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def complexity_batchnorm2d(cx, w_in):
    """Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts).

    Only the parameter count changes (two parameters per channel).
    """
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    params += 2 * w_in
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def complexity_maxpool2d(cx, k, stride, padding):
    """Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts).

    Only the spatial size changes.
    """
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    h = (h + 2 * padding - k) // stride + 1
    w = (w + 2 * padding - k) // stride + 1
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def complexity(model):
    """Compute model complexity (model can be model instance or model class).

    Starts from a square input of side cfg.TRAIN.IM_SIZE and delegates the
    accumulation to ``model.complexity``.
    """
    size = cfg.TRAIN.IM_SIZE
    cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
    cx = model.complexity(cx)
    return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}


def drop_connect(x, drop_ratio):
    """Drop connect (adapted from DARTS).

    Zeroes whole samples of ``x`` with probability ``drop_ratio`` and rescales
    the rest by 1 / (1 - drop_ratio). ``x`` is modified IN PLACE and returned.
    """
    keep_ratio = 1.0 - drop_ratio
    # One Bernoulli draw per sample, broadcast over the remaining three dims
    mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
    mask.bernoulli_(keep_ratio)
    x.div_(keep_ratio)
    x.mul_(mask)
    return x


def get_flat_weights(model):
    """Gets all model weights as a single flat (column) vector."""
    return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)


def set_flat_weights(model, flat_weights):
    """Sets all model weights from a single flat vector.

    The vector is consumed in ``model.parameters()`` order and must contain
    exactly as many elements as the model has weights.
    """
    k = 0
    for p in model.parameters():
        n = p.data.numel()
        p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
        k += n
    assert k == flat_weights.numel()


def construct_optimizer(model):
    """Constructs the optimizer.

    Note that the momentum update in PyTorch differs from the one in Caffe2.
    In particular,

        Caffe2:
            V := mu * V + lr * g
            p := p - V

        PyTorch:
            V := mu * V + g
            p := p - lr * V

    where V is the velocity, mu is the momentum factor, lr is the learning rate,
    g is the gradient and p are the parameters.

    Since V is defined independently of the learning rate in PyTorch,
    when the learning rate is changed there is no need to perform the
    momentum correction by scaling V (unlike in the Caffe2 case).
    """
    if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
        # Apply different weight decay to Batchnorm and non-batchnorm parameters.
        # Parameters are classified by whether "bn" occurs in their name.
        p_bn = [p for n, p in model.named_parameters() if "bn" in n]
        p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n]
        optim_params = [
            {"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
            {"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
        ]
    else:
        optim_params = model.parameters()
    return torch.optim.SGD(
        optim_params,
        lr=cfg.OPTIM.BASE_LR,
        momentum=cfg.OPTIM.MOMENTUM,
        weight_decay=cfg.OPTIM.WEIGHT_DECAY,
        dampening=cfg.OPTIM.DAMPENING,
        nesterov=cfg.OPTIM.NESTEROV,
    )


def lr_fun_steps(cur_epoch):
    """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
    # Index of the last step boundary that has already been reached
    ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
    return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind)


def lr_fun_exp(cur_epoch):
    """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
    return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch)


def lr_fun_cos(cur_epoch):
    """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
    base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH
    return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))


def get_lr_fun():
    """Retrieves the specified lr policy function.

    The policy name is mapped to the module-level function ``lr_fun_<policy>``.

    Raises:
        NotImplementedError: if no such function exists.
    """
    lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
    if lr_fun not in globals():
        raise NotImplementedError("Unknown LR policy:" + cfg.OPTIM.LR_POLICY)
    return globals()[lr_fun]


def get_epoch_lr(cur_epoch):
    """Retrieves the lr for the given epoch according to the policy."""
    lr = get_lr_fun()(cur_epoch)
    # Linear warmup: scale from WARMUP_FACTOR up to 1 over the warmup epochs
    if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
        alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
        warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
        lr *= warmup_factor
    return lr
def set_lr(optimizer, new_lr):
    """Sets the optimizer lr to the specified value (for every param group)."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr


def get_plot_colors(max_colors, color_format="pyplot"):
    """Generate colors for plotting.

    Starts from the 11-color qualitative "Paired" colorlover scale and
    interpolates it when more than that many colors are requested. With
    ``color_format="pyplot"`` the colors are returned as lists of floats
    (channel / 255); otherwise the colorlover color strings are returned.
    """
    colors = cl.scales["11"]["qual"]["Paired"]
    if max_colors > len(colors):
        colors = cl.to_rgb(cl.interp(colors, max_colors))
    if color_format == "pyplot":
        return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
    return colors


def prepare_plot_data(log_files, names, metric="top1_err"):
    """Load logs and extract data for plotting error curves.

    Returns one dict per (log file, name) pair with the keys ``x_train``,
    ``y_train``, ``train_label``, ``x_test``, ``y_test`` and ``test_label``.
    Each label is prefixed with the minimum of the corresponding curve.

    Raises:
        AssertionError: if there is nothing to plot.
    """
    plot_data = []
    for log_file, name in zip(log_files, names):
        d, data = {}, logging.sort_log_data(logging.load_log_data(log_file))
        for phase in ["train", "test"]:
            x = data[phase + "_epoch"]["epoch_ind"]
            y = data[phase + "_epoch"][metric]
            d["x_" + phase], d["y_" + phase] = x, y
            d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
        plot_data.append(d)
    assert len(plot_data) > 0, "No data to plot"
    return plot_data


def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
    """Plot error curves using plotly and save to file.

    Every model contributes three traces in the order (train without legend,
    test with legend, hidden train with legend); the dropdown buttons switch
    their visibility to show all curves, train only or test only.
    """
    plot_data = prepare_plot_data(log_files, names, metric)
    colors = get_plot_colors(len(plot_data), "plotly")

    def make_trace(d, phase, line, group, visible, showlegend):
        # One curve of model ``d``; ``group`` ties a model's traces together
        return go.Scatter(
            x=d["x_" + phase],
            y=d["y_" + phase],
            mode="lines",
            name=d[phase + "_label"],
            line=line,
            legendgroup=group,
            visible=visible,
            showlegend=showlegend,
        )

    # Prepare data for plots (3 sets, train duplicated w and w/o legend)
    data = []
    for i, d in enumerate(plot_data):
        s = str(i)
        line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
        line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
        data.append(make_trace(d, "train", line_train, s, True, False))
        data.append(make_trace(d, "test", line_test, s, True, True))
        data.append(make_trace(d, "train", line_train, s, False, True))
    # Prepare layout w ability to toggle 'all', 'train', 'test'
    titlefont = {"size": 18, "color": "#7f7f7f"}
    vis = [[True, True, False], [False, False, True], [False, True, False]]
    buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
    buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
    layout = go.Layout(
        # "<br>" is the line break understood by plotly text fields
        title=metric + " vs. epoch<br>[dash=train, solid=test]",
        xaxis={"title": "epoch", "titlefont": titlefont},
        yaxis={"title": metric, "titlefont": titlefont},
        showlegend=True,
        hoverlabel={"namelength": -1},
        updatemenus=[
            {
                "buttons": buttons,
                "direction": "down",
                "showactive": True,
                "x": 1.02,
                "xanchor": "left",
                "y": 1.08,
                "yanchor": "top",
            }
        ],
    )
    # Create plotly plot
    offline.plot({"data": data, "layout": layout}, filename=filename)


def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
    """Plot error curves using matplotlib.pyplot and save to file.

    Train curves are dashed and test curves solid. The figure is written to
    ``filename`` (and then cleared) when one is given, and shown otherwise.
    """
    plot_data = prepare_plot_data(log_files, names, metric)
    colors = get_plot_colors(len(names))
    for ind, d in enumerate(plot_data):
        c, lbl = colors[ind], d["test_label"]
        plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
        plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
    plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
    plt.xlabel("epoch", fontsize=14)
    plt.ylabel(metric, fontsize=14)
    plt.grid(alpha=0.4)
    plt.legend()
    if filename:
        plt.savefig(filename)
        plt.clf()
    else:
        plt.show()
class Timer(object):
    """A simple timer (adapted from Detectron).

    Call ``tic()`` to start a measurement and ``toc()`` to end it. After each
    ``toc()`` the attributes ``diff`` (last interval), ``total_time``,
    ``calls`` and ``average_time`` are up to date, all in seconds.
    """

    def __init__(self):
        # reset() defines every attribute and gives it its zero value
        self.reset()

    def tic(self):
        """Start (or restart) the current measurement."""
        # using time.time as time.clock does not normalize for multithreading
        self.start_time = time.time()

    def toc(self):
        """Finish the current measurement and update the statistics."""
        elapsed = time.time() - self.start_time
        self.diff = elapsed
        self.total_time += elapsed
        self.calls += 1
        self.average_time = self.total_time / self.calls

    def reset(self):
        """Forget every measurement taken so far."""
        self.total_time = 0.0
        self.calls = 0
        self.start_time = 0.0
        self.diff = 0.0
        self.average_time = 0.0
def Preprocess(x):
    """Flatten jigsaw tiles into the batch dimension; pass other tasks through.

    For cfg.TASK == 'jig' the input must be 5-D with cfg.JIGSAW_GRID ** 2
    entries along dim 1; dims 0 and 1 are merged. Any other task returns
    ``x`` unchanged.
    """
    if cfg.TASK != 'jig':
        return x
    assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
    assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
    merged_batch = x.shape[0] * x.shape[1]
    return x.view([merged_batch, x.shape[2], x.shape[3], x.shape[4]])


class Classifier(nn.Module):
    """Task-dependent prediction head selected by cfg.TASK.

    'jig' pools each tile and classifies the concatenated tile features,
    'col' uses a 1x1 convolution, 'seg' uses an ASPP head, and every other
    task uses global average pooling followed by a linear layer.
    """

    def __init__(self, channels, num_classes):
        super().__init__()
        task = cfg.TASK
        if task == 'col':
            self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
        elif task == 'seg':
            self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
        else:
            # 'jig' and the default task both pool before a linear layer;
            # jigsaw feeds it the features of all tiles at once
            in_features = channels
            if task == 'jig':
                self.jig_sq = cfg.JIGSAW_GRID ** 2
                in_features = channels * self.jig_sq
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x, shape):
        """Compute predictions; ``shape`` is the output size for dense tasks."""
        if cfg.TASK in ['col', 'seg']:
            dense = self.classifier(x)
            return nn.functional.interpolate(
                dense, size=shape, mode='bilinear', align_corners=True)
        pooled = self.pooling(x)
        if cfg.TASK == 'jig':
            # Regroup the tiles of each image along the channel dimension
            pooled = pooled.view([
                pooled.shape[0] // self.jig_sq,
                pooled.shape[1] * self.jig_sq,
                pooled.shape[2],
                pooled.shape[3],
            ])
        return self.classifier(pooled.view(pooled.size(0), -1))


class ASPP(nn.Module):
    """Pyramid head: a 1x1 branch, one or three dilated 3x3 branches and a
    globally pooled branch, concatenated and fused into ``num_classes`` maps.
    """

    @staticmethod
    def _branch(in_channels, out_channels, rate=None):
        # conv -> BN -> ReLU; a 1x1 conv when rate is None, else a 3x3 conv
        # dilated (and padded) by ``rate``
        if rate is None:
            conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        else:
            conv = nn.Conv2d(in_channels, out_channels, 3, dilation=rate,
                             padding=rate, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

    def __init__(self, in_channels, out_channels, num_classes, rates):
        super().__init__()
        assert len(rates) in [1, 3]
        self.rates = rates
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.aspp1 = self._branch(in_channels, out_channels)
        self.aspp2 = self._branch(in_channels, out_channels, rates[0])
        if len(self.rates) == 3:
            self.aspp3 = self._branch(in_channels, out_channels, rates[1])
            self.aspp4 = self._branch(in_channels, out_channels, rates[2])
        self.aspp5 = self._branch(in_channels, out_channels)
        fused_channels = out_channels * (len(rates) + 2)
        self.classifier = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, num_classes, 1)
        )

    def forward(self, x):
        """Return class maps with the same spatial size as ``x``."""
        spatial = (x.shape[2], x.shape[3])
        features = [self.aspp1(x), self.aspp2(x)]
        # Image-level branch: pool to 1x1, project, then resize back
        pooled = self.aspp5(self.global_pooling(x))
        pooled = nn.functional.interpolate(
            pooled, size=spatial, mode='bilinear', align_corners=True)
        if len(self.rates) == 3:
            features.append(self.aspp3(x))
            features.append(self.aspp4(x))
        features.append(pooled)
        return self.classifier(torch.cat(tuple(features), 1))
def quantize_float(f, q):
    """Round ``f`` to the nearest multiple of ``q`` and return it as an int.

    NOTE: the result is zero whenever ``f / q`` rounds to zero.
    """
    multiples = round(f / q)
    return int(multiples * q)


def adjust_ws_gs_comp(ws, bms, gs):
    """Adjusts the compatibility of widths and groups.

    For every stage the bottleneck width ``int(w * b)`` is computed, the group
    width is capped by it, the bottleneck width is snapped to a multiple of
    the group width and the stage width is recomputed from the result.
    Returns the adjusted ``(ws, gs)`` lists.
    """
    new_ws, new_gs = [], []
    for w, b, g in zip(ws, bms, gs):
        w_bot = int(w * b)
        g = min(g, w_bot)
        w_bot = quantize_float(w_bot, g)
        new_ws.append(int(w_bot / b))
        new_gs.append(g)
    return new_ws, new_gs


def get_stages_from_blocks(ws, rs):
    """Gets ws/ds of network at each stage from per block values.

    A new stage starts at every block whose (w, r) pair differs from the
    previous block's. Returns the width of each stage and its depth in blocks.
    """
    # Pair each block with its predecessor; the zero padding makes the first
    # block and the position one past the last block count as transitions
    shifted = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
    is_start = [w != w_prev or r != r_prev for w, w_prev, r, r_prev in shifted]
    s_ws = [w for w, start in zip(ws, is_start[:-1]) if start]
    start_inds = [ind for ind, start in enumerate(is_start) if start]
    s_ds = np.diff(start_inds).tolist()
    return s_ws, s_ds


def generate_regnet(w_a, w_0, w_m, d, q=8):
    """Generates per block ws from RegNet parameters.

    Returns ``(ws, num_stages, max_stage, ws_cont)``: the quantized per-block
    widths, the number of distinct widths, the largest width exponent plus
    one, and the continuous (unquantized) per-block widths.
    """
    assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
    # Linear widths, then the nearest power of w_m (relative to w_0) per block
    ws_cont = np.arange(d) * w_a + w_0
    ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
    ws = w_0 * np.power(w_m, ks)
    # Snap every width to a multiple of q
    ws = np.round(np.divide(ws, q)) * q
    num_stages = len(np.unique(ws))
    max_stage = ks.max() + 1
    return ws.astype(int).tolist(), num_stages, max_stage, ws_cont.tolist()


class RegNet(AnyNet):
    """RegNet model."""

    @staticmethod
    def get_args():
        """Convert RegNet to AnyNet parameter format."""
        reg = cfg.REGNET
        # Generate RegNet ws per block
        ws, num_stages, _, _ = generate_regnet(reg.WA, reg.W0, reg.WM, reg.DEPTH)
        # Convert to per stage format
        s_ws, s_ds = get_stages_from_blocks(ws, ws)
        # Use the same gw, bm and ss for each stage
        s_gs = [cfg.REGNET.GROUP_W] * num_stages
        s_bs = [cfg.REGNET.BOT_MUL] * num_stages
        s_ss = [cfg.REGNET.STRIDE] * num_stages
        # Adjust the compatibility of ws and gws
        s_ws, s_gs = adjust_ws_gs_comp(s_ws, s_bs, s_gs)
        # Get AnyNet arguments defining the RegNet
        se_r = cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None
        return {
            "stem_type": cfg.REGNET.STEM_TYPE,
            "stem_w": cfg.REGNET.STEM_W,
            "block_type": cfg.REGNET.BLOCK_TYPE,
            "ds": s_ds,
            "ws": s_ws,
            "ss": s_ss,
            "bms": s_bs,
            "gws": s_gs,
            "se_r": se_r,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self):
        super().__init__(**RegNet.get_args())

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update."""
        if not kwargs:
            kwargs = RegNet.get_args()
        return AnyNet.complexity(cx, **kwargs)
1 --score hook_logdet --sigma 0.05 --nasspace nds_enas --batch_size 128 --GPU 3 8 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_darts --batch_size 128 --GPU 3 9 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_darts_fix-w-d --batch_size 128 --GPU 3 10 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_nasnet --batch_size 128 --GPU 3 11 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_amoeba --batch_size 128 --GPU 3 12 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnet --batch_size 128 --GPU 3 13 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnext-a --batch_size 128 --GPU 3 14 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnext-b --batch_size 128 --GPU 3 15 | 16 | 17 | 18 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace amoeba_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/ 19 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_amoeba_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/ 20 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_darts_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/ 21 | python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_nasnet_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/ 22 | python score_networks.py --trainval --augtype none --repeat 1 --score 
def hooklogdet(K, labels=None):
    """Return the log of the absolute determinant of ``K``.

    ``labels`` is accepted for signature compatibility and ignored.
    """
    _sign, logabsdet = np.linalg.slogdet(K)
    return logabsdet


def random_score(jacob, label=None):
    """Return a standard-normal random score; both arguments are ignored."""
    return np.random.normal()


# Maps the name of a score to the function computing it
_scores = {
    'hook_logdet': hooklogdet,
    'random': random_score
}


def get_score_func(score_name):
    """Look up a scoring function by name (KeyError for unknown names)."""
    return _scores[score_name]


class DropChannel(torch.nn.Module):
    """Wrapper around a cell-like module taking ``(s0, s1, droppath)``.

    The wrapped module is called unchanged; ``p`` is stored but not used by
    ``forward``.
    """

    def __init__(self, p, mod):
        super(DropChannel, self).__init__()
        self.mod = mod
        self.p = p

    def forward(self, s0, s1, droppath):
        return self.mod(s0, s1, droppath)


class DropConnect(torch.nn.Module):
    """Zero whole channels of each sample with probability ``p``.

    Kept channels are rescaled by 1 / (1 - p).
    """

    def __init__(self, p):
        super(DropConnect, self).__init__()
        self.p = p

    def forward(self, inputs):
        keep_prob = 1 - self.p
        # One draw per (sample, channel), broadcast over any trailing dims
        mask_shape = list(inputs.shape[:2]) + [1] * (inputs.dim() - 2)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0
        random_tensor = keep_prob + torch.rand(
            mask_shape, dtype=inputs.dtype, device=inputs.device)
        binary_tensor = torch.floor(random_tensor)
        return inputs / keep_prob * binary_tensor


def add_dropout(network, p, prefix=''):
    """Insert dropout wrappers into ``network`` in place.

    Every Conv2d submodule is replaced by ``Sequential(conv, DropConnect(p))``
    and every Cell submodule by ``DropChannel(p, cell)``. Other submodules are
    searched recursively; the wrappers created here are not descended into,
    so each layer is wrapped exactly once. ``prefix`` only tracks the
    recursion depth.
    """
    # list(...) snapshots the children because setattr replaces them in place.
    # named_children() already covers every registered submodule; an extra
    # pass over dir(network) would wrap attribute-registered layers a second
    # time during the recursion below.
    for n, ch in list(network.named_children()):
        if isinstance(ch, torch.nn.Conv2d):
            setattr(network, n, torch.nn.Sequential(ch, DropConnect(p)))
        elif isinstance(ch, Cell):
            setattr(network, n, DropChannel(p, ch))
        else:
            add_dropout(ch, p, prefix + '\t')


def orth_init(m):
    """Orthogonal initialization of Conv2d / Linear weights."""
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.orthogonal_(m.weight)


def uni_init(m):
    """Uniform initialization (torch's default range) of Conv2d / Linear weights."""
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.uniform_(m.weight)


def uni2_init(m):
    """Uniform [-1, 1] initialization of Conv2d / Linear weights."""
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.uniform_(m.weight, -1., 1.)
def uni3_init(m):
    """Uniform [-0.5, 0.5] initialization of Conv2d / Linear weights."""
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.uniform_(m.weight, -.5, .5)


def norm_init(m):
    """Standard-normal initialization of Conv2d / Linear weights."""
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        # torch.nn.init has no ``norm_``; the normal initializer is ``normal_``
        torch.nn.init.normal_(m.weight)


def eye_init(m):
    """Identity initialization: eye for Linear, Dirac delta for Conv2d."""
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.eye_(m.weight)
    elif isinstance(m, torch.nn.Conv2d):
        torch.nn.init.dirac_(m.weight)


def fixup_init(m):
    """Zero the weights of Conv2d / Linear modules (and the Linear bias)."""
    # torch.nn.init has no ``zero_``; the zero initializer is ``zeros_``
    if isinstance(m, torch.nn.Conv2d):
        torch.nn.init.zeros_(m.weight)
    elif isinstance(m, torch.nn.Linear):
        torch.nn.init.zeros_(m.weight)
        # Linear layers built with bias=False have no bias tensor
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


def init_network(network, init):
    """Re-initialize every Conv2d / Linear weight of ``network`` in place.

    ``init`` selects the scheme: 'orthogonal', 'uniform', 'uniform2',
    'uniform3', 'normal' or 'identity'. Any other value leaves the network
    untouched.
    """
    if init == 'orthogonal':
        network.apply(orth_init)
    elif init == 'uniform':
        print('uniform')
        network.apply(uni_init)
    elif init == 'uniform2':
        network.apply(uni2_init)
    elif init == 'uniform3':
        network.apply(uni3_init)
    elif init == 'normal':
        network.apply(norm_init)
    elif init == 'identity':
        network.apply(eye_init)