├── README.md ├── configs ├── dn │ └── BSD500_3c4n │ │ ├── 03_search_CR_R0.yaml │ │ ├── 03_search_CR_R10.yaml │ │ ├── 03_search_CR_R20.yaml │ │ └── 03_train_CR_R0 │ │ ├── train_s30.yaml │ │ ├── train_s50.yaml │ │ └── train_s70.yaml └── sr │ └── DIV2K_2c3n │ ├── 03_search_CR.yaml │ ├── 03_x2_infe_CR.yaml │ ├── 03_x2_train_CR.yaml │ ├── 03_x3_infe_CR.yaml │ ├── 03_x3_train_CR.yaml │ ├── 03_x4_infe_CR.yaml │ ├── 03_x4_train_CR.yaml │ ├── 03_x8_infe_CR.yaml │ └── 03_x8_train_CR.yaml ├── data_generation ├── rename.py └── rgb2gray_add_noise.py ├── log file for DNAS_For_IR.txt ├── one_stage_nas ├── config │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── defaults.cpython-36.pyc │ │ └── paths_catalog.cpython-36.pyc │ ├── defaults.py │ └── paths_catalog.py ├── darts │ ├── __pycache__ │ │ ├── cell.cpython-36.pyc │ │ ├── genotypes.cpython-36.pyc │ │ └── operations.cpython-36.pyc │ ├── cell.py │ ├── genotypes.py │ └── operations.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── build_dataset.cpython-36.pyc │ ├── build_dataset.py │ └── datasets │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── denoise.cpython-36.pyc │ │ ├── see_in_dark.cpython-36.pyc │ │ ├── super_resolution.cpython-36.pyc │ │ ├── tasks_dict.cpython-36.pyc │ │ └── transforms.cpython-36.pyc │ │ ├── denoise.py │ │ ├── see_in_dark.py │ │ ├── super_resolution.py │ │ ├── tasks_dict.py │ │ └── transforms.py ├── engine │ ├── __pycache__ │ │ ├── inference.cpython-36.pyc │ │ ├── searcher.cpython-36.pyc │ │ ├── trainer.cpython-36.pyc │ │ └── trainer_joint.cpython-36.pyc │ ├── inference.py │ ├── searcher.py │ ├── trainer.py │ └── trainer_joint.py ├── modeling │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── amt_discrete.cpython-36.pyc │ │ ├── architectures.cpython-36.pyc │ │ ├── auto_multitask.cpython-36.pyc │ │ ├── common.cpython-36.pyc │ │ ├── decoders.cpython-36.pyc │ │ ├── dn_compnet.cpython-36.pyc 
│ │ ├── dn_supernet.cpython-36.pyc │ │ ├── loss.cpython-36.pyc │ │ ├── sid_compnet.cpython-36.pyc │ │ ├── sid_supernet.cpython-36.pyc │ │ ├── sr_compnet.cpython-36.pyc │ │ └── sr_supernet.cpython-36.pyc │ ├── amt_discrete.py │ ├── architectures.py │ ├── auto_multitask.py │ ├── common.py │ ├── decoders.py │ ├── dn_compnet.py │ ├── dn_supernet.py │ ├── loss.py │ ├── sid_compnet.py │ ├── sid_supernet.py │ ├── sr_compnet.py │ └── sr_supernet.py ├── solver │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── build.cpython-36.pyc │ │ └── lr_scheduler.cpython-36.pyc │ ├── build.py │ └── lr_scheduler.py └── utils │ ├── SSIM.py │ ├── __pycache__ │ ├── checkpoint.cpython-36.pyc │ ├── comm.cpython-36.pyc │ ├── evaluation_metrics.cpython-36.pyc │ ├── logger.cpython-36.pyc │ ├── metric_logger.cpython-36.pyc │ ├── misc.cpython-36.pyc │ └── visualize.cpython-36.pyc │ ├── checkpoint.py │ ├── comm.py │ ├── evaluation_metrics.py │ ├── logger.py │ ├── metric_logger.py │ ├── misc.py │ └── visualize.py ├── preprocess ├── dataset_json │ ├── dn │ │ ├── BSD500_200.json │ │ └── BSD500_300.json │ ├── sid │ │ └── Sony │ │ │ ├── test.json │ │ │ └── train.json │ └── sr │ │ ├── BSD100.json │ │ ├── DIV2K_800.json │ │ ├── Manga109.json │ │ ├── Set14.json │ │ ├── Set5.json │ │ └── Urban100.json ├── dn_preprocess.py ├── image_check.py ├── sid_preprocess.py ├── sr_preprocess.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── dn_utils.cpython-36.pyc │ ├── sid_utils.cpython-36.pyc │ └── sr_utils.cpython-36.pyc │ ├── dn_utils.py │ ├── sid_utils.py │ └── sr_utils.py ├── requirements.txt └── tools ├── dn_eval.py ├── search.py ├── sr_eval.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # HiNAS 2 | Code for paper: 3 | > [Memory-Efficient Hierarchical Neural Architecture Search for Image Denoising (CVPR 2020)](https://arxiv.org/abs/1909.08228) 4 | 5 | > [Memory-Efficient Hierarchical 
Neural Architecture Search for Image Restoration](https://arxiv.org/abs/2012.13212) 6 | 7 | Compared with the method proposed in Memory-Efficient Hierarchical Neural Architecture Search for Image Denoising, we made some improvements, which are explained in "log file for DNAS_For_IR.txt". 8 | 9 | 10 | ## Requirements 11 | ``` 12 | This code is tested on PyTorch = 1.0.0. 13 | 14 | I built my experimental environment by creating a virtual env with anaconda3. 15 | After activating your env, you can install all other dependencies by running: pip install -r requirements.txt. 16 | Note that to install graphviz, you also need to run: conda install graphviz. 17 | ``` 18 | 19 | ## Searching 20 | ``` 21 | #searching for denoising network 22 | cd ./tools/ 23 | python search.py --config-file "../configs/dn/BSD500_3c4n/03_search_CR_R0.yaml" --device '0' 24 | 25 | #searching for super-resolution network 26 | cd ./tools/ 27 | python search.py --config-file "../configs/sr/DIV2K_2c3n/03_search_CR.yaml" --device '0' 28 | ``` 29 | 30 | ## Training 31 | ``` 32 | #training the found denoising network with noise factor=30 33 | cd ./tools/ 34 | python train.py --config-file "../configs/dn/BSD500_3c4n/03_train_CR_R0/train_s30.yaml" --device '0' 35 | 36 | #training the found super-resolution network with SR factor=3 37 | cd ./tools/ 38 | python train.py --config-file "../configs/sr/DIV2K_2c3n/03_x3_train_CR.yaml" --device '0' 39 | ``` 40 | ## Inference 41 | ``` 42 | # testing the trained denoising network with noise factor=[30 50 70] 43 | cd ./tools/ 44 | python dn_eval.py --config-file "../configs/dn/BSD500_3c4n/03_train_CR_R0/03_infe.yaml" --device '0' 45 | 46 | # testing the trained super-resolution network with sr factor=3 47 | cd ./tools/ 48 | python sr_eval.py --config-file "../configs/sr/DIV2K_2c3n/03_x3_infe_CR.yaml" --device '0' 49 | ``` 50 | 51 | ## Datasets 52 | Fetch code: 1111 53 | 54 | >Denoising datasets: [BSD200](https://pan.baidu.com/s/1HS4g7DYUxZv-tqQNJI6WsA) 
[BSD300](https://pan.baidu.com/s/1WBGWtZj91p2x2bPRkmlWVw) 55 | 56 | >Super-resolution datasets: [DIV2K_800](https://pan.baidu.com/s/1J6S0dTs1lJG3c0iqdQCBEA) [DIV2K_100](https://pan.baidu.com/s/1e6SrUTx94cWGSu0vjUXASA) [BSD100](https://pan.baidu.com/s/1ARqQbmEA2XhT3NbLlDkhWA) [BSDS100](https://pan.baidu.com/s/1oueecq2DogYzQ3naDZoqtg) [Urban100](https://pan.baidu.com/s/13S6EIS3ezfwIb_9mw7p7mw) [Manga109](https://pan.baidu.com/s/1ns1uMk3KL0dja-_Xq_jQUA) [General100](https://pan.baidu.com/s/17Fsm0PDjz8jo0-LB9TiW4Q) [Set14](https://pan.baidu.com/s/1SMbippo3jX1IRKslWdaVeQ) [Set5](https://pan.baidu.com/s/1l0ygYeIQ-1PRqZIjQnl5oQ) 57 | 58 | 59 | 60 | ## Citation 61 | If you use this code in your paper, please cite our papers 62 | ``` 63 | @inproceedings{zhang2020memory, 64 | title={Memory-efficient hierarchical neural architecture search for image denoising}, 65 | author={Zhang, Haokui and Li, Ying and Chen, Hao and Shen, Chunhua}, 66 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 67 | pages={3657--3666}, 68 | year={2020} 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /configs/dn/BSD500_3c4n/03_search_CR_R0.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/ren/hkzhang/data_p/nas_data 3 | DATA_NAME: BSD500_300 4 | CROP_SIZE: 64 5 | TASK: "dn" 6 | LOAD_ALL: True 7 | TO_GRAY: True 8 | SEARCH: 9 | SEARCH_ON: True 10 | ARCH_START_EPOCH: 20 11 | TIE_CELL: False 12 | R_SEED: 0 13 | MODEL: 14 | META_ARCHITECTURE: Dn_supernet 15 | META_MODE: Width 16 | NUM_STRIDES: 3 17 | NUM_LAYERS: 3 18 | NUM_BLOCKS: 4 19 | FILTER_MULTIPLIER: 10 20 | IN_CHANNEL: 1 21 | PRIMITIVES: "NO_DEF_L" 22 | ACTIVATION_F: "Leaky" 23 | AFFINE: False 24 | USE_ASPP: False 25 | USE_RES: True 26 | DATALOADER: 27 | BATCH_SIZE_TRAIN: 6 28 | BATCH_SIZE_TEST: 12 29 | NUM_WORKERS: 4 30 | SIGMA: [30, 50, 70] 31 | DATA_LIST_DIR: 
../preprocess/dataset_json 32 | 33 | SOLVER: 34 | LOSS: ['mse', 'log_ssim'] 35 | LOSS_WEIGHT: [1.0, 0.5] 36 | MAX_EPOCH: 100 37 | 38 | OUTPUT_DIR: output_R0 39 | -------------------------------------------------------------------------------- /configs/dn/BSD500_3c4n/03_search_CR_R10.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/ren/hkzhang/data_p/nas_data 3 | DATA_NAME: BSD500_300 4 | CROP_SIZE: 64 5 | TASK: "dn" 6 | LOAD_ALL: True 7 | TO_GRAY: True 8 | SEARCH: 9 | SEARCH_ON: True 10 | ARCH_START_EPOCH: 20 11 | TIE_CELL: False 12 | R_SEED: 10 13 | MODEL: 14 | META_ARCHITECTURE: Dn_supernet 15 | META_MODE: Width 16 | NUM_STRIDES: 3 17 | NUM_LAYERS: 3 18 | NUM_BLOCKS: 4 19 | FILTER_MULTIPLIER: 10 20 | IN_CHANNEL: 1 21 | PRIMITIVES: "NO_DEF_L" 22 | ACTIVATION_F: "Leaky" 23 | AFFINE: False 24 | USE_ASPP: False 25 | USE_RES: True 26 | DATALOADER: 27 | BATCH_SIZE_TRAIN: 6 28 | BATCH_SIZE_TEST: 12 29 | NUM_WORKERS: 4 30 | SIGMA: [30, 50, 70] 31 | DATA_LIST_DIR: ../preprocess/dataset_json 32 | 33 | SOLVER: 34 | LOSS: ['mse', 'log_ssim'] 35 | LOSS_WEIGHT: [1.0, 0.5] 36 | MAX_EPOCH: 100 37 | 38 | OUTPUT_DIR: output_R10 39 | -------------------------------------------------------------------------------- /configs/dn/BSD500_3c4n/03_search_CR_R20.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/ren/hkzhang/data_p/nas_data 3 | DATA_NAME: BSD500_300 4 | CROP_SIZE: 64 5 | TASK: "dn" 6 | LOAD_ALL: True 7 | TO_GRAY: True 8 | SEARCH: 9 | SEARCH_ON: True 10 | ARCH_START_EPOCH: 20 11 | TIE_CELL: False 12 | R_SEED: 20 13 | MODEL: 14 | META_ARCHITECTURE: Dn_supernet 15 | META_MODE: Width 16 | NUM_STRIDES: 3 17 | NUM_LAYERS: 3 18 | NUM_BLOCKS: 4 19 | FILTER_MULTIPLIER: 10 20 | IN_CHANNEL: 1 21 | PRIMITIVES: "NO_DEF_L" 22 | ACTIVATION_F: "Leaky" 23 | AFFINE: False 24 | USE_ASPP: False 25 | USE_RES: True 26 | DATALOADER: 27 | BATCH_SIZE_TRAIN: 6 
28 | BATCH_SIZE_TEST: 12 29 | NUM_WORKERS: 4 30 | SIGMA: [30, 50, 70] 31 | DATA_LIST_DIR: ../preprocess/dataset_json 32 | 33 | SOLVER: 34 | LOSS: ['mse', 'log_ssim'] 35 | LOSS_WEIGHT: [1.0, 0.5] 36 | MAX_EPOCH: 100 37 | 38 | OUTPUT_DIR: output_R20 39 | -------------------------------------------------------------------------------- /configs/dn/BSD500_3c4n/03_train_CR_R0/train_s30.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/ren/hkzhang/data_p/nas_data 3 | DATA_NAME: BSD500_300 4 | CROP_SIZE: 64 5 | TASK: "dn" 6 | LOAD_ALL: True 7 | TO_GRAY: True 8 | SEARCH: 9 | TIE_CELL: False 10 | INPUT: 11 | CROP_SIZE_TRAIN: 64 12 | SOLVER: 13 | TRAIN: 14 | MAX_ITER: 600000 15 | CHECKPOINT_PERIOD: 1000 16 | VALIDATE_PERIOD: 1000 17 | LOSS: ['mse', 'log_ssim'] 18 | LOSS_WEIGHT: [1.0, 0.6] 19 | DATALOADER: 20 | NUM_WORKERS: 2 21 | BATCH_SIZE_TRAIN: 24 22 | BATCH_SIZE_TEST: 24 23 | SIGMA: [30] 24 | DATA_AUG: 5 25 | MODEL: 26 | FILTER_MULTIPLIER: 20 27 | META_ARCHITECTURE: Dn_compnet 28 | META_MODE: Width 29 | NUM_STRIDES: 3 30 | NUM_LAYERS: 3 31 | NUM_BLOCKS: 4 32 | IN_CHANNEL: 1 33 | PRIMITIVES: "NO_DEF_L" 34 | ACTIVATION_F: "Leaky" 35 | USE_ASPP: False 36 | USE_RES: True 37 | 38 | OUTPUT_DIR: output_R0 39 | 40 | -------------------------------------------------------------------------------- /configs/dn/BSD500_3c4n/03_train_CR_R0/train_s50.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/ren/hkzhang/data_p/nas_data 3 | DATA_NAME: BSD500_300 4 | CROP_SIZE: 64 5 | TASK: "dn" 6 | LOAD_ALL: True 7 | TO_GRAY: True 8 | SEARCH: 9 | TIE_CELL: False 10 | INPUT: 11 | CROP_SIZE_TRAIN: 64 12 | SOLVER: 13 | TRAIN: 14 | MAX_ITER: 600000 15 | CHECKPOINT_PERIOD: 1000 16 | VALIDATE_PERIOD: 1000 17 | LOSS: ['mse', 'log_ssim'] 18 | LOSS_WEIGHT: [1.0, 0.6] 19 | DATALOADER: 20 | NUM_WORKERS: 2 21 | BATCH_SIZE_TRAIN: 24 22 | BATCH_SIZE_TEST: 24 23 | SIGMA: 
[50] 24 | DATA_AUG: 5 25 | MODEL: 26 | FILTER_MULTIPLIER: 20 27 | META_ARCHITECTURE: Dn_compnet 28 | META_MODE: Width 29 | NUM_STRIDES: 3 30 | NUM_LAYERS: 3 31 | NUM_BLOCKS: 4 32 | IN_CHANNEL: 1 33 | PRIMITIVES: "NO_DEF_L" 34 | ACTIVATION_F: "Leaky" 35 | USE_ASPP: False 36 | USE_RES: True 37 | 38 | OUTPUT_DIR: output_R0 39 | 40 | -------------------------------------------------------------------------------- /configs/dn/BSD500_3c4n/03_train_CR_R0/train_s70.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/ren/hkzhang/data_p/nas_data 3 | DATA_NAME: BSD500_300 4 | CROP_SIZE: 64 5 | TASK: "dn" 6 | LOAD_ALL: True 7 | TO_GRAY: True 8 | SEARCH: 9 | TIE_CELL: False 10 | INPUT: 11 | CROP_SIZE_TRAIN: 64 12 | SOLVER: 13 | TRAIN: 14 | MAX_ITER: 600000 15 | CHECKPOINT_PERIOD: 1000 16 | VALIDATE_PERIOD: 1000 17 | LOSS: ['mse', 'log_ssim'] 18 | LOSS_WEIGHT: [1.0, 0.6] 19 | DATALOADER: 20 | NUM_WORKERS: 2 21 | BATCH_SIZE_TRAIN: 24 22 | BATCH_SIZE_TEST: 24 23 | SIGMA: [70] 24 | DATA_AUG: 5 25 | MODEL: 26 | FILTER_MULTIPLIER: 20 27 | META_ARCHITECTURE: Dn_compnet 28 | META_MODE: Width 29 | NUM_STRIDES: 3 30 | NUM_LAYERS: 3 31 | NUM_BLOCKS: 4 32 | IN_CHANNEL: 1 33 | PRIMITIVES: "NO_DEF_L" 34 | ACTIVATION_F: "Leaky" 35 | USE_ASPP: False 36 | USE_RES: True 37 | 38 | OUTPUT_DIR: output_R0 39 | 40 | -------------------------------------------------------------------------------- /configs/sr/DIV2K_2c3n/03_search_CR.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/zhanghaokui/data/nas_data 3 | # DATA_ROOT: /data/data2/zhk218/data/nas_data 4 | DATA_NAME: DIV2K_800 5 | CROP_SIZE: 64 6 | TASK: "sr" 7 | LOAD_ALL: False 8 | SEARCH: 9 | SEARCH_ON: True 10 | ARCH_START_EPOCH: 20 11 | TIE_CELL: False 12 | VAL_PORTION: 0.05 13 | MODEL: 14 | META_ARCHITECTURE: Sr_supernet 15 | META_MODE: Width 16 | NUM_STRIDES: 3 17 | NUM_LAYERS: 2 18 | NUM_BLOCKS: 3 19 | 
FILTER_MULTIPLIER: 8 20 | IN_CHANNEL: 3 21 | PRIMITIVES: "NO_DEF_L" 22 | ACTIVATION_F: "Leaky" 23 | AFFINE: False 24 | USE_ASPP: True 25 | USE_RES: True 26 | DATALOADER: 27 | BATCH_SIZE_TRAIN: 24 28 | BATCH_SIZE_TEST: 24 29 | NUM_WORKERS: 2 30 | S_FACTOR: 4 31 | DATA_LIST_DIR: ../preprocess/dataset_json 32 | 33 | SOLVER: 34 | LOSS: ['l1', 'log_ssim'] 35 | LOSS_WEIGHT: [1.0, 0.6] 36 | MAX_EPOCH: 100 37 | 38 | OUTPUT_DIR: output 39 | -------------------------------------------------------------------------------- /configs/sr/DIV2K_2c3n/03_x2_infe_CR.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/zhanghaokui/data/nas_data 3 | DATA_NAME: DIV2K_800 4 | #TEST_DATASETS: [Set5, Set14, BSD100, Urban100, Manga109] 5 | TEST_DATASETS: [Set5, Set14, BSD100, Urban100, Manga109] 6 | CROP_SIZE: 64 7 | TASK: "sr" 8 | TO_GRAY: False 9 | SEARCH: 10 | TIE_CELL: False 11 | DATALOADER: 12 | BATCH_SIZE_TEST: 32 13 | S_FACTOR: 2 14 | MODEL: 15 | FILTER_MULTIPLIER: 16 16 | META_ARCHITECTURE: Sr_compnet 17 | META_MODE: Width 18 | NUM_STRIDES: 3 19 | NUM_LAYERS: 2 20 | NUM_BLOCKS: 3 21 | IN_CHANNEL: 3 22 | PRIMITIVES: "NO_DEF_L" 23 | ACTIVATION_F: "Leaky" 24 | USE_ASPP: True 25 | USE_RES: True 26 | 27 | OUTPUT_DIR: output 28 | 29 | -------------------------------------------------------------------------------- /configs/sr/DIV2K_2c3n/03_x2_train_CR.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/zhanghaokui/data/nas_data 3 | # DATA_ROOT: /data/data2/zhk218/data/nas_data 4 | DATA_NAME: DIV2K_800 5 | CROP_SIZE: 64 6 | TASK: "sr" 7 | LOAD_ALL: False 8 | SEARCH: 9 | TIE_CELL: False 10 | INPUT: 11 | CROP_SIZE_TRAIN: 64 12 | SOLVER: 13 | TRAIN: 14 | MAX_ITER: 600000 15 | CHECKPOINT_PERIOD: 1000 16 | VALIDATE_PERIOD: 1000 17 | LOSS: ['l1', 'log_ssim'] 18 | LOSS_WEIGHT: [1.0, 0.6] 19 | DATALOADER: 20 | NUM_WORKERS: 4 21 | BATCH_SIZE_TRAIN: 16 22 | 
BATCH_SIZE_TEST: 16 23 | S_FACTOR: 2 24 | R_CROP: 4 25 | MODEL: 26 | FILTER_MULTIPLIER: 16 27 | META_ARCHITECTURE: Sr_compnet 28 | META_MODE: Width 29 | NUM_STRIDES: 3 30 | NUM_LAYERS: 2 31 | NUM_BLOCKS: 3 32 | IN_CHANNEL: 3 33 | PRIMITIVES: "NO_DEF_L" 34 | ACTIVATION_F: "Leaky" 35 | USE_ASPP: True 36 | USE_RES: True 37 | 38 | OUTPUT_DIR: output 39 | 40 | -------------------------------------------------------------------------------- /configs/sr/DIV2K_2c3n/03_x3_infe_CR.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/zhanghaokui/data/nas_data 3 | DATA_NAME: DIV2K_800 4 | #TEST_DATASETS: [Set5, Set14, BSD100, Urban100, Manga109] 5 | TEST_DATASETS: [Set5, Set14, BSD100, Urban100, Manga109] 6 | CROP_SIZE: 64 7 | TASK: "sr" 8 | TO_GRAY: False 9 | SEARCH: 10 | TIE_CELL: False 11 | DATALOADER: 12 | BATCH_SIZE_TEST: 32 13 | S_FACTOR: 3 14 | MODEL: 15 | FILTER_MULTIPLIER: 16 16 | META_ARCHITECTURE: Sr_compnet 17 | META_MODE: Width 18 | NUM_STRIDES: 3 19 | NUM_LAYERS: 2 20 | NUM_BLOCKS: 3 21 | IN_CHANNEL: 3 22 | PRIMITIVES: "NO_DEF_L" 23 | ACTIVATION_F: "Leaky" 24 | USE_ASPP: True 25 | USE_RES: True 26 | 27 | OUTPUT_DIR: output 28 | 29 | -------------------------------------------------------------------------------- /configs/sr/DIV2K_2c3n/03_x3_train_CR.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/zhanghaokui/data/nas_data 3 | # DATA_ROOT: /data/data2/zhk218/data/nas_data 4 | DATA_NAME: DIV2K_800 5 | CROP_SIZE: 64 6 | TASK: "sr" 7 | LOAD_ALL: False 8 | SEARCH: 9 | TIE_CELL: False 10 | INPUT: 11 | CROP_SIZE_TRAIN: 64 12 | SOLVER: 13 | TRAIN: 14 | MAX_ITER: 600000 15 | CHECKPOINT_PERIOD: 1000 16 | VALIDATE_PERIOD: 1000 17 | LOSS: ['l1', 'log_ssim'] 18 | LOSS_WEIGHT: [1.0, 0.6] 19 | DATALOADER: 20 | NUM_WORKERS: 4 21 | BATCH_SIZE_TRAIN: 16 22 | BATCH_SIZE_TEST: 16 23 | S_FACTOR: 3 24 | R_CROP: 4 25 | MODEL: 26 | 
FILTER_MULTIPLIER: 16 27 | META_ARCHITECTURE: Sr_compnet 28 | META_MODE: Width 29 | NUM_STRIDES: 3 30 | NUM_LAYERS: 2 31 | NUM_BLOCKS: 3 32 | IN_CHANNEL: 3 33 | PRIMITIVES: "NO_DEF_L" 34 | ACTIVATION_F: "Leaky" 35 | USE_ASPP: True 36 | USE_RES: True 37 | 38 | OUTPUT_DIR: output 39 | 40 | -------------------------------------------------------------------------------- /configs/sr/DIV2K_2c3n/03_x4_infe_CR.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/zhanghaokui/data/nas_data 3 | DATA_NAME: DIV2K_800 4 | #TEST_DATASETS: [Set5, Set14, BSD100, Urban100, Manga109] 5 | TEST_DATASETS: [Set5, Set14, BSD100, Urban100, Manga109] 6 | CROP_SIZE: 64 7 | TASK: "sr" 8 | TO_GRAY: False 9 | SEARCH: 10 | TIE_CELL: False 11 | DATALOADER: 12 | BATCH_SIZE_TEST: 32 13 | S_FACTOR: 4 14 | MODEL: 15 | FILTER_MULTIPLIER: 16 16 | META_ARCHITECTURE: Sr_compnet 17 | META_MODE: Width 18 | NUM_STRIDES: 3 19 | NUM_LAYERS: 2 20 | NUM_BLOCKS: 3 21 | IN_CHANNEL: 3 22 | PRIMITIVES: "NO_DEF_L" 23 | ACTIVATION_F: "Leaky" 24 | USE_ASPP: True 25 | USE_RES: True 26 | 27 | OUTPUT_DIR: output 28 | 29 | -------------------------------------------------------------------------------- /configs/sr/DIV2K_2c3n/03_x4_train_CR.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/zhanghaokui/data/nas_data 3 | # DATA_ROOT: /data/data2/zhk218/data/nas_data 4 | DATA_NAME: DIV2K_800 5 | CROP_SIZE: 64 6 | TASK: "sr" 7 | LOAD_ALL: False 8 | SEARCH: 9 | TIE_CELL: False 10 | INPUT: 11 | CROP_SIZE_TRAIN: 64 12 | SOLVER: 13 | TRAIN: 14 | MAX_ITER: 600000 15 | CHECKPOINT_PERIOD: 1000 16 | VALIDATE_PERIOD: 1000 17 | LOSS: ['l1', 'log_ssim'] 18 | LOSS_WEIGHT: [1.0, 0.6] 19 | DATALOADER: 20 | NUM_WORKERS: 4 21 | BATCH_SIZE_TRAIN: 16 22 | BATCH_SIZE_TEST: 16 23 | S_FACTOR: 4 24 | R_CROP: 4 25 | MODEL: 26 | FILTER_MULTIPLIER: 16 27 | META_ARCHITECTURE: Sr_compnet 28 | META_MODE: Width 29 | 
NUM_STRIDES: 3 30 | NUM_LAYERS: 2 31 | NUM_BLOCKS: 3 32 | IN_CHANNEL: 3 33 | PRIMITIVES: "NO_DEF_L" 34 | ACTIVATION_F: "Leaky" 35 | USE_ASPP: True 36 | USE_RES: True 37 | 38 | OUTPUT_DIR: output 39 | 40 | -------------------------------------------------------------------------------- /configs/sr/DIV2K_2c3n/03_x8_infe_CR.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/zhanghaokui/data/nas_data 3 | DATA_NAME: DIV2K_800 4 | #TEST_DATASETS: [Set5, Set14, BSD100, Urban100, Manga109] 5 | TEST_DATASETS: [Set5, Set14, BSD100, Urban100, Manga109] 6 | CROP_SIZE: 64 7 | TASK: "sr" 8 | TO_GRAY: False 9 | SEARCH: 10 | TIE_CELL: False 11 | DATALOADER: 12 | BATCH_SIZE_TEST: 32 13 | S_FACTOR: 8 14 | MODEL: 15 | FILTER_MULTIPLIER: 16 16 | META_ARCHITECTURE: Sr_compnet 17 | META_MODE: Width 18 | NUM_STRIDES: 3 19 | NUM_LAYERS: 2 20 | NUM_BLOCKS: 3 21 | IN_CHANNEL: 3 22 | PRIMITIVES: "NO_DEF_L" 23 | ACTIVATION_F: "Leaky" 24 | USE_ASPP: True 25 | USE_RES: True 26 | 27 | OUTPUT_DIR: output 28 | 29 | -------------------------------------------------------------------------------- /configs/sr/DIV2K_2c3n/03_x8_train_CR.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_ROOT: /home/zhanghaokui/data/nas_data 3 | # DATA_ROOT: /data/data2/zhk218/data/nas_data 4 | DATA_NAME: DIV2K_800 5 | CROP_SIZE: 64 6 | TASK: "sr" 7 | LOAD_ALL: False 8 | SEARCH: 9 | TIE_CELL: False 10 | INPUT: 11 | CROP_SIZE_TRAIN: 64 12 | SOLVER: 13 | TRAIN: 14 | MAX_ITER: 600000 15 | CHECKPOINT_PERIOD: 1000 16 | VALIDATE_PERIOD: 1000 17 | LOSS: ['l1', 'log_ssim'] 18 | LOSS_WEIGHT: [1.0, 0.6] 19 | DATALOADER: 20 | NUM_WORKERS: 4 21 | BATCH_SIZE_TRAIN: 16 22 | BATCH_SIZE_TEST: 16 23 | S_FACTOR: 8 24 | R_CROP: 4 25 | MODEL: 26 | FILTER_MULTIPLIER: 16 27 | META_ARCHITECTURE: Sr_compnet 28 | META_MODE: Width 29 | NUM_STRIDES: 3 30 | NUM_LAYERS: 2 31 | NUM_BLOCKS: 3 32 | IN_CHANNEL: 3 33 | 
PRIMITIVES: "NO_DEF_L" 34 | ACTIVATION_F: "Leaky" 35 | USE_ASPP: True 36 | USE_RES: True 37 | 38 | OUTPUT_DIR: output 39 | 40 | -------------------------------------------------------------------------------- /data_generation/rename.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from PIL import Image 3 | import os 4 | 5 | 6 | set = 'sigma70' 7 | name_source = '/home/hkzhang/Documents/sdb_a/nas_data/denoise/BSD500_200' 8 | image_source = '/home/hkzhang/Documents/codes/Architecture_search/projects/comparison_methods/denoise/n3net-master/src_denoising/BSD500_result' 9 | save_dir = image_source + '/{}_resort'.format(set) 10 | 11 | name_source_list = glob(name_source + '/*.jpg') 12 | name_list = [item.split('/')[-1].split('.')[0] for item in name_source_list] 13 | name_list.sort() 14 | 15 | image_list = glob(os.path.join(image_source, set, '*.jpg')) 16 | image_list.sort() 17 | for im_dir, new_id in zip(image_list, name_list): 18 | im=Image.open(im_dir) 19 | im.save(save_dir + '/{}.jpg'.format(new_id)) 20 | -------------------------------------------------------------------------------- /data_generation/rgb2gray_add_noise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from glob import glob 5 | 6 | def make_if_not_exist(path): 7 | if not os.path.exists(path): 8 | os.makedirs(path) 9 | 10 | image_dir = '/home/hkzhang/Documents/sdb_a/nas_data/denoise/BSD500_200' 11 | gray_save_dir = image_dir + '_gray/' 12 | 13 | # sigmas=[30, 50, 70] 14 | # for sigma in sigmas: 15 | # make_if_not_exist(image_dir + '_sigma_{}'.format(sigma)) 16 | # 17 | # make_if_not_exist(gray_save_dir) 18 | 19 | image_list = glob(os.path.join(image_dir, '*.jpg')) 20 | for im_dir in image_list: 21 | im_id = im_dir.split('/')[-1][:-4] 22 | im = Image.open(im_dir).convert('L') 23 | im.save(gray_save_dir+ im_id + '.png') 24 | im_clean = 
np.array(im) 25 | 26 | for sigma in sigmas: 27 | im_noise = np.array(im_clean + np.random.normal(0, 1, size=im_clean.shape) * sigma, np.uint8) 28 | im_noise = Image.fromarray(im_noise) 29 | im_noise.save(image_dir + '_sigma_{}'.format(sigma) + '/{}'.format(im_id) + '.jpg') 30 | 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /log file for DNAS_For_IR.txt: -------------------------------------------------------------------------------- 1 | v1->v2: 2 | 1) untie the cell structure, cells in different layers use different architectures and cells in the same layer share the same structure. 3 | 2) use residual learning. Predicting the residuals between high quality images and low quality images, accelerating the conversicing speed and improving performance. 4 | 3) bulid new search space. Replacing deformable conv with relatively normal conv operations, as deformable conv is time consuming. 5 | 4) new activation function. using Leaky RelU instead of ReLU in new version. 6 | 7 | v2->v3 8 | 1) optimizer, replacing sgd optimizer with adam optimizer 9 | 2) sigmoid acitivation. Abandoning the sigmoid activation function in the last conv layer. 
10 | 11 | v3->v4 12 | 1) upsample sample method is set to bicubic 13 | 2) more flexible width change 14 | 15 | -------------------------------------------------------------------------------- /one_stage_nas/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg # expose the default config node under the public name `cfg` 2 | -------------------------------------------------------------------------------- /one_stage_nas/config/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/config/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/config/__pycache__/defaults.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/config/__pycache__/defaults.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/config/__pycache__/paths_catalog.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/config/__pycache__/paths_catalog.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/config/defaults.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from yacs.config import CfgNode as CN 3 | _C = CN() # root yacs node of default options; presumably the YAML files under configs/ are merged over it 4 | 5 | 6 | # ----------------------------------------------------------------------------- 7 | # SEARCH (architecture-search stage; the search YAMLs set SEARCH_ON: True) 8 | # ----------------------------------------------------------------------------- 9 | _C.SEARCH = CN() 10 | _C.SEARCH.ARCH_START_EPOCH = 20 11 | _C.SEARCH.VAL_PORTION = 0.02 12 | _C.SEARCH.PORTION = 0.5 13 | _C.SEARCH.SEARCH_ON = False 14 | _C.SEARCH.TIE_CELL = True 15 | _C.SEARCH.R_SEED = 0 # presumably the random seed; the _R0/_R10/_R20 search YAMLs set 0/10/20 16 | 17 | 18 | 19 | # ----------------------------------------------------------------------------- 20 | # MODEL 21 | # ----------------------------------------------------------------------------- 22 | _C.MODEL = CN() 23 | _C.MODEL.META_ARCHITECTURE = "AutoMultiTask" # the YAMLs select Dn_supernet / Dn_compnet / Sr_supernet / Sr_compnet 24 | _C.MODEL.FILTER_MULTIPLIER = 20 25 | _C.MODEL.NUM_LAYERS = 12 26 | _C.MODEL.NUM_BLOCKS = 5 27 | _C.MODEL.NUM_STRIDES = 3 28 | _C.MODEL.WS_FACTORS = [1, 1.5, 2] 29 | _C.MODEL.IN_CHANNEL = 3 30 | _C.MODEL.AFFINE = True 31 | _C.MODEL.WEIGHT = "" # Init weights 32 | _C.MODEL.PRIMITIVES = "NO_DEF_R" 33 | _C.MODEL.ACTIVATION_F = "ReLU" 34 | _C.MODEL.ASPP_RATES = (2, 4, 6) 35 | _C.MODEL.META_MODE = "Scale" 36 | _C.MODEL.USE_ASPP = True 37 | _C.MODEL.USE_RES = False 38 | _C.MODEL.RES = "add" # add | mul 39 | 40 | 41 | # ----------------------------------------------------------------------------- 42 | # INPUT 43 | # ----------------------------------------------------------------------------- 44 | _C.INPUT = CN() 45 | # Size of the smallest side of the image during training 46 | _C.INPUT.MIN_SIZE_TRAIN = -1 47 | # Crop size of the side of the image during training 48 | _C.INPUT.CROP_SIZE_TRAIN = 128 49 | # Maximum size of the side of the image during training 50 | _C.INPUT.MAX_SIZE_TRAIN = 1024 51 | _C.INPUT.MIN_SIZE_TEST = -1 52 | _C.INPUT.MAX_SIZE_TEST = 1024 53 | 54 | 55 | # ----------------------------------------------------------------------------- 56 | # Dataset 57 | # ----------------------------------------------------------------------------- 58 | 
_C.DATASET = CN() 59 | # _C.DATASET.DATA_ROOT = "/home/hkzhang/Documents/sdb_a/nas_data" 60 | _C.DATASET.DATA_ROOT = "/home/a1224062/g_acvt/Architecture_search/nas_data" # machine-specific default; every YAML under configs/ sets its own 61 | _C.DATASET.DATA_NAME = "rain800" 62 | _C.DATASET.TRAIN_DATASETS = [] 63 | _C.DATASET.TRAIN_DATASETS_WEIGHT = [] 64 | _C.DATASET.TEST_DATASETS = [] 65 | _C.DATASET.CROP_SIZE = 128 66 | _C.DATASET.TASK = "derain" # the YAMLs use "dn" (denoising) or "sr" (super-resolution) 67 | _C.DATASET.LOAD_ALL = False 68 | _C.DATASET.TO_GRAY = False 69 | 70 | # ----------------------------------------------------------------------------- 71 | # DataLoader 72 | # ----------------------------------------------------------------------------- 73 | _C.DATALOADER = CN() 74 | # Number of data loading threads 75 | _C.DATALOADER.NUM_WORKERS = 4 76 | _C.DATALOADER.BATCH_SIZE_TRAIN = 2 77 | _C.DATALOADER.BATCH_SIZE_TEST = 2 78 | _C.DATALOADER.SIGMA = [] # dn YAMLs set the noise levels, e.g. [30, 50, 70] 79 | _C.DATALOADER.S_FACTOR = 1 # sr YAMLs set the scale factor (2, 3, 4 or 8) 80 | _C.DATALOADER.DATA_LIST_DIR = "../preprocess/dataset_json" 81 | _C.DATALOADER.DATA_AUG = 1 82 | _C.DATALOADER.R_CROP = 1 83 | 84 | # ----------------------------------------------------------------------------- 85 | # Solver 86 | # ----------------------------------------------------------------------------- 87 | _C.SOLVER = CN() 88 | _C.SOLVER.MAX_EPOCH = 60 89 | _C.SOLVER.BIAS_LR_FACTOR = 2 90 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0 91 | _C.SOLVER.WEIGHT_DECAY = 0.00004 92 | _C.SOLVER.MOMENTUM = 0.9 93 | # cosine learning rate 94 | _C.SOLVER.SEARCH = CN() 95 | _C.SOLVER.SEARCH.LR_START = 0.025 96 | _C.SOLVER.SEARCH.LR_END = 0.001 97 | _C.SOLVER.SEARCH.MOMENTUM = 0.9 98 | _C.SOLVER.SEARCH.WEIGHT_DECAY = 0.0003 99 | # architecture encoding Adam params 100 | _C.SOLVER.SEARCH.LR_A = 0.001 # learning rate 101 | _C.SOLVER.SEARCH.WD_A = 0.001 # weight decay 102 | _C.SOLVER.SEARCH.T_MAX = 10 # cosine lr time 103 | 104 | _C.SOLVER.TRAIN = CN() 105 | _C.SOLVER.TRAIN.INIT_LR = 0.05 106 | _C.SOLVER.TRAIN.POWER = 0.9 107 | _C.SOLVER.TRAIN.MAX_ITER = 500000 108 | _C.SOLVER.TRAIN.VAL_PORTION = 0.01 109 | 
_C.SOLVER.SCHEDULER = 'poly' # poly lr 110 | _C.SOLVER.CHECKPOINT_PERIOD = 10 # NOTE(review): train YAMLs list this right after SOLVER.TRAIN.MAX_ITER; yacs rejects unknown keys, so confirm it is nested under SOLVER (not SOLVER.TRAIN) there 111 | _C.SOLVER.VALIDATE_PERIOD = 1 # same nesting caveat as CHECKPOINT_PERIOD 112 | 113 | _C.SOLVER.LOSS = ['l1', 'neg_ssim', 'grad'] 114 | _C.SOLVER.LOSS_WEIGHT = [1, 1, 1] # presumably LOSS_WEIGHT[i] weights LOSS[i]; the YAMLs always pass equal-length lists 115 | 116 | 117 | # ---------------------------------------------------------------------------- # 118 | # Misc options 119 | # ---------------------------------------------------------------------------- # 120 | _C.OUTPUT_DIR = "." 121 | _C.RESULT_DIR = "." 122 | 123 | -------------------------------------------------------------------------------- /one_stage_nas/config/paths_catalog.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class DatasetCatalog(object): # registry of named datasets; only Cityscapes entries are defined 5 | #DATA_DIR = "datasets" 6 | DATA_DIR = '/home/hkzhang/Documents/sdb_a/segmentation' # machine-specific root joined with the relative paths below 7 | DATASETS = { 8 | "cityscapes_fine_train": { 9 | "data_file": "cityscapes/lists/train.lst", 10 | "data_dir": "cityscapes" 11 | }, 12 | "cityscapes_fine_val": { 13 | "data_file": "cityscapes/lists/val.lst", 14 | "data_dir": "cityscapes" 15 | } 16 | } 17 | 18 | @staticmethod 19 | def get(name): # returns {'factory': ..., 'args': {data_file, data_dir}} with DATA_DIR-joined paths; KeyError for an unregistered name containing 'cityscapes', RuntimeError for any other name 20 | if 'cityscapes' in name: 21 | data_dir = DatasetCatalog.DATA_DIR 22 | attrs = DatasetCatalog.DATASETS[name] 23 | args = dict( 24 | data_file=os.path.join(data_dir, attrs['data_file']), 25 | data_dir=os.path.join(data_dir, attrs['data_dir'])) 26 | return dict( 27 | factory="CityscapesDataset", 28 | args=args) 29 | raise RuntimeError("Dataset not available: {}".format(name)) 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /one_stage_nas/darts/__pycache__/cell.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/darts/__pycache__/cell.cpython-36.pyc -------------------------------------------------------------------------------- 
# ==============================================================================
# File: /one_stage_nas/darts/__pycache__/genotypes.cpython-36.pyc
# (compiled bytecode, not inlined)
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/darts/__pycache__/genotypes.cpython-36.pyc
# ==============================================================================
# File: /one_stage_nas/darts/__pycache__/operations.cpython-36.pyc
# (compiled bytecode, not inlined)
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/darts/__pycache__/operations.cpython-36.pyc
# ==============================================================================
# File: /one_stage_nas/darts/cell.py
# ==============================================================================
import torch
from torch import nn

from one_stage_nas.utils.comm import drop_path
from .operations import OPS, Identity


class MixedOp(nn.Module):
    """One searchable edge: a weighted sum over all candidate operations.

    Arguments:
        C (int): channel count passed to every candidate op factory.
        primitives (sequence of str): keys into ``OPS`` naming the candidates.
        affine (bool): forwarded to every candidate op factory.
    """

    def __init__(self, C, primitives, affine=True):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        for primitive in primitives:
            # second positional argument is 1 (presumably the stride -- see OPS)
            op = OPS[primitive](C, 1, affine)
            self._ops.append(op)

    def forward(self, x, weights):
        """Return ``sum_k weights[k] * op_k(x)`` over the candidate ops."""
        return sum(w * op(x) for w, op in zip(weights, self._ops))


class Cell(nn.Module):
    """Searchable cell: a DAG of ``blocks`` nodes whose edges are MixedOps.

    Node ``i`` receives one MixedOp edge from each cell input and from every
    earlier node, i.e. ``2 + i`` edges, created in that order.
    """

    def __init__(self, blocks, C, primitives, empty_H1=False, affine=True):
        """
        Arguments:
            blocks (int): number of intermediate nodes in the cell
            C (int): channel count of every cell input / intermediate node
            primitives (sequence of str): candidate op names for every edge
            empty_H1 (bool): if True, forward ignores the edges coming from
                ``s1`` (the first entry of the internal state list)
            affine (bool): forwarded to the candidate ops
        """
        super(Cell, self).__init__()

        self._steps = blocks
        self._multiplier = blocks
        self._ops = nn.ModuleList()
        self._empty_h1 = empty_H1
        self._primitives = primitives

        for i in range(self._steps):
            for j in range(2 + i):
                op = MixedOp(C, primitives, affine)
                self._ops.append(op)

    def forward(self, s0, s1, weights):
        """Compute the cell output.

        Arguments:
            s0, s1: cell inputs; internally ordered as ``[s1, s0]``.
            weights: one row of op weights per MixedOp, in construction order.

        Returns:
            The last ``blocks`` states (all intermediate nodes) concatenated
            along dim 1.
        """
        states = [s1, s0]

        offset = 0
        for i in range(self._steps):
            # sum all incoming edges; with empty_h1 the edge from states[0] is skipped
            s = sum(self._ops[offset + j](h, weights[offset + j])
                    for j, h in enumerate(states)
                    if not self._empty_h1 or j > 0)
            # every earlier state owns one edge, even when it was skipped
            offset += len(states)
            states.append(s)

        out = torch.cat(states[-self._multiplier:], dim=1)
        return out

    def genotype(self, weights):
        """Derive a discrete cell description from the edge weights.

        For every node, keep the two incoming edges whose strongest non-'none'
        op weight is largest, and for each kept edge choose its strongest
        non-'none' op (the first one wins ties).

        Arguments:
            weights: one row of op weights per edge in construction order;
                must support slicing and ``.clone()`` (presumably a 2-D
                tensor -- confirm with caller).

        Returns:
            list of ``(op_name, input_index)`` tuples, two per node.
        """
        primitives = self._primitives
        # Hoisted out of the loops. None when the primitive set has no 'none'
        # entry, in which case every op is eligible (previously a ValueError).
        none_idx = primitives.index('none') if 'none' in primitives else None

        gene = []
        start = 0
        for i in range(self._steps):
            end = start + i + 2  # node i has i + 2 incoming edges
            W = weights[start:end].clone().detach()

            def edge_strength(edge):
                return max(W[edge][k] for k in range(len(W[edge])) if k != none_idx)

            edges = sorted(range(i + 2), key=lambda edge: -edge_strength(edge))[:2]
            for j in edges:
                k_best = None
                for k in range(len(W[j])):
                    if k != none_idx:
                        if k_best is None or W[j][k] > W[j][k_best]:
                            k_best = k
                gene.append((primitives[k_best], j))
            start = end
        return gene


class FixCell(nn.Module):
    """Discrete cell built from a genotype produced by ``Cell.genotype``."""

    def __init__(self, genotype, C, repeats=2):
        """
        Arguments:
            genotype: sequence of ``(op_name, input_index)`` pairs, two per node
            C (int): channel count of every cell input / intermediate node
            repeats: accepted for interface compatibility; not used here
        """
        super(FixCell, self).__init__()

        op_names, indices = zip(*genotype)

        self._steps = len(op_names) // 2
        self._multiplier = self._steps
        self._ops = nn.ModuleList()

        for name in op_names:
            op = OPS[name](C, 1, True)
            self._ops.append(op)

        self._indices = indices

    def forward(self, s0, s1, drop_prob):
        """Compute the cell output, applying drop_path to non-identity ops in training.

        Returns:
            The last ``steps`` states concatenated along dim 1.
        """
        states = [s1, s0]

        for i in range(self._steps):
            s = 0
            for ind in [2 * i, 2 * i + 1]:
                op = self._ops[ind]
                h = op(states[self._indices[ind]])
                if self.training and drop_prob > 0:
                    if not isinstance(op, Identity):
                        h = drop_path(h, drop_prob)
                s = s + h
            states.append(s)
        return torch.cat(states[-self._multiplier:], dim=1)


# ==============================================================================
# File: /one_stage_nas/darts/genotypes.py
# ==============================================================================
# Candidate operation sets; the suffix of each conv op name (_relu, _leaky,
# _prelu, _sine) identifies the set. Every set also offers 'skip_connect' and
# 'none'.
NO_DEF_R = [
    'con_c_3x3_relu',
    'sep_c_3x3_relu',
    'sep_c_5x5_relu',
    'dil_c_3x3_relu',
    'dil_c_5x5_relu',
    'skip_connect',
    'none',
]

NO_DEF_L = [
    'con_c_3x3_leaky',
    'sep_c_3x3_leaky',
    'sep_c_5x5_leaky',
    'dil_c_3x3_leaky',
    'dil_c_5x5_leaky',
    'skip_connect',
    'none',
]


NO_DEF_P = [
    'con_c_3x3_prelu',
    'sep_c_3x3_prelu',
    'sep_c_5x5_prelu',
    'dil_c_3x3_prelu',
    'dil_c_5x5_prelu',
    'skip_connect',
    'none',
]


NO_DEF_S = [
    'con_c_3x3_sine',
    'sep_c_3x3_sine',
    'sep_c_5x5_sine',
    'dil_c_3x3_sine',
    'dil_c_5x5_sine',
    'skip_connect',
    'none',
]


# Lookup of candidate sets by name.
PRIMITIVES = {
    "NO_DEF_R": NO_DEF_R,
    "NO_DEF_L": NO_DEF_L,
    "NO_DEF_P": NO_DEF_P,
    "NO_DEF_S": NO_DEF_S
}

# ==============================================================================
# File: /one_stage_nas/data/__init__.py
# ==============================================================================
from .build_dataset import build_dataset
from .build_dataset import build_transforms
# ==============================================================================
# File: /one_stage_nas/data/__pycache__/__init__.cpython-36.pyc
# (compiled bytecode, not inlined)
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/data/__pycache__/__init__.cpython-36.pyc
# ==============================================================================
# File: /one_stage_nas/data/__pycache__/build_dataset.cpython-36.pyc
# (compiled bytecode, not inlined)
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/data/__pycache__/build_dataset.cpython-36.pyc
# ==============================================================================
# File: /one_stage_nas/data/build_dataset.py
# ==============================================================================
from .datasets.transforms import (RandomCrop, RandomMirror, RandomOverturn,
                                  RandomRotate, FourLRotate, Normalize,
                                  RandomRescaleCrop, ToTensor, NoiseToTensor,
                                  SID_RandomCrop, SID_RandomFlip, Ndarray2tensor,
                                  Sr_RandomCrop, Rescale, Compose)
from .datasets.tasks_dict import tasks_dict
import numpy as np
import torch
import json
import os


def json_loader(dict_file_dir):
    """Read the JSON file at ``dict_file_dir`` and return the decoded object."""
    with open(dict_file_dir, 'r') as data_file:
        return json.load(data_file)


def build_transforms(crop_size=None, task='dn', tag='train', sigma=None, s_factor=1):
    """Return the ``Compose`` transform pipeline for a task and split.

    Arguments:
        crop_size: patch size for the random-crop transforms (train only).
        task: 'dn', 'sid' or 'sr'.
        tag: 'train' or 'test'.
        sigma: noise levels for ``NoiseToTensor`` ('dn' only). ``None`` means
            an empty list; a fresh list is created per call so pipelines never
            share a mutable default.
        s_factor: scale factor for ``Sr_RandomCrop`` ('sr' train only).

    Returns:
        A ``Compose`` instance, or ``None`` for an unsupported task/tag pair
        (kept for compatibility with existing callers).
    """
    if sigma is None:
        sigma = []

    if task == 'dn':
        if tag == 'train':
            return Compose([
                RandomRescaleCrop(crop_size),
                FourLRotate(),
                RandomMirror(),
                NoiseToTensor(sigma),
                Rescale()
            ])
        elif tag == 'test':
            return Compose([
                NoiseToTensor(sigma),
                Rescale(),
            ])
    elif task == 'sid':
        if tag == 'train':
            return Compose([
                SID_RandomCrop(crop_size),
                SID_RandomFlip(),
                Ndarray2tensor(),
            ])
        elif tag == 'test':
            return Compose([
                Ndarray2tensor(),
            ])
    elif task == 'sr':
        if tag == 'train':
            return Compose([
                Sr_RandomCrop(crop_size, s_factor),
                FourLRotate(),
                RandomMirror(),
                ToTensor(),
                Rescale()
            ])
        elif tag == 'test':
            return Compose([
                ToTensor(),
                Rescale(),
            ])
    return None


def _make_dataset(cfg, task, data_list, transform):
    """Instantiate the dataset class registered for ``task`` over ``data_list``."""
    return tasks_dict[task]('/'.join((cfg.DATASET.DATA_ROOT, task)), data_list, transform,
                            cfg.DATASET.LOAD_ALL, cfg.DATASET.TO_GRAY,
                            cfg.DATALOADER.S_FACTOR, cfg.DATALOADER.R_CROP)


def _make_loader(dataset, batch_size, num_workers):
    """Wrap ``dataset`` in a shuffling, pinned-memory training DataLoader."""
    return torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=batch_size,
        # worker count is capped at the per-loader batch size
        num_workers=min(num_workers, batch_size),
        pin_memory=True)


def build_dataset(cfg):
    """Build the training DataLoader(s) and the held-out validation sample list.

    The sample list is read from ``<DATA_LIST_DIR>/<task>/<DATA_NAME>.json`` for
    'dn'/'sr', or ``<DATA_LIST_DIR>/sid/<DATA_NAME>/train.json`` for 'sid'. The
    last ``floor(SEARCH.VAL_PORTION * N)`` entries form the validation list.
    Each DataLoader yields ``BATCH_SIZE_TRAIN // R_CROP`` samples per batch,
    because every sample already holds ``R_CROP`` crops.

    Returns:
        If ``cfg.SEARCH.SEARCH_ON``: ``([loader_w, loader_a], val_list)``. The
        non-validation samples are split by ``SEARCH.PORTION`` into a weight
        part (first) and an architecture part (rest).
        Otherwise: ``(loader_t, val_list)``.

    Raises:
        ValueError: if ``cfg.DATASET.TASK`` has no sample-list layout or no
            registered dataset class, or if ``BATCH_SIZE_TRAIN // R_CROP < 1``.
    """
    data_list_dir = cfg.DATALOADER.DATA_LIST_DIR
    data_name = cfg.DATASET.DATA_NAME
    task = cfg.DATASET.TASK

    # Validate up front; previously an unsupported task surfaced later as an
    # UnboundLocalError (data_dict) or a KeyError (tasks_dict).
    if task in ('dn', 'sr'):
        list_file = '/'.join((data_list_dir, task, data_name + '.json'))
    elif task == 'sid':
        list_file = '/'.join((data_list_dir, task, data_name, 'train.json'))
    else:
        raise ValueError("Unsupported DATASET.TASK {!r}: expected 'dn', 'sr' or 'sid'".format(task))
    if task not in tasks_dict:
        raise ValueError("No dataset class registered for DATASET.TASK {!r} (available: {})".format(
            task, sorted(tasks_dict)))

    search_on = cfg.SEARCH.SEARCH_ON
    crop_size = cfg.DATASET.CROP_SIZE if search_on else cfg.INPUT.CROP_SIZE_TRAIN

    num_workers = cfg.DATALOADER.NUM_WORKERS
    loader_batch = cfg.DATALOADER.BATCH_SIZE_TRAIN // cfg.DATALOADER.R_CROP
    if loader_batch < 1:
        raise ValueError("DATALOADER.BATCH_SIZE_TRAIN ({}) must be >= DATALOADER.R_CROP ({})".format(
            cfg.DATALOADER.BATCH_SIZE_TRAIN, cfg.DATALOADER.R_CROP))

    transform = build_transforms(crop_size, task, tag='train', sigma=cfg.DATALOADER.SIGMA,
                                 s_factor=cfg.DATALOADER.S_FACTOR)
    data_dict = json_loader(list_file)

    num_samples = len(data_dict)
    val_split = int(np.floor(cfg.SEARCH.VAL_PORTION * num_samples))
    num_train = num_samples - val_split
    v_data_list = data_dict[num_train:]

    if search_on:
        train_split = int(np.floor(cfg.SEARCH.PORTION * num_train))
        dataset_w = _make_dataset(cfg, task, data_dict[:train_split], transform)
        dataset_a = _make_dataset(cfg, task, data_dict[train_split:num_train], transform)

        data_loader_w = _make_loader(dataset_w, loader_batch, num_workers)
        data_loader_a = _make_loader(dataset_a, loader_batch, num_workers)
        return [data_loader_w, data_loader_a], v_data_list

    dataset_t = _make_dataset(cfg, task, data_dict[:num_train], transform)
    data_loader_t = _make_loader(dataset_t, loader_batch, num_workers)
    return data_loader_t, v_data_list


# ==============================================================================
# File: /one_stage_nas/data/datasets/__init__.py  (empty)
# ==============================================================================
# File: /one_stage_nas/data/datasets/__pycache__/__init__.cpython-36.pyc
# (compiled bytecode, not inlined)
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/data/datasets/__pycache__/__init__.cpython-36.pyc
-------------------------------------------------------------------------------- /one_stage_nas/data/datasets/__pycache__/denoise.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/data/datasets/__pycache__/denoise.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/data/datasets/__pycache__/see_in_dark.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/data/datasets/__pycache__/see_in_dark.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/data/datasets/__pycache__/super_resolution.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/data/datasets/__pycache__/super_resolution.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/data/datasets/__pycache__/tasks_dict.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/data/datasets/__pycache__/tasks_dict.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/data/datasets/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/data/datasets/__pycache__/transforms.cpython-36.pyc 
# ==============================================================================
# File: /one_stage_nas/data/datasets/denoise.py
# ==============================================================================
import os
import json
import torch
import random
from PIL import Image
from torch.utils.data import Dataset


def json_loader(dict_file_dir):
    """Read the JSON file at ``dict_file_dir`` and return the decoded object."""
    with open(dict_file_dir, 'r') as data_file:
        return json.load(data_file)


def _load_image(path):
    """Open ``path`` with PIL and return a fully loaded copy; the file is closed on return."""
    with Image.open(path) as img:
        return img.copy()


def _rgba_to_rgb(img):
    """Drop the alpha channel of RGBA images; any other mode is returned as is."""
    return img.convert('RGB') if img.mode in ['RGBA'] else img


# This class is built for loading different datasets in denoise tasks
class Dn_datasets(Dataset):
    """Denoising dataset: input and target both start as the same clean image.

    ``transform`` receives ``{'image': img, 'target': img}`` and must return a
    dict with the same keys (noise is presumably injected by the transform,
    e.g. NoiseToTensor -- confirm). Each entry of ``data_dict`` needs a
    'path' relative to ``data_root``; with ``load_all`` it also needs 'width'
    and 'height'.
    """

    def __init__(self, data_root, data_dict, transform, load_all=False, to_gray=False, s_factor=1, repeat_crop=1):
        self.data_root = data_root
        self.transform = transform
        self.load_all = load_all
        self.to_gray = to_gray
        self.repeat_crop = repeat_crop
        if not self.load_all:
            self.data_dict = data_dict
        else:
            # Preload every image once.
            self.data_dict = []
            for sample_info in data_dict:
                sample_data = _rgba_to_rgb(_load_image('/'.join((self.data_root, sample_info['path']))))
                self.data_dict.append({
                    'data': sample_data,
                    'width': sample_info['width'],
                    'height': sample_info['height']
                })

    def __len__(self):
        return len(self.data_dict)

    def __getitem__(self, idx):
        """Return ``(image, target)``; with ``repeat_crop != 1`` each is a stack of crops."""
        sample_info = self.data_dict[idx]
        if not self.load_all:
            sample_data = _rgba_to_rgb(_load_image('/'.join((self.data_root, sample_info['path']))))
        else:
            sample_data = sample_info['data']

        if self.to_gray:
            sample_data = sample_data.convert('L')

        sample = {'image': sample_data, 'target': sample_data}

        if self.repeat_crop != 1:
            patches = [self.transform(sample) for _ in range(self.repeat_crop)]
            return (torch.stack([patch['image'] for patch in patches]),
                    torch.stack([patch['target'] for patch in patches]))

        sample = self.transform(sample)
        return sample['image'], sample['target']


# ==============================================================================
# File: /one_stage_nas/data/datasets/see_in_dark.py
# ==============================================================================
import os
import json
import torch
import rawpy
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import matplotlib.pyplot as plt


def json_loader(dict_file_dir):
    """Read the JSON file at ``dict_file_dir`` and return the decoded object."""
    with open(dict_file_dir, 'r') as data_file:
        return json.load(data_file)


# the dataloader of seeing in the dark on the Sony dataset
class Sid_dataset(Dataset):
    """Raw low-light / ground-truth pairs decoded with rawpy.

    Each ``data_dict`` entry provides 'raw_path' (list), 'raw_exposure' (list),
    'gt_path' and 'gt_exposure'. Samples are ``([packed_raw, rgb], gt)`` after
    ``transform``; with ``load_all`` the rawpy handles stay open for the
    dataset's lifetime.
    """

    def __init__(self, data_root, data_dict, transform, load_all=False, to_gray=False, s_factor=1, repeat_crop=1):
        self.data_root = data_root
        self.transform = transform
        self.load_all = load_all
        self.to_gray = to_gray
        self.repeat_crop = repeat_crop

        if not self.load_all:
            self.data_dict = data_dict
        else:
            self.data_dict = []
            for sample_info in data_dict:
                raw_data = [rawpy.imread('/'.join((self.data_root, raw_path)))
                            for raw_path in sample_info['raw_path']]
                gt_data = rawpy.imread('/'.join((self.data_root, sample_info['gt_path'])))

                self.data_dict.append({
                    'sample_id': sample_info['sample_id'],
                    'raw_data': raw_data,
                    'gt_data': gt_data,
                    'raw_exposure': sample_info['raw_exposure'],
                    'gt_exposure': sample_info['gt_exposure'],
                })

    def __len__(self):
        return len(self.data_dict)

    def __getitem__(self, idx):
        """Pick one raw capture at random, amplify it by the exposure ratio and transform."""
        sample_info = self.data_dict[idx]
        raw_exposure = sample_info['raw_exposure']
        gt_exposure = sample_info['gt_exposure']

        raw_index = random.randint(0, len(raw_exposure) - 1)
        raw_exposure_cur = raw_exposure[raw_index]

        if not self.load_all:
            raw_path = sample_info['raw_path'][raw_index]
            gt_path = sample_info['gt_path']
            # Per-sample handles: release the LibRaw resources as soon as decoded.
            with rawpy.imread('/'.join((self.data_root, raw_path))) as raw_handle:
                with rawpy.imread('/'.join((self.data_root, gt_path))) as gt_handle:
                    arw_input, rgb_input, gt_input = self._decode(raw_handle, gt_handle)
        else:
            arw_input, rgb_input, gt_input = self._decode(sample_info['raw_data'][raw_index],
                                                          sample_info['gt_data'])

        # amplification ratio, capped at 300
        ratio = min(gt_exposure / raw_exposure_cur, 300)
        arw_input = np.minimum(np.maximum(arw_input * ratio, 0), 1)
        rgb_input = np.minimum(np.maximum(rgb_input * ratio, 0), 1)
        gt_input = np.minimum(np.maximum(gt_input, 0), 1)

        sample = {'arw': arw_input, 'rgb': rgb_input, 'gt': gt_input}

        if self.repeat_crop != 1:
            patches = [self.transform(sample) for _ in range(self.repeat_crop)]
            return ([torch.stack([patch['arw'] for patch in patches]),
                     torch.stack([patch['rgb'] for patch in patches])],
                    torch.stack([patch['gt'] for patch in patches]))

        sample_patch = self.transform(sample)
        return [sample_patch['arw'], sample_patch['rgb']], sample_patch['gt']

    def _decode(self, raw_input, gt_input):
        """Return ``(packed_raw, rgb, gt_rgb)`` float32 arrays from two rawpy handles.

        The RGB images come from 16-bit rawpy postprocessing divided by 65535.
        """
        arw_input, width, height = self.pack_raw(raw_input)
        rgb_input = raw_input.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        rgb_input = (rgb_input / 65535.0).astype(np.float32)

        gt_rgb = gt_input.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        gt_rgb = (gt_rgb / 65535.0).astype(np.float32)
        return arw_input, rgb_input, gt_rgb

    def pack_raw(self, raw):
        """Pack the visible Bayer mosaic into a 4-channel half-resolution array.

        Returns:
            ``(out, W, H)`` where ``out`` has shape ``(H // 2, W // 2, 4)`` for
            even H and W, and W/H are the full mosaic dimensions.
        """
        im = raw.raw_image_visible.astype(np.float32)
        # subtract the black level (512) and normalise by the remaining range;
        # presumably Sony sensor constants -- confirm for other cameras
        im = np.maximum(im - 512, 0) / (16383 - 512)

        im = np.expand_dims(im, axis=2)
        img_shape = im.shape
        H = img_shape[0]
        W = img_shape[1]

        out = np.concatenate((im[0:H:2, 0:W:2, :],
                              im[0:H:2, 1:W:2, :],
                              im[1:H:2, 1:W:2, :],
                              im[1:H:2, 0:W:2, :]), axis=2)
        return out, W, H


# ==============================================================================
# File: /one_stage_nas/data/datasets/super_resolution.py
# ==============================================================================
import os
import json
import torch
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import matplotlib.pyplot as plt


def json_loader(dict_file_dir):
    """Read the JSON file at ``dict_file_dir`` and return the decoded object."""
    with open(dict_file_dir, 'r') as data_file:
        return json.load(data_file)


def _load_image(path):
    """Open ``path`` with PIL and return a fully loaded copy; the file is closed on return."""
    with Image.open(path) as img:
        return img.copy()


def _to_rgb(img):
    """Return ``img`` converted to RGB (unchanged if it already is RGB)."""
    return img if img.mode == 'RGB' else img.convert('RGB')


# This class is built for loading different datasets in super resolution tasks
class Sr_datasets(Dataset):
    """Low-resolution / high-resolution image pairs for scale ``s_factor``.

    Each ``data_dict`` entry provides 'gt_path', 'x{s}_path' and (with
    ``load_all``) 'x{s}_size'. Both images are converted to RGB independently,
    so a non-RGB LR image is converted even when its HR image is already RGB.
    ``to_gray`` is accepted for interface parity but not used by this class.
    """

    def __init__(self, data_root, data_dict, transform, load_all=True, to_gray=False, s_factor=1, repeat_crop=1):
        self.data_root = data_root
        self.transform = transform
        self.load_all = load_all
        self.s_factor = s_factor
        self.repeat_crop = repeat_crop
        if not self.load_all:
            self.data_dict = data_dict
        else:
            self.data_dict = []
            for sample_info in data_dict:
                hr_img = _to_rgb(_load_image('/'.join((self.data_root, sample_info['gt_path']))))
                lw_img = _to_rgb(_load_image(
                    '/'.join((self.data_root, sample_info['x{}_path'.format(self.s_factor)]))))
                [width, height] = sample_info['x{}_size'.format(self.s_factor)]
                self.data_dict.append({
                    'hr_img': hr_img,
                    'lw_img': lw_img,
                    'width': width,
                    'height': height
                })

    def __len__(self):
        return len(self.data_dict)

    def __getitem__(self, idx):
        """Return ``(lr_image, hr_target)``; with ``repeat_crop != 1`` each is a stack of crops."""
        sample_info = self.data_dict[idx]
        if not self.load_all:
            image = _to_rgb(_load_image(
                '/'.join((self.data_root, sample_info['x{}_path'.format(self.s_factor)]))))
            target = _to_rgb(_load_image('/'.join((self.data_root, sample_info['gt_path']))))
        else:
            image = sample_info['lw_img']
            target = sample_info['hr_img']

        sample = {'image': image, 'target': target}

        if self.repeat_crop != 1:
            patches = [self.transform(sample) for _ in range(self.repeat_crop)]
            return (torch.stack([patch['image'] for patch in patches]),
                    torch.stack([patch['target'] for patch in patches]))

        sample_patch = self.transform(sample)
        return sample_patch['image'], sample_patch['target']


# ==============================================================================
# File: /one_stage_nas/data/datasets/tasks_dict.py
# ==============================================================================
from .denoise import Dn_datasets
# from .see_in_dark import Sid_dataset
from .super_resolution import Sr_datasets

# Maps cfg.DATASET.TASK to its dataset class; 'sid' is currently disabled.
tasks_dict = {
    'dn': Dn_datasets,
    # 'sid': Sid_dataset,
    'sr': Sr_datasets
}


# ==============================================================================
# File: /one_stage_nas/engine/__pycache__/inference.cpython-36.pyc
# (compiled bytecode, not inlined)
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/engine/__pycache__/inference.cpython-36.pyc
# ==============================================================================
/one_stage_nas/engine/__pycache__/searcher.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/engine/__pycache__/searcher.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/engine/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/engine/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/engine/__pycache__/trainer_joint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/engine/__pycache__/trainer_joint.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/engine/searcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import time 4 | import datetime 5 | 6 | import torch 7 | import matplotlib.pyplot as plt 8 | 9 | from one_stage_nas.utils.metric_logger import MetricLogger 10 | from one_stage_nas.utils.comm import reduce_loss_dict 11 | from one_stage_nas.utils.visualize import model_visualize 12 | # from .inference import dn_inference, sid_inference, sr_inference 13 | from .inference import dn_inference, sr_inference 14 | 15 | 16 | def do_search( 17 | model, 18 | train_loaders, 19 | val_list, 20 | max_epoch, 21 | arch_start_epoch, 22 | val_period, 23 | optimizer, 24 | scheduler, 25 | checkpointer, 26 | checkpointer_period, 27 | arguments, 28 | writer, 29 | cfg, 30 | visual_dir): 31 | """ 32 | num_classes (int): number of classes. 
Required by computing mIoU. 33 | """ 34 | logger = logging.getLogger("one_stage_nas.searcher") 35 | logger.info("Start searching") 36 | 37 | start_epoch = arguments["epoch"] 38 | start_training_time = time.time() 39 | 40 | if cfg.DATASET.TASK == 'dn': 41 | inference = dn_inference 42 | # elif cfg.DATASET.TASK == 'sid': 43 | # inference = sid_inference 44 | elif cfg.DATASET.TASK == 'sr': 45 | inference = sr_inference 46 | 47 | best_val = 0 48 | for epoch in range(start_epoch, max_epoch): 49 | epoch = epoch + 1 50 | arguments["epoch"] = epoch 51 | 52 | scheduler.step() 53 | 54 | train(model, train_loaders, optimizer, epoch, 55 | train_arch=epoch > arch_start_epoch, repeat_crop=cfg.DATALOADER.R_CROP) 56 | if epoch > cfg.SEARCH.ARCH_START_EPOCH: 57 | save_dir = '/'.join((visual_dir, 'visualize', 'arch_epoch{}'.format(epoch))) 58 | model_visualize(model, save_dir, cfg.SEARCH.TIE_CELL) 59 | if epoch % val_period == 0 and epoch > 60: 60 | # if epoch % val_period == 0: 61 | ssim, psnr = inference(model, val_list, cfg) 62 | if best_val < (ssim + psnr/100): 63 | best_val = (ssim + psnr/100) 64 | checkpointer.save("model_best", **arguments) 65 | writer.add_scalars('Search_SSIM', { 'val_ssim': ssim}, epoch) 66 | writer.add_scalars('Search_PSNR', {'val_psnr': psnr}, epoch) 67 | if epoch % checkpointer_period == 0: 68 | checkpointer.save("model_{:03d}".format(epoch), **arguments) 69 | if epoch == max_epoch: 70 | checkpointer.save("model_final", **arguments) 71 | 72 | total_training_time = time.time() - start_training_time 73 | total_time_str = str(datetime.timedelta(seconds=total_training_time)) 74 | logger.info("Total training time: {}".format(total_time_str)) 75 | 76 | 77 | def train(model, data_loaders, optimizer, epoch, 78 | train_arch=False, repeat_crop=1): 79 | """ 80 | Should add some stats and log to visualise the archs 81 | """ 82 | data_loader_w = data_loaders[0] 83 | data_loader_a = data_loaders[1] 84 | optim_w = optimizer['optim_w'] 85 | optim_a = 
optimizer['optim_a'] 86 | 87 | logger = logging.getLogger("one_stage_nas.searcher") 88 | 89 | max_iter = len(data_loader_w) 90 | model.train() 91 | meters = MetricLogger(delimiter=" ") 92 | end = time.time() 93 | for iteration, (images, targets) in enumerate(data_loader_w): 94 | data_time = time.time() - end 95 | 96 | if train_arch: 97 | # print('start_train_arch') 98 | images_a, targets_a = next(iter(data_loader_a)) 99 | 100 | if repeat_crop != 1: 101 | if isinstance(images_a, list): 102 | ima0_sizes = images_a[0].shape 103 | ima1_sizes = images_a[1].shape 104 | images_a = [images_a[0].view(ima0_sizes[0] * ima0_sizes[1], ima0_sizes[2], ima0_sizes[3], ima0_sizes[4]), 105 | images_a[1].view(ima1_sizes[0] * ima1_sizes[1], ima1_sizes[2], ima1_sizes[3], ima1_sizes[4]) 106 | ] 107 | else: 108 | im_sizes = images_a.shape 109 | images_a = images_a.view(im_sizes[0] * im_sizes[1], im_sizes[2], im_sizes[3], im_sizes[4]) 110 | 111 | ta_sizes = targets_a.shape 112 | targets_a = targets.view(ta_sizes[0] * ta_sizes[1], ta_sizes[2], ta_sizes[3], ta_sizes[4]) 113 | 114 | loss_dict = model(images_a, targets_a) 115 | losses = sum(loss for loss in loss_dict.values()).mean() 116 | 117 | optim_a.zero_grad() 118 | losses.backward() 119 | optim_a.step() 120 | 121 | if repeat_crop!=1: 122 | 123 | if isinstance(images, list): 124 | im0_sizes = images[0].shape 125 | im1_sizes = images[1].shape 126 | images = [images[0].view(im0_sizes[0]*im0_sizes[1], im0_sizes[2], im0_sizes[3], im0_sizes[4]), 127 | images[1].view(im1_sizes[0]*im1_sizes[1], im1_sizes[2], im1_sizes[3], im1_sizes[4]) 128 | ] 129 | else: 130 | im_sizes = images.shape 131 | images = images.view(im_sizes[0] * im_sizes[1], im_sizes[2], im_sizes[3], im_sizes[4]) 132 | 133 | ta_sizes = targets.shape 134 | targets = targets.view(ta_sizes[0]*ta_sizes[1], ta_sizes[2], ta_sizes[3], ta_sizes[4]) 135 | 136 | loss_dict = model(images, targets) 137 | losses = sum(loss for loss in loss_dict.values()).mean() 138 | 139 | # reduce losses over 
all GPUs for logging purposes 140 | loss_dict_reduced = reduce_loss_dict(loss_dict) 141 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 142 | meters.update(loss=losses_reduced, **loss_dict_reduced) 143 | 144 | optim_w.zero_grad() 145 | losses.backward() 146 | optim_w.step() 147 | 148 | batch_time = time.time() - end 149 | end = time.time() 150 | meters.update(time=batch_time, data=data_time) 151 | 152 | eta_seconds = meters.time.global_avg * (max_iter - iteration) 153 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 154 | 155 | if iteration % 50 == 0: 156 | logger.info( 157 | meters.delimiter.join( 158 | ["eta: {eta}", 159 | "iter: {epoch}/{iter}", 160 | "{meters}", 161 | "lr: {lr:.6f}", 162 | "max_mem: {memory:.1f} G"]).format( 163 | eta=eta_string, 164 | epoch=epoch, 165 | iter=iteration, 166 | meters=str(meters), 167 | lr=optim_w.param_groups[0]['lr'], 168 | memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0)) 169 | -------------------------------------------------------------------------------- /one_stage_nas/engine/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import datetime 4 | 5 | import torch 6 | import matplotlib.pyplot as plt 7 | 8 | from one_stage_nas.utils.metric_logger import MetricLogger 9 | from one_stage_nas.utils.comm import reduce_loss_dict, compute_params 10 | from .inference import dn_inference, sid_inference, sr_inference 11 | from one_stage_nas.utils.evaluation_metrics import SSIM, PSNR 12 | 13 | 14 | def do_train( 15 | model, 16 | train_loader, 17 | val_list, 18 | max_iter, 19 | val_period, 20 | optimizer, 21 | scheduler, 22 | checkpointer, 23 | checkpointer_period, 24 | arguments, 25 | writer, 26 | cfg): 27 | """ 28 | num_classes (int): number of classes. Required by computing mIoU. 
29 | """ 30 | logger = logging.getLogger("one_stage_nas.trainer") 31 | logger.info("Model Params: {:.2f}M".format(compute_params(model) / 1024 / 1024)) 32 | 33 | logger.info("Start training") 34 | 35 | start_iter = arguments["iteration"] 36 | start_training_time = time.time() 37 | 38 | if cfg.DATASET.TASK == 'dn': 39 | inference = dn_inference 40 | elif cfg.DATASET.TASK == 'sid': 41 | inference = sid_inference 42 | elif cfg.DATASET.TASK == 'sr': 43 | inference = sr_inference 44 | 45 | 46 | best_val = 0 47 | model.train() 48 | data_iter = iter(train_loader) 49 | 50 | meters = MetricLogger(delimiter=" ") 51 | if cfg.DATASET.TASK in ['sid']: 52 | metric_SSIM = SSIM(window_size=11, channel=3, is_cuda=True) 53 | else: 54 | metric_SSIM = SSIM(window_size=11, channel=cfg.MODEL.IN_CHANNEL, is_cuda=True) 55 | metric_PSNR = PSNR() 56 | repeat_crop = cfg.DATALOADER.R_CROP 57 | 58 | end = time.time() 59 | for iteration in range(start_iter, max_iter): 60 | iteration = iteration + 1 61 | arguments["iteration"] = iteration 62 | 63 | scheduler.step() 64 | 65 | try: 66 | images, targets = next(data_iter) 67 | except StopIteration: 68 | data_iter = iter(train_loader) 69 | images, targets = next(data_iter) 70 | data_time = time.time() - end 71 | 72 | if repeat_crop!=1: 73 | 74 | if isinstance(images, list): 75 | im0_sizes = images[0].shape 76 | im1_sizes = images[1].shape 77 | images = [images[0].view(im0_sizes[0]*im0_sizes[1], im0_sizes[2], im0_sizes[3], im0_sizes[4]), 78 | images[1].view(im1_sizes[0]*im1_sizes[1], im1_sizes[2], im1_sizes[3], im1_sizes[4]) 79 | ] 80 | else: 81 | im_sizes = images.shape 82 | images = images.view(im_sizes[0] * im_sizes[1], im_sizes[2], im_sizes[3], im_sizes[4]) 83 | 84 | ta_sizes = targets.shape 85 | targets = targets.view(ta_sizes[0]*ta_sizes[1], ta_sizes[2], ta_sizes[3], ta_sizes[4]) 86 | 87 | pred, loss_dict = model(images, targets) 88 | losses = sum(loss for loss in loss_dict.values()).mean() 89 | 90 | # # reduce losses over all GPUs for logging 
purposes 91 | # loss_dict_reduced = reduce_loss_dict(loss_dict) 92 | # losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 93 | # meters.update(loss=losses_reduced, **loss_dict_reduced) 94 | 95 | optimizer.zero_grad() 96 | losses.backward() 97 | torch.nn.utils.clip_grad_value_(model.parameters(), 5.0) 98 | optimizer.step() 99 | 100 | batch_time = time.time() - end 101 | end = time.time() 102 | meters.update(time=batch_time, data=data_time) 103 | 104 | eta_seconds = meters.time.global_avg * (max_iter - iteration) 105 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 106 | 107 | pred[pred>1.0] = 1.0 108 | pred[pred<0.0] = 0.0 109 | 110 | targets = targets.cuda() 111 | 112 | metric_SSIM(pred.detach(), targets, transpose=False) 113 | metric_PSNR(pred.detach(), targets) 114 | 115 | if iteration % (val_period // 4) == 0: 116 | logger.info( 117 | meters.delimiter.join( 118 | ["eta: {eta}", 119 | "iter: {iter}", 120 | "{meters}", 121 | "lr: {lr:.6f}", 122 | "max_mem: {memory:.0f}"]).format( 123 | eta=eta_string, 124 | iter=iteration, 125 | meters=str(meters), 126 | lr=optimizer.param_groups[0]['lr'], 127 | memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)) 128 | print(float(losses)) 129 | 130 | if iteration % val_period == 0: 131 | train_ssim, train_psnr = metric_SSIM.metric_get(), metric_PSNR.metric_get() 132 | metric_SSIM.reset() 133 | metric_PSNR.reset() 134 | 135 | if iteration > int(max_iter*3/4): 136 | ssim, psnr, input_img, output_img, target_img = inference(model, val_list, cfg, show_img=True, tag='train') 137 | if best_val < (ssim + psnr/100): 138 | best_val = (ssim + psnr/100) 139 | checkpointer.save("model_best", **arguments) 140 | # set mode back to train 141 | model.train() 142 | writer.add_image('img/train/input', input_img, iteration) 143 | writer.add_image('img/train/output', output_img, iteration) 144 | writer.add_image('img/train/target', target_img, iteration) 145 | writer.add_scalars('SSIM', {'train_ssim': 
train_ssim, 'val_ssim': ssim}, iteration) 146 | writer.add_scalars('PSNR', {'train_psnr': train_psnr, 'val_psnr': psnr}, iteration) 147 | else: 148 | writer.add_scalars('SSIM', {'train_ssim': train_ssim}, iteration) 149 | writer.add_scalars('PSNR', {'train_psnr': train_psnr}, iteration) 150 | 151 | if iteration % val_period == 0: 152 | checkpointer.save("model_{:06d}".format(iteration), **arguments) 153 | if iteration == max_iter: 154 | checkpointer.save("model_final", **arguments) 155 | 156 | total_training_time = time.time() - start_training_time 157 | total_time_str = str(datetime.timedelta(seconds=total_training_time)) 158 | logger.info("Total training time: {}".format(total_time_str)) 159 | 160 | writer.close() 161 | 162 | -------------------------------------------------------------------------------- /one_stage_nas/engine/trainer_joint.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import datetime 4 | import numpy as np 5 | import torch 6 | import random 7 | import matplotlib.pyplot as plt 8 | 9 | from one_stage_nas.utils.metric_logger import MetricLogger 10 | from one_stage_nas.utils.comm import reduce_loss_dict, compute_params 11 | from .inference import denoise_inference_joint 12 | from one_stage_nas.utils.evaluation_metrics import SSIM, PSNR 13 | 14 | 15 | def do_train( 16 | model, 17 | train_loader_list, 18 | val_set_list, 19 | max_iter, 20 | val_period, 21 | optimizer, 22 | scheduler, 23 | checkpointer, 24 | checkpointer_period, 25 | arguments, 26 | writer, 27 | cfg): 28 | """ 29 | num_classes (int): number of classes. Required by computing mIoU. 
30 | """ 31 | logger = logging.getLogger("one_stage_nas.trainer") 32 | logger.info("Model Params: {:.2f}M".format(compute_params(model) / 1024 / 1024)) 33 | 34 | logger.info("Start training") 35 | 36 | start_iter = arguments["iteration"] 37 | start_training_time = time.time() 38 | 39 | inference = denoise_inference_joint 40 | 41 | best_val = 0 42 | model.train() 43 | 44 | data_iter_list = [] 45 | for train_loader in train_loader_list: 46 | data_iter_list.append(iter(train_loader)) 47 | 48 | meters = MetricLogger(delimiter=" ") 49 | metric_SSIM = SSIM(window_size=11, channel=cfg.MODEL.IN_CHANNEL, is_cuda=True) 50 | metric_PSNR = PSNR() 51 | 52 | datasets_weight = np.array(cfg.DATASET.TRAIN_DATASETS_WEIGHT) 53 | weights = [datasets_weight[:i+1].sum() for i in range(len(datasets_weight))] 54 | end = time.time() 55 | for iteration in range(start_iter, max_iter): 56 | iteration = iteration + 1 57 | arguments["iteration"] = iteration 58 | 59 | scheduler.step() 60 | 61 | random_id = random_choose(weights) 62 | try: 63 | images, targets = next(data_iter_list[random_id]) 64 | except StopIteration: 65 | data_iter_list = [] 66 | for train_loader in train_loader_list: 67 | data_iter_list.append(iter(train_loader)) 68 | images, targets = next(data_iter_list[random_id]) 69 | data_time = time.time() - end 70 | 71 | pred, loss_dict = model(images, targets) 72 | losses = sum(loss for loss in loss_dict.values()).mean() 73 | 74 | # reduce losses over all GPUs for logging purposes 75 | loss_dict_reduced = reduce_loss_dict(loss_dict) 76 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 77 | meters.update(loss=losses_reduced, **loss_dict_reduced) 78 | 79 | optimizer.zero_grad() 80 | losses.backward() 81 | torch.nn.utils.clip_grad_value_(model.parameters(), 3) 82 | optimizer.step() 83 | 84 | batch_time = time.time() - end 85 | end = time.time() 86 | meters.update(time=batch_time, data=data_time) 87 | 88 | eta_seconds = meters.time.global_avg * (max_iter - iteration) 89 
| eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 90 | 91 | pred[pred>1.0] = 1.0 92 | pred[pred<0.0] = 0.0 93 | 94 | targets = targets.cuda() 95 | 96 | metric_SSIM(pred.detach(), targets, transpose=False) 97 | metric_PSNR(pred.detach(), targets) 98 | 99 | if iteration % (val_period // 4) == 0: 100 | logger.info( 101 | meters.delimiter.join( 102 | ["eta: {eta}", 103 | "iter: {iter}", 104 | "{meters}", 105 | "lr: {lr:.6f}", 106 | "max_mem: {memory:.0f}"]).format( 107 | eta=eta_string, 108 | iter=iteration, 109 | meters=str(meters), 110 | lr=optimizer.param_groups[0]['lr'], 111 | memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)) 112 | 113 | if iteration % val_period == 0: 114 | train_ssim, train_psnr = metric_SSIM.metric_get(), metric_PSNR.metric_get() 115 | metric_SSIM.reset() 116 | metric_PSNR.reset() 117 | 118 | ssim, psnr, input_img, output_img, target_img = inference(model, val_set_list, cfg, show_img=True, tag='train') 119 | if best_val < (ssim + psnr/100): 120 | best_val = (ssim + psnr/100) 121 | checkpointer.save("model_best", **arguments) 122 | # set mode back to train 123 | model.train() 124 | writer.add_image('img/train/input', input_img, iteration) 125 | writer.add_image('img/train/output', output_img, iteration) 126 | writer.add_image('img/train/target', target_img, iteration) 127 | writer.add_scalars('SSIM', {'train_ssim': train_ssim, 'val_ssim': ssim}, iteration) 128 | writer.add_scalars('PSNR', {'train_psnr': train_psnr, 'val_psnr': psnr}, iteration) 129 | # writer.add_scalars('SSIM', {'train_ssim': train_ssim}, iteration) 130 | # writer.add_scalars('PSNR', {'train_psnr': train_psnr}, iteration) 131 | 132 | if iteration % val_period == 0: 133 | checkpointer.save("model_{:06d}".format(iteration), **arguments) 134 | if iteration == max_iter: 135 | checkpointer.save("model_final", **arguments) 136 | 137 | total_training_time = time.time() - start_training_time 138 | total_time_str = 
str(datetime.timedelta(seconds=total_training_time)) 139 | logger.info("Total training time: {}".format(total_time_str)) 140 | 141 | writer.close() 142 | 143 | 144 | def random_choose(weights): 145 | random_id = random.randint(1, weights[-1]) 146 | for id, region in enumerate(weights): 147 | if random_id<=region: 148 | return id 149 | 150 | 151 | -------------------------------------------------------------------------------- /one_stage_nas/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .architectures import build_model 2 | -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/amt_discrete.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/amt_discrete.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/architectures.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/architectures.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/auto_multitask.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/auto_multitask.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/decoders.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/decoders.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/dn_compnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/dn_compnet.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/dn_supernet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/dn_supernet.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/sid_compnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/sid_compnet.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/sid_supernet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/sid_supernet.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/sr_compnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/sr_compnet.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/__pycache__/sr_supernet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/modeling/__pycache__/sr_supernet.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/modeling/amt_discrete.py: -------------------------------------------------------------------------------- 1 | """ 2 | Discrete structure of Auto-DeepLab 3 | 4 | Includes utils to convert continous Auto-DeepLab to discrete ones 5 | 
""" 6 | 7 | import os 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from one_stage_nas.darts.cell import FixCell 13 | from .auto_multitask import AutoMultiTask 14 | from .common import conv3x3_bn, conv1x1_bn 15 | from .decoders import build_decoder 16 | from .loss import loss_dict 17 | 18 | 19 | def get_genotype_from_adl(cfg): 20 | # create ADL model 21 | adl_cfg = cfg.clone() 22 | adl_cfg.defrost() 23 | 24 | adl_cfg.merge_from_list(['MODEL.META_ARCHITECTURE', 'AutoDeepLab', 25 | 'MODEL.FILTER_MULTIPLIER', 8, 26 | 'MODEL.AFFINE', True, 27 | 'SEARCH.SEARCH_ON', True]) 28 | 29 | model = AutoMultiTask(adl_cfg) 30 | # load weights 31 | SEARCH_RESULT_DIR = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 32 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 33 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 34 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 35 | 'search/models/model_best.pth')) 36 | ckpt = torch.load(SEARCH_RESULT_DIR) 37 | restore = {k: v for k, v in ckpt['model'].items() if 'arch' in k} 38 | model.load_state_dict(restore, strict=False) 39 | return model.genotype() 40 | 41 | 42 | class Scaler(nn.Module): 43 | """Reshape features""" 44 | def __init__(self, scale, inp, C): 45 | """ 46 | Arguments: 47 | scale (int) [-2, 2]: scale < 0 for downsample 48 | """ 49 | super(Scaler, self).__init__() 50 | if scale == 0: 51 | self.scale = conv1x1_bn(inp, C, 1, relu=False) 52 | if scale == 1: 53 | self.scale = nn.Sequential( 54 | nn.Upsample(scale_factor=2, mode='bilinear', 55 | align_corners=False), 56 | conv1x1_bn(inp, C, 1, relu=False)) 57 | if scale == 2: 58 | self.scale = nn.Sequential( 59 | nn.Upsample(scale_factor=4, mode='bilinear', 60 | align_corners=False), 61 | conv1x1_bn(inp, C, 1, relu=False)) 62 | # official implementation used bilinear for all scalers 63 | if scale == -1: 64 | self.scale = conv3x3_bn(inp, C, 2, relu=False) 65 | if scale == -2: 
66 | self.scale = nn.Sequential(conv3x3_bn(inp, inp * 2, 2), 67 | conv3x3_bn(inp * 2, C, 2, relu=False)) 68 | 69 | def forward(self, hidden_state): 70 | return self.scale(hidden_state) 71 | 72 | 73 | class DeepLabScaler_Width(nn.Module): 74 | """Official implementation 75 | https://github.com/tensorflow/models/blob/master/research/deeplab/core/nas_cell.py#L90 76 | """ 77 | def __init__(self, scale, inp, C, activate_f='ReLU'): 78 | super(DeepLabScaler_Width, self).__init__() 79 | self.activate_f = activate_f 80 | self.scale = 2 ** scale 81 | self.conv = conv1x1_bn(inp, C, 1, activate_f=None) 82 | 83 | def forward(self, hidden_state): 84 | if self.activate_f.lower() == 'relu': 85 | return self.conv(F.relu(hidden_state)) 86 | elif self.activate_f.lower() in ['leaky', 'prelu']: 87 | return self.conv(F.leaky_relu(hidden_state, negative_slope=0.2)) 88 | elif self.activate_f.lower() == 'sine': 89 | return self.conv(torch.sin(hidden_state)) 90 | 91 | 92 | class AMTDiscrete(nn.Module): 93 | def __init__(self, cfg): 94 | super(AMTDiscrete, self).__init__() 95 | 96 | # load genotype 97 | if len(cfg.DATASET.TRAIN_DATASETS) == 0: 98 | geno_file = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 99 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 100 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 101 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 102 | 'search/models/model_best.geno')) 103 | 104 | else: 105 | geno_file = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 106 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 
107 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 108 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 109 | 'search/models/model_best.geno')) 110 | 111 | if os.path.exists(geno_file): 112 | print("Loading genotype from {}".format(geno_file)) 113 | genotype = torch.load(geno_file, map_location=torch.device("cpu")) 114 | else: 115 | genotype = get_genotype_from_adl(cfg) 116 | print("Saving genotype to {}".format(geno_file)) 117 | torch.save(genotype, geno_file) 118 | 119 | geno_cell, geno_path = genotype 120 | 121 | self.genotpe = genotype 122 | 123 | if 0 in geno_path: 124 | self.endpoint = (len(geno_path) - 1) - list(reversed(geno_path)).index(0) 125 | if self.endpoint == (len(geno_path) -1): 126 | self.endpoint = None 127 | else: 128 | self.endpoint = None 129 | 130 | # basic configs 131 | self.activate_f = cfg.MODEL.ACTIVATION_F 132 | self.use_res = cfg.MODEL.USE_RES 133 | self.f = cfg.MODEL.FILTER_MULTIPLIER 134 | self.num_layers = cfg.MODEL.NUM_LAYERS 135 | self.num_blocks = cfg.MODEL.NUM_BLOCKS 136 | self.num_strides = cfg.MODEL.NUM_STRIDES 137 | self.in_channel = cfg.MODEL.IN_CHANNEL 138 | self.stem1 = conv3x3_bn(self.in_channel, 64, 1, activate_f=self.activate_f) 139 | self.stem2 = conv3x3_bn(64, 64, 1, activate_f=None) 140 | self.reduce = conv3x3_bn(64, self.f*self.num_blocks, 1, affine=False, activate_f=None) 141 | 142 | # create cells 143 | self.cells = nn.ModuleList() 144 | self.scalers = nn.ModuleList() 145 | if cfg.SEARCH.TIE_CELL: 146 | geno_cell = [geno_cell] * self.num_layers 147 | 148 | DeepLabScaler = DeepLabScaler_Width 149 | 150 | h_0 = 0 # prev hidden index 151 | h_1 = -1 # prev prev hidden index 152 | for layer, (geno, h) in enumerate(zip(geno_cell, geno_path), 1): 153 | stride = 2 ** h 154 | self.cells.append(FixCell(geno, self.f * stride)) 155 | # scalers 156 | if layer == 1: 157 | inp0 = 64 158 | inp1 = 64 159 | elif layer == 2: 160 | inp0 = 2 ** h_0 * self.f * 
self.num_blocks 161 | inp1 = 64 162 | else: 163 | inp0 = 2 ** h_0 * self.f * self.num_blocks 164 | inp1 = 2 ** h_1 * self.f * self.num_blocks 165 | 166 | if layer == 1: 167 | scaler0 = DeepLabScaler(h_0 - h, inp0, 168 | stride * self.f, activate_f=self.activate_f) 169 | scaler1 = DeepLabScaler(h_0 - h, inp1, 170 | stride * self.f, activate_f=self.activate_f) 171 | else: 172 | scaler0 = DeepLabScaler(h_0 - h, inp0, 173 | stride * self.f, activate_f=self.activate_f) 174 | scaler1 = DeepLabScaler(h_1 - h, inp1, 175 | stride * self.f, activate_f=self.activate_f) 176 | 177 | h_1 = h_0 178 | h_0 = h 179 | self.scalers.append(scaler0) 180 | self.scalers.append(scaler1) 181 | self.decoder = build_decoder(cfg, out_strides=stride) 182 | if cfg.SOLVER.LOSS is not None: 183 | self.loss_dict = [] 184 | self.loss_weight = [] 185 | for loss_item, loss_weight in zip(cfg.SOLVER.LOSS, cfg.SOLVER.LOSS_WEIGHT): 186 | if 'ssim' in loss_item or 'grad' in loss_item: 187 | self.loss_dict.append(loss_dict[loss_item](channel=cfg.MODEL.IN_CHANNEL)) 188 | else: 189 | self.loss_dict.append(loss_dict[loss_item]()) 190 | self.loss_weight.append(loss_weight) 191 | 192 | else: 193 | self.loss_dict = None 194 | self.loss_weight = None 195 | 196 | def genotype(self): 197 | return self.genotpe 198 | 199 | def forward(self, images, targets=None, drop_prob=-1): 200 | if self.training and targets is None: 201 | raise ValueError("In training mode, targets should be passed.") 202 | 203 | h1 = self.stem1(images) 204 | if self.activate_f.lower() == 'relu': 205 | h0 = self.stem2(F.relu(h1)) 206 | elif self.activate_f.lower() in ['leaky', 'prelu']: 207 | h0 = self.stem2(F.leaky_relu(h1, negative_slope=0.2)) 208 | elif self.activate_f.lower() == 'sine': 209 | h0 = self.stem2(torch.sin(h1)) 210 | 211 | if self.endpoint==None: 212 | endpoint = self.reduce(h0) 213 | 214 | for i, cell in enumerate(self.cells): 215 | s0 = self.scalers[i*2](h0) 216 | s1 = self.scalers[i*2+1](h1) 217 | h1 = h0 218 | h0 = cell(s0, s1, 
drop_prob) 219 | if self.endpoint is not None and i == self.endpoint: 220 | endpoint = h0 221 | 222 | if self.activate_f.lower() == 'relu': 223 | pred = self.decoder([endpoint, F.relu(h0)]) 224 | elif self.activate_f.lower() in ['leaky', 'prelu']: 225 | pred = self.decoder([endpoint, F.leaky_relu(h0, negative_slope=0.2)]) 226 | elif self.activate_f.lower() == 'sine': 227 | pred= self.decoder([endpoint, torch.sin(h0)]) 228 | 229 | if self.use_res: 230 | pred = images-pred 231 | pred = torch.sigmoid(pred) 232 | 233 | if self.training: 234 | if loss_dict is not None: 235 | loss = [] 236 | for loss_item, weight in zip(self.loss_dict, self.loss_weight): 237 | loss.append(loss_item(pred, targets) * weight) 238 | else: 239 | loss = F.mse_loss(pred, targets) 240 | return pred, {'decoder_loss': sum(loss) / len(loss)} 241 | 242 | else: 243 | return pred 244 | 245 | -------------------------------------------------------------------------------- /one_stage_nas/modeling/architectures.py: -------------------------------------------------------------------------------- 1 | from .dn_supernet import Dn_supernet 2 | from .dn_compnet import Dn_compnet 3 | from .sid_supernet import Sid_supernet 4 | from .sid_compnet import Sid_compnet 5 | from .sr_supernet import Sr_supernet 6 | from .sr_compnet import Sr_compnet 7 | 8 | 9 | ARCHITECTURES = { 10 | "Dn_supernet": Dn_supernet, 11 | "Dn_compnet": Dn_compnet, 12 | "Sid_supernet": Sid_supernet, 13 | "Sid_compnet": Sid_compnet, 14 | "Sr_supernet": Sr_supernet, 15 | "Sr_compnet": Sr_compnet 16 | } 17 | 18 | 19 | def build_model(cfg): 20 | meta_arch = ARCHITECTURES[cfg.MODEL.META_ARCHITECTURE] 21 | return meta_arch(cfg) 22 | -------------------------------------------------------------------------------- /one_stage_nas/modeling/auto_multitask.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements Auto-DeepLab framework 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | import 
torch.nn.functional as F 8 | 9 | from one_stage_nas.darts.cell import Cell 10 | from one_stage_nas.darts.genotypes import PRIMITIVES 11 | from .decoders import build_decoder 12 | from .common import conv3x3_bn, conv1x1_bn, viterbi 13 | from .loss import loss_dict 14 | 15 | class Router_Width(nn.Module): 16 | """ Propagate hidden states to next layer 17 | """ 18 | 19 | def __init__(self, ind, inp, C, num_strides=4, affine=True): 20 | """ 21 | Arguments: 22 | ind (int) [2-5]: index of the cell, which decides output scales 23 | inp (int): inp size 24 | C (int): output size of the same scale 25 | """ 26 | super(Router_Width, self).__init__() 27 | self.ind = ind 28 | self.num_strides = num_strides 29 | 30 | if ind > 0: 31 | # upsample 32 | self.postprocess0 = conv1x1_bn(inp, C // 2, 1, affine=affine, activate_f=None) 33 | 34 | self.postprocess1 = conv1x1_bn(inp, C, 1, affine=affine, activate_f=None) 35 | if ind < num_strides - 1: 36 | # downsample 37 | self.postprocess2 = conv1x1_bn(inp, C * 2, 1, affine=affine, activate_f=None) 38 | 39 | def forward(self, out): 40 | """ 41 | Returns: 42 | h_next ([Tensor]): None for empty 43 | """ 44 | if self.ind > 0: 45 | h_next_0 = self.postprocess0(out) 46 | else: 47 | h_next_0 = None 48 | h_next_1 = self.postprocess1(out) 49 | if self.ind < self.num_strides - 1: 50 | h_next_2 = self.postprocess2(out) 51 | else: 52 | h_next_2 = None 53 | return h_next_0, h_next_1, h_next_2 54 | 55 | 56 | class AutoMultiTask(nn.Module): 57 | """ 58 | Main class for Auto-DeepLab. 
59 | 60 | Use one cell per hidden states 61 | """ 62 | 63 | def __init__(self, cfg): 64 | super(AutoMultiTask, self).__init__() 65 | self.f = cfg.MODEL.FILTER_MULTIPLIER 66 | self.num_layers = cfg.MODEL.NUM_LAYERS 67 | self.num_blocks = cfg.MODEL.NUM_BLOCKS 68 | self.num_strides = cfg.MODEL.NUM_STRIDES 69 | self.primitives = PRIMITIVES[cfg.MODEL.PRIMITIVES] 70 | self.activatioin_f = cfg.MODEL.ACTIVATION_F 71 | self.use_res = cfg.MODEL.USE_RES 72 | affine = cfg.MODEL.AFFINE 73 | self.stem1 = nn.Sequential( 74 | conv3x3_bn(cfg.MODEL.IN_CHANNEL, 64, 1, affine=affine, activate_f=self.activatioin_f) 75 | ) 76 | self.stem2 = conv3x3_bn(64, self.f * self.num_blocks, 1, affine=affine, activate_f=None) 77 | # generates first h_1 78 | self.reduce1 = conv3x3_bn(64, self.f, 1, affine=affine, activate_f=None) 79 | 80 | # upsample module for other strides 81 | self.upsamplers = nn.ModuleList() 82 | 83 | 84 | Router = Router_Width 85 | for i in range(1, self.num_strides): 86 | self.upsamplers.append(conv1x1_bn(self.f * 2 ** (i - 1), 87 | self.f * 2 ** i, 88 | 1, affine=affine, activate_f=None)) 89 | 90 | self.cells = nn.ModuleList() 91 | self.routers = nn.ModuleList() 92 | self.cell_configs = [] 93 | self.tie_cell = cfg.SEARCH.TIE_CELL 94 | 95 | for l in range(1, self.num_layers + 1): 96 | for h in range(min(self.num_strides, l + 1)): 97 | stride = 2 ** h 98 | C = self.f * stride 99 | 100 | if h < l: 101 | self.routers.append(Router(h, C * self.num_blocks, 102 | C, affine=affine)) 103 | 104 | self.cell_configs.append( 105 | "L{}H{}: {}".format(l, h, C)) 106 | self.cells.append(Cell(self.num_blocks, C, 107 | self.primitives, 108 | affine=affine)) 109 | 110 | # ASPP 111 | self.decoder = build_decoder(cfg) 112 | self.init_alphas() 113 | if cfg.SOLVER.LOSS is not None: 114 | self.loss_dict = [] 115 | self.loss_weight = [] 116 | for loss_item, loss_weight in zip(cfg.SOLVER.LOSS, cfg.SOLVER.LOSS_WEIGHT): 117 | if 'ssim' in loss_item or 'grad' in loss_item: 118 | 
self.loss_dict.append(loss_dict[loss_item](channel=cfg.MODEL.IN_CHANNEL)) 119 | else: 120 | self.loss_dict.append(loss_dict[loss_item]()) 121 | self.loss_weight.append(loss_weight) 122 | 123 | else: 124 | self.loss_dict = None 125 | self.loss_weight = None 126 | 127 | 128 | def w_parameters(self): 129 | return [value for key, value in self.named_parameters() 130 | if 'arch' not in key and value.requires_grad] 131 | 132 | def a_parameters(self): 133 | a_params = [value for key, value in self.named_parameters() if 'arch' in key] 134 | return a_params 135 | 136 | def init_alphas(self): 137 | k = sum(2 + i for i in range(self.num_blocks)) 138 | num_ops = len(self.primitives) 139 | if self.tie_cell: 140 | self.arch_alphas = nn.Parameter(torch.ones(k, num_ops)) 141 | else: 142 | self.arch_alphas = nn.Parameter(torch.ones(self.num_layers, k, num_ops)) 143 | 144 | m = sum(min(l+1, self.num_strides) for l in range(self.num_layers)) 145 | beta_weights = torch.ones(m, 3) 146 | # mask out 147 | top_inds = [] 148 | btm_inds = [] 149 | start = 0 150 | for l in range(self.num_layers): 151 | top_inds.append(start) 152 | if l+1 < self.num_strides: 153 | start += l+1 154 | else: 155 | start += self.num_strides 156 | btm_inds.append(start-1) 157 | 158 | beta_weights[top_inds, 0] = -50 159 | beta_weights[btm_inds, 2] = -50 160 | self.arch_betas = nn.Parameter(beta_weights) 161 | self.score_func = F.softmax 162 | 163 | def scores(self): 164 | return (self.score_func(self.arch_alphas, dim=-1), 165 | self.score_func(self.arch_betas, dim=-1)) 166 | 167 | def forward(self, images, targets=None): 168 | 169 | alphas, betas = self.scores() 170 | 171 | # The first layer is different 172 | features = self.stem1(images) 173 | inputs_1 = [self.reduce1(features)] 174 | if self.activatioin_f.lower() == 'relu': 175 | features_t = F.relu(features) 176 | elif self.activatioin_f.lower() in ['leaky', 'prelu']: 177 | features_t = F.leaky_relu(features, negative_slope=0.2) 178 | elif 
self.activatioin_f.lower() == 'sine': 179 | features_t = torch.sin(features) 180 | features = self.stem2(features_t) 181 | 182 | hidden_states = [features] 183 | 184 | cell_ind = 0 185 | router_ind = 0 186 | for l in range(self.num_layers): 187 | # prepare next inputs 188 | inputs_0 = [0] * min(l + 2, self.num_strides) 189 | for i, hs in enumerate(hidden_states): 190 | # print('router {}: '.format(router_ind), self.cell_configs[router_ind]) 191 | h_0, h_1, h_2 = self.routers[router_ind](hs) 192 | # print(h_0 is None, h_1 is None, h_2 is None) 193 | # print(betas[router_ind]) 194 | if i > 0: 195 | inputs_0[i-1] = inputs_0[i-1] + h_0 * betas[router_ind][0] 196 | inputs_0[i] = inputs_0[i] + h_1 * betas[router_ind][1] 197 | if i < self.num_strides-1: 198 | inputs_0[i+1] = inputs_0[i+1] + h_2 * betas[router_ind][2] 199 | router_ind += 1 200 | 201 | # run cells 202 | hidden_states = [] 203 | for i, s0 in enumerate(inputs_0): 204 | # prepare next input 205 | if i >= len(inputs_1): 206 | # print("using upsampler {}.".format(i-1)) 207 | inputs_1.append(self.upsamplers[i-1](inputs_1[-1])) 208 | s1 = inputs_1[i] 209 | # print('cell: ', self.cell_configs[cell_ind]) 210 | if self.tie_cell: 211 | cell_weights = alphas 212 | else: 213 | cell_weights = alphas[l] 214 | hidden_states.append(self.cells[cell_ind](s0, s1, cell_weights)) 215 | cell_ind += 1 216 | 217 | inputs_1 = inputs_0 218 | 219 | # apply ASPP on hidden_state 220 | pred = self.decoder(hidden_states) 221 | if self.use_res: 222 | pred = images - pred 223 | pred = torch.sigmoid(pred) 224 | 225 | if self.training: 226 | if self.loss_dict is not None: 227 | loss = [] 228 | for loss_item, weight in zip(self.loss_dict, self.loss_weight): 229 | loss.append(loss_item(pred, targets) * weight) 230 | else: 231 | loss = F.mse_loss(pred, targets) 232 | return {'decoder_loss': sum(loss) / len(loss)} 233 | else: 234 | return pred 235 | 236 | def get_path_genotype(self, betas): 237 | # construct transition matrix 238 | trans = [] 239 
| b_ind = 0 240 | for l in range(self.num_layers): 241 | layer = [] 242 | for i in range(self.num_strides): 243 | if i < l + 1: 244 | layer.append(betas[b_ind].detach().numpy().tolist()) 245 | b_ind += 1 246 | else: 247 | layer.append([0, 0, 0]) 248 | trans.append(layer) 249 | return viterbi(trans) 250 | 251 | def genotype(self): 252 | alphas, betas = self.scores() 253 | if self.tie_cell: 254 | gene_cell = self.cells[0].genotype(alphas) 255 | else: 256 | gene_cell = [] 257 | for i in range(self.num_layers): 258 | gene_cell.append(self.cells[0].genotype(alphas[i])) 259 | gene_path = self.get_path_genotype(betas) 260 | return gene_cell, gene_path 261 | -------------------------------------------------------------------------------- /one_stage_nas/modeling/common.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | from torch import nn 4 | import torch 5 | 6 | 7 | class Sine(nn.Module): 8 | def __init__(self, w0 = 1.): 9 | super().__init__() 10 | self.w0 = w0 11 | def forward(self, x): 12 | return torch.sin(self.w0 * x) 13 | 14 | 15 | def conv3x3_bn(inp, oup, stride, affine=True, activate_f='ReLU'): 16 | if activate_f is None: 17 | return nn.Sequential( 18 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 19 | nn.BatchNorm2d(oup, affine=affine)) 20 | elif activate_f.lower() == 'relu': 21 | return nn.Sequential( 22 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 23 | nn.BatchNorm2d(oup, affine=affine), 24 | nn.ReLU(inplace=True) 25 | ) 26 | elif activate_f.lower() == 'leaky': 27 | return nn.Sequential( 28 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 29 | nn.BatchNorm2d(oup, affine=affine), 30 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 31 | ) 32 | elif activate_f.lower() == 'prelu': 33 | return nn.Sequential( 34 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 35 | nn.BatchNorm2d(oup, affine=affine), 36 | nn.PReLU() 37 | ) 38 | elif activate_f.lower() == 'sine': 39 | return nn.Sequential( 40 | 
nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 41 | nn.BatchNorm2d(oup, affine=affine), 42 | Sine() 43 | ) 44 | 45 | 46 | 47 | def conv1x1_bn(inp, oup, stride, affine=True, activate_f='ReLU'): 48 | if activate_f is None: 49 | return nn.Sequential( 50 | nn.Conv2d(inp, oup, 1, stride, 0, bias=False), 51 | nn.BatchNorm2d(oup, affine=affine)) 52 | elif activate_f.lower() == 'relu': 53 | return nn.Sequential( 54 | nn.Conv2d(inp, oup, 1, stride, 0, bias=False), 55 | nn.BatchNorm2d(oup, affine=affine), 56 | nn.ReLU(inplace=True) 57 | ) 58 | elif activate_f.lower() == 'leaky': 59 | return nn.Sequential( 60 | nn.Conv2d(inp, oup, 1, stride, 0, bias=False), 61 | nn.BatchNorm2d(oup, affine=affine), 62 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 63 | ) 64 | elif activate_f.lower() == 'prelu': 65 | return nn.Sequential( 66 | nn.Conv2d(inp, oup, 1, stride, 0, bias=False), 67 | nn.BatchNorm2d(oup, affine=affine), 68 | nn.PReLU() 69 | ) 70 | elif activate_f.lower() == 'sine': 71 | return nn.Sequential( 72 | nn.Conv2d(inp, oup, 1, stride, 0, bias=False), 73 | nn.BatchNorm2d(oup, affine=affine), 74 | Sine() 75 | ) 76 | 77 | 78 | def sep3x3_bn(inp, oup, rate=1, activate_f='ReLU'): 79 | if activate_f.lower() == 'relu': 80 | return nn.Sequential( 81 | nn.Conv2d(inp, inp, 3, stride=1, 82 | padding=rate, dilation=rate, groups=inp, 83 | bias=False), 84 | nn.BatchNorm2d(inp), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(inp, oup, 1, bias=False), 87 | nn.BatchNorm2d(oup), 88 | nn.ReLU(inplace=True)) 89 | elif activate_f.lower() == 'leaky': 90 | return nn.Sequential( 91 | nn.Conv2d(inp, inp, 3, stride=1, 92 | padding=rate, dilation=rate, groups=inp, 93 | bias=False), 94 | nn.BatchNorm2d(inp), 95 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 96 | nn.Conv2d(inp, oup, 1, bias=False), 97 | nn.BatchNorm2d(oup), 98 | nn.LeakyReLU(negative_slope=0.2, inplace=True)) 99 | elif activate_f.lower() == 'prelu': 100 | return nn.Sequential( 101 | nn.Conv2d(inp, inp, 3, stride=1, 102 | padding=rate, 
dilation=rate, groups=inp, 103 | bias=False), 104 | nn.BatchNorm2d(inp), 105 | nn.PReLU(), 106 | nn.Conv2d(inp, oup, 1, bias=False), 107 | nn.BatchNorm2d(oup), 108 | nn.PReLU()) 109 | elif activate_f.lower() == 'sine': 110 | return nn.Sequential( 111 | nn.Conv2d(inp, inp, 3, stride=1, 112 | padding=rate, dilation=rate, groups=inp, 113 | bias=False), 114 | nn.BatchNorm2d(inp), 115 | Sine(), 116 | nn.Conv2d(inp, oup, 1, bias=False), 117 | nn.BatchNorm2d(oup), 118 | Sine()) 119 | 120 | 121 | def viterbi(trans): 122 | """Dynamic programming to find the most likely path. 123 | 124 | Arguments: 125 | trans (LxSx3 array)""" 126 | prob = [1, 0, 0, 0] # keeps the path with highest prob 127 | probs = [prob] 128 | paths = [] 129 | for layer in trans: 130 | prob_next = [0, 0, 0, 0] 131 | path = [-1, -1, -1, -1] 132 | for i, stride in enumerate(layer): 133 | if i > 0: 134 | prob_up = stride[0] * prob[i] 135 | if prob_up > prob_next[i-1]: 136 | prob_next[i-1] = prob_up 137 | path[i-1] = 0 138 | prob_same = stride[1] * prob[i] 139 | if prob_same > prob_next[i]: 140 | prob_next[i] = prob_same 141 | path[i] = 1 142 | if i < 3: 143 | prob_down = stride[2] * prob[i] 144 | if prob_down > prob_next[i+1]: 145 | prob_next[i+1] = prob_down 146 | path[i+1] = 2 147 | prob = prob_next 148 | probs.append(prob) 149 | paths.append(path) 150 | 151 | max_ind, max_prob = max(enumerate(probs[-1]), key=operator.itemgetter(1)) 152 | 153 | ml_path = [max_ind] 154 | for i in range(len(paths) - 1, 0, -1): 155 | path = paths[i] 156 | ml_path.insert(0, max_ind - path[max_ind] + 1) 157 | max_ind = max_ind - path[max_ind] + 1 158 | print(ml_path) 159 | 160 | # check the prob 161 | ind = 0 162 | prob = 1 163 | for i, layer in enumerate(trans): 164 | next_ind = ml_path[i] 165 | stride = layer[ind] 166 | print(i, layer[ind]) 167 | prob = prob * stride[next_ind-ind+1] 168 | ind = next_ind 169 | 170 | assert(max_prob - prob < 0.00001) 171 | return ml_path 172 | 
-------------------------------------------------------------------------------- /one_stage_nas/modeling/dn_compnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Discrete structure of Auto-DeepLab 3 | 4 | Includes utils to convert continous Auto-DeepLab to discrete ones 5 | """ 6 | 7 | import os 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from one_stage_nas.darts.cell import FixCell 13 | from .dn_supernet import Dn_supernet 14 | from .common import conv3x3_bn, conv1x1_bn 15 | from .decoders import build_decoder 16 | from .loss import loss_dict 17 | 18 | 19 | def get_genotype_from_adl(cfg): 20 | # create ADL model 21 | adl_cfg = cfg.clone() 22 | adl_cfg.defrost() 23 | 24 | adl_cfg.merge_from_list(['MODEL.META_ARCHITECTURE', 'AutoDeepLab', 25 | 'MODEL.FILTER_MULTIPLIER', 8, 26 | 'MODEL.AFFINE', True, 27 | 'SEARCH.SEARCH_ON', True]) 28 | 29 | model = Dn_supernet(adl_cfg) 30 | # load weights 31 | SEARCH_RESULT_DIR = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 32 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 
33 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 34 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 35 | 'search/models/model_best.pth')) 36 | ckpt = torch.load(SEARCH_RESULT_DIR) 37 | restore = {k: v for k, v in ckpt['model'].items() if 'arch' in k} 38 | model.load_state_dict(restore, strict=False) 39 | return model.genotype() 40 | 41 | 42 | class DeepLabScaler_Width(nn.Module): 43 | """Official implementation 44 | https://github.com/tensorflow/models/blob/master/research/deeplab/core/nas_cell.py#L90 45 | """ 46 | def __init__(self, inp, C, activate_f='ReLU'): 47 | super(DeepLabScaler_Width, self).__init__() 48 | self.activate_f = activate_f 49 | self.conv = conv1x1_bn(inp, C, 1, activate_f=None) 50 | 51 | def forward(self, hidden_state): 52 | if self.activate_f.lower() == 'relu': 53 | return self.conv(F.relu(hidden_state)) 54 | elif self.activate_f.lower() in ['leaky', 'prelu']: 55 | return self.conv(F.leaky_relu(hidden_state, negative_slope=0.2)) 56 | elif self.activate_f.lower() == 'sine': 57 | return self.conv(torch.sin(hidden_state)) 58 | 59 | 60 | class Dn_compnet(nn.Module): 61 | def __init__(self, cfg): 62 | super(Dn_compnet, self).__init__() 63 | 64 | # load genotype 65 | if len(cfg.DATASET.TRAIN_DATASETS) == 0: 66 | geno_file = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 67 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 68 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 69 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 70 | 'search/models/model_best.geno')) 71 | 72 | else: 73 | geno_file = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 74 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 
75 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 76 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 77 | 'search/models/model_best.geno')) 78 | 79 | if os.path.exists(geno_file): 80 | print("Loading genotype from {}".format(geno_file)) 81 | genotype = torch.load(geno_file, map_location=torch.device("cpu")) 82 | else: 83 | genotype = get_genotype_from_adl(cfg) 84 | print("Saving genotype to {}".format(geno_file)) 85 | torch.save(genotype, geno_file) 86 | 87 | geno_cell, geno_path = genotype 88 | 89 | self.genotpe = genotype 90 | 91 | if 0 in geno_path: 92 | self.endpoint = (len(geno_path) - 1) - list(reversed(geno_path)).index(0) 93 | if self.endpoint == (len(geno_path) -1): 94 | self.endpoint = None 95 | else: 96 | self.endpoint = None 97 | 98 | # basic configs 99 | self.activate_f = cfg.MODEL.ACTIVATION_F 100 | self.use_res = cfg.MODEL.USE_RES 101 | self.f = cfg.MODEL.FILTER_MULTIPLIER 102 | self.num_layers = cfg.MODEL.NUM_LAYERS 103 | self.num_blocks = cfg.MODEL.NUM_BLOCKS 104 | self.num_strides = cfg.MODEL.NUM_STRIDES 105 | self.ws_factors = cfg.MODEL.WS_FACTORS 106 | self.in_channel = cfg.MODEL.IN_CHANNEL 107 | self.stem1 = conv3x3_bn(self.in_channel, 64, 1, activate_f=self.activate_f) 108 | self.stem2 = conv3x3_bn(64, 64, 1, activate_f=None) 109 | self.reduce = conv3x3_bn(64, self.f*self.num_blocks, 1, affine=False, activate_f=None) 110 | 111 | # create cells 112 | self.cells = nn.ModuleList() 113 | self.scalers = nn.ModuleList() 114 | if cfg.SEARCH.TIE_CELL: 115 | geno_cell = [geno_cell] * self.num_layers 116 | 117 | DeepLabScaler = DeepLabScaler_Width 118 | 119 | h_0 = 0 # prev hidden index 120 | h_1 = -1 # prev prev hidden index 121 | for layer, (geno, h_ind) in enumerate(zip(geno_cell, geno_path), 1): 122 | stride = self.ws_factors[h_ind] 123 | h = self.ws_factors[h_ind] 124 | self.cells.append(FixCell(geno, int(self.f * stride))) 125 | # scalers 126 | if layer == 1: 127 | inp0 = 64 128 | 
inp1 = 64 129 | elif layer == 2: 130 | inp0 = int(h_0 * self.f * self.num_blocks) 131 | inp1 = 64 132 | else: 133 | inp0 = int(h_0 * self.f * self.num_blocks) 134 | inp1 = int(h_1 * self.f * self.num_blocks) 135 | 136 | if layer == 1: 137 | scaler0 = DeepLabScaler(inp0, int(stride * self.f), activate_f=self.activate_f) 138 | scaler1 = DeepLabScaler(inp1, int(stride * self.f), activate_f=self.activate_f) 139 | else: 140 | scaler0 = DeepLabScaler(inp0, int(stride * self.f), activate_f=self.activate_f) 141 | scaler1 = DeepLabScaler(inp1, int(stride * self.f), activate_f=self.activate_f) 142 | 143 | h_1 = h_0 144 | h_0 = h 145 | self.scalers.append(scaler0) 146 | self.scalers.append(scaler1) 147 | self.decoder = build_decoder(cfg, out_strides=stride) 148 | if cfg.SOLVER.LOSS is not None: 149 | self.loss_dict = [] 150 | self.loss_weight = [] 151 | for loss_item, loss_weight in zip(cfg.SOLVER.LOSS, cfg.SOLVER.LOSS_WEIGHT): 152 | if 'ssim' in loss_item or 'grad' in loss_item: 153 | self.loss_dict.append(loss_dict[loss_item](channel=cfg.MODEL.IN_CHANNEL)) 154 | else: 155 | self.loss_dict.append(loss_dict[loss_item]()) 156 | self.loss_weight.append(loss_weight) 157 | 158 | else: 159 | self.loss_dict = None 160 | self.loss_weight = None 161 | 162 | def genotype(self): 163 | return self.genotpe 164 | 165 | def forward(self, images, targets=None, drop_prob=-1): 166 | if self.training and targets is None: 167 | raise ValueError("In training mode, targets should be passed.") 168 | 169 | h1 = self.stem1(images) 170 | if self.activate_f.lower() == 'relu': 171 | h0 = self.stem2(F.relu(h1)) 172 | elif self.activate_f.lower() in ['leaky', 'prelu']: 173 | h0 = self.stem2(F.leaky_relu(h1, negative_slope=0.2)) 174 | elif self.activate_f.lower() == 'sine': 175 | h0 = self.stem2(torch.sin(h1)) 176 | 177 | if self.endpoint==None: 178 | endpoint = self.reduce(h0) 179 | 180 | for i, cell in enumerate(self.cells): 181 | s0 = self.scalers[i*2](h0) 182 | s1 = self.scalers[i*2+1](h1) 183 | h1 = 
h0 184 | h0 = cell(s0, s1, drop_prob) 185 | if self.endpoint is not None and i == self.endpoint: 186 | endpoint = h0 187 | 188 | if self.activate_f.lower() == 'relu': 189 | pred = self.decoder([endpoint, F.relu(h0)]) 190 | elif self.activate_f.lower() in ['leaky', 'prelu']: 191 | pred = self.decoder([endpoint, F.leaky_relu(h0, negative_slope=0.2)]) 192 | elif self.activate_f.lower() == 'sine': 193 | pred= self.decoder([endpoint, torch.sin(h0)]) 194 | 195 | if self.use_res: 196 | pred = images-pred 197 | 198 | 199 | if self.training: 200 | if loss_dict is not None: 201 | loss = [] 202 | for loss_item, weight in zip(self.loss_dict, self.loss_weight): 203 | loss.append(loss_item(pred, targets) * weight) 204 | else: 205 | loss = F.mse_loss(pred, targets) 206 | return pred, {'decoder_loss': sum(loss) / len(loss)} 207 | 208 | else: 209 | return pred 210 | 211 | 212 | -------------------------------------------------------------------------------- /one_stage_nas/modeling/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | 9 | def gaussian(window_size, sigma): 10 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 11 | return gauss / gauss.sum() 12 | 13 | 14 | def create_window(window_size, channel): 15 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 16 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 17 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 18 | return window 19 | 20 | 21 | class log_SSIM_loss(nn.Module): 22 | def __init__(self, window_size=11, channel=3, is_cuda=True, size_average=True): 23 | super(log_SSIM_loss, self).__init__() 24 | self.window_size = window_size 25 | self.channel = channel 26 | self.size_average 
= size_average 27 | self.window = create_window(window_size, channel) 28 | if is_cuda: 29 | self.window = self.window.cuda() 30 | 31 | 32 | def forward(self, img1, img2): 33 | mu1 = F.conv2d(img1, self.window, padding=self.window_size // 2, groups=self.channel) 34 | mu2 = F.conv2d(img2, self.window, padding=self.window_size // 2, groups=self.channel) 35 | 36 | mu1_sq = mu1.pow(2) 37 | mu2_sq = mu2.pow(2) 38 | mu1_mu2 = mu1 * mu2 39 | 40 | sigma1_sq = F.conv2d(img1 * img1, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_sq 41 | sigma2_sq = F.conv2d(img2 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu2_sq 42 | sigma12 = F.conv2d(img1 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_mu2 43 | 44 | C1 = 0.01 ** 2 45 | C2 = 0.03 ** 2 46 | 47 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 48 | 49 | return -torch.log10(ssim_map.mean()) 50 | 51 | 52 | class negative_SSIM_loss(nn.Module): 53 | def __init__(self, window_size=11, channel=3, is_cuda=True, size_average=True): 54 | super(negative_SSIM_loss, self).__init__() 55 | self.window_size = window_size 56 | self.channel = channel 57 | self.size_average = size_average 58 | self.window = create_window(window_size, channel) 59 | if is_cuda: 60 | self.window = self.window.cuda() 61 | 62 | 63 | def forward(self, img1, img2): 64 | mu1 = F.conv2d(img1, self.window, padding=self.window_size // 2, groups=self.channel) 65 | mu2 = F.conv2d(img2, self.window, padding=self.window_size // 2, groups=self.channel) 66 | 67 | mu1_sq = mu1.pow(2) 68 | mu2_sq = mu2.pow(2) 69 | mu1_mu2 = mu1 * mu2 70 | 71 | sigma1_sq = F.conv2d(img1 * img1, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_sq 72 | sigma2_sq = F.conv2d(img2 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu2_sq 73 | sigma12 = F.conv2d(img1 * img2, self.window, padding=self.window_size 
// 2, groups=self.channel) - mu1_mu2 74 | 75 | C1 = 0.01 ** 2 76 | C2 = 0.03 ** 2 77 | 78 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 79 | 80 | return 1.0-ssim_map.mean() 81 | 82 | 83 | class GRAD_loss(nn.Module): 84 | def __init__(self, channel=3, is_cuda=True): 85 | super(GRAD_loss, self).__init__() 86 | self.edge_conv = nn.Conv2d(channel, channel*2, kernel_size=3, stride=1, padding=1, groups=channel, bias=False) 87 | edge_kx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) 88 | edge_ky = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) 89 | edge_k = [] 90 | for i in range(channel): 91 | edge_k.append(edge_kx) 92 | edge_k.append(edge_ky) 93 | 94 | edge_k = np.stack(edge_k) 95 | 96 | edge_k = torch.from_numpy(edge_k).float().view(channel*2, 1, 3, 3) 97 | self.edge_conv.weight = nn.Parameter(edge_k) 98 | for param in self.parameters(): 99 | param.requires_grad = False 100 | 101 | if is_cuda: self.edge_conv.cuda() 102 | 103 | def forward(self, img1, img2): 104 | img1_grad = self.edge_conv(img1) 105 | img2_grad = self.edge_conv(img2) 106 | 107 | return F.l1_loss(img1_grad, img2_grad) 108 | 109 | 110 | class exp_GRAD_loss(nn.Module): 111 | def __init__(self, channel=3, is_cuda=True): 112 | super(exp_GRAD_loss, self).__init__() 113 | self.edge_conv = nn.Conv2d(channel, channel*2, kernel_size=3, stride=1, padding=1, groups=channel, bias=False) 114 | edge_kx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) 115 | edge_ky = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) 116 | edge_k = [] 117 | for i in range(channel): 118 | edge_k.append(edge_kx) 119 | edge_k.append(edge_ky) 120 | 121 | edge_k = np.stack(edge_k) 122 | 123 | edge_k = torch.from_numpy(edge_k).float().view(channel*2, 1, 3, 3) 124 | self.edge_conv.weight = nn.Parameter(edge_k) 125 | for param in self.parameters(): 126 | param.requires_grad = False 127 | 128 | if is_cuda: self.edge_conv.cuda() 129 | 130 | def forward(self, img1, img2): 131 | 
img1_grad = self.edge_conv(img1) 132 | img2_grad = self.edge_conv(img2) 133 | 134 | return torch.exp(F.l1_loss(img1_grad, img2_grad)) - 1 135 | 136 | 137 | class log_PSNR_loss(torch.nn.Module): 138 | def __init__(self): 139 | super(log_PSNR_loss, self).__init__() 140 | 141 | def forward(self, img1, img2): 142 | diff = img1 - img2 143 | mse = diff*diff.mean() 144 | return -torch.log10(1.0-mse) 145 | 146 | 147 | class MSE_loss(torch.nn.Module): 148 | def __init__(self): 149 | super(MSE_loss, self).__init__() 150 | 151 | def forward(self, img1, img2): 152 | return F.mse_loss(img1, img2) 153 | 154 | 155 | class L1_loss(torch.nn.Module): 156 | def __init__(self): 157 | super(L1_loss, self).__init__() 158 | 159 | def forward(self, img1, img2): 160 | return F.l1_loss(img1, img2) 161 | 162 | 163 | loss_dict = { 164 | 'l1': L1_loss, 165 | 'mse': MSE_loss, 166 | 'grad': GRAD_loss, 167 | 'exp_grad': exp_GRAD_loss, 168 | 'log_ssim': log_SSIM_loss, 169 | 'neg_ssim': negative_SSIM_loss, 170 | 'log_psnr': log_PSNR_loss, 171 | } 172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /one_stage_nas/modeling/sid_compnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Discrete structure of Auto-DeepLab 3 | 4 | Includes utils to convert continous Auto-DeepLab to discrete ones 5 | """ 6 | 7 | import os 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from one_stage_nas.darts.cell import FixCell 13 | from .sid_supernet import Sid_supernet 14 | from .common import conv3x3_bn, conv1x1_bn 15 | from .decoders import build_decoder 16 | from .loss import loss_dict 17 | 18 | 19 | def get_genotype_from_adl(cfg): 20 | # create ADL model 21 | adl_cfg = cfg.clone() 22 | adl_cfg.defrost() 23 | 24 | adl_cfg.merge_from_list(['MODEL.META_ARCHITECTURE', 'Sid_supernet', 25 | 'MODEL.FILTER_MULTIPLIER', 8, 26 | 'MODEL.AFFINE', True, 27 | 
'SEARCH.SEARCH_ON', True]) 28 | 29 | model = Sid_supernet(adl_cfg) 30 | # load weights 31 | SEARCH_RESULT_DIR = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 32 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 33 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 34 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 35 | 'search/models/model_best.pth')) 36 | ckpt = torch.load(SEARCH_RESULT_DIR) 37 | restore = {k: v for k, v in ckpt['model'].items() if 'arch' in k} 38 | model.load_state_dict(restore, strict=False) 39 | return model.genotype() 40 | 41 | 42 | class DeepLabScaler_Width(nn.Module): 43 | """Official implementation 44 | https://github.com/tensorflow/models/blob/master/research/deeplab/core/nas_cell.py#L90 45 | """ 46 | def __init__(self, scale, inp, C, activate_f='ReLU'): 47 | super(DeepLabScaler_Width, self).__init__() 48 | self.activate_f = activate_f 49 | self.scale = 2 ** scale 50 | self.conv = conv1x1_bn(inp, C, 1, activate_f=None) 51 | 52 | def forward(self, hidden_state): 53 | if self.activate_f.lower() == 'relu': 54 | return self.conv(F.relu(hidden_state)) 55 | elif self.activate_f.lower() in ['leaky', 'prelu']: 56 | return self.conv(F.leaky_relu(hidden_state, negative_slope=0.2)) 57 | elif self.activate_f.lower() == 'sine': 58 | return self.conv(torch.sin(hidden_state)) 59 | 60 | 61 | class Sid_compnet(nn.Module): 62 | def __init__(self, cfg): 63 | super(Sid_compnet, self).__init__() 64 | 65 | # load genotype 66 | if len(cfg.DATASET.TRAIN_DATASETS) == 0: 67 | geno_file = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 68 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 
69 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 70 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 71 | 'search/models/model_best.geno')) 72 | 73 | else: 74 | geno_file = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 75 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 76 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 77 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 78 | 'search/models/model_best.geno')) 79 | 80 | if os.path.exists(geno_file): 81 | print("Loading genotype from {}".format(geno_file)) 82 | genotype = torch.load(geno_file, map_location=torch.device("cpu")) 83 | else: 84 | genotype = get_genotype_from_adl(cfg) 85 | print("Saving genotype to {}".format(geno_file)) 86 | torch.save(genotype, geno_file) 87 | 88 | geno_cell, geno_path = genotype 89 | 90 | self.genotpe = genotype 91 | 92 | if 0 in geno_path: 93 | self.endpoint = (len(geno_path) - 1) - list(reversed(geno_path)).index(0) 94 | if self.endpoint == (len(geno_path) -1): 95 | self.endpoint = None 96 | else: 97 | self.endpoint = None 98 | 99 | # basic configs 100 | self.activate_f = cfg.MODEL.ACTIVATION_F 101 | self.use_res = cfg.MODEL.USE_RES 102 | self.res = cfg.MODEL.RES 103 | self.f = cfg.MODEL.FILTER_MULTIPLIER 104 | self.num_layers = cfg.MODEL.NUM_LAYERS 105 | self.num_blocks = cfg.MODEL.NUM_BLOCKS 106 | self.num_strides = cfg.MODEL.NUM_STRIDES 107 | self.in_channel = cfg.MODEL.IN_CHANNEL 108 | self.stem1 = conv3x3_bn(self.in_channel, 64, 1, activate_f=self.activate_f) 109 | self.stem2 = conv3x3_bn(64, 64, 1, activate_f=None) 110 | self.reduce = conv3x3_bn(64, self.f*self.num_blocks, 1, affine=False, activate_f=None) 111 | 112 | # create cells 113 | self.cells = nn.ModuleList() 114 | self.scalers = nn.ModuleList() 115 | if cfg.SEARCH.TIE_CELL: 116 | geno_cell = [geno_cell] * self.num_layers 117 | 118 | DeepLabScaler = DeepLabScaler_Width 119 | 120 | h_0 = 0 # 
prev hidden index 121 | h_1 = -1 # prev prev hidden index 122 | for layer, (geno, h) in enumerate(zip(geno_cell, geno_path), 1): 123 | stride = 2 ** h 124 | self.cells.append(FixCell(geno, self.f * stride)) 125 | # scalers 126 | if layer == 1: 127 | inp0 = 64 128 | inp1 = 64 129 | elif layer == 2: 130 | inp0 = 2 ** h_0 * self.f * self.num_blocks 131 | inp1 = 64 132 | else: 133 | inp0 = 2 ** h_0 * self.f * self.num_blocks 134 | inp1 = 2 ** h_1 * self.f * self.num_blocks 135 | 136 | if layer == 1: 137 | scaler0 = DeepLabScaler(h_0 - h, inp0, 138 | stride * self.f, activate_f=self.activate_f) 139 | scaler1 = DeepLabScaler(h_0 - h, inp1, 140 | stride * self.f, activate_f=self.activate_f) 141 | else: 142 | scaler0 = DeepLabScaler(h_0 - h, inp0, 143 | stride * self.f, activate_f=self.activate_f) 144 | scaler1 = DeepLabScaler(h_1 - h, inp1, 145 | stride * self.f, activate_f=self.activate_f) 146 | 147 | h_1 = h_0 148 | h_0 = h 149 | self.scalers.append(scaler0) 150 | self.scalers.append(scaler1) 151 | self.decoder = build_decoder(cfg, out_strides=stride) 152 | if cfg.SOLVER.LOSS is not None: 153 | self.loss_dict = [] 154 | self.loss_weight = [] 155 | for loss_item, loss_weight in zip(cfg.SOLVER.LOSS, cfg.SOLVER.LOSS_WEIGHT): 156 | if 'ssim' in loss_item or 'grad' in loss_item: 157 | self.loss_dict.append(loss_dict[loss_item](channel=3)) 158 | else: 159 | self.loss_dict.append(loss_dict[loss_item]()) 160 | self.loss_weight.append(loss_weight) 161 | 162 | else: 163 | self.loss_dict = None 164 | self.loss_weight = None 165 | 166 | def genotype(self): 167 | return self.genotpe 168 | 169 | def forward(self, images, targets=None, drop_prob=-1): 170 | if self.training and targets is None: 171 | raise ValueError("In training mode, targets should be passed.") 172 | 173 | input_arw = images[0] 174 | input_rgb = images[1] 175 | h1 = self.stem1(input_arw) 176 | if self.activate_f.lower() == 'relu': 177 | h0 = self.stem2(F.relu(h1)) 178 | elif self.activate_f.lower() in ['leaky', 
'prelu']: 179 | h0 = self.stem2(F.leaky_relu(h1, negative_slope=0.2)) 180 | elif self.activate_f.lower() == 'sine': 181 | h0 = self.stem2(torch.sin(h1)) 182 | 183 | if self.endpoint==None: 184 | endpoint = self.reduce(h0) 185 | 186 | for i, cell in enumerate(self.cells): 187 | s0 = self.scalers[i*2](h0) 188 | s1 = self.scalers[i*2+1](h1) 189 | h1 = h0 190 | h0 = cell(s0, s1, drop_prob) 191 | if self.endpoint is not None and i == self.endpoint: 192 | endpoint = h0 193 | 194 | if self.activate_f.lower() == 'relu': 195 | pred = self.decoder([endpoint, F.relu(h0)]) 196 | elif self.activate_f.lower() in ['leaky', 'prelu']: 197 | pred = self.decoder([endpoint, F.leaky_relu(h0, negative_slope=0.2)]) 198 | elif self.activate_f.lower() == 'sine': 199 | pred= self.decoder([endpoint, torch.sin(h0)]) 200 | 201 | if self.use_res and self.res == 'add': 202 | pred = input_rgb - pred 203 | elif self.use_res and self.res == 'mul': 204 | pred = input_rgb * pred 205 | 206 | if self.training: 207 | if loss_dict is not None: 208 | loss = [] 209 | for loss_item, weight in zip(self.loss_dict, self.loss_weight): 210 | loss.append(loss_item(pred, targets) * weight) 211 | else: 212 | loss = F.mse_loss(pred, targets) 213 | return pred, {'decoder_loss': sum(loss) / len(loss)} 214 | 215 | else: 216 | return pred 217 | 218 | -------------------------------------------------------------------------------- /one_stage_nas/modeling/sr_compnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Discrete structure of Auto-DeepLab 3 | 4 | Includes utils to convert continous Auto-DeepLab to discrete ones 5 | """ 6 | 7 | import os 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from one_stage_nas.darts.cell import FixCell 13 | from .sr_supernet import Sr_supernet 14 | from .common import conv3x3_bn, conv1x1_bn 15 | from .decoders import build_decoder 16 | from .loss import loss_dict 17 | 18 | 19 | def 
get_genotype_from_adl(cfg): 20 | # create ADL model 21 | adl_cfg = cfg.clone() 22 | adl_cfg.defrost() 23 | 24 | adl_cfg.merge_from_list(['MODEL.META_ARCHITECTURE', 'Sr_supernet', 25 | 'MODEL.FILTER_MULTIPLIER', 8, 26 | 'MODEL.AFFINE', True, 27 | 'SEARCH.SEARCH_ON', True]) 28 | 29 | model = Sr_supernet(adl_cfg) 30 | # load weights 31 | SEARCH_RESULT_DIR = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 32 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 33 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 34 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 35 | 'search/models/model_best.pth')) 36 | ckpt = torch.load(SEARCH_RESULT_DIR) 37 | restore = {k: v for k, v in ckpt['model'].items() if 'arch' in k} 38 | model.load_state_dict(restore, strict=False) 39 | return model.genotype() 40 | 41 | 42 | class DeepLabScaler_Width(nn.Module): 43 | """Official implementation 44 | https://github.com/tensorflow/models/blob/master/research/deeplab/core/nas_cell.py#L90 45 | """ 46 | def __init__(self, inp, C, activate_f='ReLU'): 47 | super(DeepLabScaler_Width, self).__init__() 48 | self.activate_f = activate_f 49 | self.conv = conv1x1_bn(inp, C, 1, activate_f=None) 50 | 51 | def forward(self, hidden_state): 52 | if self.activate_f.lower() == 'relu': 53 | return self.conv(F.relu(hidden_state)) 54 | elif self.activate_f.lower() in ['leaky', 'prelu']: 55 | return self.conv(F.leaky_relu(hidden_state, negative_slope=0.2)) 56 | elif self.activate_f.lower() == 'sine': 57 | return self.conv(torch.sin(hidden_state)) 58 | 59 | 60 | class Sr_compnet(nn.Module): 61 | def __init__(self, cfg): 62 | super(Sr_compnet, self).__init__() 63 | 64 | # load genotype 65 | if len(cfg.DATASET.TRAIN_DATASETS) == 0: 66 | geno_file = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 67 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 
68 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 69 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 70 | 'search/models/model_best.geno')) 71 | 72 | else: 73 | geno_file = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 74 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 75 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 76 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 77 | 'search/models/model_best.geno')) 78 | 79 | if os.path.exists(geno_file): 80 | print("Loading genotype from {}".format(geno_file)) 81 | genotype = torch.load(geno_file, map_location=torch.device("cpu")) 82 | 83 | # # topological structure 84 | # genotype[0][0][2] = ('con_c_3x3_leaky', 0) 85 | # genotype[0][0][4] = ('dil_c_5x5_leaky', 2) 86 | # genotype[0][1][2] = ('con_c_3x3_leaky', 1) 87 | # genotype[0][1][5] = ('con_c_3x3_leaky', 2) 88 | 89 | 90 | # # operation 91 | # genotype[0][0][1] = ('con_c_5x5_leaky', 1) 92 | # genotype[0][0][4] = ('sep_c_3x3_leaky', 0) 93 | # genotype[0][1][3] = ('sep_c_3x3_leaky', 2) 94 | # genotype[0][1][5] = ('con_c_5x5_leaky', 1) 95 | 96 | # print('the founded result has been modified') 97 | 98 | else: 99 | genotype = get_genotype_from_adl(cfg) 100 | print("Saving genotype to {}".format(geno_file)) 101 | torch.save(genotype, geno_file) 102 | 103 | geno_cell, geno_path = genotype 104 | 105 | self.genotpe = genotype 106 | 107 | if 0 in geno_path: 108 | self.endpoint = (len(geno_path) - 1) - list(reversed(geno_path)).index(0) 109 | if self.endpoint == (len(geno_path) -1): 110 | self.endpoint = None 111 | else: 112 | self.endpoint = None 113 | 114 | # basic configs 115 | self.activate_f = cfg.MODEL.ACTIVATION_F 116 | self.use_res = cfg.MODEL.USE_RES 117 | self.res = cfg.MODEL.RES 118 | self.s_factor = cfg.DATALOADER.S_FACTOR 119 | self.f = cfg.MODEL.FILTER_MULTIPLIER 120 | self.num_layers = cfg.MODEL.NUM_LAYERS 121 | self.num_blocks = 
cfg.MODEL.NUM_BLOCKS 122 | self.num_strides = cfg.MODEL.NUM_STRIDES 123 | self.ws_factors = cfg.MODEL.WS_FACTORS 124 | self.in_channel = cfg.MODEL.IN_CHANNEL 125 | self.stem1 = conv3x3_bn(self.in_channel, 64, 1, activate_f=self.activate_f) 126 | self.stem2 = conv3x3_bn(64, 64, 1, activate_f=None) 127 | self.reduce = conv3x3_bn(64, self.f*self.num_blocks, 1, affine=False, activate_f=None) 128 | 129 | # create cells 130 | self.cells = nn.ModuleList() 131 | self.scalers = nn.ModuleList() 132 | if cfg.SEARCH.TIE_CELL: 133 | geno_cell = [geno_cell] * self.num_layers 134 | 135 | DeepLabScaler = DeepLabScaler_Width 136 | 137 | h_0 = 1 # prev hidden index 138 | h_1 = -1 # prev prev hidden index 139 | for layer, (geno, h_ind) in enumerate(zip(geno_cell, geno_path), 1): 140 | stride = self.ws_factors[h_ind] 141 | h = self.ws_factors[h_ind] 142 | self.cells.append(FixCell(geno, int(self.f * stride))) 143 | # scalers 144 | if layer == 1: 145 | inp0 = 64 146 | inp1 = 64 147 | elif layer == 2: 148 | inp0 = int(h_0 * self.f * self.num_blocks) 149 | inp1 = 64 150 | else: 151 | inp0 = int(h_0 * self.f * self.num_blocks) 152 | inp1 = int(h_1 * self.f * self.num_blocks) 153 | 154 | if layer == 1: 155 | scaler0 = DeepLabScaler(inp0, int(stride * self.f), activate_f=self.activate_f) 156 | scaler1 = DeepLabScaler(inp1, int(stride * self.f), activate_f=self.activate_f) 157 | else: 158 | scaler0 = DeepLabScaler(inp0, int(stride * self.f), activate_f=self.activate_f) 159 | scaler1 = DeepLabScaler(inp1, int(stride * self.f), activate_f=self.activate_f) 160 | 161 | h_1 = h_0 162 | h_0 = h 163 | self.scalers.append(scaler0) 164 | self.scalers.append(scaler1) 165 | self.decoder = build_decoder(cfg, out_strides=stride) 166 | if cfg.SOLVER.LOSS is not None: 167 | self.loss_dict = [] 168 | self.loss_weight = [] 169 | for loss_item, loss_weight in zip(cfg.SOLVER.LOSS, cfg.SOLVER.LOSS_WEIGHT): 170 | if 'ssim' in loss_item or 'grad' in loss_item: 171 | 
self.loss_dict.append(loss_dict[loss_item](channel=3, is_cuda=False)) 172 | else: 173 | self.loss_dict.append(loss_dict[loss_item]()) 174 | self.loss_weight.append(loss_weight) 175 | 176 | else: 177 | self.loss_dict = None 178 | self.loss_weight = None 179 | 180 | def genotype(self): 181 | return self.genotpe 182 | 183 | def forward(self, images, targets=None, drop_prob=-1): 184 | if self.training and targets is None: 185 | raise ValueError("In training mode, targets should be passed.") 186 | 187 | h1 = self.stem1(images) 188 | if self.activate_f.lower() == 'relu': 189 | h0 = self.stem2(F.relu(h1)) 190 | elif self.activate_f.lower() in ['leaky', 'prelu']: 191 | h0 = self.stem2(F.leaky_relu(h1, negative_slope=0.2)) 192 | elif self.activate_f.lower() == 'sine': 193 | h0 = self.stem2(torch.sin(h1)) 194 | 195 | if self.endpoint==None: 196 | endpoint = self.reduce(h0) 197 | 198 | for i, cell in enumerate(self.cells): 199 | s0 = self.scalers[i*2](h0) 200 | s1 = self.scalers[i*2+1](h1) 201 | h1 = h0 202 | h0 = cell(s0, s1, drop_prob) 203 | if self.endpoint is not None and i == self.endpoint: 204 | endpoint = h0 205 | 206 | if self.activate_f.lower() == 'relu': 207 | pred = self.decoder([endpoint, F.relu(h0)]) 208 | elif self.activate_f.lower() in ['leaky', 'prelu']: 209 | pred = self.decoder([endpoint, F.leaky_relu(h0, negative_slope=0.2)]) 210 | elif self.activate_f.lower() == 'sine': 211 | pred= self.decoder([endpoint, torch.sin(h0)]) 212 | 213 | if self.use_res: 214 | pred = F.interpolate(images, size=pred.size()[-2:], mode='bicubic') + pred 215 | 216 | if self.training: 217 | if loss_dict is not None: 218 | loss = [] 219 | for loss_item, weight in zip(self.loss_dict, self.loss_weight): 220 | loss.append(loss_item(pred, targets) * weight) 221 | else: 222 | loss = F.mse_loss(pred, targets) 223 | return pred, {'decoder_loss': sum(loss) / len(loss)} 224 | 225 | else: 226 | return pred 227 | 228 | 
-------------------------------------------------------------------------------- /one_stage_nas/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .build import make_optimizer 3 | from .build import make_lr_scheduler 4 | from .lr_scheduler import WarmupMultiStepLR 5 | -------------------------------------------------------------------------------- /one_stage_nas/solver/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/solver/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/solver/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/solver/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/solver/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/solver/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /one_stage_nas/solver/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
import torch

from .lr_scheduler import WarmupMultiStepLR, PolynormialLR
from .lr_scheduler import PolyCosineAnnealingLR


class OptimizerDict(dict):
    """A ``dict`` of named optimizers that checkpoints like one optimizer.

    ``state_dict`` returns a list holding one state dict per contained
    optimizer (in the dict's iteration order); ``load_state_dict`` consumes
    such a list in the same order.
    """

    def __init__(self, *args, **kwargs):
        super(OptimizerDict, self).__init__(*args, **kwargs)

    def state_dict(self):
        """Return the state dicts of all contained optimizers, in order."""
        return [optim.state_dict() for optim in self.values()]

    def load_state_dict(self, state_dicts):
        """Restore each optimizer from the matching entry of ``state_dicts``.

        When CUDA is available every tensor in the restored optimizer state
        is moved to the current CUDA device, as before. On CPU-only machines
        the state is left where ``Optimizer.load_state_dict`` placed it,
        instead of crashing inside ``Tensor.cuda()``.
        """
        use_cuda = torch.cuda.is_available()
        for state_dict, optim in zip(state_dicts, self.values()):
            optim.load_state_dict(state_dict)
            if not use_cuda:
                continue
            for state in optim.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()


def make_optimizer(cfg, model):
    """Return the search optimizers or the plain training optimizer,
    depending on ``cfg.SEARCH.SEARCH_ON``."""
    if cfg.SEARCH.SEARCH_ON:
        return make_search_optimizers(cfg, model)
    else:
        return make_normal_optimizer(cfg, model)


def make_normal_optimizer(cfg, model):
    """Build an SGD optimizer over the trainable parameters of ``model``.

    Every parameter gets its own group. Parameters whose name contains
    ``"bias"`` use ``INIT_LR * BIAS_LR_FACTOR`` and ``WEIGHT_DECAY_BIAS``;
    all others use ``INIT_LR`` and ``WEIGHT_DECAY``.
    """
    base_lr = cfg.SOLVER.TRAIN.INIT_LR
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = base_lr
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = base_lr * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    # Every group sets its own "lr", so the positional default is only a
    # fallback. Passing base_lr (not the loop variable `lr`) avoids an
    # UnboundLocalError when the model has no trainable parameters.
    optimizer = torch.optim.SGD(params, base_lr, momentum=cfg.SOLVER.MOMENTUM)
    return optimizer


def make_search_optimizers(cfg, model):
    """Build the architecture-search optimizers.

    Returns an OptimizerDict holding:
      * ``optim_w``: SGD over ``model.w_parameters()`` (network weights);
      * ``optim_a``: Adam over ``model.a_parameters()`` (architecture params).
    """
    lr = cfg.SOLVER.SEARCH.LR_START

    optim_w = torch.optim.SGD(model.w_parameters(), lr,
                              momentum=cfg.SOLVER.SEARCH.MOMENTUM,
                              weight_decay=cfg.SOLVER.SEARCH.WEIGHT_DECAY)
    optim_a = torch.optim.Adam(model.a_parameters(),
                               lr=cfg.SOLVER.SEARCH.LR_A,
                               weight_decay=cfg.SOLVER.SEARCH.WD_A)
    return OptimizerDict(optim_w=optim_w, optim_a=optim_a)


def make_search_lr_scheduler(cfg, optimizer_dict):
    """Attach a PolyCosineAnnealingLR to the weight optimizer (``optim_w``)."""
    optimizer = optimizer_dict['optim_w']

    return PolyCosineAnnealingLR(
        optimizer,
        max_iter=cfg.SOLVER.MAX_EPOCH,
        T_max=cfg.SOLVER.SEARCH.T_MAX,
        eta_min=cfg.SOLVER.SEARCH.LR_END
    )


def make_lr_scheduler(cfg, optimizer):
    """Return the scheduler selected by ``cfg``.

    The search scheduler is used when searching. Otherwise a polynomial
    scheduler is used when ``SOLVER.SCHEDULER == 'poly'``, and a warm-up
    multi-step scheduler in all other cases.
    """
    if cfg.SEARCH.SEARCH_ON:
        return make_search_lr_scheduler(cfg, optimizer)
    if cfg.SOLVER.SCHEDULER == 'poly':
        power = cfg.SOLVER.TRAIN.POWER
        max_iter = cfg.SOLVER.TRAIN.MAX_ITER
        return PolynormialLR(optimizer, max_iter, power)
    return WarmupMultiStepLR(
        optimizer,
        cfg.SOLVER.STEPS,
        cfg.SOLVER.GAMMA,
        warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
        warmup_iters=cfg.SOLVER.WARMUP_ITERS,
        warmup_method=cfg.SOLVER.WARMUP_METHOD,
    )

# --------------------------------------------------------------------------------
# /one_stage_nas/solver/lr_scheduler.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from bisect import bisect_right
import math

import torch
from torch import nn, optim


# FIXME ideally this would be achieved with a CombinedLRScheduler,
# separating MultiStepLR with WarmupLR
# but the current LRScheduler design doesn't allow it
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step decay (``gamma`` at each milestone) with an initial warm-up.

    During the first ``warmup_iters`` steps the lr is scaled by
    ``warmup_factor``. In "constant" mode the scale is fixed; in "linear"
    mode it ramps linearly from ``warmup_factor`` towards 1.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            # The message previously passed "{}" and the milestones as two
            # separate ValueError args, so it was never formatted.
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = self.last_epoch / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr
            * warmup_factor
            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]


class PolynormialLR(torch.optim.lr_scheduler._LRScheduler):
    """Polynomial decay: ``base_lr * (1 - last_epoch / max_iter) ** power``.

    The ``(1 - last_epoch / max_iter)`` term is clamped at 0. Stepping past
    ``max_iter`` would otherwise raise a negative float to a fractional power,
    which in Python 3 yields a complex number.
    """

    def __init__(
        self,
        optimizer,
        max_iter,
        power=0.9,
        last_epoch=-1,
    ):
        self.max_iter = max_iter
        self.power = power
        super(PolynormialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        remaining = max(0.0, 1 - self.last_epoch / self.max_iter)
        return [
            base_lr
            * remaining ** self.power
            for base_lr in self.base_lrs
        ]


class PolyCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Cosine annealing (period ``2 * T_max``) multiplied by a polynomial decay.

    lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
         * max(0, 1 - t / max_iter) ** power

    The clamp keeps the lr real once ``t`` exceeds ``max_iter``.
    """

    def __init__(self, optimizer, max_iter, T_max, eta_min=0, power=0.9, last_epoch=-1):
        self.max_iter = max_iter
        self.power = power
        self.T_max = T_max
        self.eta_min = eta_min
        super(PolyCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        remaining = max(0.0, 1 - self.last_epoch / self.max_iter)
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
                * remaining ** self.power
                for base_lr in self.base_lrs]
# Compare with legacy implementation
class LegacyCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Closed-form cosine annealing, kept for comparison with the built-in
    ``torch.optim.lr_scheduler.CosineAnnealingLR``.

    lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * last_epoch / T_max)) / 2
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.eta_min = eta_min
        super(LegacyCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        swing = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
        lrs = []
        for base_lr in self.base_lrs:
            lrs.append(self.eta_min + (base_lr - self.eta_min) * swing / 2)
        return lrs


def main():
    """Print the lr trajectory of the built-in and the legacy cosine schedulers."""

    def run(sched, opt, n_steps):
        # Two passes of n_steps iterations, printing after every step.
        for _epoch in range(2):
            for _idx in range(n_steps):
                sched.step()
                print(sched.get_lr())
                print(opt.param_groups[0]['lr'])

    # test cosine
    # Test new scheduler
    print("test new scheduler")
    steps = 10
    optimizer = optim.SGD(nn.Linear(10, 2).parameters(), lr=1.)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, steps)
    try:
        run(scheduler, optimizer, steps)
    except ZeroDivisionError as e:
        print(e)

    print("test old scheduler")
    optimizer = optim.SGD(nn.Linear(10, 2).parameters(), lr=1.)
    scheduler = LegacyCosineAnnealingLR(optimizer, steps)
    run(scheduler, optimizer, steps)

if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /one_stage_nas/utils/SSIM.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` normalised to sum 1."""
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    gauss = torch.Tensor([exp(-(x - center) ** 2 / denom) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    """Return a ``(channel, 1, window_size, window_size)`` Gaussian window (sigma 1.5)."""
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = column.mm(column.t()).float()[None, None]
    return Variable(kernel_2d.expand(channel, 1, window_size, window_size).contiguous())

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    """SSIM between two batches filtered with a depthwise Gaussian window.

    Inputs are presumably (N, channel, H, W), as F.conv2d with
    groups=channel expects. Returns the mean of the SSIM map when
    ``size_average`` is true; otherwise returns one value per sample
    (mean over the remaining three dims).
    """
    pad = window_size // 2

    def blur(t):
        # depthwise Gaussian filtering; output keeps H, W for odd window sizes
        return F.conv2d(t, window, padding=pad, groups=channel)

    mu1 = blur(img1)
    mu2 = blur(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = blur(img1 * img1) - mu1_sq
    sigma2_sq = blur(img2 * img2) - mu2_sq
    sigma12 = blur(img1 * img2) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator

    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
    """SSIM as an ``nn.Module``.

    The Gaussian window is cached for the most recently seen channel count
    and tensor type, and rebuilt whenever either changes.
    """

    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        cache_hit = (channel == self.channel
                     and self.window.data.type() == img1.data.type())
        if not cache_hit:
            fresh = create_window(self.window_size, channel)
            if img1.is_cuda:
                fresh = fresh.cuda(img1.get_device())
            self.window = fresh.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    """Functional SSIM: builds a fresh window matching img1's device and type."""
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    return _ssim(img1, img2, window.type_as(img1), window_size, channel, size_average)

# --------------------------------------------------------------------------------
# /one_stage_nas/utils/__pycache__/checkpoint.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/utils/__pycache__/checkpoint.cpython-36.pyc
# --------------------------------------------------------------------------------
# /one_stage_nas/utils/__pycache__/comm.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/utils/__pycache__/comm.cpython-36.pyc
# --------------------------------------------------------------------------------
# /one_stage_nas/utils/__pycache__/evaluation_metrics.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/utils/__pycache__/evaluation_metrics.cpython-36.pyc
# --------------------------------------------------------------------------------
# /one_stage_nas/utils/__pycache__/logger.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/utils/__pycache__/logger.cpython-36.pyc
# --------------------------------------------------------------------------------
# /one_stage_nas/utils/__pycache__/metric_logger.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/utils/__pycache__/metric_logger.cpython-36.pyc
# --------------------------------------------------------------------------------
# /one_stage_nas/utils/__pycache__/misc.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/utils/__pycache__/misc.cpython-36.pyc
# --------------------------------------------------------------------------------
# /one_stage_nas/utils/__pycache__/visualize.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/one_stage_nas/utils/__pycache__/visualize.cpython-36.pyc
# --------------------------------------------------------------------------------
# /one_stage_nas/utils/checkpoint.py:
# --------------------------------------------------------------------------------
import logging
import os

import torch


class Checkpointer(object):
    """Save and restore training state under ``save_dir``.

    The path of the newest checkpoint is recorded in ``save_dir/last_checkpoint``.
    """

    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger

    def save(self, name, **kwargs):
        """Write ``save_dir/<name>.pth`` and tag it as the last checkpoint.

        The file holds the model state, the optimizer/scheduler state when
        they are attached, and any extra ``kwargs``. Nothing is written
        unless both ``save_dir`` and ``save_to_disk`` are truthy.
        """
        if not self.save_dir:
            return

        if not self.save_to_disk:
            return

        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)
        self.tag_last_checkpoint(save_file)

    def load(self, f=None):
        """Load model weights and return the remaining checkpoint entries.

        A checkpoint named in ``save_dir/last_checkpoint`` takes precedence
        over ``f``. Returns ``{}`` when no checkpoint is found. Optimizer and
        scheduler state stored in the file are discarded, so only the weights
        (plus any extra entries returned to the caller) are restored.
        """
        if self.has_checkpoint():
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found. Initializing model from scratch")
            return {}
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        # ``save`` only writes these keys when an optimizer/scheduler is
        # attached, so pop with a default: a bare pop raised KeyError for
        # weights-only checkpoints.
        checkpoint.pop('optimizer', None)
        checkpoint.pop('scheduler', None)
        # checkpoint.pop('iteration')
        self._load_model(checkpoint)
        # NOTE(review): given the pops above, these two branches never run.
        # They are kept so that restoring optimizer/scheduler state only
        # requires removing the pops.
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        # return any further checkpoint data
        return checkpoint

    def has_checkpoint(self):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        return os.path.exists(save_file)

    def get_checkpoint_file(self):
        """Return the path stored in ``last_checkpoint``, or ``""`` if unreadable."""
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        try:
            with open(save_file, "r") as f:
                last_saved = f.read().strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            last_saved = ""
        return last_saved

    def tag_last_checkpoint(self, last_filename):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        with open(save_file, "w") as f:
            f.write(last_filename)

    def _load_file(self, f):
        return torch.load(f)

    def _load_model(self, checkpoint):
        """Load ``checkpoint["model"]`` (popped from the dict) into the model.

        If the outer model rejects the state dict, the load is retried on
        ``self.model.module`` (the wrapped model of a DataParallel-style wrapper).
        """
        model_state_dict = checkpoint.pop("model")
        try:
            self.model.load_state_dict(model_state_dict)
        except RuntimeError:
            # load_state_dict raises RuntimeError on key/shape mismatches.
            # Without a wrapped module there is nothing to retry, so surface
            # the original error rather than an unrelated AttributeError.
            if not hasattr(self.model, "module"):
                raise
            self.model.module.load_state_dict(model_state_dict)

# --------------------------------------------------------------------------------
# /one_stage_nas/utils/comm.py:
# --------------------------------------------------------------------------------
"""multigpu utils
"""

import torch


def reduce_loss_dict(loss_dict):
    """
    Average every loss tensor in ``loss_dict`` with ``.mean()``.

    Returns a new dict with the same keys (inserted in sorted order) whose
    values are the means. No cross-process communication takes place; the
    inputs are presumably the per-replica losses gathered by a
    DataParallel-style wrapper (confirm at the call site).
    """
    with torch.no_grad():
        reduced_losses = {k: loss_dict[k].mean() for k in sorted(loss_dict.keys())}
    return reduced_losses


def drop_path(x, drop_prob):
    """Zero whole samples of ``x`` in place with probability ``drop_prob``.

    Kept samples are rescaled by ``1 / (1 - drop_prob)``. The mask shape
    ``(N, 1, 1, 1)`` assumes a 4-D ``x``. The mask is allocated on ``x``'s
    own device and dtype, so CPU tensors work too; previously it was
    hard-coded to ``torch.cuda.FloatTensor``. Returns ``x``.
    """
    if drop_prob > 0:
        keep_prob = 1 - drop_prob
        mask = x.new_empty(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def compute_params(model):
    """Return the total number of parameter elements in ``model``.

    A wrapped model (one exposing ``.module``) is unwrapped first, exactly as
    before; unwrapped modules are now also accepted.
    """
    module = getattr(model, "module", model)
    return sum(p.numel() for p in module.parameters())

# --------------------------------------------------------------------------------
# /one_stage_nas/utils/evaluation_metrics.py:
# --------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn.functional as F
from math import exp

# PSNR
class PSNR(object):
    """Running-average PSNR for inputs scaled to [0, 1] (peak value 255 after rescaling)."""

    def __init__(self):
        self.sum_psnr = 0
        self.im_count = 0

    def __call__(self, output, gt):

        output = output*255.0
        gt = gt*255.0
        diff = (output - gt)
        mse = torch.mean(diff*diff)
        psnr = float(10*torch.log10(255.0*255.0/mse))

        self.sum_psnr = self.sum_psnr + psnr
        self.im_count += 1.0

    def metric_get(self, frac=4):
        """Return the mean PSNR rounded to ``frac`` digits (requires >= 1 update)."""
        return round(self.sum_psnr/self.im_count, frac)

    def reset(self):
        self.sum_psnr = 0
        self.im_count = 0


def gaussian(window_size=11, sigma=1.5):
    """Return a 1-D Gaussian kernel of length ``window_size`` normalised to sum 1."""
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size=11, channel=3):
    """Return a ``(channel, 1, window_size, window_size)`` Gaussian window (sigma 1.5)."""
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


#SSIM
class SSIM(object):
    """Running-average SSIM.

    With ``transpose=True`` (the default), inputs are presumably
    (H, W, C) images and are converted to (1, C, H, W) before filtering.
    """

    def __init__(self, window_size=11, channel=3, is_cuda=True):
        # Fall back to CPU when CUDA was requested but is unavailable;
        # __call__ moves the window to the input's device anyway.
        if is_cuda and torch.cuda.is_available():
            self.window = create_window(window_size, channel).to('cuda')
        else:
            self.window = create_window(window_size, channel).to('cpu')

        self.window_size = window_size
        self.channel = channel
        self.sum_ssim = 0
        self.im_count = 0

    def __call__(self, output, gt, transpose=True):
        if transpose:
            output = output.transpose(0, 1).transpose(0, 2).unsqueeze(0)
            gt = gt.transpose(0, 1).transpose(0, 2).unsqueeze(0)

        # No-op when the window already matches; otherwise avoids a
        # device/dtype mismatch inside conv2d.
        window = self.window.to(device=output.device, dtype=output.dtype)
        pad = self.window_size // 2

        mu1 = F.conv2d(output, window, padding=pad, groups=self.channel)
        mu2 = F.conv2d(gt, window, padding=pad, groups=self.channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(output * output, window, padding=pad, groups=self.channel) - mu1_sq
        sigma2_sq = F.conv2d(gt * gt, window, padding=pad, groups=self.channel) - mu2_sq
        sigma12 = F.conv2d(output * gt, window, padding=pad, groups=self.channel) - mu1_mu2

        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        self.sum_ssim = self.sum_ssim + float(ssim_map.mean())
        self.im_count += 1.0


    def metric_get(self, frac=4):
        """Return the mean SSIM rounded to ``frac`` digits (requires >= 1 update)."""
        return round(self.sum_ssim/self.im_count, frac)

    def reset(self):
        self.sum_ssim = 0
        self.im_count = 0


metric_dict = {
    'PSNR': PSNR,
    'SSIM': SSIM
}

# --------------------------------------------------------------------------------
# /one_stage_nas/utils/logger.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
/one_stage_nas/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import os 4 | import sys 5 | 6 | 7 | def setup_logger(name, save_dir, distributed_rank=0): 8 | logger = logging.getLogger(name) 9 | logger.setLevel(logging.DEBUG) 10 | # don't log results for the non-master process 11 | if distributed_rank > 0: 12 | return logger 13 | ch = logging.StreamHandler(stream=sys.stdout) 14 | ch.setLevel(logging.DEBUG) 15 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 16 | ch.setFormatter(formatter) 17 | logger.addHandler(ch) 18 | 19 | if save_dir: 20 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt")) 21 | fh.setLevel(logging.DEBUG) 22 | fh.setFormatter(formatter) 23 | logger.addHandler(fh) 24 | 25 | return logger 26 | -------------------------------------------------------------------------------- /one_stage_nas/utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import defaultdict 3 | from collections import deque 4 | 5 | import torch 6 | 7 | 8 | class SmoothedValue(object): 9 | """Track a series of values and provide access to smoothed values over a 10 | window or the global series average. 
11 | """ 12 | 13 | def __init__(self, window_size=20): 14 | self.deque = deque(maxlen=window_size) 15 | self.series = [] 16 | self.total = 0.0 17 | self.count = 0 18 | 19 | def update(self, value): 20 | self.deque.append(value) 21 | self.series.append(value) 22 | self.count += 1 23 | self.total += value 24 | 25 | @property 26 | def median(self): 27 | d = torch.tensor(list(self.deque)) 28 | return d.median().item() 29 | 30 | @property 31 | def avg(self): 32 | d = torch.tensor(list(self.deque)) 33 | return d.mean().item() 34 | 35 | @property 36 | def global_avg(self): 37 | return self.total / self.count 38 | 39 | 40 | class MetricLogger(object): 41 | def __init__(self, delimiter="\t"): 42 | self.meters = defaultdict(SmoothedValue) 43 | self.delimiter = delimiter 44 | 45 | def update(self, **kwargs): 46 | for k, v in kwargs.items(): 47 | if isinstance(v, torch.Tensor): 48 | v = v.item() 49 | assert isinstance(v, (float, int)) 50 | if v != v: 51 | # skip nan 52 | continue 53 | self.meters[k].update(v) 54 | 55 | def __getattr__(self, attr): 56 | if attr in self.meters: 57 | return self.meters[attr] 58 | return object.__getattr__(self, attr) 59 | 60 | def __str__(self): 61 | loss_str = [] 62 | for name, meter in self.meters.items(): 63 | loss_str.append( 64 | "{}: {:.4f} ({:.4f})".format(name, meter.avg, meter.global_avg) 65 | ) 66 | return self.delimiter.join(loss_str) 67 | -------------------------------------------------------------------------------- /one_stage_nas/utils/misc.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | 4 | 5 | def mkdir(path): 6 | try: 7 | os.makedirs(path) 8 | except OSError as e: 9 | if e.errno != errno.EEXIST: 10 | raise 11 | 12 | 13 | -------------------------------------------------------------------------------- /one_stage_nas/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | from 
graphviz import Digraph 4 | 5 | def model_visualize(model, save_dir, tie=True, return_cell=False): 6 | geno_cell, geno_path = copy.deepcopy(model).module.cpu().genotype() 7 | visualize(geno_cell, geno_path, save_dir, tie=tie) 8 | 9 | if return_cell: 10 | return geno_cell 11 | 12 | def visualize(geno_cell, geno_path, save_dir, tie=True): 13 | g = Digraph( 14 | format='png', 15 | edge_attr=dict(fontsize='20', fontname="times"), 16 | node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', 17 | penwidth='2', fontname="times"), 18 | engine='dot' 19 | ) 20 | g.body.extend(['rankdir=LR']) 21 | 22 | # cell tie==True 23 | if tie == True: 24 | g.node("Pre_pre_cell", fillcolor='darkseagreen2') 25 | g.node("Pre_cell", fillcolor='darkseagreen2') 26 | 27 | node_num = len(geno_cell)//2 28 | for i in range(node_num): 29 | g.node(name='Node {}'.format(i), fillcolor='lightblue') 30 | 31 | for i in range(node_num): 32 | for k in [2*i, 2*i+1]: 33 | op, j = geno_cell[k] 34 | if op != 'none': 35 | if j==1: 36 | u = "Pre_pre_cell" 37 | v = 'Node {}'.format(i) 38 | g.edge(u, v, label=op, fillcolor='red') 39 | elif j== 0: 40 | u = "Pre_cell" 41 | v = 'Node {}'.format(i) 42 | g.edge(u, v, label=op, fillcolor='red') 43 | else: 44 | u = 'Node {}'.format(j-2) 45 | v = 'Node {}'.format(i) 46 | g.edge(u, v, label=op, fillcolor='gray') 47 | 48 | g.node('Cur_cell', fillcolor='palegoldenrod') 49 | for i in range(node_num): 50 | g.edge('Node {}'.format(i), 'Cur_cell', fillcolor='palegoldenrod') 51 | 52 | # cell tie == False 53 | else: 54 | for cell_id in range(len(geno_cell)): 55 | geno_cell_i = geno_cell[cell_id] 56 | 57 | if cell_id == 0: 58 | pre_pre_cell = 'stem1' 59 | pre_cell = 'stem2' 60 | elif cell_id == 1: 61 | pre_pre_cell = 'stem2' 62 | pre_cell = 'cell_0' 63 | elif cell_id > 1: 64 | pre_pre_cell = 'cell_{}'.format(cell_id-2) 65 | pre_cell = 'cell_{}'.format(cell_id-1) 66 | 67 | cur_cell = 'cell_{}'.format(cell_id) 68 | 69 | 
g.node(pre_pre_cell, fillcolor='darkseagreen2') 70 | g.node(pre_cell, fillcolor='darkseagreen2') 71 | 72 | node_num = len(geno_cell_i) // 2 73 | for i in range(node_num): 74 | g.node(name='C{}_N{}'.format(cell_id, i), fillcolor='lightblue') 75 | 76 | for i in range(node_num): 77 | for k in [2 * i, 2 * i + 1]: 78 | op, j = geno_cell_i[k] 79 | if op != 'none': 80 | if j == 1: 81 | u = pre_pre_cell 82 | v = 'C{}_N{}'.format(cell_id, i) 83 | g.edge(u, v, label=op, fillcolor='red') 84 | elif j == 0: 85 | u = pre_cell 86 | v = 'C{}_N{}'.format(cell_id, i) 87 | g.edge(u, v, label=op, fillcolor='red') 88 | else: 89 | u = 'C{}_N{}'.format(cell_id, j - 2) 90 | v = 'C{}_N{}'.format(cell_id, i) 91 | g.edge(u, v, label=op, fillcolor='gray') 92 | 93 | g.node(cur_cell, fillcolor='palegoldenrod') 94 | for i in range(node_num): 95 | g.edge('C{}_N{}'.format(cell_id, i), cur_cell, fillcolor='palegoldenrod') 96 | 97 | 98 | # path 99 | arch = [] 100 | for layer_num, width in enumerate(geno_path): 101 | cell = 'cell:{} w:{}'.format(layer_num, math.pow(2, width)) 102 | arch.append(cell) 103 | 104 | g.node(name=' |-->| '.join(arch), fillcolor='lightyellow') 105 | 106 | g.render(save_dir, view=False) 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /preprocess/dataset_json/sr/Set14.json: -------------------------------------------------------------------------------- 1 | [{"gt_path": "Set14/HR/baboon.png", "x2_path": "Set14/LR/Bi_x2/baboon.png", "x2_size": [250, 240], "x3_path": "Set14/LR/Bi_x3/baboon.png", "x3_size": [166, 160], "x4_path": "Set14/LR/Bi_x4/baboon.png", "x4_size": [125, 120], "x8_path": "Set14/LR/Bi_x8/baboon.png", "x8_size": [62, 60]}, {"gt_path": "Set14/HR/barbara.png", "x2_path": "Set14/LR/Bi_x2/barbara.png", "x2_size": [360, 288], "x3_path": "Set14/LR/Bi_x3/barbara.png", "x3_size": [240, 192], "x4_path": "Set14/LR/Bi_x4/barbara.png", "x4_size": [180, 144], "x8_path": "Set14/LR/Bi_x8/barbara.png", "x8_size": 
[90, 72]}, {"gt_path": "Set14/HR/bridge.png", "x2_path": "Set14/LR/Bi_x2/bridge.png", "x2_size": [256, 256], "x3_path": "Set14/LR/Bi_x3/bridge.png", "x3_size": [170, 170], "x4_path": "Set14/LR/Bi_x4/bridge.png", "x4_size": [128, 128], "x8_path": "Set14/LR/Bi_x8/bridge.png", "x8_size": [64, 64]}, {"gt_path": "Set14/HR/coastguard.png", "x2_path": "Set14/LR/Bi_x2/coastguard.png", "x2_size": [176, 144], "x3_path": "Set14/LR/Bi_x3/coastguard.png", "x3_size": [117, 96], "x4_path": "Set14/LR/Bi_x4/coastguard.png", "x4_size": [88, 72], "x8_path": "Set14/LR/Bi_x8/coastguard.png", "x8_size": [44, 36]}, {"gt_path": "Set14/HR/comic.png", "x2_path": "Set14/LR/Bi_x2/comic.png", "x2_size": [125, 180], "x3_path": "Set14/LR/Bi_x3/comic.png", "x3_size": [83, 120], "x4_path": "Set14/LR/Bi_x4/comic.png", "x4_size": [62, 90], "x8_path": "Set14/LR/Bi_x8/comic.png", "x8_size": [31, 45]}, {"gt_path": "Set14/HR/face.png", "x2_path": "Set14/LR/Bi_x2/face.png", "x2_size": [138, 138], "x3_path": "Set14/LR/Bi_x3/face.png", "x3_size": [92, 92], "x4_path": "Set14/LR/Bi_x4/face.png", "x4_size": [69, 69], "x8_path": "Set14/LR/Bi_x8/face.png", "x8_size": [34, 34]}, {"gt_path": "Set14/HR/flowers.png", "x2_path": "Set14/LR/Bi_x2/flowers.png", "x2_size": [250, 181], "x3_path": "Set14/LR/Bi_x3/flowers.png", "x3_size": [166, 120], "x4_path": "Set14/LR/Bi_x4/flowers.png", "x4_size": [125, 90], "x8_path": "Set14/LR/Bi_x8/flowers.png", "x8_size": [62, 45]}, {"gt_path": "Set14/HR/foreman.png", "x2_path": "Set14/LR/Bi_x2/foreman.png", "x2_size": [176, 144], "x3_path": "Set14/LR/Bi_x3/foreman.png", "x3_size": [117, 96], "x4_path": "Set14/LR/Bi_x4/foreman.png", "x4_size": [88, 72], "x8_path": "Set14/LR/Bi_x8/foreman.png", "x8_size": [44, 36]}, {"gt_path": "Set14/HR/lenna.png", "x2_path": "Set14/LR/Bi_x2/lenna.png", "x2_size": [256, 256], "x3_path": "Set14/LR/Bi_x3/lenna.png", "x3_size": [170, 170], "x4_path": "Set14/LR/Bi_x4/lenna.png", "x4_size": [128, 128], "x8_path": "Set14/LR/Bi_x8/lenna.png", "x8_size": 
[64, 64]}, {"gt_path": "Set14/HR/man.png", "x2_path": "Set14/LR/Bi_x2/man.png", "x2_size": [256, 256], "x3_path": "Set14/LR/Bi_x3/man.png", "x3_size": [170, 170], "x4_path": "Set14/LR/Bi_x4/man.png", "x4_size": [128, 128], "x8_path": "Set14/LR/Bi_x8/man.png", "x8_size": [64, 64]}, {"gt_path": "Set14/HR/monarch.png", "x2_path": "Set14/LR/Bi_x2/monarch.png", "x2_size": [384, 256], "x3_path": "Set14/LR/Bi_x3/monarch.png", "x3_size": [256, 170], "x4_path": "Set14/LR/Bi_x4/monarch.png", "x4_size": [192, 128], "x8_path": "Set14/LR/Bi_x8/monarch.png", "x8_size": [96, 64]}, {"gt_path": "Set14/HR/pepper.png", "x2_path": "Set14/LR/Bi_x2/pepper.png", "x2_size": [256, 256], "x3_path": "Set14/LR/Bi_x3/pepper.png", "x3_size": [170, 170], "x4_path": "Set14/LR/Bi_x4/pepper.png", "x4_size": [128, 128], "x8_path": "Set14/LR/Bi_x8/pepper.png", "x8_size": [64, 64]}, {"gt_path": "Set14/HR/ppt3.png", "x2_path": "Set14/LR/Bi_x2/ppt3.png", "x2_size": [264, 328], "x3_path": "Set14/LR/Bi_x3/ppt3.png", "x3_size": [176, 218], "x4_path": "Set14/LR/Bi_x4/ppt3.png", "x4_size": [132, 164], "x8_path": "Set14/LR/Bi_x8/ppt3.png", "x8_size": [66, 82]}, {"gt_path": "Set14/HR/zebra.png", "x2_path": "Set14/LR/Bi_x2/zebra.png", "x2_size": [293, 195], "x3_path": "Set14/LR/Bi_x3/zebra.png", "x3_size": [195, 130], "x4_path": "Set14/LR/Bi_x4/zebra.png", "x4_size": [146, 97], "x8_path": "Set14/LR/Bi_x8/zebra.png", "x8_size": [73, 48]}] -------------------------------------------------------------------------------- /preprocess/dataset_json/sr/Set5.json: -------------------------------------------------------------------------------- 1 | [{"gt_path": "Set5/HR/baby.png", "x2_path": "Set5/LR/Bi_x2/baby.png", "x2_size": [256, 256], "x3_path": "Set5/LR/Bi_x3/baby.png", "x3_size": [170, 170], "x4_path": "Set5/LR/Bi_x4/baby.png", "x4_size": [128, 128], "x8_path": "Set5/LR/Bi_x8/baby.png", "x8_size": [64, 64]}, {"gt_path": "Set5/HR/bird.png", "x2_path": "Set5/LR/Bi_x2/bird.png", "x2_size": [144, 144], "x3_path": 
"Set5/LR/Bi_x3/bird.png", "x3_size": [96, 96], "x4_path": "Set5/LR/Bi_x4/bird.png", "x4_size": [72, 72], "x8_path": "Set5/LR/Bi_x8/bird.png", "x8_size": [36, 36]}, {"gt_path": "Set5/HR/butterfly.png", "x2_path": "Set5/LR/Bi_x2/butterfly.png", "x2_size": [128, 128], "x3_path": "Set5/LR/Bi_x3/butterfly.png", "x3_size": [85, 85], "x4_path": "Set5/LR/Bi_x4/butterfly.png", "x4_size": [64, 64], "x8_path": "Set5/LR/Bi_x8/butterfly.png", "x8_size": [32, 32]}, {"gt_path": "Set5/HR/head.png", "x2_path": "Set5/LR/Bi_x2/head.png", "x2_size": [140, 140], "x3_path": "Set5/LR/Bi_x3/head.png", "x3_size": [93, 93], "x4_path": "Set5/LR/Bi_x4/head.png", "x4_size": [70, 70], "x8_path": "Set5/LR/Bi_x8/head.png", "x8_size": [35, 35]}, {"gt_path": "Set5/HR/woman.png", "x2_path": "Set5/LR/Bi_x2/woman.png", "x2_size": [114, 172], "x3_path": "Set5/LR/Bi_x3/woman.png", "x3_size": [76, 114], "x4_path": "Set5/LR/Bi_x4/woman.png", "x4_size": [57, 86], "x8_path": "Set5/LR/Bi_x8/woman.png", "x8_size": [28, 43]}] -------------------------------------------------------------------------------- /preprocess/dn_preprocess.py: -------------------------------------------------------------------------------- 1 | from utils import (denoise_dict_build, json_save, make_if_not_exist) 2 | import argparse 3 | import json 4 | import os 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description='dataset preprocess') 9 | parser.add_argument('--data_root', type=str, default='D:/02_data/nas_data') 10 | parser.add_argument('--task', type=str, default='dn') 11 | parser.add_argument('--datasets', type=str, default=['BSD500_300', 'BSD500_200']) 12 | parser.add_argument('--json_dir', type=str, default='dataset_json') 13 | args = parser.parse_args() 14 | 15 | dict_list = denoise_dict_build(args) 16 | 17 | json_save_dir = os.path.join(args.json_dir, args.task) 18 | make_if_not_exist(json_save_dir) 19 | for dataset, dict in zip(args.datasets, dict_list): 20 | json_save(os.path.join(json_save_dir, 
'{}.json'.format(dataset)), dict) 21 | 22 | 23 | if __name__ == '__main__': 24 | main() 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /preprocess/image_check.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import cv2 4 | from PIL import Image 5 | 6 | root_dir = '/home/hkzhang/Documents/sdb_a/nas_data/denoise/' 7 | json_dir = './dataset_json/denoise/CBD_real.json' 8 | with open(json_dir, 'r') as f: 9 | img_dict = json.load(f) 10 | 11 | for img_info in img_dict: 12 | img=Image.open(root_dir + img_info['path_clean']) 13 | img_name = img_info['path_clean'].split('/')[-1] 14 | if img.mode != 'RGB': 15 | print(img_name + ' : ' + img.mode) 16 | 17 | -------------------------------------------------------------------------------- /preprocess/sid_preprocess.py: -------------------------------------------------------------------------------- 1 | from utils import (sid_dict_build, json_save, make_if_not_exist) 2 | import argparse 3 | import os 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser(description='dataset preprocess') 8 | parser.add_argument('--data_root', type=str, default='D:/02_data/nas_data') 9 | parser.add_argument('--task', type=str, default='sid') 10 | parser.add_argument('--dataset', type=str, default='Sony') 11 | parser.add_argument('--json_dir', type=str, default='dataset_json') 12 | args = parser.parse_args() 13 | 14 | train_dict, test_dict = sid_dict_build(args) 15 | 16 | json_save_dir = os.path.join(args.json_dir, args.task, args.dataset) 17 | make_if_not_exist(json_save_dir) 18 | json_save(os.path.join(json_save_dir, 'train.json'), train_dict) 19 | json_save(os.path.join(json_save_dir, 'test.json'), test_dict) 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /preprocess/sr_preprocess.py: 
-------------------------------------------------------------------------------- 1 | from utils import (sr_dict_build, json_save, make_if_not_exist) 2 | import argparse 3 | import json 4 | import os 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description='dataset preprocess') 9 | parser.add_argument('--data_root', type=str, default='D:/02_data/nas_data') 10 | parser.add_argument('--task', type=str, default='sr') 11 | parser.add_argument('--datasets', type=str, default=['DIV2K_800', 'Set5', 'Set14', 'BSD100', 'Urban100', 'Manga109']) 12 | parser.add_argument('--json_dir', type=str, default='dataset_json') 13 | args = parser.parse_args() 14 | 15 | 16 | dict_list = sr_dict_build(args) 17 | 18 | json_save_dir = os.path.join(args.json_dir, args.task) 19 | make_if_not_exist(json_save_dir) 20 | for dataset, dict in zip(args.datasets, dict_list): 21 | json_save(os.path.join(json_save_dir, '{}.json'.format(dataset)), dict) 22 | 23 | if __name__ == '__main__': 24 | main() 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /preprocess/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dn_utils import denoise_dict_build, json_save, make_if_not_exist 2 | from .sid_utils import sid_dict_build 3 | from .sr_utils import sr_dict_build -------------------------------------------------------------------------------- /preprocess/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/preprocess/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /preprocess/utils/__pycache__/dn_utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/preprocess/utils/__pycache__/dn_utils.cpython-36.pyc -------------------------------------------------------------------------------- /preprocess/utils/__pycache__/sid_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/preprocess/utils/__pycache__/sid_utils.cpython-36.pyc -------------------------------------------------------------------------------- /preprocess/utils/__pycache__/sr_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkzhang-git/HiNAS/fc56724fde81870ac90bd6d87075f0d1284b1f8c/preprocess/utils/__pycache__/sr_utils.cpython-36.pyc -------------------------------------------------------------------------------- /preprocess/utils/dn_utils.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from PIL import Image 3 | import json 4 | import os 5 | 6 | 7 | def denoise_dict_build(args): 8 | dict_list = [] 9 | for dataset in args.datasets: 10 | if dataset in ['BSD500_300', 'BSD500_200', 'Urben100', 'set14']: 11 | dict = [] 12 | data_dir = os.path.join(args.data_root, args.task, dataset) 13 | 14 | im_list = glob(os.path.join(data_dir, '*.jpg')) 15 | if len(im_list) == 0: 16 | im_list = glob(os.path.join(data_dir, '*.png')) 17 | if len(im_list) == 0: 18 | im_list = glob(os.path.join(data_dir, '*.bmp')) 19 | 20 | im_list.sort() 21 | 22 | for im_dir in im_list: 23 | if '\\' in im_dir: 24 | im_dir=im_dir.replace('\\', '/') 25 | with Image.open(im_dir) as img: 26 | w, h = img.width, img.height 27 | 28 | sample_info = { 29 | 'path': '/'.join(im_dir.split('/')[-2:]), 30 | 'width': int(w), 31 | 'height': int(h) 32 | } 33 | dict.append(sample_info) 34 | dict_list.append(dict) 35 | 36 | 
return dict_list 37 | 38 | 39 | def json_save(save_path, dict_file): 40 | with open(save_path, 'w') as f: 41 | json.dump(dict_file, f) 42 | 43 | 44 | def make_if_not_exist(path): 45 | if not os.path.exists(path): 46 | os.makedirs(path) 47 | 48 | -------------------------------------------------------------------------------- /preprocess/utils/sid_utils.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from PIL import Image 3 | import json 4 | import os 5 | import rawpy 6 | import numpy as np 7 | 8 | 9 | # def sid_dict_build(args): 10 | # data_dir = args.data_root + '/' + args.task 11 | # train_dict=[] 12 | # test_dict=[] 13 | # 14 | # # build train_dict 15 | # with open(os.path.join(data_dir, 'Sony_train_list.txt'), 'r') as f: 16 | # train_info_list = f.readlines() 17 | # for train_info in train_info_list: 18 | # info = train_info.split() 19 | # raw_info = info[0] 20 | # raw_path = data_dir + raw_info[1:] 21 | # raw_exposure = info[0].split('/')[-1].split('_')[-1][:-5] 22 | # gt_info = info[1] 23 | # gt_path = (data_dir + gt_info[1:]) 24 | # gt_exposure = info[1].split('/')[-1].split('_')[-1][:-5] 25 | # assert os.path.exists(raw_path) and os.path.exists(gt_path) 26 | # device = '_'.join((info[2], info[3])) 27 | # sample_id = info[0].split('/')[-1][:-4] + '-' + '{}s'.format(gt_exposure) 28 | # sample_info = { 29 | # 'sample_id': sample_id, 30 | # 'raw_path': '/'.join(raw_path.split('/')[-3:]), 31 | # 'gt_path': '/'.join(gt_path.split('/')[-3:]), 32 | # 'raw_exposure': float(raw_exposure), 33 | # 'gt_exposure': float(gt_exposure), 34 | # 'device': device 35 | # } 36 | # train_dict.append(sample_info) 37 | # 38 | # with open(os.path.join(data_dir, 'Sony_val_list.txt'), 'r') as f: 39 | # train_info_list = f.readlines() 40 | # for train_info in train_info_list: 41 | # info = train_info.split() 42 | # raw_info = info[0] 43 | # raw_path = data_dir + raw_info[1:] 44 | # raw_exposure = 
info[0].split('/')[-1].split('_')[-1][:-5] 45 | # gt_info = info[1] 46 | # gt_path = (data_dir + gt_info[1:]) 47 | # gt_exposure = info[1].split('/')[-1].split('_')[-1][:-5] 48 | # assert os.path.exists(raw_path) and os.path.exists(gt_path) 49 | # device = '_'.join((info[2], info[3])) 50 | # sample_id = info[0].split('/')[-1][:-4] + '-' + '{}s'.format(gt_exposure) 51 | # sample_info = { 52 | # 'sample_id': sample_id, 53 | # 'raw_path': '/'.join(raw_path.split('/')[-3:]), 54 | # 'gt_path': '/'.join(gt_path.split('/')[-3:]), 55 | # 'raw_exposure': float(raw_exposure), 56 | # 'gt_exposure': float(gt_exposure), 57 | # 'device': device 58 | # } 59 | # train_dict.append(sample_info) 60 | # 61 | # # build test_dict 62 | # with open(os.path.join(data_dir, 'Sony_test_list.txt'), 'r') as f: 63 | # test_info_list = f.readlines() 64 | # for test_info in test_info_list: 65 | # info = test_info.split() 66 | # raw_info = info[0] 67 | # raw_path = data_dir + raw_info[1:] 68 | # raw_exposure = info[0].split('/')[-1].split('_')[-1][:-5] 69 | # gt_info = info[1] 70 | # gt_path = (data_dir + gt_info[1:]) 71 | # gt_exposure = info[1].split('/')[-1].split('_')[-1][:-5] 72 | # assert os.path.exists(raw_path) and os.path.exists(gt_path) 73 | # device = '_'.join((info[2], info[3])) 74 | # sample_id = info[0].split('/')[-1][:-4] + '-' + '{}s'.format(gt_exposure) 75 | # sample_info = { 76 | # 'sample_id': sample_id, 77 | # 'raw_path': '/'.join(raw_path.split('/')[-3:]), 78 | # 'gt_path': '/'.join(gt_path.split('/')[-3:]), 79 | # 'raw_exposure': float(raw_exposure), 80 | # 'gt_exposure': float(gt_exposure), 81 | # 'device': device 82 | # } 83 | # test_dict.append(sample_info) 84 | # 85 | # return train_dict, test_dict 86 | 87 | 88 | def sid_dict_build(args): 89 | data_dir = args.data_root + '/' + args.task 90 | train_dict=[] 91 | test_dict=[] 92 | 93 | # build train_dict 94 | with open(os.path.join(data_dir, 'Sony_train_list.txt'), 'r') as f: 95 | train_info_list = f.readlines() 96 | 
raw_info_list = np.array([info.split()[0] for info in train_info_list]) 97 | gt_info_list = [info.split()[1] for info in train_info_list] 98 | gt_set = list(set(gt_info_list)) 99 | gt_set.sort() 100 | gt_info_list = np.array(gt_info_list) 101 | 102 | for gt_info in gt_set: 103 | gt_arw = rawpy.imread('/'.join((data_dir, gt_info))) 104 | width, height = gt_arw.sizes.iwidth, gt_arw.sizes.iheight 105 | raw_info_set = list(raw_info_list[gt_info_list==gt_info]) 106 | sample_id = gt_info.split('/')[-1].split('.')[0] 107 | gt_exposure = gt_info.split('/')[-1].split('_')[-1][:-5] 108 | sample_info = { 109 | 'sample_id': sample_id, 110 | 'raw_path': [], 111 | 'gt_path': gt_info[2:], 112 | 'raw_exposure': [], 113 | 'gt_exposure': float(gt_exposure), 114 | 'width': width, 115 | 'height': height, 116 | } 117 | for raw_info in raw_info_set: 118 | raw_path = raw_info[2:] 119 | raw_exposure = raw_info.split('/')[-1].split('_')[-1][:-5] 120 | sample_info['raw_path'].append(raw_path) 121 | sample_info['raw_exposure'].append(float(raw_exposure)) 122 | 123 | train_dict.append(sample_info) 124 | 125 | with open(os.path.join(data_dir, 'Sony_val_list.txt'), 'r') as f: 126 | train_info_list = f.readlines() 127 | raw_info_list = np.array([info.split()[0] for info in train_info_list]) 128 | gt_info_list = [info.split()[1] for info in train_info_list] 129 | gt_set = list(set(gt_info_list)) 130 | gt_set.sort() 131 | gt_info_list = np.array(gt_info_list) 132 | 133 | for gt_info in gt_set: 134 | gt_arw = rawpy.imread('/'.join((data_dir, gt_info))) 135 | width, height = gt_arw.sizes.iwidth, gt_arw.sizes.iheight 136 | raw_info_set = list(raw_info_list[gt_info_list == gt_info]) 137 | sample_id = gt_info.split('/')[-1].split('.')[0] 138 | gt_exposure = gt_info.split('/')[-1].split('_')[-1][:-5] 139 | sample_info = { 140 | 'sample_id': sample_id, 141 | 'raw_path': [], 142 | 'gt_path': gt_info[2:], 143 | 'raw_exposure': [], 144 | 'gt_exposure': float(gt_exposure), 145 | 'width': width, 146 | 
'height': height, 147 | } 148 | for raw_info in raw_info_set: 149 | raw_path = raw_info[2:] 150 | raw_exposure = raw_info.split('/')[-1].split('_')[-1][:-5] 151 | sample_info['raw_path'].append(raw_path) 152 | sample_info['raw_exposure'].append(float(raw_exposure)) 153 | 154 | train_dict.append(sample_info) 155 | 156 | # build test_dict 157 | with open(os.path.join(data_dir, 'Sony_test_list.txt'), 'r') as f: 158 | test_info_list = f.readlines() 159 | for test_info in test_info_list: 160 | info = test_info.split() 161 | raw_info = info[0] 162 | raw_path = data_dir + raw_info[1:] 163 | raw_exposure = info[0].split('/')[-1].split('_')[-1][:-5] 164 | gt_info = info[1] 165 | gt_path = (data_dir + gt_info[1:]) 166 | gt_exposure = info[1].split('/')[-1].split('_')[-1][:-5] 167 | assert os.path.exists(raw_path) and os.path.exists(gt_path) 168 | device = '_'.join((info[2], info[3])) 169 | sample_id = info[0].split('/')[-1][:-4] + '-' + '{}s'.format(gt_exposure) 170 | gt_arw = rawpy.imread(gt_path) 171 | width, height = gt_arw.sizes.iwidth, gt_arw.sizes.iheight 172 | sample_info = { 173 | 'sample_id': sample_id, 174 | 'raw_path': '/'.join(raw_path.split('/')[-3:]), 175 | 'gt_path': '/'.join(gt_path.split('/')[-3:]), 176 | 'raw_exposure': float(raw_exposure), 177 | 'gt_exposure': float(gt_exposure), 178 | 'device': device, 179 | 'width': width, 180 | 'height': height, 181 | } 182 | test_dict.append(sample_info) 183 | 184 | return train_dict, test_dict 185 | 186 | -------------------------------------------------------------------------------- /preprocess/utils/sr_utils.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from PIL import Image 3 | 4 | import os 5 | 6 | 7 | def sr_dict_build(args): 8 | dict_list = [] 9 | for dataset in args.datasets: 10 | dict = [] 11 | data_dir = '/'.join((args.data_root, args.task, dataset, 'HR')) 12 | 13 | im_list = glob(data_dir + '/*.jpg') 14 | if len(im_list) == 0: 15 | im_list 
= glob(data_dir + '/*.png') 16 | if len(im_list) == 0: 17 | im_list = glob(os.path.join(data_dir, '/*.bmp')) 18 | 19 | im_list.sort() 20 | 21 | for im_dir in im_list: 22 | if '\\' in im_dir: 23 | im_dir = im_dir.replace('\\', '/') 24 | # x2 25 | im_x2_dir = im_dir.replace('HR', 'LR/Bi_x2') 26 | with Image.open(im_x2_dir) as img: 27 | w_x2, h_x2 = img.width, img.height 28 | # x3 29 | im_x3_dir = im_dir.replace('HR', 'LR/Bi_x3') 30 | with Image.open(im_x3_dir) as img: 31 | w_x3, h_x3 = img.width, img.height 32 | # x4 33 | im_x4_dir = im_dir.replace('HR', 'LR/Bi_x4') 34 | with Image.open(im_x4_dir) as img: 35 | w_x4, h_x4 = img.width, img.height 36 | # x8 37 | im_x8_dir = im_dir.replace('HR', 'LR/Bi_x8') 38 | with Image.open(im_x8_dir) as img: 39 | w_x8, h_x8 = img.width, img.height 40 | 41 | sample_info = { 42 | 'gt_path': '/'.join(im_dir.split('/')[-3:]), 43 | 'x2_path': '/'.join(im_x2_dir.split('/')[-4:]), 44 | 'x2_size': [int(w_x2), int(h_x2)], 45 | 'x3_path': '/'.join(im_x3_dir.split('/')[-4:]), 46 | 'x3_size': [int(w_x3), int(h_x3)], 47 | 'x4_path': '/'.join(im_x4_dir.split('/')[-4:]), 48 | 'x4_size': [int(w_x4), int(h_x4)], 49 | 'x8_path': '/'.join(im_x8_dir.split('/')[-4:]), 50 | 'x8_size': [int(w_x8), int(h_x8)], 51 | } 52 | dict.append(sample_info) 53 | dict_list.append(dict) 54 | 55 | return dict_list 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv_python==4.1.0.25 2 | numpy==1.16.2 3 | setuptools==40.8.0 4 | matplotlib==3.0.3 5 | scipy==1.2.1 6 | torch==1.0.0 7 | yacs==0.1.6 8 | graphviz==0.11.1 9 | Pillow==6.1.0 10 | tensorboardX==1.8 11 | setuptools==40.8.0 12 | -------------------------------------------------------------------------------- /tools/dn_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Searching script 3 | """ 4 | 5 | 
import argparse 6 | import os 7 | import json 8 | import torch 9 | import sys 10 | import numpy as np 11 | sys.path.append('..') 12 | from one_stage_nas.config import cfg 13 | from one_stage_nas.data import build_transforms 14 | from one_stage_nas.utils.misc import mkdir 15 | from one_stage_nas.modeling.architectures import build_model 16 | from PIL import Image 17 | from one_stage_nas.utils.evaluation_metrics import SSIM, PSNR 18 | 19 | 20 | def crop(crop_size, w, h): 21 | slide_step = crop_size - crop_size // 4 22 | x1 = list(range(0, w-crop_size, slide_step)) 23 | x1.append(w-crop_size) 24 | y1 = list(range(0, h-crop_size, slide_step)) 25 | y1.append(h-crop_size) 26 | 27 | x2 = [x+crop_size for x in x1] 28 | y2 = [y+crop_size for y in y1] 29 | 30 | return x1, x2, y1, y2 31 | 32 | 33 | def json_loader(dict_file_dir): 34 | with open(dict_file_dir, 'r') as data_file: 35 | return json.load(data_file) 36 | 37 | 38 | def joint_patches(output_buffer, w, h, channel): 39 | count_matrix = np.zeros((int(h), int(w), channel), dtype=np.float32) 40 | im_result = torch.from_numpy(np.zeros((int(h), int(w), channel), dtype=np.float32)) 41 | gt_result = torch.from_numpy(np.zeros((int(h), int(w), channel), dtype=np.float32)) 42 | 43 | for item in output_buffer: 44 | im_patch = item['im_patch'] 45 | gt_patch = item['gt_patch'] 46 | crop_position = item['crop_position'] 47 | w0, w1, h0, h1 = int(crop_position[0]), int(crop_position[1]), int(crop_position[2]), int(crop_position[3]) 48 | 49 | im_result[h0:h1, w0:w1] = im_result[h0:h1, w0:w1] + im_patch.transpose(0, 2).transpose(0, 1).contiguous() 50 | gt_result[h0:h1, w0:w1] = gt_result[h0:h1, w0:w1] + gt_patch.transpose(0, 2).transpose(0, 1).contiguous() 51 | count_matrix[h0:h1, w0:w1] = count_matrix[h0:h1, w0:w1] + 1.0 52 | 53 | return im_result / torch.from_numpy(count_matrix), gt_result / torch.from_numpy(count_matrix) 54 | 55 | 56 | def evaluation(cfg, SIGMA, dataset): 57 | print('load test set') 58 | dataset_json_dir = 
'/'.join((cfg.DATALOADER.DATA_LIST_DIR, cfg.DATASET.TASK, '{}.json'.format(dataset))) 59 | data_dict = json_loader(dataset_json_dir) 60 | 61 | crop_size = cfg.DATASET.CROP_SIZE 62 | data_root = cfg.DATASET.DATA_ROOT 63 | 64 | test_dict = [] 65 | for im_info in data_dict: 66 | 67 | w, h = im_info['width'], im_info['height'] 68 | im_id = im_info['path'].split('/')[-1] 69 | 70 | # assert w >= crop_size and h >= crop_size 71 | crop_size = min(crop_size, w, h) 72 | x1, x2, y1, y2 = crop(crop_size, int(w), int(h)) 73 | 74 | for x_start, x_end in zip(x1, x2): 75 | for y_start, y_end in zip(y1, y2): 76 | 77 | sample_info = { 78 | 'path': os.path.join(data_root, cfg.DATASET.TASK, '/'.join(im_info['path'].split('/')[-3:])), 79 | 'im_id': im_id, 80 | 'width': w, 81 | 'height': h, 82 | 'x1': x_start, 83 | 'x2': x_end, 84 | 'y1': y_start, 85 | 'y2': y_end 86 | } 87 | test_dict.append(sample_info) 88 | 89 | 90 | print('model build') 91 | 92 | trained_model_dir = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 93 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 
94 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 95 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 96 | 'train_noise_[{}]'.format(SIGMA), 'models/model_best.pth')) 97 | 98 | if not os.path.exists(trained_model_dir): 99 | print('trained_model does not exist') 100 | return None, None 101 | model = build_model(cfg) 102 | model = torch.nn.DataParallel(model).cuda() 103 | # trained_model_dir = os.path.join(cfg.MODEL_DIR, 'model_best.pth') 104 | 105 | model_state_dict = torch.load(trained_model_dir).pop("model") 106 | try: 107 | model.load_state_dict(model_state_dict) 108 | except: 109 | model.module.load_state_dict(model_state_dict) 110 | 111 | print('dataset {} evaluation...'.format(dataset)) 112 | 113 | transforms = build_transforms(task='dn', tag='test', sigma=[SIGMA]) 114 | 115 | model.eval() 116 | metric_SSIM = SSIM(window_size=11, channel=cfg.MODEL.IN_CHANNEL, is_cuda=True) 117 | metric_PSNR = PSNR() 118 | 119 | batch_size = cfg.DATALOADER.BATCH_SIZE_TEST 120 | 121 | result_save_dir = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 122 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 
123 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 124 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 125 | 'eval_noise_[{}]'.format(SIGMA), dataset)) 126 | 127 | mkdir(result_save_dir) 128 | 129 | with torch.no_grad(): 130 | previous_im_id = '' 131 | current_im_id = '' 132 | previous_im_w = None 133 | previous_im_h = None 134 | output_buffer = [] 135 | 136 | dict_len = len(test_dict) 137 | batch_index_end=0 138 | 139 | i = 0 140 | while batch_index_end < dict_len: 141 | 142 | batch_index_start = batch_index_end 143 | batch_index_end = min(batch_index_end + batch_size, dict_len) 144 | 145 | images = [] 146 | targets = [] 147 | im_id = [] 148 | w, h = [], [] 149 | x1, x2, y1, y2 = [], [], [], [] 150 | 151 | for index in range(batch_index_start, batch_index_end): 152 | patch_info = test_dict[index] 153 | 154 | if patch_info['im_id'] != current_im_id: 155 | sample_data = Image.open(patch_info['path']) 156 | print(patch_info['im_id'] + ':' + sample_data.mode) 157 | width = patch_info['width'] 158 | height = patch_info['height'] 159 | current_im_id = patch_info['im_id'] 160 | 161 | p_x1, p_x2, p_y1, p_y2 = patch_info['x1'], patch_info['x2'], patch_info['y1'], patch_info['y2'] 162 | 163 | image = sample_data.crop((p_x1, p_y1, p_x2, p_y2)) 164 | if cfg.DATASET.TO_GRAY: 165 | image = image.convert('L') 166 | target = image 167 | 168 | sample = {'image': image, 'target': target} 169 | sample = transforms(sample) 170 | 171 | images.append(sample['image']) 172 | targets.append(sample['target']) 173 | im_id.append(patch_info['im_id']) 174 | w.append(width) 175 | h.append(height) 176 | x1.append(p_x1) 177 | x2.append(p_x2) 178 | y1.append(p_y1) 179 | y2.append(p_y2) 180 | 181 | images = torch.stack(images) 182 | targets = torch.stack(targets) 183 | output = model(images) 184 | 185 | for j in range(images.size(0)): 186 | if not (i == 0 and j == 0) and im_id[j] != previous_im_id: 187 | im_result, gt_result = 
joint_patches(output_buffer, previous_im_w, previous_im_h, cfg.MODEL.IN_CHANNEL) 188 | im_result[im_result > 1.0] = 1.0 189 | im_result[im_result < 0.0] = 0.0 190 | 191 | metric_SSIM(im_result.cuda(), gt_result.cuda()) 192 | metric_PSNR(im_result, gt_result) 193 | im_PIL = Image.fromarray(np.array(im_result.squeeze() * 255, np.uint8)) 194 | im_PIL.save(os.path.join(result_save_dir, previous_im_id)) 195 | output_buffer = [] 196 | 197 | previous_im_id = im_id[j] 198 | previous_im_w = w[j] 199 | previous_im_h = h[j] 200 | 201 | patch_info = { 202 | 'im_patch': output[j].cpu(), 203 | 'gt_patch': targets[j], 204 | 'crop_position': [x1[j], x2[j], y1[j], y2[j]] 205 | } 206 | output_buffer.append(patch_info) 207 | 208 | i+=1 209 | 210 | im_result, gt_result = joint_patches(output_buffer, previous_im_w, previous_im_h, cfg.MODEL.IN_CHANNEL) 211 | im_result[im_result>1.0] = 1.0 212 | im_result[im_result<0.0] = 0.0 213 | metric_SSIM(im_result.cuda(), gt_result.cuda()) 214 | metric_PSNR(im_result, gt_result) 215 | im_PIL = Image.fromarray(np.array(im_result.squeeze() * 255, np.uint8)) 216 | im_PIL.save(os.path.join(result_save_dir, previous_im_id)) 217 | 218 | ssim = metric_SSIM.metric_get() 219 | psnr = metric_PSNR.metric_get() 220 | 221 | print('dataset:{} ssim:{}, psnr:{}'.format(dataset, ssim, psnr)) 222 | with open(os.path.join(result_save_dir, 'evaluation_result.txt'), 'w') as f: 223 | f.write('SSIM:{} PSNR:{}'.format(ssim, psnr)) 224 | 225 | 226 | def main(): 227 | parser = argparse.ArgumentParser(description="evaluation") 228 | parser.add_argument( 229 | "--config-file", 230 | default="../configs/dn/BSD500_2c3n/inference.yaml", 231 | metavar="FILE", 232 | help="path to config file", 233 | type=str, 234 | ) 235 | parser.add_argument( 236 | "--device", 237 | default='0', 238 | help="path to config file", 239 | type=str, 240 | ) 241 | 242 | args = parser.parse_args() 243 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device 244 | cfg.merge_from_file(args.config_file) 245 | 
cfg.freeze() 246 | 247 | dataset_list = cfg.DATASET.TEST_DATASETS 248 | noise_levels = cfg.DATALOADER.SIGMA 249 | 250 | for dataset in dataset_list: 251 | for SIGMA in noise_levels: 252 | evaluation(cfg, SIGMA, dataset) 253 | 254 | 255 | if __name__ == "__main__": 256 | main() 257 | -------------------------------------------------------------------------------- /tools/search.py: -------------------------------------------------------------------------------- 1 | """ 2 | Searching script 3 | """ 4 | import argparse 5 | import torch 6 | import os 7 | import sys 8 | sys.path.append('..') 9 | from one_stage_nas.config import cfg 10 | from one_stage_nas.data import build_dataset 11 | from one_stage_nas.solver import make_lr_scheduler 12 | from one_stage_nas.solver import make_optimizer 13 | from one_stage_nas.engine.searcher import do_search 14 | from one_stage_nas.modeling.architectures import build_model 15 | from one_stage_nas.utils.checkpoint import Checkpointer 16 | from one_stage_nas.utils.logger import setup_logger 17 | from one_stage_nas.utils.misc import mkdir 18 | from tensorboardX import SummaryWriter 19 | 20 | 21 | def search(cfg, output_dir): 22 | 23 | # set random seed 24 | torch.manual_seed(cfg.SEARCH.R_SEED) 25 | torch.cuda.manual_seed(cfg.SEARCH.R_SEED) 26 | 27 | model = build_model(cfg) 28 | optimizer = make_optimizer(cfg, model) 29 | scheduler = make_lr_scheduler(cfg, optimizer) 30 | 31 | checkpointer = Checkpointer( 32 | model, optimizer, scheduler, output_dir + '/models', save_to_disk=True) 33 | 34 | train_loaders, val_dict = build_dataset(cfg) 35 | 36 | arguments = {} 37 | arguments["epoch"] = 0 38 | 39 | extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT) 40 | arguments.update(extra_checkpoint_data) 41 | 42 | # just use data parallel 43 | model = torch.nn.DataParallel(model).cuda() 44 | 45 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 46 | val_period = cfg.SOLVER.VALIDATE_PERIOD 47 | max_epoch = cfg.SOLVER.MAX_EPOCH 48 | 
arch_start_epoch = cfg.SEARCH.ARCH_START_EPOCH 49 | 50 | writer = SummaryWriter(logdir=output_dir + '/log', comment=cfg.DATASET.TASK + '_' + cfg.DATASET.DATA_NAME) 51 | 52 | do_search( 53 | model, 54 | train_loaders, 55 | val_dict, 56 | max_epoch, 57 | arch_start_epoch, 58 | val_period, 59 | optimizer, 60 | scheduler, 61 | checkpointer, 62 | checkpoint_period, 63 | arguments, 64 | writer, 65 | cfg, 66 | visual_dir=output_dir, 67 | ) 68 | 69 | 70 | def main(): 71 | parser = argparse.ArgumentParser(description="neural architecture search for four different low-level tasks") 72 | parser.add_argument( 73 | "--config-file", 74 | default='../configs/dn/BSD500_3c4n/03_search_CR_R0.yaml', 75 | metavar="FILE", 76 | help="path to config file", 77 | type=str, 78 | ) 79 | parser.add_argument( 80 | "--device", 81 | default='3', 82 | help="path to config file", 83 | type=str, 84 | ) 85 | parser.add_argument( 86 | "opts", 87 | help="Modify config options using the command-line", 88 | default=None, 89 | nargs=argparse.REMAINDER, 90 | ) 91 | 92 | args = parser.parse_args() 93 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device 94 | cfg.merge_from_file(args.config_file) 95 | cfg.merge_from_list(args.opts) 96 | cfg.freeze() 97 | 98 | output_dir = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 99 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 
100 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 101 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 102 | 'search')) 103 | mkdir(output_dir) 104 | mkdir(output_dir + '/models') 105 | logger = setup_logger("one_stage_nas", output_dir) 106 | logger.info(args) 107 | 108 | logger.info("Loaded configuration file {}".format(args.config_file)) 109 | with open(args.config_file, "r") as cf: 110 | config_str = "\n" + cf.read() 111 | logger.info(config_str) 112 | logger.info("Running with config:\n{}".format(cfg)) 113 | 114 | search(cfg, output_dir) 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /tools/sr_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Searching script 3 | """ 4 | 5 | import argparse 6 | import os 7 | import json 8 | import torch 9 | import random 10 | import sys 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | sys.path.append('..') 14 | from one_stage_nas.config import cfg 15 | from one_stage_nas.data import build_transforms 16 | from one_stage_nas.utils.misc import mkdir 17 | from one_stage_nas.modeling.architectures import build_model 18 | from PIL import Image 19 | from one_stage_nas.utils.evaluation_metrics import SSIM, PSNR 20 | 21 | import time 22 | 23 | def json_loader(dict_file_dir): 24 | with open(dict_file_dir, 'r') as data_file: 25 | return json.load(data_file) 26 | 27 | 28 | def evaluation(cfg, s_factor, dataset): 29 | print('load test set') 30 | dataset_json_dir = '/'.join((cfg.DATALOADER.DATA_LIST_DIR, cfg.DATASET.TASK, '{}.json'.format(dataset))) 31 | data_dict = json_loader(dataset_json_dir) 32 | 33 | data_root = cfg.DATASET.DATA_ROOT 34 | s_factor = cfg.DATALOADER.S_FACTOR 35 | # rearrange the testing sample list 36 | 37 | print('model build') 38 | 39 | trained_model_dir = '/'.join((cfg.OUTPUT_DIR, 
cfg.DATASET.TASK, 40 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 41 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 42 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, 43 | cfg.MODEL.PRIMITIVES), 44 | 'train_x{}/models/model_best.pth'.format(s_factor))) 45 | 46 | if not os.path.exists(trained_model_dir): 47 | print('trained_model does not exist') 48 | return None, None 49 | model = build_model(cfg) 50 | model = torch.nn.DataParallel(model).cuda() 51 | 52 | model_state_dict = torch.load(trained_model_dir).pop("model") 53 | try: 54 | model.load_state_dict(model_state_dict) 55 | except: 56 | model.module.load_state_dict(model_state_dict) 57 | 58 | print('dataset {} evaluation...'.format(dataset)) 59 | 60 | transforms = build_transforms(task='sr', tag='test') 61 | 62 | # as we record the PSNR and SSIM on Y channel, here the number of input channel is set to 1 63 | model.eval() 64 | metric_SSIM = SSIM(window_size=11, channel=1, is_cuda=True) 65 | metric_PSNR = PSNR() 66 | 67 | result_save_dir = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 68 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 
69 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 70 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, 71 | cfg.MODEL.PRIMITIVES), 72 | 'eval_x{}/'.format(s_factor), dataset)) 73 | 74 | mkdir(result_save_dir) 75 | 76 | with torch.no_grad(): 77 | 78 | for item in data_dict: 79 | sample_id = item['gt_path'].split('/')[-1] 80 | hr_path = item['gt_path'] 81 | lr_path = item['x{}_path'.format(s_factor)] 82 | width, height = item['x{}_size'.format(s_factor)] 83 | 84 | hr_im = Image.open('/'.join((data_root, 'sr', hr_path))).crop((0, 0, width * s_factor, height * s_factor)) 85 | lr_im = Image.open('/'.join((data_root, 'sr', lr_path))) 86 | 87 | sample = {'image': lr_im, 'target': hr_im} 88 | sample = transforms(sample) 89 | 90 | image, target = sample['image'], sample['target'] 91 | 92 | output = model(image.unsqueeze(dim=0)) 93 | 94 | im_result = output.squeeze().transpose(0, 2).transpose(0, 1) 95 | gt_result = target.transpose(0, 2).transpose(0, 1) 96 | 97 | im_result[im_result > 1.0] = 1.0 98 | im_result[im_result < 0.0] = 0.0 99 | 100 | im_result_Y = (im_result[:, :, 0] * 24.966 + 101 | im_result[:, :, 1] * 128.553 + 102 | im_result[:, :, 2] * 65.481 + 103 | 16.0) / 255.0 104 | gt_result_Y = (gt_result[:, :, 0] * 24.966 + 105 | gt_result[:, :, 1] * 128.553 + 106 | gt_result[:, :, 2] * 65.481 + 107 | 16.0) / 255.0 108 | 109 | im_result_Y = im_result_Y.unsqueeze(dim=2) 110 | gt_result_Y = gt_result_Y.unsqueeze(dim=2) 111 | 112 | metric_SSIM(im_result_Y, gt_result_Y.cuda()) 113 | metric_PSNR(im_result_Y, gt_result_Y.cuda()) 114 | im_PIL = Image.fromarray(np.array(im_result.cpu().squeeze() * 255, np.uint8)) 115 | im_PIL.save(os.path.join(result_save_dir, sample_id)) 116 | 117 | ssim = metric_SSIM.metric_get() 118 | psnr = metric_PSNR.metric_get() 119 | 120 | print('dataset:{} ssim:{}, psnr:{}'.format(dataset, ssim, psnr)) 121 | with open(os.path.join(result_save_dir, 'evaluation_result.txt'), 'w') as f: 122 | f.write('SSIM:{} 
PSNR:{}'.format(ssim, psnr)) 123 | 124 | 125 | def main(): 126 | parser = argparse.ArgumentParser(description="evaluation") 127 | parser.add_argument( 128 | "--config-file", 129 | default="../configs/sr/DIV2K_2c3n/03_x4_infe_CR.yaml", 130 | metavar="FILE", 131 | help="path to config file", 132 | type=str, 133 | ) 134 | parser.add_argument( 135 | "--device", 136 | default='0', 137 | help="path to config file", 138 | type=str, 139 | ) 140 | 141 | args = parser.parse_args() 142 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device 143 | cfg.merge_from_file(args.config_file) 144 | cfg.freeze() 145 | 146 | for dataset in cfg.DATASET.TEST_DATASETS: 147 | evaluation(cfg, cfg.DATALOADER.S_FACTOR, dataset) 148 | 149 | 150 | if __name__ == "__main__": 151 | main() 152 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Searching script 3 | """ 4 | 5 | import argparse 6 | from tensorboardX import SummaryWriter 7 | 8 | import torch 9 | import os 10 | import sys 11 | sys.path.append('..') 12 | from one_stage_nas.config import cfg 13 | from one_stage_nas.data import build_dataset 14 | from one_stage_nas.solver import make_lr_scheduler 15 | from one_stage_nas.solver import make_optimizer 16 | from one_stage_nas.engine.trainer import do_train 17 | from one_stage_nas.modeling.architectures import build_model 18 | from one_stage_nas.utils.checkpoint import Checkpointer 19 | from one_stage_nas.utils.logger import setup_logger 20 | from one_stage_nas.utils.misc import mkdir 21 | from one_stage_nas.utils.visualize import visualize 22 | 23 | 24 | def train(cfg, output_dir): 25 | model = build_model(cfg) 26 | 27 | # visualize 28 | visual_dir = output_dir + '/arch' 29 | geno_cell, geno_path = model.genotype() 30 | visualize(geno_cell, geno_path, visual_dir, cfg.SEARCH.TIE_CELL) 31 | 32 | # just use data parallel 33 | model = 
torch.nn.DataParallel(model).cuda() 34 | 35 | optimizer = make_optimizer(cfg, model) 36 | scheduler = make_lr_scheduler(cfg, optimizer) 37 | 38 | checkpointer = Checkpointer( 39 | model, optimizer, scheduler, output_dir + '/models', save_to_disk=True) 40 | 41 | train_loader, val_list = build_dataset(cfg) 42 | 43 | arguments = {} 44 | arguments["iteration"] = 0 45 | arguments["genotype"] = model.module.genotype() 46 | 47 | extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT) 48 | arguments.update(extra_checkpoint_data) 49 | 50 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 51 | val_period = cfg.SOLVER.VALIDATE_PERIOD 52 | max_iter = cfg.SOLVER.TRAIN.MAX_ITER 53 | 54 | writer = SummaryWriter(logdir=output_dir + '/log', comment=cfg.DATASET.TASK + '_' + cfg.DATASET.DATA_NAME) 55 | 56 | do_train( 57 | model, 58 | train_loader, 59 | val_list, 60 | max_iter, 61 | val_period, 62 | optimizer, 63 | scheduler, 64 | checkpointer, 65 | checkpoint_period, 66 | arguments, 67 | writer, 68 | cfg 69 | ) 70 | 71 | 72 | def main(): 73 | parser = argparse.ArgumentParser(description="One-stage NAS Training") 74 | parser.add_argument( 75 | "--config-file", 76 | default="../configs/sr/DIV2K_2c3n/03_x4_train_CR.yaml", 77 | metavar="FILE", 78 | help="path to config file", 79 | type=str, 80 | ) 81 | parser.add_argument( 82 | "--device", 83 | default='4', 84 | help="path to config file", 85 | type=str, 86 | ) 87 | parser.add_argument( 88 | "opts", 89 | help="Modify config options using the command-line", 90 | default=None, 91 | nargs=argparse.REMAINDER, 92 | ) 93 | 94 | args = parser.parse_args() 95 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device 96 | cfg.merge_from_file(args.config_file) 97 | cfg.merge_from_list(args.opts) 98 | cfg.freeze() 99 | 100 | if cfg.DATASET.TASK in ['dn']: 101 | output_dir = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 102 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 
103 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 104 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 105 | 'train_noise_{}'.format(cfg.DATALOADER.SIGMA))) 106 | elif cfg.DATASET.TASK in ['sid']: 107 | output_dir = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 108 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 109 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 110 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 111 | 'train')) 112 | elif cfg.DATASET.TASK in ['sr']: 113 | output_dir = '/'.join((cfg.OUTPUT_DIR, cfg.DATASET.TASK, 114 | '{}/Outline-{}c{}n_TC-{}_ASPP-{}_Res-{}_Prim-{}'. 115 | format(cfg.DATASET.DATA_NAME, cfg.MODEL.NUM_LAYERS, cfg.MODEL.NUM_BLOCKS, 116 | cfg.SEARCH.TIE_CELL, cfg.MODEL.USE_ASPP, cfg.MODEL.USE_RES, cfg.MODEL.PRIMITIVES), 117 | 'train_x{}'.format(cfg.DATALOADER.S_FACTOR))) 118 | 119 | mkdir(output_dir+'/models') 120 | 121 | logger = setup_logger("one_stage_nas", output_dir) 122 | logger.info(args) 123 | 124 | logger.info("Loaded configuration file {}".format(args.config_file)) 125 | with open(args.config_file, "r") as cf: 126 | config_str = "\n" + cf.read() 127 | logger.info(config_str) 128 | logger.info("Running with config:\n{}".format(cfg)) 129 | 130 | train(cfg, output_dir) 131 | 132 | 133 | if __name__ == "__main__": 134 | main() 135 | --------------------------------------------------------------------------------