├── LICENSE
├── README.md
├── config
│   ├── coco
│   │   ├── coco_split0_resnet50_manet.yaml
│   │   ├── coco_split0_resnet50_manet_5s.yaml
│   │   ├── coco_split0_vgg_manet.yaml
│   │   ├── coco_split0_vgg_manet_5s.yaml
│   │   ├── coco_split1_resnet50_manet.yaml
│   │   ├── coco_split1_resnet50_manet_5s.yaml
│   │   ├── coco_split1_vgg_manet.yaml
│   │   ├── coco_split1_vgg_manet_5s.yaml
│   │   ├── coco_split2_resnet50_manet.yaml
│   │   ├── coco_split2_resnet50_manet_5s.yaml
│   │   ├── coco_split2_vgg_manet.yaml
│   │   ├── coco_split2_vgg_manet_5s.yaml
│   │   ├── coco_split3_resnet50_manet.yaml
│   │   ├── coco_split3_resnet50_manet_5s.yaml
│   │   ├── coco_split3_vgg_manet.yaml
│   │   └── coco_split3_vgg_manet_5s.yaml
│   └── pascal
│       ├── pascal_split0_resnet50_manet.yaml
│       ├── pascal_split0_resnet50_manet_5s.yaml
│       ├── pascal_split0_vgg_manet.yaml
│       ├── pascal_split0_vgg_manet_5s.yaml
│       ├── pascal_split1_resnet50_manet.yaml
│       ├── pascal_split1_resnet50_manet_5s.yaml
│       ├── pascal_split1_vgg_manet.yaml
│       ├── pascal_split1_vgg_manet_5s.yaml
│       ├── pascal_split2_resnet50_manet.yaml
│       ├── pascal_split2_resnet50_manet_5s.yaml
│       ├── pascal_split2_vgg_manet.yaml
│       ├── pascal_split2_vgg_manet_5s.yaml
│       ├── pascal_split3_resnet50_manet.yaml
│       ├── pascal_split3_resnet50_manet_5s.yaml
│       ├── pascal_split3_vgg_manet.yaml
│       └── pascal_split3_vgg_manet_5s.yaml
├── env.yaml
├── lists
│   ├── coco
│   │   ├── fss_list
│   │   │   ├── train
│   │   │   │   ├── data_list_0.txt
│   │   │   │   ├── data_list_1.txt
│   │   │   │   ├── data_list_2.txt
│   │   │   │   ├── data_list_3.txt
│   │   │   │   ├── sub_class_file_list_0.txt
│   │   │   │   ├── sub_class_file_list_1.txt
│   │   │   │   ├── sub_class_file_list_2.txt
│   │   │   │   └── sub_class_file_list_3.txt
│   │   │   └── val
│   │   │       ├── data_list_0.txt
│   │   │       ├── data_list_1.txt
│   │   │       ├── data_list_2.txt
│   │   │       ├── data_list_3.txt
│   │   │       ├── sub_class_file_list_0.txt
│   │   │       ├── sub_class_file_list_1.txt
│   │   │       ├── sub_class_file_list_2.txt
│   │   │       └── sub_class_file_list_3.txt
│   │   ├── train.txt
│   │   ├── train_data_list.txt
│   │   ├── val.txt
│   │   └── val_data_list.txt
│   └── pascal
│       ├── duplicate_removel.py
│       ├── fss_list
│       │   ├── train
│       │   │   ├── data_list_0.txt
│       │   │   ├── data_list_1.txt
│       │   │   ├── data_list_2.txt
│       │   │   ├── data_list_3.txt
│       │   │   ├── sub_class_file_list_0.txt
│       │   │   ├── sub_class_file_list_1.txt
│       │   │   ├── sub_class_file_list_2.txt
│       │   │   └── sub_class_file_list_3.txt
│       │   └── val
│       │       ├── data_list_0.txt
│       │       ├── data_list_1.txt
│       │       ├── data_list_2.txt
│       │       ├── data_list_3.txt
│       │       ├── sub_class_file_list_0.txt
│       │       ├── sub_class_file_list_1.txt
│       │       ├── sub_class_file_list_2.txt
│       │       └── sub_class_file_list_3.txt
│       ├── sbd_data.txt
│       ├── val.txt
│       ├── voc_original_train.txt
│       ├── voc_sbd_merge.txt
│       └── voc_sbd_merge_noduplicate.txt
├── model
│   ├── ASPP.py
│   ├── HMNet.py
│   ├── HMNetAMP.py
│   ├── PPM.py
│   ├── PSPNet.py
│   ├── backbone_res.py
│   ├── backbone_utils.py
│   ├── loss.py
│   ├── mamba_blocks.py
│   ├── resnet.py
│   └── vgg.py
├── test_coco.py
├── test_coco.sh
├── test_pascal.py
├── test_pascal.sh
├── train_coco.py
├── train_coco.sh
├── train_pascal.py
├── train_pascal.sh
└── util
    ├── __init__.py
    ├── config.py
    ├── dataset.py
    ├── get_weak_anns.py
    ├── transform.py
    ├── transform_tri.py
    └── util.py

/LICENSE:
--------------------------------------------------------------------------------
S-Lab License 1.0

Copyright 2024 S-Lab

Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hybrid Mamba for Few-Shot Segmentation

This repository contains the code for our NeurIPS 2024 [paper](https://arxiv.org/abs/2409.19613) "*Hybrid Mamba for Few-Shot Segmentation*", where we design a cross attention-like Mamba method to enable support-query interactions.

> **Abstract**: *Many few-shot segmentation (FSS) methods use cross attention to fuse support foreground (FG) into query features, regardless of the quadratic complexity. A recent advance Mamba can also well capture intra-sequence dependencies, yet the complexity is only linear. Hence, we aim to devise a cross (attention-like) Mamba to capture inter-sequence dependencies for FSS. A simple idea is to scan on support features to selectively compress them into the hidden state, which is then used as the initial hidden state to sequentially scan query features. Nevertheless, it suffers from (1) support forgetting issue: query features will also gradually be compressed when scanning on them, so the support features in hidden state keep reducing, and many query pixels cannot fuse sufficient support features; (2) intra-class gap issue: query FG is essentially more similar to itself rather than support FG, i.e., query may prefer not to fuse support but their own features from the hidden state, yet the effective use of support information leads to the success of FSS. To tackle them, we design a hybrid Mamba network (HMNet), including (1) a support recapped Mamba to periodically recap the support features when scanning query, so the hidden state can always contain rich support information; (2) a query intercepted Mamba to forbid the mutual interactions among query pixels, and encourage them to fuse more support features from the hidden state. Consequently, the support information is better utilized, leading to better performance. Extensive experiments have been conducted on two public benchmarks, showing the superiority of HMNet.*

## Dependencies

- Python 3.10
- PyTorch 1.12.0
- CUDA 11.6
- torchvision 0.13.0
```
> conda env create -f env.yaml
```

## Datasets

- PASCAL-5i: [VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) + [SBD](http://home.bharathh.info/pubs/codes/SBD/download.html)
- COCO-20i: [COCO2014](https://cocodataset.org/#download)

You can download the pre-processed PASCAL-5i and COCO-20i datasets [here](https://entuedu-my.sharepoint.com/:f:/g/personal/qianxion001_e_ntu_edu_sg/ErEg1GJF6ldCt1vh00MLYYwBapLiCIbd-VgbPAgCjBb_TQ?e=ibJ4DM), and extract them into the `data/` folder. Then, create a symbolic link to the `pascal/VOCdevkit` data folder as follows:
```
> ln -s /data/pascal/VOCdevkit /data/VOCdevkit2012
```

The directory structure is:

    ../
    ├── HMNet/
    └── data/
        ├── VOCdevkit2012/
        │   └── VOC2012/
        │       ├── JPEGImages/
        │       ├── ...
        │       └── SegmentationClassAug/
        └── MSCOCO2014/
            ├── annotations/
            │   ├── train2014/
            │   └── val2014/
            ├── train2014/
            └── val2014/

## Models

- Download the pretrained backbones from [here](https://entuedu-my.sharepoint.com/:u:/g/personal/qianxion001_e_ntu_edu_sg/EUHlKdET3mJGie_IjtpzW5kBo45yz0PB2dW9n55Vo5acXw?e=uyuUDX) and put them into the `initmodel` directory.
- Download [exp.tar.gz](https://entuedu-my.sharepoint.com/:u:/g/personal/qianxion001_e_ntu_edu_sg/EbcNC1Ram0lJozZ2qe624uEBhXNmKrI64CM0uhEPJuxaig?e=iDMvrI) to obtain all trained models for PASCAL-5i and COCO-20i.

## Testing

- **Commands**:
  ```
  sh test_pascal.sh {Split: 0/1/2/3} {Net: resnet50/vgg} {Postfix: manet/manet_5s}
  sh test_coco.sh {Split: 0/1/2/3} {Net: resnet50/vgg} {Postfix: manet/manet_5s}

  # e.g., testing split 0 under 1-shot setting on PASCAL-5i, with ResNet50 as the pretrained backbone:
  sh test_pascal.sh 0 resnet50 manet

  # e.g., testing split 0 under 5-shot setting on COCO-20i, with ResNet50 as the pretrained backbone:
  sh test_coco.sh 0 resnet50 manet_5s
  ```

## References

This repo is built mainly on [BAM](https://github.com/chunbolang/BAM). Thanks for their great work!
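
The `{Split}`, `{Net}` and `{Postfix}` arguments of the test commands line up with the config file names in the repository tree, for example `config/pascal/pascal_split0_resnet50_manet.yaml`. The following is a minimal sketch of that mapping. It assumes the test scripts pick their config by this naming pattern, which is not confirmed here, and `config_path` is a hypothetical helper that is not part of the repository.

```python
from pathlib import PurePosixPath

# Allowed values, taken from the Testing commands and the config/ listing.
DATASETS = ("pascal", "coco")
SPLITS = (0, 1, 2, 3)
NETS = ("resnet50", "vgg")
POSTFIXES = ("manet", "manet_5s")  # 1-shot / 5-shot


def config_path(dataset: str, split: int, net: str, postfix: str) -> str:
    """Build the config path for one test setting (hypothetical helper)."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset!r}")
    if split not in SPLITS:
        raise ValueError(f"unknown split: {split!r}")
    if net not in NETS:
        raise ValueError(f"unknown backbone: {net!r}")
    if postfix not in POSTFIXES:
        raise ValueError(f"unknown postfix: {postfix!r}")
    # Assumed naming pattern: <dataset>_split<split>_<net>_<postfix>.yaml
    name = f"{dataset}_split{split}_{net}_{postfix}.yaml"
    return str(PurePosixPath("config") / dataset / name)


if __name__ == "__main__":
    # Mirrors `sh test_pascal.sh 0 resnet50 manet`
    print(config_path("pascal", 0, "resnet50", "manet"))
    # Mirrors `sh test_coco.sh 0 resnet50 manet_5s`
    print(config_path("coco", 0, "resnet50", "manet_5s"))
```

Validating the arguments up front turns a mistyped split or backbone into an immediate error rather than a missing-file failure later in the run.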


--------------------------------------------------------------------------------
/config/coco/coco_split0_resnet50_manet.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 0
  shot: 1
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: False
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_1shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
  ann_type: 'mask' # mask/bbox



## deprecated multi-processing training
# Distributed:
#   dist_url: tcp://127.0.0.1:6789
#   dist_backend: 'nccl'
#   multiprocessing_distributed: False
#   world_size: 1
#   rank: 0
#   use_apex: False
#   opt_level: 'O0'
#   keep_batchnorm_fp32:
#   loss_scale:


--------------------------------------------------------------------------------
/config/coco/coco_split0_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 0
  shot: 5
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: False
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_5shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
  ann_type: 'mask' # mask/bbox



## deprecated multi-processing training
# Distributed:
#   dist_url: tcp://127.0.0.1:6789
#   dist_backend: 'nccl'
#   multiprocessing_distributed: False
#   world_size: 1
#   rank: 0
#   use_apex: False
#   opt_level: 'O0'
#   keep_batchnorm_fp32:
#   loss_scale:


--------------------------------------------------------------------------------
/config/coco/coco_split0_vgg_manet.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 0
  shot: 1
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: True
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_1shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
  ann_type: 'mask' # mask/bbox



## deprecated multi-processing training
# Distributed:
#   dist_url: tcp://127.0.0.1:6789
#   dist_backend: 'nccl'
#   multiprocessing_distributed: False
#   world_size: 1
#   rank: 0
#   use_apex: False
#   opt_level: 'O0'
#   keep_batchnorm_fp32:
#   loss_scale:


--------------------------------------------------------------------------------
/config/coco/coco_split0_vgg_manet_5s.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 0
  shot: 5
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: True
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_5shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
  ann_type: 'mask' # mask/bbox



## deprecated multi-processing training
# Distributed:
#   dist_url: tcp://127.0.0.1:6789
#   dist_backend: 'nccl'
#   multiprocessing_distributed: False
#   world_size: 1
#   rank: 0
#   use_apex: False
#   opt_level: 'O0'
#   keep_batchnorm_fp32:
#   loss_scale:


--------------------------------------------------------------------------------
/config/coco/coco_split1_resnet50_manet.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 1
  shot: 1
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: False
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_1shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
  ann_type: 'mask' # mask/bbox



## deprecated multi-processing training
# Distributed:
#   dist_url: tcp://127.0.0.1:6789
#   dist_backend: 'nccl'
#   multiprocessing_distributed: False
#   world_size: 1
#   rank: 0
#   use_apex: False
#   opt_level: 'O0'
#   keep_batchnorm_fp32:
#   loss_scale:


--------------------------------------------------------------------------------
/config/coco/coco_split1_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 1
  shot: 5
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: False
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_5shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
  ann_type: 'mask' # mask/bbox



## deprecated multi-processing training
# Distributed:
#   dist_url: tcp://127.0.0.1:6789
#   dist_backend: 'nccl'
#   multiprocessing_distributed: False
#   world_size: 1
#   rank: 0
#   use_apex: False
#   opt_level: 'O0'
#   keep_batchnorm_fp32:
#   loss_scale:


--------------------------------------------------------------------------------
/config/coco/coco_split1_vgg_manet.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 1
  shot: 1
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: True
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_1shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
  ann_type: 'mask' # mask/bbox



## deprecated multi-processing training
# Distributed:
#   dist_url: tcp://127.0.0.1:6789
#   dist_backend: 'nccl'
#   multiprocessing_distributed: False
#   world_size: 1
#   rank: 0
#   use_apex: False
#   opt_level: 'O0'
#   keep_batchnorm_fp32:
#   loss_scale:


--------------------------------------------------------------------------------
/config/coco/coco_split1_vgg_manet_5s.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 1
  shot: 5
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: True
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_5shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
  ann_type: 'mask' # mask/bbox



## deprecated multi-processing training
# Distributed:
#   dist_url: tcp://127.0.0.1:6789
#   dist_backend: 'nccl'
#   multiprocessing_distributed: False
#   world_size: 1
#   rank: 0
#   use_apex: False
#   opt_level: 'O0'
#   keep_batchnorm_fp32:
#   loss_scale:


--------------------------------------------------------------------------------
/config/coco/coco_split2_resnet50_manet.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 2
  shot: 1
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: False
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_1shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
  ann_type: 'mask' # mask/bbox



## deprecated multi-processing training
# Distributed:
#   dist_url: tcp://127.0.0.1:6789
#   dist_backend: 'nccl'
#   multiprocessing_distributed: False
#   world_size: 1
#   rank: 0
#   use_apex: False
#   opt_level: 'O0'
#   keep_batchnorm_fp32:
#   loss_scale:


--------------------------------------------------------------------------------
/config/coco/coco_split2_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 2
  shot: 5
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: False
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_5shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
  ann_type: 'mask' # mask/bbox



## deprecated multi-processing training
# Distributed:
#   dist_url: tcp://127.0.0.1:6789
#   dist_backend: 'nccl'
#   multiprocessing_distributed: False
#   world_size: 1
#   rank: 0
#   use_apex: False
#   opt_level: 'O0'
#   keep_batchnorm_fp32:
#   loss_scale:


--------------------------------------------------------------------------------
/config/coco/coco_split2_vgg_manet.yaml:
--------------------------------------------------------------------------------
Data:
  data_root: ../data/MSCOCO2014
  train_list: ./lists/coco/train.txt
  val_list: ./lists/coco/val.txt
  classes: 2


Train:
  # Aug
  train_h: 633
  train_w: 633
  val_size: 633
  scale_min: 0.8 # minimum random scale
  scale_max: 1.25 # maximum random scale
  rotate_min: -10 # minimum random rotate
  rotate_max: 10 # maximum random rotate
  ignore_label: 255
  padding_label: 255
  # Dataset & Mode
  split: 2
  shot: 1
  data_set: 'coco'
  use_split_coco: True # True means FWB setting
  # Optimizer
  batch_size: 2 # batch size for training (bs8 for 1GPU)
  base_lr: 0.005
  epochs: 75
  start_epoch: 0
  stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
  index_split: -1 # index for determining the params group with 10x learning rate
  power: 0.9 # 0 means no decay
  momentum: 0.9
  weight_decay: 0.0001
  warmup: False
  # Viz & Save & Resume
  print_freq: 10
  save_freq: 10
  resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
  # Validate
  evaluate: True
  SubEpoch_val: True # val at the half epoch
  fix_random_seed_val: True
  batch_size_val: 1
  resized_val: True
  ori_resize: True # use original label for evaluation
  # Else
  workers: 32
  fix_bn: True
  manual_seed: 321
  seed_deterministic: False
  zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]

Method:
  layers: 50
  vgg: True
  aux_weight1: 1.0
  aux_weight2: 1.0
  low_fea: 'layer2' # low_fea for computing the Gram matrix
  kshot_trans_dim: 2 # K-shot dimensionality reduction
  merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
  merge_tau: 0.9 # fusion threshold tau

Test_Finetune:
  weight: best_1shot.pth #
train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/coco/coco_split2_vgg_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/MSCOCO2014 3 | train_list: ./lists/coco/train.txt 4 | val_list: ./lists/coco/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 633 11 | train_w: 633 12 | val_size: 633 13 | scale_min: 0.8 # minimum random scale 14 | scale_max: 1.25 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 2 21 | shot: 5 22 | data_set: 'coco' 23 | use_split_coco: True # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 75 28 | start_epoch: 0 29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for 
evaluation 46 | # Else 47 | workers: 32 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: True 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/coco/coco_split3_resnet50_manet.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/MSCOCO2014 3 | train_list: ./lists/coco/train.txt 4 | val_list: ./lists/coco/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 633 11 | train_w: 633 12 | val_size: 633 13 | scale_min: 0.8 # minimum random scale 14 | scale_max: 1.25 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 3 21 | shot: 1 22 | data_set: 'coco' 23 | use_split_coco: True # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 75 28 | start_epoch: 0 29 | stop_interval: 75 # stop when 
the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 32 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: False 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/coco/coco_split3_resnet50_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/MSCOCO2014 3 | train_list: ./lists/coco/train.txt 4 | val_list: ./lists/coco/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # 
Aug 10 | train_h: 633 11 | train_w: 633 12 | val_size: 633 13 | scale_min: 0.8 # minimum random scale 14 | scale_max: 1.25 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 3 21 | shot: 5 22 | data_set: 'coco' 23 | use_split_coco: True # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 75 28 | start_epoch: 0 29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 32 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: False 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # 
dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/coco/coco_split3_vgg_manet.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/MSCOCO2014 3 | train_list: ./lists/coco/train.txt 4 | val_list: ./lists/coco/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 633 11 | train_w: 633 12 | val_size: 633 13 | scale_min: 0.8 # minimum random scale 14 | scale_max: 1.25 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 3 21 | shot: 1 22 | data_set: 'coco' 23 | use_split_coco: True # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 75 28 | start_epoch: 0 29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 32 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: True 56 | aux_weight1: 
1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_1shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/coco/coco_split3_vgg_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/MSCOCO2014 3 | train_list: ./lists/coco/train.txt 4 | val_list: ./lists/coco/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 633 11 | train_w: 633 12 | val_size: 633 13 | scale_min: 0.8 # minimum random scale 14 | scale_max: 1.25 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 3 21 | shot: 5 22 | data_set: 'coco' 23 | use_split_coco: True # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 75 28 | start_epoch: 0 29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 
| print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 32 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: True 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split0_resnet50_manet.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # 
maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 0 21 | shot: 1 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: False 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_43_0.6753.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | 
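Note on the optimizer keys used in these configs: the comment on `power` ("0 means no decay") suggests a polynomial learning-rate schedule driven by `base_lr` and `power`. The sketch below is a minimal illustration under that assumption; the function name `poly_learning_rate` and its arguments `curr_iter` and `max_iter` are hypothetical and are not taken from the training scripts in this repository.

```python
def poly_learning_rate(base_lr: float, curr_iter: int, max_iter: int,
                       power: float = 0.9) -> float:
    """Polynomial decay: base_lr * (1 - curr_iter / max_iter) ** power.

    Hypothetical helper for illustration only; the actual schedule used by
    the training scripts may differ (for example warmup or per-group scaling).
    """
    if max_iter <= 0:
        raise ValueError("max_iter must be positive")
    # Clamp progress to [0, 1] so the base of the power is never negative.
    progress = min(max(curr_iter / max_iter, 0.0), 1.0)
    return base_lr * (1.0 - progress) ** power


# Values mirroring the config above: base_lr 0.005, power 0.9.
# max_iter here is an arbitrary illustrative number, not a repository value.
schedule = [poly_learning_rate(0.005, it, max_iter=100, power=0.9)
            for it in range(0, 101, 25)]

# With power = 0 the factor is always 1, so the rate stays at base_lr.
constant = [poly_learning_rate(0.005, it, max_iter=100, power=0.0)
            for it in range(0, 101, 25)]
```

Under this assumption, `schedule` starts at `base_lr` and falls to zero at the final iteration, while `constant` never changes, which matches the "0 means no decay" comment.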
-------------------------------------------------------------------------------- /config/pascal/pascal_split0_resnet50_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 0 21 | shot: 5 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: False 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | 
merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_43_0.6753.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split0_vgg_manet.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 0 21 | shot: 1 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | 
evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: True 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_43_0.6753.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split0_vgg_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 0 21 | shot: 5 22 | data_set: 'pascal' 23 | use_split_coco: 
False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: True 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_43_0.6753.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split1_resnet50_manet.yaml: 
-------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 1 21 | shot: 1 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: False 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | 
Test_Finetune: 64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_124_0.7262.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split1_resnet50_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 1 21 | shot: 5 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 
| resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: False 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_124_0.7262.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split1_vgg_manet.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 1 21 | shot: 1 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 
0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: True 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_124_0.7262.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split1_vgg_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: 
./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 1 21 | shot: 5 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: True 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_124_0.7262.pth) 65 | ann_type: 'mask' # mask/bbox 
66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split2_resnet50_manet.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 2 21 | shot: 1 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | 
seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: False 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_47_0.6718.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split2_resnet50_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 2 21 | shot: 5 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | 
index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: False 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_47_0.6718.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split2_vgg_manet.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | 
train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 2 21 | shot: 1 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: True 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_47_0.6718.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 
| # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split2_vgg_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 2 21 | shot: 5 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: 
True 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_47_0.6718.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split3_resnet50_manet.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 3 21 | shot: 1 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | 
warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: False 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_93_0.6053.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split3_resnet50_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 
16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 3 21 | shot: 5 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: False 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_93_0.6053.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # 
loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split3_vgg_manet.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 3 21 | shot: 1 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | # Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: True 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality 
reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_93_0.6053.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /config/pascal/pascal_split3_vgg_manet_5s.yaml: -------------------------------------------------------------------------------- 1 | Data: 2 | data_root: ../data/VOCdevkit2012/VOC2012 3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt 4 | val_list: ./lists/pascal/val.txt 5 | classes: 2 6 | 7 | 8 | Train: 9 | # Aug 10 | train_h: 473 11 | train_w: 473 12 | val_size: 473 13 | scale_min: 0.9 # minimum random scale 14 | scale_max: 1.1 # maximum random scale 15 | rotate_min: -10 # minimum random rotate 16 | rotate_max: 10 # maximum random rotate 17 | ignore_label: 255 18 | padding_label: 255 19 | # Dataset & Mode 20 | split: 3 21 | shot: 5 22 | data_set: 'pascal' 23 | use_split_coco: False # True means FWB setting 24 | # Optimizer 25 | batch_size: 2 # batch size for training (bs8 for 1GPU) 26 | base_lr: 0.005 27 | epochs: 300 28 | start_epoch: 0 29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs 30 | index_split: -1 # index for determining the params group with 10x learning rate 31 | power: 0.9 # 0 means no decay 32 | momentum: 0.9 33 | weight_decay: 0.0001 34 | warmup: False 35 | # Viz & Save & Resume 36 | print_freq: 10 37 | save_freq: 10 38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth) 39 | 
# Validate 40 | evaluate: True 41 | SubEpoch_val: True # val at the half epoch 42 | fix_random_seed_val: True 43 | batch_size_val: 1 44 | resized_val: True 45 | ori_resize: True # use original label for evaluation 46 | # Else 47 | workers: 8 48 | fix_bn: True 49 | manual_seed: 321 50 | seed_deterministic: False 51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8] 52 | 53 | Method: 54 | layers: 50 55 | vgg: True 56 | aux_weight1: 1.0 57 | aux_weight2: 1.0 58 | low_fea: 'layer2' # low_fea for computing the Gram matrix 59 | kshot_trans_dim: 2 # K-shot dimensionality reduction 60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) ) 61 | merge_tau: 0.9 # fusion threshold tau 62 | 63 | Test_Finetune: 64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_93_0.6053.pth) 65 | ann_type: 'mask' # mask/bbox 66 | 67 | 68 | 69 | ## deprecated multi-processing training 70 | # Distributed: 71 | # dist_url: tcp://127.0.0.1:6789 72 | # dist_backend: 'nccl' 73 | # multiprocessing_distributed: False 74 | # world_size: 1 75 | # rank: 0 76 | # use_apex: False 77 | # opt_level: 'O0' 78 | # keep_batchnorm_fp32: 79 | # loss_scale: 80 | 81 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: hmnet 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 10 | - defaults 11 | dependencies: 12 | - _libgcc_mutex=0.1=main 13 | - _openmp_mutex=5.1=1_gnu 14 | - blas=1.0=mkl 15 | - brotli-python=1.0.9=py310hd8f1fbe_7 16 | - bzip2=1.0.8=h5eee18b_5 17 | - 
ca-certificates=2024.2.2=hbcca054_0 18 | - certifi=2024.2.2=pyhd8ed1ab_0 19 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 20 | - cudatoolkit=11.6.0=hecad31d_10 21 | - ffmpeg=4.3=hf484d3e_0 22 | - freetype=2.10.4=h0708190_1 23 | - giflib=5.2.1=h36c2ea0_2 24 | - gmp=6.2.1=h58526e2_0 25 | - gnutls=3.6.13=h85f3911_1 26 | - idna=3.7=pyhd8ed1ab_0 27 | - intel-openmp=2021.4.0=h06a4308_3561 28 | - jbig=2.1=h7f98852_2003 29 | - jpeg=9e=h166bdaf_1 30 | - lame=3.100=h7f98852_1001 31 | - lcms2=2.12=hddcbb42_0 32 | - ld_impl_linux-64=2.38=h1181459_1 33 | - lerc=2.2.1=h9c3ff4c_0 34 | - libdeflate=1.7=h7f98852_5 35 | - libffi=3.4.4=h6a678d5_0 36 | - libgcc-ng=11.2.0=h1234567_1 37 | - libgomp=11.2.0=h1234567_1 38 | - libiconv=1.17=h166bdaf_0 39 | - libpng=1.6.37=h21135ba_2 40 | - libstdcxx-ng=11.2.0=h1234567_1 41 | - libtiff=4.3.0=hf544144_1 42 | - libuuid=1.41.5=h5eee18b_0 43 | - libwebp=1.2.2=h3452ae3_0 44 | - libwebp-base=1.2.2=h7f98852_1 45 | - lz4-c=1.9.3=h9c3ff4c_1 46 | - mkl=2021.4.0=h06a4308_640 47 | - mkl-service=2.4.0=py310ha2c4b55_0 48 | - mkl_fft=1.3.1=py310h2b4bcf5_1 49 | - mkl_random=1.2.2=py310h00e6091_0 50 | - ncurses=6.4=h6a678d5_0 51 | - nettle=3.6=he412f7d_0 52 | - numpy=1.24.3=py310hd5efca6_0 53 | - numpy-base=1.24.3=py310h8e6c178_0 54 | - openh264=2.1.1=h780b84a_0 55 | - openssl=3.0.13=h7f8727e_0 56 | - pillow=9.4.0=py310h6a678d5_0 57 | - pip=23.3.1=py310h06a4308_0 58 | - pysocks=1.7.1=pyha2e5f31_6 59 | - python=3.10.14=h955ad1f_0 60 | - python_abi=3.10=2_cp310 61 | - pytorch=1.12.0=py3.10_cuda11.6_cudnn8.3.2_0 62 | - pytorch-mutex=1.0=cuda 63 | - readline=8.2=h5eee18b_0 64 | - requests=2.31.0=pyhd8ed1ab_0 65 | - setuptools=68.2.2=py310h06a4308_0 66 | - six=1.16.0=pyh6c4a22f_0 67 | - sqlite=3.41.2=h5eee18b_0 68 | - tk=8.6.12=h1ccaba5_0 69 | - torchaudio=0.12.0=py310_cu116 70 | - torchvision=0.13.0=py310_cu116 71 | - typing_extensions=4.11.0=pyha770c72_0 72 | - urllib3=2.2.1=pyhd8ed1ab_0 73 | - wheel=0.41.2=py310h06a4308_0 74 | - xz=5.4.6=h5eee18b_0 75 | - 
zlib=1.2.13=h5eee18b_0 76 | - zstd=1.5.0=ha95c52a_0 77 | - pip: 78 | - addict==2.4.0 79 | - args==0.1.0 80 | - chardet==5.2.0 81 | - clint==0.5.1 82 | - cloudpickle==3.0.0 83 | - contourpy==1.2.1 84 | - coverage==7.5.0 85 | - cycler==0.12.1 86 | - einops==0.7.0 87 | - exceptiongroup==1.2.1 88 | - filelock==3.13.4 89 | - fonttools==4.51.0 90 | - fsspec==2024.3.1 91 | - huggingface-hub==0.22.2 92 | - imageio==2.34.1 93 | - importlib-metadata==7.1.0 94 | - iniconfig==2.0.0 95 | - joblib==1.4.0 96 | - jsonpatch==1.33 97 | - jsonpointer==2.4 98 | - kiwisolver==1.4.5 99 | - lazy-loader==0.4 100 | - mamba==0.11.3 101 | - mamba-ssm==1.2.0.post1 102 | - matplotlib==3.8.4 103 | - mmcls==0.25.0 104 | - mmcv-full==1.6.1 105 | - mmsegmentation==0.27.0 106 | - networkx==3.3 107 | - ninja==1.11.1.1 108 | - opencv-python==4.5.5.64 109 | - packaging==24.0 110 | - pandas==2.2.2 111 | - platformdirs==4.2.1 112 | - pluggy==1.5.0 113 | - prettytable==3.10.0 114 | - protobuf==5.26.1 115 | - pyparsing==3.1.2 116 | - pytest==8.1.1 117 | - python-dateutil==2.9.0.post0 118 | - pytz==2024.1 119 | - pyyaml==6.0.1 120 | - regex==2024.4.16 121 | - safetensors==0.4.3 122 | - scikit-image==0.23.2 123 | - scikit-learn==1.4.2 124 | - scipy==1.13.0 125 | - seaborn==0.13.2 126 | - submitit==1.5.1 127 | - tensorboardx==2.6.2.2 128 | - termcolor==2.4.0 129 | - thop==0.1.1-2209072238 130 | - threadpoolctl==3.4.0 131 | - tifffile==2024.4.18 132 | - timm==0.4.12 133 | - tokenizers==0.19.1 134 | - tomli==2.0.1 135 | - tornado==6.4 136 | - tqdm==4.66.2 137 | - transformers==4.40.1 138 | - triton==2.3.0 139 | - tzdata==2024.1 140 | - visdom==0.2.4 141 | - wcwidth==0.2.13 142 | - websocket-client==1.8.0 143 | - yacs==0.1.8 144 | - yapf==0.40.2 145 | - zipp==3.18.1 146 | -------------------------------------------------------------------------------- /lists/pascal/duplicate_removel.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | raw_path = 
'./voc_sbd_merge.txt'
new_path = './voc_sbd_merge_noduplicate.txt'
lines = open(raw_path).readlines()
new_f = open(new_path, 'w+')


# Keep only the first occurrence of each line, preserving the original order.
existing_lines = []
for line in lines:
    if line not in existing_lines:
        existing_lines.append(line)
        new_f.write(line)
new_f.close()  # flush and release the output file
print('Ori: {}, new: {}'.format(len(lines), len(existing_lines)))
print('Finished.')

--------------------------------------------------------------------------------
/model/ASPP.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

class ASPP(nn.Module):
    def __init__(self, out_channels=256):
        super(ASPP, self).__init__()
        # Branch applied to the globally pooled feature.
        self.layer6_0 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(),
        )
        # 1x1 branch followed by three 3x3 branches with dilation 6, 12, and 18.
        self.layer6_1 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(),
        )
        self.layer6_2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=6, dilation=6, bias=True),
            nn.ReLU(),
        )
        self.layer6_3 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=12, dilation=12, bias=True),
            nn.ReLU(),
        )
        self.layer6_4 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=18, dilation=18, bias=True),
            nn.ReLU(),
        )

        self._init_weight()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        feature_size = x.shape[-2:]
        # Average-pool over the full spatial extent, giving an (N, C, 1, 1) tensor.
        global_feature = F.avg_pool2d(x, kernel_size=feature_size)

        global_feature = self.layer6_0(global_feature)

        # Broadcast the pooled feature back to the input resolution.
        global_feature = global_feature.expand(-1, -1, feature_size[0], feature_size[1])
        # Concatenate all five branches along the channel dimension.
        out = torch.cat(
            [global_feature, self.layer6_1(x), self.layer6_2(x), self.layer6_3(x), self.layer6_4(x)], dim=1)
        return out

--------------------------------------------------------------------------------
/model/PPM.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class PPM(nn.Module):
    def __init__(self, in_dim, reduction_dim, bins):
        super(PPM, self).__init__()
        self.features = []
        for bin in bins:
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bin),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim),
                nn.ReLU(inplace=True)
            ))
        self.features = nn.ModuleList(self.features)

    def forward(self, x):
        x_size = x.size()
        out = [x]
        # Upsample each pooled branch back to the input size, then concatenate with x.
        for f in self.features:
            out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
        return torch.cat(out, 1)

--------------------------------------------------------------------------------
/model/PSPNet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import BatchNorm2d as BatchNorm
import model.resnet as models
import model.vgg as vgg_models
from model.PPM import PPM


def get_vgg16_layer(model):
    # Split the VGG-16 (BN) feature extractor into five sequential stages.
    layer0_idx = range(0,7)
    layer1_idx = range(7,14)
    layer2_idx = range(14,24)
    layer3_idx = range(24,34)
    layer4_idx = range(34,43)
    layers_0 = []
    layers_1 = []
    layers_2 = []
    layers_3 = []
    layers_4 = []
    for idx in layer0_idx:
        layers_0 += [model.features[idx]]
    for idx in layer1_idx:
        layers_1 += [model.features[idx]]
    for idx in layer2_idx:
        layers_2 += [model.features[idx]]
    for idx in layer3_idx:
        layers_3 += [model.features[idx]]
    for idx in layer4_idx:
        layers_4 += [model.features[idx]]
    layer0 = nn.Sequential(*layers_0)
    layer1 = nn.Sequential(*layers_1)
    layer2 = nn.Sequential(*layers_2)
    layer3 = nn.Sequential(*layers_3)
    layer4 = nn.Sequential(*layers_4)
    return layer0,layer1,layer2,layer3,layer4

class OneModel(nn.Module):
    def __init__(self, args):
        super(OneModel, self).__init__()

        self.layers = args.layers
        self.zoom_factor = args.zoom_factor
        self.vgg = args.vgg
        self.dataset = args.data_set
        self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)

        self.pretrained = True
        self.classes = 16 if self.dataset=='pascal' else 61

        assert self.layers in [50, 101, 152]

        if self.vgg:
            print('INFO: Using VGG_16 bn')
            vgg_models.BatchNorm = BatchNorm
            vgg16 = vgg_models.vgg16_bn(pretrained=self.pretrained)
            print(vgg16)
            self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = get_vgg16_layer(vgg16)
        else:
            print('INFO: Using ResNet {}'.format(self.layers))
            if self.layers == 50:
                resnet = models.resnet50(pretrained=self.pretrained)
            elif self.layers == 101:
                resnet = models.resnet101(pretrained=self.pretrained)
            else:
                resnet = models.resnet152(pretrained=self.pretrained)
            self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu1, resnet.conv2, resnet.bn2, resnet.relu2, resnet.conv3, resnet.bn3, resnet.relu3, resnet.maxpool)
            self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

            # Replace striding with dilation in layer3 and layer4 so they no longer downsample.
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)

        # Base Learner
        self.encoder = nn.Sequential(self.layer0, self.layer1, self.layer2, self.layer3, self.layer4)
        fea_dim = 512 if self.vgg else 2048
        bins=(1, 2, 3, 6)
        self.ppm = PPM(fea_dim, int(fea_dim/len(bins)), bins)
        # PPM concatenates its input with len(bins) reduced branches, hence fea_dim*2 input channels.
        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim*2, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.1),
            nn.Conv2d(512, self.classes, kernel_size=1))

    def get_optim(self, model, args, LR):
        optimizer = torch.optim.SGD(
            [
            {'params': model.encoder.parameters()},
            {'params': model.ppm.parameters()},
            {'params': model.cls.parameters()},
            ], lr=LR, momentum=args.momentum, weight_decay=args.weight_decay)
        return optimizer


    def forward(self, x, y_m):
        x_size = x.size()
        h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1)
        w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1)  # 473

        x = self.encoder(x)
        x = self.ppm(x)
        x = self.cls(x)

        if self.zoom_factor != 1:
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)

        if self.training:
            main_loss = self.criterion(x, y_m.long())
            return x.max(1)[1], main_loss
        else:
            return x


--------------------------------------------------------------------------------
/model/backbone_res.py:
--------------------------------------------------------------------------------
import torch
from torch import Tensor
import torch.nn as nn
# from torchvision.models.utils import load_state_dict_from_url
from torch.hub import load_state_dict_from_url
from typing import Type, Any, Callable, Union, List, Optional
from collections import OrderedDict


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50',
'resnet101', 11 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 12 | 'wide_resnet50_2', 'wide_resnet101_2'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 25 | } 26 | 27 | 28 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 29 | """3x3 convolution with padding""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 31 | padding=dilation, groups=groups, bias=False, dilation=dilation) 32 | 33 | 34 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 35 | """1x1 convolution""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | expansion: int = 1 41 | 42 | def __init__( 43 | self, 44 | inplanes: int, 45 | planes: int, 46 | stride: int = 1, 47 | downsample: Optional[nn.Module] = None, 48 | groups: int = 1, 49 | base_width: int = 64, 50 | dilation: int = 1, 51 | norm_layer: Optional[Callable[..., nn.Module]] = None 52 | ) -> None: 53 | super(BasicBlock, self).__init__() 54 | if norm_layer is None: 55 | norm_layer = nn.BatchNorm2d 56 | if groups != 1 or base_width != 64: 57 | raise ValueError('BasicBlock only supports groups=1 and 
base_width=64') 58 | if dilation > 1: 59 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 60 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 61 | self.conv1 = conv3x3(inplanes, planes, stride) 62 | self.bn1 = norm_layer(planes) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.conv2 = conv3x3(planes, planes) 65 | self.bn2 = norm_layer(planes) 66 | self.downsample = downsample 67 | self.stride = stride 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | identity = x 71 | 72 | out = self.conv1(x) 73 | out = self.bn1(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv2(out) 77 | out = self.bn2(out) 78 | 79 | if self.downsample is not None: 80 | identity = self.downsample(x) 81 | 82 | out += identity 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | 88 | class Bottleneck(nn.Module): 89 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 90 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 91 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 92 | # This variant is also known as ResNet V1.5 and improves accuracy according to 93 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
94 | 95 | expansion: int = 4 96 | 97 | def __init__( 98 | self, 99 | inplanes: int, 100 | planes: int, 101 | stride: int = 1, 102 | downsample: Optional[nn.Module] = None, 103 | groups: int = 1, 104 | base_width: int = 64, 105 | dilation: int = 1, 106 | norm_layer: Optional[Callable[..., nn.Module]] = None 107 | ) -> None: 108 | super(Bottleneck, self).__init__() 109 | if norm_layer is None: 110 | norm_layer = nn.BatchNorm2d 111 | width = int(planes * (base_width / 64.)) * groups 112 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 113 | self.conv1 = conv1x1(inplanes, width) 114 | self.bn1 = norm_layer(width) 115 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 116 | self.bn2 = norm_layer(width) 117 | self.conv3 = conv1x1(width, planes * self.expansion) 118 | self.bn3 = norm_layer(planes * self.expansion) 119 | self.relu = nn.ReLU(inplace=True) 120 | self.downsample = downsample 121 | self.stride = stride 122 | 123 | def forward(self, x: Tensor) -> Tensor: 124 | identity = x 125 | 126 | out = self.conv1(x) 127 | out = self.bn1(out) 128 | out = self.relu(out) 129 | 130 | out = self.conv2(out) 131 | out = self.bn2(out) 132 | out = self.relu(out) 133 | 134 | out = self.conv3(out) 135 | out = self.bn3(out) 136 | 137 | if self.downsample is not None: 138 | identity = self.downsample(x) 139 | 140 | out += identity 141 | out = self.relu(out) 142 | 143 | return out 144 | 145 | 146 | class ResNet(nn.Module): 147 | 148 | def __init__( 149 | self, 150 | block: Type[Union[BasicBlock, Bottleneck]], 151 | layers: List[int], 152 | num_classes: int = 1000, 153 | zero_init_residual: bool = False, 154 | groups: int = 1, 155 | width_per_group: int = 64, 156 | replace_stride_with_dilation: Optional[List[bool]] = None, 157 | norm_layer: Optional[Callable[..., nn.Module]] = None, 158 | deep_stem: bool = True, 159 | ) -> None: 160 | super(ResNet, self).__init__() 161 | if norm_layer is None: 162 | norm_layer = nn.BatchNorm2d 163 | 
self._norm_layer = norm_layer 164 | 165 | self.inplanes = 128 # 64 166 | self.dilation = 1 167 | if replace_stride_with_dilation is None: 168 | # each element in the tuple indicates if we should replace 169 | # the 2x2 stride with a dilated convolution instead 170 | replace_stride_with_dilation = [False, False, False] 171 | if len(replace_stride_with_dilation) != 3: 172 | raise ValueError("replace_stride_with_dilation should be None " 173 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 174 | self.groups = groups 175 | self.deep_stem = deep_stem 176 | self.base_width = width_per_group 177 | self._make_stem_layer(3, self.inplanes) 178 | self.layer1 = self._make_layer(block, 64, layers[0]) 179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 180 | dilate=replace_stride_with_dilation[0]) 181 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 182 | dilate=replace_stride_with_dilation[1]) 183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 184 | dilate=replace_stride_with_dilation[2]) 185 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 186 | self.fc = nn.Linear(512 * block.expansion, num_classes) 187 | 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 192 | nn.init.constant_(m.weight, 1) 193 | nn.init.constant_(m.bias, 0) 194 | 195 | # Zero-initialize the last BN in each residual branch, 196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 198 | if zero_init_residual: 199 | for m in self.modules(): 200 | if isinstance(m, Bottleneck): 201 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 202 | elif isinstance(m, BasicBlock): 203 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 204 | 205 | def _make_stem_layer(self, in_channels, stem_channels): 206 | """Make stem layer for ResNet.""" 207 | if self.deep_stem: 208 | self.conv1 = conv3x3(in_channels, 64, stride=2) 209 | self.bn1 = self._norm_layer(64) 210 | self.relu1 = nn.ReLU(inplace=True) 211 | self.conv2 = conv3x3(64, 64) 212 | self.bn2 = self._norm_layer(64) 213 | self.relu2 = nn.ReLU(inplace=True) 214 | self.conv3 = conv3x3(64, 128) 215 | self.bn3 = self._norm_layer(128) 216 | self.relu3 = nn.ReLU(inplace=True) 217 | else: 218 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 219 | bias=False) 220 | self.bn1 = self._norm_layer(self.inplanes) 221 | self.relu = nn.ReLU(inplace=True) 222 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 223 | 224 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 225 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 226 | norm_layer = self._norm_layer 227 | downsample = None 228 | previous_dilation = self.dilation 229 | if dilate: 230 | self.dilation *= stride 231 | stride = 1 232 | if stride != 1 or self.inplanes != planes * block.expansion: 233 | downsample = nn.Sequential( 234 | conv1x1(self.inplanes, planes * block.expansion, stride), 235 | norm_layer(planes * block.expansion), 236 | ) 237 | 238 | layers = [] 239 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 240 | self.base_width, previous_dilation, norm_layer)) 241 | self.inplanes = planes * block.expansion 242 | for _ in range(1, blocks): 243 | layers.append(block(self.inplanes, planes, groups=self.groups, 244 | 
base_width=self.base_width, dilation=self.dilation, 245 | norm_layer=norm_layer)) 246 | 247 | return nn.Sequential(*layers) 248 | 249 | def _forward_impl(self, x: Tensor) -> Tensor: 250 | # The deep stem registers relu1/relu2/relu3; only the plain stem registers self.relu. 251 | x = (self.relu1 if self.deep_stem else self.relu)(self.bn1(self.conv1(x))) 252 | if self.deep_stem: 253 | x = self.relu3(self.bn3(self.conv3(self.relu2(self.bn2(self.conv2(x)))))) 254 | x = self.maxpool(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.layer4(x) 260 | 261 | # x = self.avgpool(x) 262 | # x = torch.flatten(x, 1) 263 | # x = self.fc(x) 264 | 265 | return x 266 | 267 | def forward(self, x: Tensor) -> Tensor: 268 | return self._forward_impl(x) 269 | 270 | 271 | def _resnet( 272 | arch: str, 273 | block: Type[Union[BasicBlock, Bottleneck]], 274 | layers: List[int], 275 | pretrained: Union[bool, str], # False, or a path to a local checkpoint 276 | progress: bool, 277 | **kwargs: Any 278 | ) -> ResNet: 279 | model = ResNet(block, layers, **kwargs) 280 | if isinstance(pretrained, str) and 'warp' in pretrained: # guard: `in` raises TypeError on a bool 281 | model.fc = nn.Identity() 282 | if pretrained: 283 | # state_dict = load_state_dict_from_url(model_urls[arch], 284 | # progress=progress) 285 | state_dict = torch.load(pretrained, map_location='cpu') 286 | model.load_state_dict(state_dict) 287 | return model 288 | 289 | 290 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 291 | r"""ResNet-18 model from 292 | `"Deep Residual Learning for Image Recognition" `_. 293 | 294 | Args: 295 | pretrained (bool): If True, returns a model pre-trained on ImageNet 296 | progress (bool): If True, displays a progress bar of the download to stderr 297 | """ 298 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 299 | **kwargs) 300 | 301 | 302 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 303 | r"""ResNet-34 model from 304 | `"Deep Residual Learning for Image Recognition" `_. 
305 | 306 | Args: 307 | pretrained (bool): If True, returns a model pre-trained on ImageNet 308 | progress (bool): If True, displays a progress bar of the download to stderr 309 | """ 310 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 311 | **kwargs) 312 | 313 | 314 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 315 | r"""ResNet-50 model from 316 | `"Deep Residual Learning for Image Recognition" `_. 317 | 318 | Args: 319 | pretrained (bool): If True, returns a model pre-trained on ImageNet 320 | progress (bool): If True, displays a progress bar of the download to stderr 321 | """ 322 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 323 | **kwargs) 324 | 325 | 326 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 327 | r"""ResNet-101 model from 328 | `"Deep Residual Learning for Image Recognition" `_. 329 | 330 | Args: 331 | pretrained (bool): If True, returns a model pre-trained on ImageNet 332 | progress (bool): If True, displays a progress bar of the download to stderr 333 | """ 334 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 335 | **kwargs) 336 | 337 | 338 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 339 | r"""ResNet-152 model from 340 | `"Deep Residual Learning for Image Recognition" `_. 341 | 342 | Args: 343 | pretrained (bool): If True, returns a model pre-trained on ImageNet 344 | progress (bool): If True, displays a progress bar of the download to stderr 345 | """ 346 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 347 | **kwargs) 348 | 349 | 350 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 351 | r"""ResNeXt-50 32x4d model from 352 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
353 | 354 | Args: 355 | pretrained (bool): If True, returns a model pre-trained on ImageNet 356 | progress (bool): If True, displays a progress bar of the download to stderr 357 | """ 358 | kwargs['groups'] = 32 359 | kwargs['width_per_group'] = 4 360 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 361 | pretrained, progress, **kwargs) 362 | 363 | 364 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 365 | r"""ResNeXt-101 32x8d model from 366 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 367 | 368 | Args: 369 | pretrained (bool): If True, returns a model pre-trained on ImageNet 370 | progress (bool): If True, displays a progress bar of the download to stderr 371 | """ 372 | kwargs['groups'] = 32 373 | kwargs['width_per_group'] = 8 374 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 375 | pretrained, progress, **kwargs) 376 | 377 | 378 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 379 | r"""Wide ResNet-50-2 model from 380 | `"Wide Residual Networks" `_. 381 | 382 | The model is the same as ResNet except for the bottleneck number of channels 383 | which is twice larger in every block. The number of channels in outer 1x1 384 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 385 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 386 | 387 | Args: 388 | pretrained (bool): If True, returns a model pre-trained on ImageNet 389 | progress (bool): If True, displays a progress bar of the download to stderr 390 | """ 391 | kwargs['width_per_group'] = 64 * 2 392 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 393 | pretrained, progress, **kwargs) 394 | 395 | 396 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 397 | r"""Wide ResNet-101-2 model from 398 | `"Wide Residual Networks" `_. 
399 | 400 | The model is the same as ResNet except for the bottleneck number of channels 401 | which is twice larger in every block. The number of channels in outer 1x1 402 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 403 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 404 | 405 | Args: 406 | pretrained (bool): If True, returns a model pre-trained on ImageNet 407 | progress (bool): If True, displays a progress bar of the download to stderr 408 | """ 409 | kwargs['width_per_group'] = 64 * 2 410 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 411 | pretrained, progress, **kwargs) 412 | -------------------------------------------------------------------------------- /model/backbone_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torchvision.models._utils import IntermediateLayerGetter 5 | from model.backbone_res import * 6 | 7 | 8 | class FrozenBatchNorm2d(torch.nn.Module): 9 | """ 10 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 11 | Copied from torchvision.ops.misc with eps added before rsqrt, 12 | without which models other than torchvision.models.resnet[18,34,50,101] 13 | produce NaNs. 
14 | """ 15 | 16 | def __init__(self, n): 17 | super(FrozenBatchNorm2d, self).__init__() 18 | self.register_buffer("weight", torch.ones(n)) 19 | self.register_buffer("bias", torch.zeros(n)) 20 | self.register_buffer("running_mean", torch.zeros(n)) 21 | self.register_buffer("running_var", torch.ones(n)) 22 | 23 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 24 | missing_keys, unexpected_keys, error_msgs): 25 | num_batches_tracked_key = prefix + 'num_batches_tracked' 26 | if num_batches_tracked_key in state_dict: 27 | del state_dict[num_batches_tracked_key] 28 | 29 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 30 | state_dict, prefix, local_metadata, strict, 31 | missing_keys, unexpected_keys, error_msgs) 32 | 33 | def forward(self, x): 34 | # move reshapes to the beginning 35 | # to make it fuser-friendly 36 | w = self.weight.reshape(1, -1, 1, 1) 37 | b = self.bias.reshape(1, -1, 1, 1) 38 | rv = self.running_var.reshape(1, -1, 1, 1) 39 | rm = self.running_mean.reshape(1, -1, 1, 1) 40 | eps = 1e-5 41 | scale = w * (rv + eps).rsqrt() 42 | bias = b - rm * scale 43 | return x * scale + bias 44 | 45 | 46 | class BackboneBase(nn.Module): 47 | 48 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 49 | super().__init__() 50 | for name, parameter in backbone.named_parameters(): 51 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 52 | parameter.requires_grad_(False) 53 | if return_interm_layers: 54 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 55 | else: 56 | return_layers = {'layer4': "0"} 57 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 58 | self.num_channels = num_channels 59 | 60 | def forward(self, x): 61 | x = self.body(x) 62 | return x 63 | 64 | 65 | resnets_dict = { 66 | 'resnet50': (resnet50, 'initmodel/resnet50_v2.pth'), 67 | 'resnet101': (resnet101, 
'initmodel/resnet101_v2.pth'), 68 | } 69 | 70 | 71 | class Backbone(BackboneBase): 72 | """ResNet backbone with frozen BatchNorm.""" 73 | 74 | def __init__(self, name: str, 75 | train_backbone: bool, 76 | return_interm_layers: bool, 77 | dilation: list): 78 | backbone = resnets_dict[name][0]( 79 | replace_stride_with_dilation=dilation, 80 | pretrained=resnets_dict[name][1], norm_layer=FrozenBatchNorm2d) 81 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 82 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 83 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # ======================================== 7 | # Weighted Dice loss 8 | # By default, weight is uniformly fixed as 1 9 | # From https://github.com/YanFangCS/CyCTR-Pytorch 10 | # ======================================== 11 | def weighted_dice_loss(prediction, target_seg, weighted_val: float = 1.0, reduction: str = "sum", eps: float = 1e-8): 12 | """ 13 | Weighted version of Dice Loss 14 | 15 | Args: 16 | prediction: prediction 17 | target_seg: segmentation target 18 | weighted_val: values of k positives, 19 | reduction: 'none' | 'mean' | 'sum' 20 | 'none': No reduction will be applied to the output. 21 | 'mean': The output will be averaged. 22 | 'sum' : The output will be summed. 
23 | eps: the minimum eps, 24 | """ 25 | target_seg_fg = target_seg == 1 26 | target_seg_bg = target_seg == 0 27 | target_seg = torch.stack([target_seg_bg, target_seg_fg], dim=1).float() 28 | 29 | n, _, h, w = target_seg.shape 30 | 31 | prediction = prediction.reshape(-1, h, w) 32 | target_seg = target_seg.reshape(-1, h, w) 33 | prediction = torch.sigmoid(prediction) 34 | prediction = prediction.reshape(-1, h * w) 35 | target_seg = target_seg.reshape(-1, h * w) 36 | 37 | # calculate dice loss 38 | loss_part = (prediction ** 2).sum(dim=-1) + (target_seg ** 2).sum(dim=-1) 39 | loss = 1 - 2 * (target_seg * prediction).sum(dim=-1) / torch.clamp(loss_part, min=eps) 40 | # normalize the loss 41 | loss = loss * weighted_val 42 | 43 | if reduction == "sum": 44 | loss = loss.sum() / n 45 | elif reduction == "mean": 46 | loss = loss.mean() 47 | return loss 48 | 49 | 50 | class WeightedDiceLoss(nn.Module): 51 | def __init__(self, weighted_val: float = 1.0, reduction: str = "sum"): 52 | super(WeightedDiceLoss, self).__init__() 53 | self.weighted_val = weighted_val 54 | self.reduction = reduction 55 | 56 | def forward(self, prediction, target_seg): 57 | return weighted_dice_loss(prediction, target_seg, self.weighted_val, self.reduction) 58 | 59 | 60 | # ======================================== 61 | # Cross-Entropy loss + Dice loss 62 | # ======================================== 63 | class CEDiceLoss(nn.Module): 64 | def __init__(self, reduction="mean", ignore_index=255): 65 | super(CEDiceLoss, self).__init__() 66 | self.reduction = reduction 67 | self.ignore_index = ignore_index 68 | 69 | self.ce = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction=self.reduction) 70 | self.dice = WeightedDiceLoss(reduction=reduction) 71 | 72 | def forward(self, output, target): 73 | ce = self.ce(output, target) 74 | dice = self.dice(output, target) 75 | return ce + dice 76 | -------------------------------------------------------------------------------- /model/resnet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | BatchNorm = nn.BatchNorm2d 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = BatchNorm(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = BatchNorm(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = 
BatchNorm(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = BatchNorm(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 70 | self.bn3 = BatchNorm(planes * self.expansion) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class ResNet(nn.Module): 99 | 100 | def __init__(self, block, layers, num_classes=1000, deep_base=True): 101 | super(ResNet, self).__init__() 102 | self.deep_base = deep_base 103 | if not self.deep_base: 104 | self.inplanes = 64 105 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 106 | self.bn1 = BatchNorm(64) 107 | self.relu = nn.ReLU(inplace=True) 108 | else: 109 | self.inplanes = 128 110 | self.conv1 = conv3x3(3, 64, stride=2) 111 | self.bn1 = BatchNorm(64) 112 | self.relu1 = nn.ReLU(inplace=True) 113 | self.conv2 = conv3x3(64, 64) 114 | self.bn2 = BatchNorm(64) 115 | self.relu2 = nn.ReLU(inplace=True) 116 | self.conv3 = conv3x3(64, 128) 117 | self.bn3 = BatchNorm(128) 118 | self.relu3 = nn.ReLU(inplace=True) 119 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 120 | self.layer1 = self._make_layer(block, 64, layers[0]) 121 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 122 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 123 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 124 | self.avgpool = nn.AvgPool2d(7, stride=1) 125 | self.fc = nn.Linear(512 * 
block.expansion, num_classes) 126 | 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 130 | elif isinstance(m, BatchNorm): 131 | nn.init.constant_(m.weight, 1) 132 | nn.init.constant_(m.bias, 0) 133 | 134 | def _make_layer(self, block, planes, blocks, stride=1): 135 | downsample = None 136 | if stride != 1 or self.inplanes != planes * block.expansion: 137 | downsample = nn.Sequential( 138 | nn.Conv2d(self.inplanes, planes * block.expansion, 139 | kernel_size=1, stride=stride, bias=False), 140 | BatchNorm(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | x = (self.relu1 if self.deep_base else self.relu)(self.bn1(self.conv1(x))) # the shallow stem only defines self.relu 153 | if self.deep_base: 154 | x = self.relu2(self.bn2(self.conv2(x))) 155 | x = self.relu3(self.bn3(self.conv3(x))) 156 | x = self.maxpool(x) 157 | 158 | x = self.layer1(x) 159 | x = self.layer2(x) 160 | x = self.layer3(x) 161 | x = self.layer4(x) 162 | 163 | x = self.avgpool(x) 164 | x = x.view(x.size(0), -1) 165 | x = self.fc(x) 166 | 167 | return x 168 | 169 | 170 | def resnet18(pretrained=False, **kwargs): 171 | """Constructs a ResNet-18 model. 172 | 173 | Args: 174 | pretrained (bool): If True, returns a model pre-trained on ImageNet 175 | """ 176 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 177 | if pretrained: 178 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 179 | return model 180 | 181 | 182 | def resnet34(pretrained=False, **kwargs): 183 | """Constructs a ResNet-34 model. 
184 | 185 | Args: 186 | pretrained (bool): If True, returns a model pre-trained on ImageNet 187 | """ 188 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 189 | if pretrained: 190 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 191 | return model 192 | 193 | 194 | def resnet50(pretrained=True, **kwargs): 195 | """Constructs a ResNet-50 model. 196 | 197 | Args: 198 | pretrained (bool): If True, returns a model pre-trained on ImageNet 199 | """ 200 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 201 | if pretrained: 202 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 203 | model_path = './initmodel/resnet50_v2.pth' 204 | model.load_state_dict(torch.load(model_path), strict=False) 205 | return model 206 | 207 | 208 | def resnet101(pretrained=False, **kwargs): 209 | """Constructs a ResNet-101 model. 210 | 211 | Args: 212 | pretrained (bool): If True, returns a model pre-trained on ImageNet 213 | """ 214 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 215 | if pretrained: 216 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 217 | model_path = './initmodel/resnet101_v2.pth' 218 | model.load_state_dict(torch.load(model_path), strict=False) 219 | return model 220 | 221 | 222 | def resnet152(pretrained=False, **kwargs): 223 | """Constructs a ResNet-152 model. 
224 | 225 | Args: 226 | pretrained (bool): If True, returns a model pre-trained on ImageNet 227 | """ 228 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 229 | if pretrained: 230 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 231 | model_path = './initmodel/resnet152_v2.pth' 232 | model.load_state_dict(torch.load(model_path), strict=False) 233 | return model 234 | -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | BatchNorm = nn.BatchNorm2d 6 | 7 | __all__ = [ 8 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 9 | 'vgg19_bn', 'vgg19', 10 | ] 11 | 12 | 13 | model_urls = { 14 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 15 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 16 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 17 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 18 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 19 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 20 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 21 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 22 | } 23 | 24 | 25 | class VGG(nn.Module): 26 | 27 | def __init__(self, features, num_classes=1000, init_weights=True): 28 | super(VGG, self).__init__() 29 | self.features = features 30 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 31 | self.classifier = nn.Sequential( 32 | nn.Linear(512 * 7 * 7, 4096), 33 | nn.ReLU(True), 34 | nn.Dropout(), 35 | nn.Linear(4096, 4096), 36 | nn.ReLU(True), 37 | nn.Dropout(), 38 | nn.Linear(4096, num_classes), 39 | ) 40 | if init_weights: 41 | self._initialize_weights() 42 | 43 | def forward(self, x): 44 | 
x = self.features(x) 45 | x = self.avgpool(x) 46 | x = x.view(x.size(0), -1) 47 | x = self.classifier(x) 48 | return x 49 | 50 | def _initialize_weights(self): 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 54 | if m.bias is not None: 55 | nn.init.constant_(m.bias, 0) 56 | elif isinstance(m, BatchNorm): 57 | nn.init.constant_(m.weight, 1) 58 | nn.init.constant_(m.bias, 0) 59 | elif isinstance(m, nn.Linear): 60 | nn.init.normal_(m.weight, 0, 0.01) 61 | nn.init.constant_(m.bias, 0) 62 | 63 | 64 | def make_layers(cfg, batch_norm=False): 65 | layers = [] 66 | in_channels = 3 67 | for v in cfg: 68 | if v == 'M': 69 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 70 | else: 71 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 72 | if batch_norm: 73 | layers += [conv2d, BatchNorm(v), nn.ReLU(inplace=True)] 74 | else: 75 | layers += [conv2d, nn.ReLU(inplace=True)] 76 | in_channels = v 77 | return nn.Sequential(*layers) 78 | 79 | 80 | cfg = { 81 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 82 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 83 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 84 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 85 | } 86 | 87 | 88 | def vgg11(pretrained=False, **kwargs): 89 | """VGG 11-layer model (configuration "A") 90 | Args: 91 | pretrained (bool): If True, returns a model pre-trained on ImageNet 92 | """ 93 | if pretrained: 94 | kwargs['init_weights'] = False 95 | model = VGG(make_layers(cfg['A']), **kwargs) 96 | if pretrained: 97 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 98 | return model 99 | 100 | 101 | def vgg11_bn(pretrained=False, **kwargs): 102 | """VGG 11-layer model (configuration "A") with batch normalization 103 
| Args: 104 | pretrained (bool): If True, returns a model pre-trained on ImageNet 105 | """ 106 | if pretrained: 107 | kwargs['init_weights'] = False 108 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 109 | if pretrained: 110 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 111 | return model 112 | 113 | 114 | def vgg13(pretrained=False, **kwargs): 115 | """VGG 13-layer model (configuration "B") 116 | Args: 117 | pretrained (bool): If True, returns a model pre-trained on ImageNet 118 | """ 119 | if pretrained: 120 | kwargs['init_weights'] = False 121 | model = VGG(make_layers(cfg['B']), **kwargs) 122 | if pretrained: 123 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 124 | return model 125 | 126 | 127 | def vgg13_bn(pretrained=False, **kwargs): 128 | """VGG 13-layer model (configuration "B") with batch normalization 129 | Args: 130 | pretrained (bool): If True, returns a model pre-trained on ImageNet 131 | """ 132 | if pretrained: 133 | kwargs['init_weights'] = False 134 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 135 | if pretrained: 136 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 137 | return model 138 | 139 | 140 | def vgg16(pretrained=False, **kwargs): 141 | """VGG 16-layer model (configuration "D") 142 | Args: 143 | pretrained (bool): If True, returns a model pre-trained on ImageNet 144 | """ 145 | if pretrained: 146 | kwargs['init_weights'] = False 147 | model = VGG(make_layers(cfg['D']), **kwargs) 148 | if pretrained: 149 | #model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 150 | model_path = './initmodel/vgg16.pth' 151 | model.load_state_dict(torch.load(model_path), strict=False) 152 | return model 153 | 154 | 155 | def vgg16_bn(pretrained=False, **kwargs): 156 | """VGG 16-layer model (configuration "D") with batch normalization 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | """ 160 | if pretrained: 
161 | kwargs['init_weights'] = False 162 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 163 | if pretrained: 164 | #model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 165 | model_path = './initmodel/vgg16_bn.pth' 166 | model.load_state_dict(torch.load(model_path), strict=False) 167 | return model 168 | 169 | 170 | def vgg19(pretrained=False, **kwargs): 171 | """VGG 19-layer model (configuration "E") 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | if pretrained: 176 | kwargs['init_weights'] = False 177 | model = VGG(make_layers(cfg['E']), **kwargs) 178 | if pretrained: 179 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 180 | return model 181 | 182 | 183 | def vgg19_bn(pretrained=False, **kwargs): 184 | """VGG 19-layer model (configuration 'E') with batch normalization 185 | Args: 186 | pretrained (bool): If True, returns a model pre-trained on ImageNet 187 | """ 188 | if pretrained: 189 | kwargs['init_weights'] = False 190 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 191 | if pretrained: 192 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 193 | return model 194 | 195 | if __name__ =='__main__': 196 | import os 197 | # os.environ["CUDA_VISIBLE_DEVICES"] = '7' 198 | input = torch.rand(4, 3, 473, 473).cuda() 199 | target = torch.rand(4, 473, 473).cuda()*1.0 200 | model = vgg16_bn(pretrained=False).cuda() 201 | model.train() 202 | layer0_idx = range(0,6) 203 | layer1_idx = range(6,13) 204 | layer2_idx = range(13,23) 205 | layer3_idx = range(23,33) 206 | layer4_idx = range(34,43) 207 | #layer4_idx = range(34,43) 208 | print(model.features) 209 | layers_0 = [] 210 | layers_1 = [] 211 | layers_2 = [] 212 | layers_3 = [] 213 | layers_4 = [] 214 | for idx in layer0_idx: 215 | layers_0 += [model.features[idx]] 216 | for idx in layer1_idx: 217 | layers_1 += [model.features[idx]] 218 | for idx in layer2_idx: 219 | layers_2 += 
[model.features[idx]] 220 | for idx in layer3_idx: 221 | layers_3 += [model.features[idx]] 222 | for idx in layer4_idx: 223 | layers_4 += [model.features[idx]] 224 | 225 | layer0 = nn.Sequential(*layers_0) 226 | layer1 = nn.Sequential(*layers_1) 227 | layer2 = nn.Sequential(*layers_2) 228 | layer3 = nn.Sequential(*layers_3) 229 | layer4 = nn.Sequential(*layers_4) 230 | 231 | output = layer0(input) 232 | print(layer0) 233 | print('layer 0: {}'.format(output.size())) 234 | output = layer1(output) 235 | print(layer1) 236 | print('layer 1: {}'.format(output.size())) 237 | output = layer2(output) 238 | print(layer2) 239 | print('layer 2: {}'.format(output.size())) 240 | output = layer3(output) 241 | print(layer3) 242 | print('layer 3: {}'.format(output.size())) 243 | output = layer4(output) 244 | print(layer4) 245 | print('layer 4: {}'.format(output.size())) 246 | -------------------------------------------------------------------------------- /test_coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import random 4 | import time 5 | import cv2 6 | import numpy as np 7 | import logging 8 | import argparse 9 | import math 10 | from visdom import Visdom 11 | import os.path as osp 12 | 13 | import torch 14 | import torch.backends.cudnn as cudnn 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.nn.parallel 18 | import torch.optim 19 | import torch.utils.data 20 | import torch.multiprocessing as mp 21 | import torch.distributed as dist 22 | from torch.utils.data.distributed import DistributedSampler 23 | from torch.cuda.amp import autocast 24 | 25 | from tensorboardX import SummaryWriter 26 | 27 | from model import HMNetAMP 28 | 29 | from util import dataset 30 | from util import transform, transform_tri, config 31 | from util.util import AverageMeter, poly_learning_rate, intersectionAndUnionGPU, get_model_para_number, setup_seed, \ 32 | get_logger, get_save_path, \ 33 
| is_same_model, fix_bn, sum_list, check_makedirs 34 | import matplotlib.pyplot as plt 35 | 36 | cv2.ocl.setUseOpenCL(False) 37 | cv2.setNumThreads(0) 38 | val_manual_seed = 123 39 | setup_seed(val_manual_seed, False) 40 | seed_array = [321] 41 | val_num = len(seed_array) 42 | 43 | 44 | def get_parser(): 45 | parser = argparse.ArgumentParser(description='PyTorch Few-Shot Semantic Segmentation') 46 | parser.add_argument('--arch', type=str, default='HMNetAMP') 47 | parser.add_argument('--viz', action='store_true', default=False) 48 | parser.add_argument('--config', type=str, default='config/coco/coco_split3_resnet50.yaml', 49 | help='config file') # coco/coco_split0_resnet50.yaml 50 | parser.add_argument('--episode', help='number of test episodes', type=int, default=4000) 51 | parser.add_argument('--opts', help='see config/ade20k/ade20k_pspnet50.yaml for all options', default=None, 52 | nargs=argparse.REMAINDER) 53 | args = parser.parse_args() 54 | assert args.config is not None 55 | cfg = config.load_cfg_from_cfg_file(args.config) 56 | cfg = config.merge_cfg_from_args(cfg, args) 57 | if args.opts is not None: 58 | cfg = config.merge_cfg_from_list(cfg, args.opts) 59 | return cfg 60 | 61 | 62 | def get_model(args): 63 | model = eval(args.arch).OneModel(args) 64 | optimizer, optimizer_swin = model.get_optim(model, args, LR=args.base_lr) 65 | 66 | model = model.cuda() 67 | 68 | # Resume 69 | get_save_path(args) 70 | check_makedirs(args.snapshot_path) 71 | check_makedirs(args.result_path) 72 | 73 | if args.weight: 74 | weight_path = osp.join(args.snapshot_path, args.weight) 75 | if os.path.isfile(weight_path): 76 | logger.info("=> loading checkpoint '{}'".format(weight_path)) 77 | checkpoint = torch.load(weight_path, map_location=torch.device('cpu')) 78 | args.start_epoch = checkpoint['epoch'] 79 | new_param = checkpoint['state_dict'] 80 | try: 81 | model.load_state_dict(new_param) 82 | except RuntimeError: # 1GPU loads mGPU model 83 | for key in list(new_param.keys()): 
84 | new_param[key[7:]] = new_param.pop(key) 85 | model.load_state_dict(new_param) 86 | optimizer.load_state_dict(checkpoint['optimizer']) 87 | optimizer_swin.load_state_dict(checkpoint['optimizer_swin']) 88 | logger.info("=> loaded checkpoint '{}' (epoch {})".format(weight_path, checkpoint['epoch'])) 89 | else: 90 | logger.info("=> no checkpoint found at '{}'".format(weight_path)) 91 | 92 | # Get model para. 93 | total_number, learnable_number = get_model_para_number(model) 94 | print('Number of Parameters: %d' % (total_number)) 95 | print('Number of Learnable Parameters: %d' % (learnable_number)) 96 | 97 | time.sleep(5) 98 | return model, optimizer, optimizer_swin 99 | 100 | 101 | def main(): 102 | global args, logger, writer 103 | args = get_parser() 104 | logger = get_logger() 105 | args.distributed = True if torch.cuda.device_count() > 1 else False 106 | print(args) 107 | 108 | if args.manual_seed is not None: 109 | setup_seed(args.manual_seed, args.seed_deterministic) 110 | 111 | assert args.classes > 1 112 | assert args.zoom_factor in [1, 2, 4, 8] 113 | assert (args.train_h - 1) % 8 == 0 and (args.train_w - 1) % 8 == 0 114 | 115 | logger.info("=> creating model ...") 116 | model, optimizer, optimizer_swin = get_model(args) 117 | logger.info(model) 118 | 119 | # ---------------------- DATASET ---------------------- 120 | value_scale = 255 121 | mean = [0.485, 0.456, 0.406] 122 | mean = [item * value_scale for item in mean] 123 | std = [0.229, 0.224, 0.225] 124 | std = [item * value_scale for item in std] 125 | # Val 126 | if args.evaluate: 127 | if args.resized_val: 128 | val_transform = transform.Compose([ 129 | transform.Resize(size=args.val_size), 130 | transform.ToTensor(), 131 | transform.Normalize(mean=mean, std=std)]) 132 | else: 133 | val_transform = transform.Compose([ 134 | transform.test_Resize(size=args.val_size), 135 | transform.ToTensor(), 136 | transform.Normalize(mean=mean, std=std)]) 137 | if args.data_set == 'pascal' or args.data_set == 
'coco': 138 | val_data = dataset.SemData(split=args.split, shot=args.shot, data_root=args.data_root, 139 | data_list=args.val_list, transform=val_transform, mode='val', 140 | ann_type=args.ann_type, data_set=args.data_set, use_split_coco=args.use_split_coco) 141 | val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, 142 | num_workers=args.workers, pin_memory=False, sampler=None) 143 | 144 | # ---------------------- VAL ---------------------- 145 | start_time = time.time() 146 | FBIoU_array = np.zeros(val_num) 147 | mIoU_array = np.zeros(val_num) 148 | pIoU_array = np.zeros(val_num) 149 | for val_id in range(val_num): 150 | val_seed = seed_array[val_id] 151 | print('Val: [{}/{}] \t Seed: {}'.format(val_id + 1, val_num, val_seed)) 152 | fb_iou, miou, piou = validate(val_loader, model, val_seed, args.episode) 153 | FBIoU_array[val_id], mIoU_array[val_id], pIoU_array[val_id] = \ 154 | fb_iou, miou, piou 155 | 156 | total_time = time.time() - start_time 157 | t_m, t_s = divmod(total_time, 60) 158 | t_h, t_m = divmod(t_m, 60) 159 | total_time = '{:02d}h {:02d}m {:02d}s'.format(int(t_h), int(t_m), int(t_s)) 160 | 161 | print('\nTotal running time: {}'.format(total_time)) 162 | print('Seed0: {}'.format(val_manual_seed)) 163 | print('Seed: {}'.format(seed_array)) 164 | print('mIoU: {}'.format(np.round(mIoU_array, 4))) 165 | print('FBIoU: {}'.format(np.round(FBIoU_array, 4))) 166 | print('pIoU: {}'.format(np.round(pIoU_array, 4))) 167 | print('-' * 43) 168 | print('Best_Seed_m: {} \t Best_Seed_F: {} \t Best_Seed_p: {}'.format(seed_array[mIoU_array.argmax()], 169 | seed_array[FBIoU_array.argmax()], 170 | seed_array[pIoU_array.argmax()])) 171 | print('Best_mIoU: {:.4f} \t Best_FBIoU: {:.4f} \t Best_pIoU: {:.4f}'.format( 172 | mIoU_array.max(), FBIoU_array.max(), pIoU_array.max())) 173 | print('Mean_mIoU: {:.4f} \t Mean_FBIoU: {:.4f} \t Mean_pIoU: {:.4f}'.format( 174 | mIoU_array.mean(), FBIoU_array.mean(), pIoU_array.mean())) 175 | 
176 | 177 | def validate(val_loader, model, val_seed, episode): 178 | logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>') 179 | batch_time = AverageMeter() 180 | model_time = AverageMeter() 181 | data_time = AverageMeter() 182 | loss_meter = AverageMeter() 183 | 184 | intersection_meter = AverageMeter() # final 185 | union_meter = AverageMeter() 186 | target_meter = AverageMeter() 187 | 188 | if args.data_set == 'pascal': 189 | test_num = 1000 190 | split_gap = 5 191 | elif args.data_set == 'coco': 192 | test_num = episode 193 | split_gap = 20 194 | 195 | class_intersection_meter = [0] * split_gap 196 | class_union_meter = [0] * split_gap 197 | 198 | setup_seed(val_seed, args.seed_deterministic) 199 | 200 | criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label) 201 | 202 | model.eval() 203 | end = time.time() 204 | val_start = end 205 | 206 | assert test_num % args.batch_size_val == 0 207 | db_epoch = math.ceil(test_num / (len(val_loader) - args.batch_size_val)) 208 | iter_num = 0 209 | 210 | for e in range(db_epoch): 211 | for i, (input, target, s_input, s_mask, subcls, ori_label) in enumerate(val_loader): 212 | if iter_num * args.batch_size_val >= test_num: 213 | break 214 | iter_num += 1 215 | data_time.update(time.time() - end) 216 | 217 | s_input = s_input.cuda(non_blocking=True) 218 | s_mask = s_mask.cuda(non_blocking=True) 219 | input = input.cuda(non_blocking=True) 220 | target = target.cuda(non_blocking=True) 221 | ori_label = ori_label.cuda(non_blocking=True) 222 | 223 | start_time = time.time() 224 | 225 | with autocast(): 226 | output = model(s_x=s_input, s_y=s_mask, x=input, y_m=target, cat_idx=subcls) 227 | model_time.update(time.time() - start_time) 228 | 229 | if args.ori_resize: 230 | output = F.interpolate(output, size=ori_label.size()[-2:], mode='bilinear', align_corners=True) 231 | target = ori_label.long() 232 | 233 | output = F.interpolate(output, size=target.size()[1:], mode='bilinear', align_corners=True) 234 | loss = 
criterion(output, target) 235 | 236 | output = output.max(1)[1] 237 | subcls = subcls[0].cpu().numpy()[0] 238 | 239 | intersection, union, new_target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label) 240 | intersection, union, new_target = intersection.cpu().numpy(), union.cpu().numpy(), new_target.cpu().numpy() 241 | intersection_meter.update(intersection), union_meter.update(union), target_meter.update(new_target) 242 | class_intersection_meter[subcls] += intersection[1] 243 | class_union_meter[subcls] += union[1] 244 | 245 | accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10) 246 | loss_meter.update(loss.item(), input.size(0)) 247 | batch_time.update(time.time() - end) 248 | end = time.time() 249 | 250 | remain_iter = test_num / args.batch_size_val - iter_num 251 | remain_time = remain_iter * batch_time.avg 252 | t_m, t_s = divmod(remain_time, 60) 253 | t_h, t_m = divmod(t_m, 60) 254 | remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s)) 255 | 256 | if ((i + 1) % round((test_num / 100)) == 0): 257 | logger.info('Test: [{}/{}] ' 258 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' 259 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 260 | 'Remain {remain_time} ' 261 | 'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) ' 262 | 'Accuracy {accuracy:.4f}.'.format(iter_num * args.batch_size_val, test_num, 263 | data_time=data_time, 264 | batch_time=batch_time, 265 | remain_time=remain_time, 266 | loss_meter=loss_meter, 267 | accuracy=accuracy)) 268 | val_time = time.time() - val_start 269 | 270 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 271 | mIoU = np.mean(iou_class) 272 | 273 | class_iou_class = [] 274 | class_miou = 0 275 | for i in range(len(class_intersection_meter)): 276 | class_iou = class_intersection_meter[i] / (class_union_meter[i] + 1e-10) 277 | class_iou_class.append(class_iou) 278 | class_miou += class_iou 279 | 280 | class_miou = class_miou * 1.0 / 
len(class_intersection_meter) 281 | logger.info('meanIoU---Val result: mIoU {:.4f}.'.format(class_miou)) # final 282 | logger.info('<<<<<<< Novel Results <<<<<<<') 283 | for i in range(split_gap): 284 | logger.info('Class_{} Result: iou {:.4f}.'.format(i + 1, class_iou_class[i])) 285 | 286 | logger.info('FBIoU---Val result: FBIoU {:.4f}.'.format(mIoU)) 287 | for i in range(args.classes): 288 | logger.info('Class_{} Result: iou_f {:.4f}.'.format(i, iou_class[i])) 289 | logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<') 290 | 291 | print('total time: {:.4f}, avg inference time: {:.4f}, count: {}'.format(val_time, model_time.avg, test_num)) 292 | 293 | return mIoU, class_miou, iou_class[1] 294 | 295 | 296 | if __name__ == '__main__': 297 | main() 298 | -------------------------------------------------------------------------------- /test_coco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | 4 | dataset=coco 5 | exp_name=split$1 # 0 6 | arch=HMNetAMP 7 | net=$2 # vgg/resnet50 8 | postfix=$3 # manet/manet_5s 9 | 10 | config=config/${dataset}/${dataset}_${exp_name}_${net}_${postfix}.yaml 11 | 12 | python test_coco.py --config=${config} --arch=${arch} --episode=4000 13 | -------------------------------------------------------------------------------- /test_pascal.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import random 4 | import time 5 | import cv2 6 | import numpy as np 7 | import logging 8 | import argparse 9 | import math 10 | from visdom import Visdom 11 | import os.path as osp 12 | 13 | import torch 14 | import torch.backends.cudnn as cudnn 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.nn.parallel 18 | import torch.optim 19 | import torch.utils.data 20 | import torch.multiprocessing as mp 21 | import torch.distributed as dist 22 | from
torch.utils.data.distributed import DistributedSampler 23 | 24 | from tensorboardX import SummaryWriter 25 | 26 | from model import HMNet 27 | 28 | from util import dataset 29 | from util import transform, transform_tri, config 30 | from util.util import AverageMeter, poly_learning_rate, intersectionAndUnionGPU, get_model_para_number, setup_seed, \ 31 | get_logger, get_save_path, \ 32 | is_same_model, fix_bn, sum_list, check_makedirs 33 | import matplotlib.pyplot as plt 34 | 35 | cv2.ocl.setUseOpenCL(False) 36 | cv2.setNumThreads(0) 37 | val_manual_seed = 123 38 | setup_seed(val_manual_seed, False) 39 | seed_array = [321] 40 | val_num = len(seed_array) 41 | 42 | 43 | def get_parser(): 44 | parser = argparse.ArgumentParser(description='PyTorch Few-Shot Semantic Segmentation') 45 | parser.add_argument('--arch', type=str, default='HMNet') 46 | parser.add_argument('--viz', action='store_true', default=False) 47 | parser.add_argument('--config', type=str, default='config/coco/coco_split3_resnet50.yaml', 48 | help='config file') # coco/coco_split0_resnet50.yaml 49 | parser.add_argument('--episode', help='number of test episodes', type=int, default=1000) 50 | parser.add_argument('--opts', help='see config/ade20k/ade20k_pspnet50.yaml for all options', default=None, 51 | nargs=argparse.REMAINDER) 52 | args = parser.parse_args() 53 | assert args.config is not None 54 | cfg = config.load_cfg_from_cfg_file(args.config) 55 | cfg = config.merge_cfg_from_args(cfg, args) 56 | if args.opts is not None: 57 | cfg = config.merge_cfg_from_list(cfg, args.opts) 58 | return cfg 59 | 60 | 61 | def get_model(args): 62 | model = eval(args.arch).OneModel(args) 63 | optimizer, optimizer_swin = model.get_optim(model, args, LR=args.base_lr) 64 | 65 | model = model.cuda() 66 | 67 | # Resume 68 | get_save_path(args) 69 | check_makedirs(args.snapshot_path) 70 | check_makedirs(args.result_path) 71 | 72 | if args.weight: 73 | weight_path = osp.join(args.snapshot_path, args.weight) 74 | if 
os.path.isfile(weight_path): 75 | logger.info("=> loading checkpoint '{}'".format(weight_path)) 76 | checkpoint = torch.load(weight_path, map_location=torch.device('cpu')) 77 | args.start_epoch = checkpoint['epoch'] 78 | new_param = checkpoint['state_dict'] 79 | try: 80 | model.load_state_dict(new_param) 81 | except RuntimeError: # 1GPU loads mGPU model 82 | for key in list(new_param.keys()): 83 | new_param[key[7:]] = new_param.pop(key) 84 | model.load_state_dict(new_param) 85 | optimizer.load_state_dict(checkpoint['optimizer']) 86 | optimizer_swin.load_state_dict(checkpoint['optimizer_swin']) 87 | logger.info("=> loaded checkpoint '{}' (epoch {})".format(weight_path, checkpoint['epoch'])) 88 | else: 89 | logger.info("=> no checkpoint found at '{}'".format(weight_path)) 90 | 91 | # Get model para. 92 | total_number, learnable_number = get_model_para_number(model) 93 | print('Number of Parameters: %d' % (total_number)) 94 | print('Number of Learnable Parameters: %d' % (learnable_number)) 95 | 96 | time.sleep(5) 97 | return model, optimizer, optimizer_swin 98 | 99 | 100 | def main(): 101 | global args, logger, writer 102 | args = get_parser() 103 | logger = get_logger() 104 | args.distributed = True if torch.cuda.device_count() > 1 else False 105 | print(args) 106 | 107 | if args.manual_seed is not None: 108 | setup_seed(args.manual_seed, args.seed_deterministic) 109 | 110 | assert args.classes > 1 111 | assert args.zoom_factor in [1, 2, 4, 8] 112 | assert (args.train_h - 1) % 8 == 0 and (args.train_w - 1) % 8 == 0 113 | 114 | logger.info("=> creating model ...") 115 | model, optimizer, optimizer_swin = get_model(args) 116 | logger.info(model) 117 | 118 | # ---------------------- DATASET ---------------------- 119 | value_scale = 255 120 | mean = [0.485, 0.456, 0.406] 121 | mean = [item * value_scale for item in mean] 122 | std = [0.229, 0.224, 0.225] 123 | std = [item * value_scale for item in std] 124 | # Val 125 | if args.evaluate: 126 | if args.resized_val: 127 | 
val_transform = transform.Compose([ 128 | transform.Resize(size=args.val_size), 129 | transform.ToTensor(), 130 | transform.Normalize(mean=mean, std=std)]) 131 | else: 132 | val_transform = transform.Compose([ 133 | transform.test_Resize(size=args.val_size), 134 | transform.ToTensor(), 135 | transform.Normalize(mean=mean, std=std)]) 136 | if args.data_set == 'pascal' or args.data_set == 'coco': 137 | val_data = dataset.SemData(split=args.split, shot=args.shot, data_root=args.data_root, 138 | data_list=args.val_list, transform=val_transform, mode='val', 139 | ann_type=args.ann_type, data_set=args.data_set, use_split_coco=args.use_split_coco) 140 | val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, 141 | num_workers=args.workers, pin_memory=False, sampler=None) 142 | 143 | # ---------------------- VAL ---------------------- 144 | start_time = time.time() 145 | FBIoU_array = np.zeros(val_num) 146 | mIoU_array = np.zeros(val_num) 147 | pIoU_array = np.zeros(val_num) 148 | for val_id in range(val_num): 149 | val_seed = seed_array[val_id] 150 | print('Val: [{}/{}] \t Seed: {}'.format(val_id + 1, val_num, val_seed)) 151 | fb_iou, miou, piou = validate(val_loader, model, val_seed, args.episode) 152 | FBIoU_array[val_id], mIoU_array[val_id], pIoU_array[val_id] = \ 153 | fb_iou, miou, piou 154 | 155 | total_time = time.time() - start_time 156 | t_m, t_s = divmod(total_time, 60) 157 | t_h, t_m = divmod(t_m, 60) 158 | total_time = '{:02d}h {:02d}m {:02d}s'.format(int(t_h), int(t_m), int(t_s)) 159 | 160 | print('\nTotal running time: {}'.format(total_time)) 161 | print('Seed0: {}'.format(val_manual_seed)) 162 | print('Seed: {}'.format(seed_array)) 163 | print('mIoU: {}'.format(np.round(mIoU_array, 4))) 164 | print('FBIoU: {}'.format(np.round(FBIoU_array, 4))) 165 | print('pIoU: {}'.format(np.round(pIoU_array, 4))) 166 | print('-' * 43) 167 | print('Best_Seed_m: {} \t Best_Seed_F: {} \t Best_Seed_p: 
{}'.format(seed_array[mIoU_array.argmax()], 168 | seed_array[FBIoU_array.argmax()], 169 | seed_array[pIoU_array.argmax()])) 170 | print('Best_mIoU: {:.4f} \t Best_FBIoU: {:.4f} \t Best_pIoU: {:.4f}'.format( 171 | mIoU_array.max(), FBIoU_array.max(), pIoU_array.max())) 172 | print('Mean_mIoU: {:.4f} \t Mean_FBIoU: {:.4f} \t Mean_pIoU: {:.4f}'.format( 173 | mIoU_array.mean(), FBIoU_array.mean(), pIoU_array.mean())) 174 | 175 | 176 | def validate(val_loader, model, val_seed, episode): 177 | logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>') 178 | batch_time = AverageMeter() 179 | model_time = AverageMeter() 180 | data_time = AverageMeter() 181 | loss_meter = AverageMeter() 182 | 183 | intersection_meter = AverageMeter() # final 184 | union_meter = AverageMeter() 185 | target_meter = AverageMeter() 186 | 187 | if args.data_set == 'pascal': 188 | test_num = 1000 189 | split_gap = 5 190 | elif args.data_set == 'coco': 191 | test_num = episode 192 | split_gap = 20 193 | 194 | class_intersection_meter = [0] * split_gap 195 | class_union_meter = [0] * split_gap 196 | 197 | setup_seed(val_seed, args.seed_deterministic) 198 | 199 | criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label) 200 | 201 | model.eval() 202 | end = time.time() 203 | val_start = end 204 | 205 | assert test_num % args.batch_size_val == 0 206 | db_epoch = math.ceil(test_num / (len(val_loader) - args.batch_size_val)) 207 | iter_num = 0 208 | 209 | for e in range(db_epoch): 210 | for i, (input, target, s_input, s_mask, subcls, ori_label) in enumerate(val_loader): 211 | if iter_num * args.batch_size_val >= test_num: 212 | break 213 | iter_num += 1 214 | data_time.update(time.time() - end) 215 | 216 | s_input = s_input.cuda(non_blocking=True) 217 | s_mask = s_mask.cuda(non_blocking=True) 218 | input = input.cuda(non_blocking=True) 219 | target = target.cuda(non_blocking=True) 220 | ori_label = ori_label.cuda(non_blocking=True) 221 | 222 | start_time = time.time() 223 | 224 | output = 
model(s_x=s_input, s_y=s_mask, x=input, y_m=target, cat_idx=subcls) 225 | model_time.update(time.time() - start_time) 226 | 227 | if args.ori_resize: 228 | output = F.interpolate(output, size=ori_label.size()[-2:], mode='bilinear', align_corners=True) 229 | target = ori_label.long() 230 | 231 | output = F.interpolate(output, size=target.size()[1:], mode='bilinear', align_corners=True) 232 | loss = criterion(output, target) 233 | 234 | output = output.max(1)[1] 235 | subcls = subcls[0].cpu().numpy()[0] 236 | 237 | intersection, union, new_target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label) 238 | intersection, union, new_target = intersection.cpu().numpy(), union.cpu().numpy(), new_target.cpu().numpy() 239 | intersection_meter.update(intersection), union_meter.update(union), target_meter.update(new_target) 240 | class_intersection_meter[subcls] += intersection[1] 241 | class_union_meter[subcls] += union[1] 242 | 243 | accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10) 244 | loss_meter.update(loss.item(), input.size(0)) 245 | batch_time.update(time.time() - end) 246 | end = time.time() 247 | 248 | remain_iter = test_num / args.batch_size_val - iter_num 249 | remain_time = remain_iter * batch_time.avg 250 | t_m, t_s = divmod(remain_time, 60) 251 | t_h, t_m = divmod(t_m, 60) 252 | remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s)) 253 | 254 | if ((i + 1) % round((test_num / 100)) == 0): 255 | logger.info('Test: [{}/{}] ' 256 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' 257 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 258 | 'Remain {remain_time} ' 259 | 'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) ' 260 | 'Accuracy {accuracy:.4f}.'.format(iter_num * args.batch_size_val, test_num, 261 | data_time=data_time, 262 | batch_time=batch_time, 263 | remain_time=remain_time, 264 | loss_meter=loss_meter, 265 | accuracy=accuracy)) 266 | val_time = time.time() - val_start 267 | 268 | 
iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 269 | mIoU = np.mean(iou_class) 270 | 271 | class_iou_class = [] 272 | class_miou = 0 273 | for i in range(len(class_intersection_meter)): 274 | class_iou = class_intersection_meter[i] / (class_union_meter[i] + 1e-10) 275 | class_iou_class.append(class_iou) 276 | class_miou += class_iou 277 | 278 | class_miou = class_miou * 1.0 / len(class_intersection_meter) 279 | logger.info('meanIoU---Val result: mIoU {:.4f}.'.format(class_miou)) # final 280 | logger.info('<<<<<<< Novel Results <<<<<<<') 281 | for i in range(split_gap): 282 | logger.info('Class_{} Result: iou {:.4f}.'.format(i + 1, class_iou_class[i])) 283 | 284 | logger.info('FBIoU---Val result: FBIoU {:.4f}.'.format(mIoU)) 285 | for i in range(args.classes): 286 | logger.info('Class_{} Result: iou_f {:.4f}.'.format(i, iou_class[i])) 287 | logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<') 288 | 289 | print('total time: {:.4f}, avg inference time: {:.4f}, count: {}'.format(val_time, model_time.avg, test_num)) 290 | 291 | return mIoU, class_miou, iou_class[1] 292 | 293 | 294 | if __name__ == '__main__': 295 | main() 296 | -------------------------------------------------------------------------------- /test_pascal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | 4 | dataset=pascal 5 | exp_name=split$1 # 0 6 | arch=HMNet 7 | net=$2 # vgg/resnet50 8 | postfix=$3 # manet/manet_5s 9 | 10 | config=config/${dataset}/${dataset}_${exp_name}_${net}_${postfix}.yaml 11 | 12 | python test_pascal.py --config=${config} --arch=${arch} --episode=1000 13 | -------------------------------------------------------------------------------- /train_coco.sh: -------------------------------------------------------------------------------- 1 | gpu=8 2 | port=$1 # 1234 3 | dataset=coco 4 | exp_name=split$2 # 0/1/2/3 5 | arch=HMNetAMP 6 | net=$3 # vgg/resnet50 7 | postfix=$4
# manet/manet_5s 8 | 9 | exp_dir=exp/${dataset}/${arch}/${exp_name}/${net} 10 | snapshot_dir=${exp_dir}/snapshot 11 | result_dir=${exp_dir}/result 12 | config=config/${dataset}/${dataset}_${exp_name}_${net}_${postfix}.yaml 13 | mkdir -p ${snapshot_dir} ${result_dir} 14 | now=$(date +"%Y%m%d_%H%M%S") 15 | cp train_coco.sh train_coco.py ${config} ${exp_dir} 16 | 17 | echo ${arch} 18 | echo ${config} 19 | 20 | python3 -m torch.distributed.launch --nproc_per_node=${gpu} --master_port=${port} train_coco.py \ 21 | --config=${config} \ 22 | --arch=${arch} \ 23 | 2>&1 | tee ${result_dir}/train-$now.log 24 | -------------------------------------------------------------------------------- /train_pascal.sh: -------------------------------------------------------------------------------- 1 | gpu=4 2 | port=$1 # 1234 3 | dataset=pascal 4 | exp_name=split$2 # 0/1/2/3 5 | arch=HMNet 6 | net=$3 # vgg/resnet50 7 | postfix=$4 # manet/manet_5s 8 | 9 | exp_dir=exp/${dataset}/${arch}/${exp_name}/${net} 10 | snapshot_dir=${exp_dir}/snapshot 11 | result_dir=${exp_dir}/result 12 | config=config/${dataset}/${dataset}_${exp_name}_${net}_${postfix}.yaml 13 | mkdir -p ${snapshot_dir} ${result_dir} 14 | now=$(date +"%Y%m%d_%H%M%S") 15 | cp train_pascal.sh train_pascal.py ${config} ${exp_dir} 16 | 17 | echo ${arch} 18 | echo ${config} 19 | 20 | python3 -m torch.distributed.launch --nproc_per_node=${gpu} --master_port=${port} train_pascal.py \ 21 | --config=${config} \ 22 | --arch=${arch} \ 23 | 2>&1 | tee ${result_dir}/train-$now.log 24 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sam1224/HMNet/556916cc2036f32d68d8e0b810eb9ebcbee6ae1a/util/__init__.py -------------------------------------------------------------------------------- /util/config.py: -------------------------------------------------------------------------------- 
1 | # ----------------------------------------------------------------------------- 2 | # Functions for parsing args 3 | # ----------------------------------------------------------------------------- 4 | import yaml 5 | import os 6 | from ast import literal_eval 7 | import copy 8 | import logging 9 | logger = logging.getLogger(__name__) # module-level logger used by _assert_with_logging 10 | class CfgNode(dict): 11 | """ 12 | CfgNode represents an internal node in the configuration tree. It's a simple 13 | dict-like container that allows for attribute-based access to keys. 14 | """ 15 | 16 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 17 | # Recursively convert nested dictionaries in init_dict into CfgNodes 18 | init_dict = {} if init_dict is None else init_dict 19 | key_list = [] if key_list is None else key_list 20 | for k, v in init_dict.items(): 21 | if type(v) is dict: 22 | # Convert dict to CfgNode 23 | init_dict[k] = CfgNode(v, key_list=key_list + [k]) 24 | super(CfgNode, self).__init__(init_dict) 25 | 26 | def __getattr__(self, name): 27 | if name in self: 28 | return self[name] 29 | else: 30 | raise AttributeError(name) 31 | 32 | def __setattr__(self, name, value): 33 | self[name] = value 34 | 35 | def __str__(self): 36 | def _indent(s_, num_spaces): 37 | s = s_.split("\n") 38 | if len(s) == 1: 39 | return s_ 40 | first = s.pop(0) 41 | s = [(num_spaces * " ") + line for line in s] 42 | s = "\n".join(s) 43 | s = first + "\n" + s 44 | return s 45 | 46 | r = "" 47 | s = [] 48 | for k, v in sorted(self.items()): 49 | separator = "\n" if isinstance(v, CfgNode) else " " 50 | attr_str = "{}:{}{}".format(str(k), separator, str(v)) 51 | attr_str = _indent(attr_str, 2) 52 | s.append(attr_str) 53 | r += "\n".join(s) 54 | return r 55 | 56 | def __repr__(self): # print 57 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) 58 | 59 | 60 | def load_cfg_from_cfg_file(file): 61 | cfg = {} 62 | assert os.path.isfile(file) and file.endswith('.yaml'), \ 63 | '{} is not a yaml file'.format(file) 64 | 65 | with open(file, 
'r') as f: 66 | cfg_from_file = yaml.safe_load(f) 67 | 68 | for key in cfg_from_file: 69 | for k, v in cfg_from_file[key].items(): 70 | cfg[k] = v 71 | 72 | cfg = CfgNode(cfg) 73 | return cfg 74 | 75 | def merge_cfg_from_args(cfg, args): 76 | args_dict = args.__dict__ 77 | for k, v in args_dict.items(): 78 | if k != 'config': # equivalent to the original `not k == 'config' or k == 'opts'`, which by operator precedence skips only 'config' (so 'opts' is merged too) 79 | cfg[k] = v 80 | 81 | return cfg 82 | 83 | def merge_cfg_from_list(cfg, cfg_list): 84 | new_cfg = copy.deepcopy(cfg) 85 | assert len(cfg_list) % 2 == 0 86 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 87 | subkey = full_key.split('.')[-1] 88 | assert subkey in cfg, 'Non-existent key: {}'.format(full_key) 89 | value = _decode_cfg_value(v) 90 | value = _check_and_coerce_cfg_value_type( 91 | value, cfg[subkey], subkey, full_key 92 | ) 93 | setattr(new_cfg, subkey, value) 94 | 95 | return new_cfg 96 | 97 | 98 | def _decode_cfg_value(v): 99 | """Decodes a raw config value (e.g., from a yaml config file or command 100 | line argument) into a Python object. 101 | """ 102 | # All remaining processing is only applied to strings 103 | if not isinstance(v, str): 104 | return v 105 | # Try to interpret `v` as a: 106 | # string, number, tuple, list, dict, boolean, or None 107 | try: 108 | v = literal_eval(v) 109 | # The following two excepts allow v to pass through when it represents a 110 | # string. 111 | # 112 | # Longer explanation: 113 | # The type of v is always a string (before calling literal_eval), but 114 | # sometimes it *represents* a string and other times a data structure, like 115 | # a list. In the case that v represents a string, what we got back from the 116 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 117 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 118 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 119 | # will raise a SyntaxError. 
120 | except ValueError: 121 | pass 122 | except SyntaxError: 123 | pass 124 | return v 125 | 126 | 127 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): 128 | """Checks that `replacement`, which is intended to replace `original` is of 129 | the right type. The type is correct if it matches exactly or is one of a few 130 | cases in which the type can be easily coerced. 131 | """ 132 | original_type = type(original) 133 | replacement_type = type(replacement) 134 | 135 | # The types must match (with some exceptions) 136 | if replacement_type == original_type: 137 | return replacement 138 | 139 | # Cast replacement from from_type to to_type if the replacement and original 140 | # types match from_type and to_type 141 | def conditional_cast(from_type, to_type): 142 | if replacement_type == from_type and original_type == to_type: 143 | return True, to_type(replacement) 144 | else: 145 | return False, None 146 | 147 | # Conditionally casts 148 | # list <-> tuple 149 | casts = [(tuple, list), (list, tuple)] 150 | # For py2: allow converting from str (bytes) to a unicode string 151 | try: 152 | casts.append((str, unicode)) # noqa: F821 153 | except Exception: 154 | pass 155 | 156 | for (from_type, to_type) in casts: 157 | converted, converted_value = conditional_cast(from_type, to_type) 158 | if converted: 159 | return converted_value 160 | 161 | raise ValueError( 162 | "Type mismatch ({} vs. {}) with values ({} vs. 
{}) for config " 163 | "key: {}".format( 164 | original_type, replacement_type, original, replacement, full_key 165 | ) 166 | ) 167 | 168 | 169 | def _assert_with_logging(cond, msg): 170 | if not cond: 171 | logger.debug(msg) 172 | assert cond, msg 173 | 174 | -------------------------------------------------------------------------------- /util/dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path 4 | import cv2 5 | import numpy as np 6 | import copy 7 | 8 | from torch.utils.data import Dataset 9 | import torch.nn.functional as F 10 | import torch 11 | import random 12 | import time 13 | from tqdm import tqdm 14 | 15 | from .get_weak_anns import transform_anns 16 | 17 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 18 | 19 | 20 | def is_image_file(filename): 21 | filename_lower = filename.lower() 22 | return any(filename_lower.endswith(extension) for extension in IMG_EXTENSIONS) 23 | 24 | 25 | def make_dataset(split=0, data_root=None, data_list=None, sub_list=None, filter_intersection=False): 26 | assert split in [0, 1, 2, 3] 27 | if not os.path.isfile(data_list): 28 | raise (RuntimeError("Image list file does not exist: " + data_list + "\n")) 29 | 30 | # Shaban uses these lines to remove small objects: 31 | # if util.change_coordinates(mask, 32.0, 0.0).sum() > 2: 32 | # filtered_item.append(item) 33 | # which means the mask will be downsampled to 1/32 of the original size and the valid area should be larger than 2, 34 | # therefore the area in original size should be accordingly larger than 2 * 32 * 32 35 | image_label_list = [] 36 | list_read = open(data_list).readlines() 37 | print("Processing data...") 38 | sub_class_file_list = {} 39 | for sub_c in sub_list: 40 | sub_class_file_list[sub_c] = [] 41 | 42 | for l_idx in tqdm(range(len(list_read))): 43 | line = list_read[l_idx] 44 | line = line.strip() 45 | line_split = line.split(' ') 46 | 
image_name = os.path.join(data_root, line_split[0]) 47 | label_name = os.path.join(data_root, line_split[1]) 48 | item = (image_name, label_name) 49 | label = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE) 50 | label_class = np.unique(label).tolist() 51 | 52 | if 0 in label_class: 53 | label_class.remove(0) 54 | if 255 in label_class: 55 | label_class.remove(255) 56 | 57 | new_label_class = [] 58 | 59 | if filter_intersection: # filter images containing objects of novel categories during meta-training 60 | if set(label_class).issubset(set(sub_list)): 61 | for c in label_class: 62 | if c in sub_list: 63 | tmp_label = np.zeros_like(label) 64 | target_pix = np.where(label == c) 65 | tmp_label[target_pix[0], target_pix[1]] = 1 66 | if tmp_label.sum() >= 2 * 32 * 32: 67 | new_label_class.append(c) 68 | else: 69 | for c in label_class: 70 | if c in sub_list: 71 | tmp_label = np.zeros_like(label) 72 | target_pix = np.where(label == c) 73 | tmp_label[target_pix[0], target_pix[1]] = 1 74 | if tmp_label.sum() >= 2 * 32 * 32: 75 | new_label_class.append(c) 76 | 77 | label_class = new_label_class 78 | 79 | if len(label_class) > 0: 80 | image_label_list.append(item) 81 | for c in label_class: 82 | if c in sub_list: 83 | sub_class_file_list[c].append(item) 84 | 85 | print("Checking image&label pair {} list done! 
".format(split)) 86 | return image_label_list, sub_class_file_list 87 | 88 | 89 | class SemData(Dataset): 90 | def __init__(self, split=3, shot=1, data_root=None, base_data_root=None, data_list=None, data_set=None, 91 | use_split_coco=False, 92 | transform=None, transform_tri=None, mode='train', ann_type='mask', 93 | ft_transform=None, ft_aug_size=None): 94 | 95 | assert mode in ['train', 'val', 'demo', 'finetune'] 96 | assert data_set in ['pascal', 'coco'] 97 | if mode == 'finetune': 98 | assert ft_transform is not None 99 | assert ft_aug_size is not None 100 | 101 | if data_set == 'pascal': 102 | self.num_classes = 20 103 | elif data_set == 'coco': 104 | self.num_classes = 80 105 | 106 | self.mode = mode 107 | self.split = split 108 | self.shot = shot 109 | self.data_root = data_root 110 | self.base_data_root = base_data_root 111 | self.ann_type = ann_type 112 | 113 | if data_set == 'pascal': 114 | self.class_list = list(range(1, 21)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] 115 | if self.split == 3: 116 | self.sub_list = list(range(1, 16)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] 117 | self.sub_val_list = list(range(16, 21)) # [16,17,18,19,20] 118 | elif self.split == 2: 119 | self.sub_list = list(range(1, 11)) + list(range(16, 21)) # [1,2,3,4,5,6,7,8,9,10,16,17,18,19,20] 120 | self.sub_val_list = list(range(11, 16)) # [11,12,13,14,15] 121 | elif self.split == 1: 122 | self.sub_list = list(range(1, 6)) + list(range(11, 21)) # [1,2,3,4,5,11,12,13,14,15,16,17,18,19,20] 123 | self.sub_val_list = list(range(6, 11)) # [6,7,8,9,10] 124 | elif self.split == 0: 125 | self.sub_list = list(range(6, 21)) # [6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] 126 | self.sub_val_list = list(range(1, 6)) # [1,2,3,4,5] 127 | 128 | elif data_set == 'coco': 129 | if use_split_coco: 130 | print('INFO: using SPLIT COCO (FWB)') 131 | self.class_list = list(range(1, 81)) 132 | if self.split == 3: 133 | self.sub_val_list = list(range(4, 81, 4)) 134 | self.sub_list = 
list(set(self.class_list) - set(self.sub_val_list)) 135 | elif self.split == 2: 136 | self.sub_val_list = list(range(3, 80, 4)) 137 | self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) 138 | elif self.split == 1: 139 | self.sub_val_list = list(range(2, 79, 4)) 140 | self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) 141 | elif self.split == 0: 142 | self.sub_val_list = list(range(1, 78, 4)) 143 | self.sub_list = list(set(self.class_list) - set(self.sub_val_list)) 144 | else: 145 | print('INFO: using COCO (PANet)') 146 | self.class_list = list(range(1, 81)) 147 | if self.split == 3: 148 | self.sub_list = list(range(1, 61)) 149 | self.sub_val_list = list(range(61, 81)) 150 | elif self.split == 2: 151 | self.sub_list = list(range(1, 41)) + list(range(61, 81)) 152 | self.sub_val_list = list(range(41, 61)) 153 | elif self.split == 1: 154 | self.sub_list = list(range(1, 21)) + list(range(41, 81)) 155 | self.sub_val_list = list(range(21, 41)) 156 | elif self.split == 0: 157 | self.sub_list = list(range(21, 81)) 158 | self.sub_val_list = list(range(1, 21)) 159 | 160 | print('sub_list: ', self.sub_list) 161 | print('sub_val_list: ', self.sub_val_list) 162 | 163 | # @@@ For convenience, we skip the step of building datasets and instead use the pre-generated lists @@@ 164 | # if self.mode == 'train': 165 | # self.data_list, self.sub_class_file_list = make_dataset(split, data_root, data_list, self.sub_list, True) 166 | # assert len(self.sub_class_file_list.keys()) == len(self.sub_list) 167 | # elif self.mode == 'val' or self.mode == 'demo' or self.mode == 'finetune': 168 | # self.data_list, self.sub_class_file_list = make_dataset(split, data_root, data_list, self.sub_val_list, False) 169 | # assert len(self.sub_class_file_list.keys()) == len(self.sub_val_list) 170 | 171 | mode = 'train' if self.mode == 'train' else 'val' 172 | 173 | fss_list_root = './lists/{}/fss_list/{}/'.format(data_set, mode) 174 | fss_data_list_path = fss_list_root + 
'data_list_{}.txt'.format(split) 175 | fss_sub_class_file_list_path = fss_list_root + 'sub_class_file_list_{}.txt'.format(split) 176 | 177 | # Write FSS Data 178 | # with open(fss_data_list_path, 'w') as f: 179 | # for item in self.data_list: 180 | # img, label = item 181 | # f.write(img + ' ') 182 | # f.write(label + '\n') 183 | # with open(fss_sub_class_file_list_path, 'w') as f: 184 | # f.write(str(self.sub_class_file_list)) 185 | 186 | # Read FSS Data 187 | with open(fss_data_list_path, 'r') as f: 188 | f_str = f.readlines() 189 | self.data_list = [] 190 | for line in f_str: 191 | img, mask = line.split(' ') 192 | self.data_list.append((img, mask.strip())) 193 | 194 | with open(fss_sub_class_file_list_path, 'r') as f: 195 | f_str = f.read() 196 | self.sub_class_file_list = eval(f_str) 197 | 198 | self.transform = transform 199 | self.transform_tri = transform_tri 200 | self.ft_transform = ft_transform 201 | self.ft_aug_size = ft_aug_size 202 | 203 | def __len__(self): 204 | return len(self.data_list) 205 | 206 | def __getitem__(self, index): 207 | image_path, label_path = self.data_list[index] 208 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) 209 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 210 | image = np.float32(image) 211 | label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) 212 | 213 | if image.shape[0] != label.shape[0] or image.shape[1] != label.shape[1]: 214 | raise (RuntimeError("Query Image & label shape mismatch: " + image_path + " " + label_path + "\n")) 215 | label_class = np.unique(label).tolist() 216 | if 0 in label_class: 217 | label_class.remove(0) 218 | if 255 in label_class: 219 | label_class.remove(255) 220 | new_label_class = [] 221 | for c in label_class: 222 | if c in self.sub_val_list: 223 | if self.mode == 'val' or self.mode == 'demo' or self.mode == 'finetune': 224 | new_label_class.append(c) 225 | if c in self.sub_list: 226 | if self.mode == 'train': 227 | new_label_class.append(c) 228 | label_class = new_label_class 229 | 
assert len(label_class) > 0 230 | 231 | class_chosen = label_class[random.randint(1, len(label_class)) - 1] 232 | target_pix = np.where(label == class_chosen) 233 | ignore_pix = np.where(label == 255) 234 | label[:, :] = 0 235 | if target_pix[0].shape[0] > 0: 236 | label[target_pix[0], target_pix[1]] = 1 237 | label[ignore_pix[0], ignore_pix[1]] = 255 238 | 239 | file_class_chosen = self.sub_class_file_list[class_chosen] 240 | num_file = len(file_class_chosen) 241 | 242 | support_image_path_list = [] 243 | support_label_path_list = [] 244 | support_idx_list = [] 245 | for k in range(self.shot): 246 | support_idx = random.randint(1, num_file) - 1 247 | support_image_path = image_path 248 | support_label_path = label_path 249 | while (( 250 | support_image_path == image_path and support_label_path == label_path) or support_idx in support_idx_list): 251 | support_idx = random.randint(1, num_file) - 1 252 | support_image_path, support_label_path = file_class_chosen[support_idx] 253 | support_idx_list.append(support_idx) 254 | support_image_path_list.append(support_image_path) 255 | support_label_path_list.append(support_label_path) 256 | 257 | support_image_list_ori = [] 258 | support_label_list_ori = [] 259 | support_label_list_ori_mask = [] 260 | subcls_list = [] 261 | if self.mode == 'train': 262 | subcls_list.append(self.sub_list.index(class_chosen)) 263 | else: 264 | subcls_list.append(self.sub_val_list.index(class_chosen)) 265 | for k in range(self.shot): 266 | support_image_path = support_image_path_list[k] 267 | support_label_path = support_label_path_list[k] 268 | support_image = cv2.imread(support_image_path, cv2.IMREAD_COLOR) 269 | support_image = cv2.cvtColor(support_image, cv2.COLOR_BGR2RGB) 270 | support_image = np.float32(support_image) 271 | support_label = cv2.imread(support_label_path, cv2.IMREAD_GRAYSCALE) 272 | target_pix = np.where(support_label == class_chosen) 273 | ignore_pix = np.where(support_label == 255) 274 | support_label[:, :] = 0 275 | 
support_label[target_pix[0], target_pix[1]] = 1 276 | 277 | support_label, support_label_mask = transform_anns(support_label, self.ann_type) # mask/bbox 278 | support_label[ignore_pix[0], ignore_pix[1]] = 255 279 | support_label_mask[ignore_pix[0], ignore_pix[1]] = 255 280 | if support_image.shape[0] != support_label.shape[0] or support_image.shape[1] != support_label.shape[1]: 281 | raise (RuntimeError( 282 | "Support Image & label shape mismatch: " + support_image_path + " " + support_label_path + "\n")) 283 | support_image_list_ori.append(support_image) 284 | support_label_list_ori.append(support_label) 285 | support_label_list_ori_mask.append(support_label_mask) 286 | assert len(support_label_list_ori) == self.shot and len(support_image_list_ori) == self.shot 287 | 288 | raw_image = image.copy() 289 | raw_label = label.copy() 290 | support_image_list = [[] for _ in range(self.shot)] 291 | support_label_list = [[] for _ in range(self.shot)] 292 | if self.transform is not None: 293 | image, label = self.transform(image, label) 294 | for k in range(self.shot): 295 | support_image_list[k], support_label_list[k] = self.transform(support_image_list_ori[k], 296 | support_label_list_ori[k]) 297 | 298 | s_xs = support_image_list 299 | s_ys = support_label_list 300 | s_x = s_xs[0].unsqueeze(0) 301 | for i in range(1, self.shot): 302 | s_x = torch.cat([s_xs[i].unsqueeze(0), s_x], 0) 303 | s_y = s_ys[0].unsqueeze(0) 304 | for i in range(1, self.shot): 305 | s_y = torch.cat([s_ys[i].unsqueeze(0), s_y], 0) 306 | 307 | # Return 308 | if self.mode == 'train': 309 | return image, label, s_x, s_y, subcls_list 310 | elif self.mode == 'val': 311 | return image, label, s_x, s_y, subcls_list, raw_label 312 | -------------------------------------------------------------------------------- /util/get_weak_anns.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import networkx as nx 4 | import 
numpy as np 5 | from scipy.ndimage import binary_dilation, binary_erosion, maximum_filter 6 | from scipy.special import comb 7 | from skimage.filters import rank 8 | from skimage.morphology import dilation, disk, erosion, medial_axis 9 | from sklearn.neighbors import radius_neighbors_graph 10 | import cv2 11 | import matplotlib.pyplot as plt 12 | import matplotlib.patches as mpatches 13 | from scipy import ndimage 14 | 15 | 16 | def find_bbox(mask): 17 | _, labels, stats, centroids = cv2.connectedComponentsWithStats(mask.astype(np.uint8)) 18 | return stats[1:] # remove bg stat 19 | 20 | 21 | def transform_anns(mask, ann_type): 22 | mask_ori = mask.copy() 23 | 24 | if ann_type == 'bbox': 25 | bboxs = find_bbox(mask) 26 | for j in bboxs: 27 | cv2.rectangle(mask, (j[0], j[1]), (j[0] + j[2], j[1] + j[3]), 1, -1) # -1->fill; 2->draw_rec 28 | return mask, mask_ori 29 | 30 | elif ann_type == 'mask': 31 | return mask, mask_ori 32 | 33 | 34 | if __name__ == '__main__': 35 | label_path = '2008_001227.png' 36 | mask = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) 37 | bboxs = find_bbox(mask) 38 | mask_color = cv2.imread(label_path, cv2.IMREAD_COLOR) 39 | for j in bboxs: 40 | cv2.rectangle(mask_color, (j[0], j[1]), (j[0] + j[2], j[1] + j[3]), (0, 255, 0), -1) 41 | cv2.imwrite('bbox.png', mask_color) 42 | 43 | print('done') 44 | -------------------------------------------------------------------------------- /util/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import numpy as np 4 | import numbers 5 | import collections.abc # import the submodule explicitly; `import collections` alone does not guarantee collections.abc is loaded 6 | collections.Iterable = collections.abc.Iterable 7 | import cv2 8 | 9 | import torch 10 | 11 | manual_seed = 123 12 | torch.manual_seed(manual_seed) 13 | np.random.seed(manual_seed) 14 | torch.manual_seed(manual_seed) 15 | torch.cuda.manual_seed_all(manual_seed) 16 | random.seed(manual_seed) 17 | 18 | 19 | class Compose(object): 20 | # Composes segtransforms: 
segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()]) 21 | def __init__(self, segtransform): 22 | self.segtransform = segtransform 23 | 24 | def __call__(self, image, label): 25 | for t in self.segtransform: 26 | image, label = t(image, label) 27 | return image, label 28 | 29 | 30 | import time 31 | 32 | 33 | class ToTensor(object): 34 | # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 35 | def __call__(self, image, label): 36 | if not isinstance(image, np.ndarray) or not isinstance(label, np.ndarray): 37 | raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray " 38 | "[eg: data read by cv2.imread()].\n")) 39 | if len(image.shape) > 3 or len(image.shape) < 2: 40 | raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray with 3 dims or 2 dims.\n")) 41 | if len(image.shape) == 2: 42 | image = np.expand_dims(image, axis=2) 43 | if not len(label.shape) == 2: 44 | raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray label with 2 dims.\n")) 45 | 46 | image = torch.from_numpy(image.transpose((2, 0, 1))) 47 | if not isinstance(image, torch.FloatTensor): 48 | image = image.float() 49 | label = torch.from_numpy(label) 50 | if not isinstance(label, torch.LongTensor): 51 | label = label.long() 52 | return image, label 53 | 54 | 55 | class CLAHE(object): 56 | # Apply Contrast Limited Adaptive Histogram Equalization to the input image. 
57 | def __init__(self, clip_limit=4.0, tile_grid_size=(8, 8)): 58 | self.clip_limit = clip_limit 59 | self.tile_grid_size = tile_grid_size 60 | 61 | def __call__(self, image, label): 62 | if image.dtype != np.uint8: 63 | image = image.astype(np.uint8) 64 | 65 | clahe_mat = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size) 66 | 67 | if len(image.shape) == 2 or image.shape[2] == 1: 68 | image = clahe_mat.apply(image) 69 | else: 70 | image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB) 71 | image[:, :, 0] = clahe_mat.apply(image[:, :, 0]) 72 | image = cv2.cvtColor(image, cv2.COLOR_LAB2RGB) 73 | image = image.astype(np.float32) 74 | return image, label 75 | 76 | 77 | class ToNumpy(object): 78 | # Converts torch.FloatTensor of shape (C x H x W) to a numpy.ndarray (H x W x C). 79 | def __call__(self, image, label): 80 | if not isinstance(image, torch.Tensor) or not isinstance(label, torch.Tensor): 81 | raise (RuntimeError("segtransform.ToNumpy() only handles torch.Tensor")) 82 | 83 | image = image.cpu().numpy().transpose((1, 2, 0)) 84 | if not image.dtype == np.uint8: 85 | image = image.astype(np.uint8) 86 | label = label.cpu().numpy() if label.dim() == 2 else label.cpu().numpy().transpose((1, 2, 0)) # 2-D (H x W) labels, as produced by ToTensor, need no transpose 87 | if not label.dtype == np.uint8: 88 | label = label.astype(np.uint8) 89 | return image, label 90 | 91 | 92 | class Normalize(object): 93 | # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std 94 | def __init__(self, mean, std=None): 95 | if std is None: 96 | assert len(mean) > 0 97 | else: 98 | assert len(mean) == len(std) 99 | self.mean = mean 100 | self.std = std 101 | 102 | def __call__(self, image, label): 103 | if self.std is None: 104 | for t, m in zip(image, self.mean): 105 | t.sub_(m) 106 | else: 107 | for t, m, s in zip(image, self.mean, self.std): 108 | t.sub_(m).div_(s) 109 | return image, label 110 | 111 | 112 | class UnNormalize(object): 113 | # UnNormalize tensor with mean and standard deviation along channel: channel = (channel * 
std) + mean 114 | def __init__(self, mean, std=None): 115 | if std is None: 116 | assert len(mean) > 0 117 | else: 118 | assert len(mean) == len(std) 119 | self.mean = mean 120 | self.std = std 121 | 122 | def __call__(self, image, label): 123 | if self.std is None: 124 | for t, m in zip(image, self.mean): 125 | t.add_(m) 126 | else: 127 | for t, m, s in zip(image, self.mean, self.std): 128 | t.mul_(s).add_(m) 129 | return image, label 130 | 131 | 132 | class Resize(object): 133 | # Resize the input to a square of side 'size'; 'size' is a single int (the active code passes dsize=(size, size)). 134 | def __init__(self, size): 135 | self.size = size 136 | 137 | def __call__(self, image, label): 138 | 139 | # value_scale = 255 140 | # mean = [0.485, 0.456, 0.406] 141 | # mean = [item * value_scale for item in mean] 142 | # std = [0.229, 0.224, 0.225] 143 | # std = [item * value_scale for item in std] 144 | # 145 | # def find_new_hw(ori_h, ori_w, test_size): 146 | # if ori_h >= ori_w: 147 | # ratio = test_size * 1.0 / ori_h 148 | # new_h = test_size 149 | # new_w = int(ori_w * ratio) 150 | # elif ori_w > ori_h: 151 | # ratio = test_size * 1.0 / ori_w 152 | # new_h = int(ori_h * ratio) 153 | # new_w = test_size 154 | # 155 | # if new_h % 8 != 0: 156 | # new_h = (int(new_h / 8)) * 8 157 | # else: 158 | # new_h = new_h 159 | # if new_w % 8 != 0: 160 | # new_w = (int(new_w / 8)) * 8 161 | # else: 162 | # new_w = new_w 163 | # return new_h, new_w 164 | # 165 | # test_size = self.size 166 | # new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size) 167 | # # new_h, new_w = test_size, test_size 168 | # image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR) 169 | # back_crop = np.zeros((test_size, test_size, 3)) 170 | # # back_crop[:,:,0] = mean[0] 171 | # # back_crop[:,:,1] = mean[1] 172 | # # back_crop[:,:,2] = mean[2] 173 | # back_crop[:new_h, :new_w, :] = image_crop 174 | # image = back_crop 175 | # 176 | # s_mask = label 177 | # new_h, 
new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size) 178 | # # new_h, new_w = test_size, test_size 179 | # s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_NEAREST) 180 | # back_crop_s_mask = np.ones((test_size, test_size)) * 255 181 | # back_crop_s_mask[:new_h, :new_w] = s_mask 182 | # label = back_crop_s_mask 183 | 184 | image = cv2.resize(image, dsize=(self.size, self.size), interpolation=cv2.INTER_LINEAR) 185 | label = cv2.resize(label, dsize=(self.size, self.size), interpolation=cv2.INTER_NEAREST) 186 | 187 | return image, label 188 | 189 | 190 | class test_Resize(object): 191 | # Resize the input to a square of side 'size'; 'size' is a single int (the active code passes dsize=(size, size)). 192 | def __init__(self, size): 193 | self.size = size 194 | 195 | def __call__(self, image, label): 196 | 197 | # value_scale = 255 198 | # mean = [0.485, 0.456, 0.406] 199 | # mean = [item * value_scale for item in mean] 200 | # std = [0.229, 0.224, 0.225] 201 | # std = [item * value_scale for item in std] 202 | # 203 | # def find_new_hw(ori_h, ori_w, test_size): 204 | # if max(ori_h, ori_w) > test_size: 205 | # if ori_h >= ori_w: 206 | # ratio = test_size * 1.0 / ori_h 207 | # new_h = test_size 208 | # new_w = int(ori_w * ratio) 209 | # elif ori_w > ori_h: 210 | # ratio = test_size * 1.0 / ori_w 211 | # new_h = int(ori_h * ratio) 212 | # new_w = test_size 213 | # 214 | # if new_h % 8 != 0: 215 | # new_h = (int(new_h / 8)) * 8 216 | # else: 217 | # new_h = new_h 218 | # if new_w % 8 != 0: 219 | # new_w = (int(new_w / 8)) * 8 220 | # else: 221 | # new_w = new_w 222 | # return new_h, new_w 223 | # else: 224 | # return ori_h, ori_w 225 | # 226 | # test_size = self.size 227 | # new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size) 228 | # if new_w != image.shape[0] or new_h != image.shape[1]: 229 | # image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR) 230 | # else: 231 
| # image_crop = image.copy() 232 | # back_crop = np.zeros((test_size, test_size, 3)) 233 | # back_crop[:new_h, :new_w, :] = image_crop 234 | # image = back_crop 235 | # 236 | # s_mask = label 237 | # new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size) 238 | # if new_w != s_mask.shape[0] or new_h != s_mask.shape[1]: 239 | # s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)), 240 | # interpolation=cv2.INTER_NEAREST) 241 | # back_crop_s_mask = np.ones((test_size, test_size)) * 255 242 | # back_crop_s_mask[:new_h, :new_w] = s_mask 243 | # label = back_crop_s_mask 244 | 245 | image = cv2.resize(image, dsize=(self.size, self.size), interpolation=cv2.INTER_LINEAR) 246 | label = cv2.resize(label, dsize=(self.size, self.size), interpolation=cv2.INTER_NEAREST) 247 | 248 | return image, label 249 | 250 | 251 | class Direct_Resize(object): 252 | # Resize the input to a square of side 'size'; 'size' is a single int (the code passes dsize=(size, size)). 253 | def __init__(self, size): 254 | self.size = size 255 | 256 | def __call__(self, image, label): 257 | test_size = self.size 258 | 259 | image = cv2.resize(image, dsize=(test_size, test_size), interpolation=cv2.INTER_LINEAR) 260 | label = cv2.resize(label.astype(np.float32), dsize=(test_size, test_size), interpolation=cv2.INTER_NEAREST) 261 | 262 | return image, label 263 | 264 | 265 | class RandScale(object): 266 | # Randomly resize image & label with scale factor in [scale_min, scale_max] 267 | def __init__(self, scale, aspect_ratio=None): 268 | assert (isinstance(scale, collections.Iterable) and len(scale) == 2) 269 | if isinstance(scale, collections.Iterable) and len(scale) == 2 \ 270 | and isinstance(scale[0], numbers.Number) and isinstance(scale[1], numbers.Number) \ 271 | and 0 < scale[0] < scale[1]: 272 | self.scale = scale 273 | else: 274 | raise (RuntimeError("segtransform.RandScale() scale param error.\n")) 275 | if aspect_ratio is None: 276 | self.aspect_ratio = aspect_ratio 
277 | elif isinstance(aspect_ratio, collections.Iterable) and len(aspect_ratio) == 2 \ 278 | and isinstance(aspect_ratio[0], numbers.Number) and isinstance(aspect_ratio[1], numbers.Number) \ 279 | and 0 < aspect_ratio[0] < aspect_ratio[1]: 280 | self.aspect_ratio = aspect_ratio 281 | else: 282 | raise (RuntimeError("segtransform.RandScale() aspect_ratio param error.\n")) 283 | 284 | def __call__(self, image, label): 285 | temp_scale = self.scale[0] + (self.scale[1] - self.scale[0]) * random.random() 286 | temp_aspect_ratio = 1.0 287 | if self.aspect_ratio is not None: 288 | temp_aspect_ratio = self.aspect_ratio[0] + (self.aspect_ratio[1] - self.aspect_ratio[0]) * random.random() 289 | temp_aspect_ratio = math.sqrt(temp_aspect_ratio) 290 | scale_factor_x = temp_scale * temp_aspect_ratio 291 | scale_factor_y = temp_scale / temp_aspect_ratio 292 | image = cv2.resize(image, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_LINEAR) 293 | label = cv2.resize(label, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_NEAREST) 294 | return image, label 295 | 296 | 297 | class Crop(object): 298 | """Crops the given ndarray image (H*W*C or H*W). 299 | Args: 300 | size (sequence or int): Desired output size of the crop. If size is an 301 | int instead of sequence like (h, w), a square crop (size, size) is made. 
302 | """ 303 | 304 | def __init__(self, size, crop_type='center', padding=None, ignore_label=255): 305 | self.size = size 306 | if isinstance(size, int): 307 | self.crop_h = size 308 | self.crop_w = size 309 | elif isinstance(size, collections.Iterable) and len(size) == 2 \ 310 | and isinstance(size[0], int) and isinstance(size[1], int) \ 311 | and size[0] > 0 and size[1] > 0: 312 | self.crop_h = size[0] 313 | self.crop_w = size[1] 314 | else: 315 | raise (RuntimeError("crop size error.\n")) 316 | if crop_type == 'center' or crop_type == 'rand': 317 | self.crop_type = crop_type 318 | else: 319 | raise (RuntimeError("crop type error: rand | center\n")) 320 | if padding is None: 321 | self.padding = padding 322 | elif isinstance(padding, list): 323 | if all(isinstance(i, numbers.Number) for i in padding): 324 | self.padding = padding 325 | else: 326 | raise (RuntimeError("padding in Crop() should be a number list\n")) 327 | if len(padding) != 3: 328 | raise (RuntimeError("padding channel is not equal with 3\n")) 329 | else: 330 | raise (RuntimeError("padding in Crop() should be a number list\n")) 331 | if isinstance(ignore_label, int): 332 | self.ignore_label = ignore_label 333 | else: 334 | raise (RuntimeError("ignore_label should be an integer number\n")) 335 | 336 | def __call__(self, image, label): 337 | h, w = label.shape 338 | 339 | pad_h = max(self.crop_h - h, 0) 340 | pad_w = max(self.crop_w - w, 0) 341 | pad_h_half = int(pad_h / 2) 342 | pad_w_half = int(pad_w / 2) 343 | if pad_h > 0 or pad_w > 0: 344 | if self.padding is None: 345 | raise (RuntimeError("segtransform.Crop() need padding while padding argument is None\n")) 346 | image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, 347 | cv2.BORDER_CONSTANT, value=self.padding) 348 | label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, 349 | cv2.BORDER_CONSTANT, value=self.ignore_label) 350 | h, w = label.shape 351 | 
raw_label = label 352 | raw_image = image 353 | 354 | if self.crop_type == 'rand': 355 | h_off = random.randint(0, h - self.crop_h) 356 | w_off = random.randint(0, w - self.crop_w) 357 | else: 358 | h_off = int((h - self.crop_h) / 2) 359 | w_off = int((w - self.crop_w) / 2) 360 | image = image[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w] 361 | label = label[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w] 362 | raw_pos_num = np.sum(raw_label == 1) 363 | pos_num = np.sum(label == 1) 364 | crop_cnt = 0 365 | while (pos_num < 0.85 * raw_pos_num and crop_cnt <= 30): 366 | image = raw_image 367 | label = raw_label 368 | if self.crop_type == 'rand': 369 | h_off = random.randint(0, h - self.crop_h) 370 | w_off = random.randint(0, w - self.crop_w) 371 | else: 372 | h_off = int((h - self.crop_h) / 2) 373 | w_off = int((w - self.crop_w) / 2) 374 | image = image[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w] 375 | label = label[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w] 376 | raw_pos_num = np.sum(raw_label == 1) 377 | pos_num = np.sum(label == 1) 378 | crop_cnt += 1 379 | if crop_cnt >= 50: # never true: the retry loop above stops once crop_cnt reaches 31 380 | image = cv2.resize(raw_image, (self.crop_w, self.crop_h), interpolation=cv2.INTER_LINEAR) 381 | label = cv2.resize(raw_label, (self.crop_w, self.crop_h), interpolation=cv2.INTER_NEAREST) 382 | 383 | if image.shape != (self.crop_h, self.crop_w, 3): # use crop_h/crop_w so an int 'size' also works 384 | image = cv2.resize(image, (self.crop_w, self.crop_h), interpolation=cv2.INTER_LINEAR) 385 | label = cv2.resize(label, (self.crop_w, self.crop_h), interpolation=cv2.INTER_NEAREST) 386 | 387 | return image, label 388 | 389 | 390 | class RandRotate(object): 391 | # Randomly rotate image & label with rotate factor in [rotate_min, rotate_max] 392 | def __init__(self, rotate, padding, ignore_label=255, p=0.5): 393 | assert (isinstance(rotate, collections.Iterable) and len(rotate) == 2) 394 | if isinstance(rotate[0], numbers.Number) and isinstance(rotate[1], numbers.Number) and rotate[0] < rotate[1]: 
395 | self.rotate = rotate 396 | else: 397 | raise (RuntimeError("segtransform.RandRotate() rotate param error.\n")) 398 | assert padding is not None 399 | assert isinstance(padding, list) and len(padding) == 3 400 | if all(isinstance(i, numbers.Number) for i in padding): 401 | self.padding = padding 402 | else: 403 | raise (RuntimeError("padding in RandRotate() should be a number list\n")) 404 | assert isinstance(ignore_label, int) 405 | self.ignore_label = ignore_label 406 | self.p = p 407 | 408 | def __call__(self, image, label): 409 | if random.random() < self.p: 410 | angle = self.rotate[0] + (self.rotate[1] - self.rotate[0]) * random.random() 411 | h, w = label.shape 412 | matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) 413 | image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, 414 | borderValue=self.padding) 415 | label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, 416 | borderValue=self.ignore_label) 417 | return image, label 418 | 419 | 420 | class RandomHorizontalFlip(object): 421 | def __init__(self, p=0.5): 422 | self.p = p 423 | 424 | def __call__(self, image, label): 425 | if random.random() < self.p: 426 | image = cv2.flip(image, 1) 427 | label = cv2.flip(label, 1) 428 | return image, label 429 | 430 | 431 | class RandomVerticalFlip(object): 432 | def __init__(self, p=0.5): 433 | self.p = p 434 | 435 | def __call__(self, image, label): 436 | if random.random() < self.p: 437 | image = cv2.flip(image, 0) 438 | label = cv2.flip(label, 0) 439 | return image, label 440 | 441 | 442 | class RandomGaussianBlur(object): 443 | def __init__(self, radius=5): 444 | self.radius = radius 445 | 446 | def __call__(self, image, label): 447 | if random.random() < 0.5: 448 | image = cv2.GaussianBlur(image, (self.radius, self.radius), 0) 449 | return image, label 450 | 451 | 452 | class RGB2BGR(object): 453 | # Converts image from RGB order to BGR order, for 
model initialized from Caffe 454 | def __call__(self, image, label): 455 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 456 | return image, label 457 | 458 | 459 | class BGR2RGB(object): 460 | # Converts image from BGR order to RGB order, for model initialized from Pytorch 461 | def __call__(self, image, label): 462 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 463 | return image, label 464 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import random 5 | import logging 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | from matplotlib.pyplot import MultipleLocator 9 | from matplotlib.ticker import FuncFormatter, FormatStrFormatter 10 | from matplotlib import font_manager 11 | from matplotlib import rcParams 12 | import seaborn as sns 13 | import pandas as pd 14 | import math 15 | from seaborn.distributions import distplot 16 | from tqdm import tqdm 17 | from scipy import ndimage 18 | 19 | # from util.get_weak_anns import find_bbox 20 | 21 | import torch 22 | from torch import nn 23 | import torch.backends.cudnn as cudnn 24 | import torch.nn.init as initer 25 | 26 | 27 | class AverageMeter(object): 28 | """Computes and stores the average and current value""" 29 | 30 | def __init__(self): 31 | self.reset() 32 | 33 | def reset(self): 34 | self.val = 0 35 | self.avg = 0 36 | self.sum = 0 37 | self.count = 0 38 | 39 | def update(self, val, n=1): 40 | self.val = val 41 | self.sum += val * n 42 | self.count += n 43 | self.avg = self.sum / self.count 44 | 45 | 46 | def step_learning_rate(optimizer, base_lr, epoch, step_epoch, multiplier=0.1): 47 | """Sets the learning rate to the base LR decayed by 10 every step epochs""" 48 | lr = base_lr * (multiplier ** (epoch // step_epoch)) 49 | for param_group in optimizer.param_groups: 50 | param_group['lr'] = lr 51 | 52 | 53 | 
def poly_learning_rate(optimizer, base_lr, curr_iter, max_iter, power=0.9, index_split=-1, scale_lr=10., warmup=False, 54 | warmup_step=500): 55 | """poly learning rate policy""" 56 | if warmup and curr_iter < warmup_step: 57 | lr = base_lr * (0.1 + 0.9 * (curr_iter / warmup_step)) 58 | else: 59 | lr = base_lr * (1 - float(curr_iter) / max_iter) ** power 60 | 61 | # if curr_iter % 50 == 0: 62 | # print('Base LR: {:.4f}, Curr LR: {:.4f}, Warmup: {}.'.format(base_lr, lr, (warmup and curr_iter < warmup_step))) 63 | 64 | for index, param_group in enumerate(optimizer.param_groups): 65 | if index <= index_split: 66 | param_group['lr'] = lr 67 | else: 68 | param_group['lr'] = lr * scale_lr # 10x LR 69 | 70 | 71 | def intersectionAndUnion(output, target, K, ignore_index=255): 72 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 73 | assert (output.ndim in [1, 2, 3]) 74 | assert output.shape == target.shape 75 | output = output.reshape(output.size).copy() 76 | target = target.reshape(target.size) 77 | output[np.where(target == ignore_index)[0]] = ignore_index 78 | intersection = output[np.where(output == target)[0]] 79 | area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1)) 80 | area_output, _ = np.histogram(output, bins=np.arange(K + 1)) 81 | area_target, _ = np.histogram(target, bins=np.arange(K + 1)) 82 | area_union = area_output + area_target - area_intersection 83 | return area_intersection, area_union, area_target 84 | 85 | 86 | def intersectionAndUnionGPU(output, target, K, ignore_index=255): 87 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
88 | assert (output.dim() in [1, 2, 3]) 89 | assert output.shape == target.shape 90 | output = output.view(-1) 91 | target = target.view(-1) 92 | output[target == ignore_index] = ignore_index 93 | intersection = output[output == target] 94 | area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1) 95 | area_output = torch.histc(output, bins=K, min=0, max=K - 1) 96 | area_target = torch.histc(target, bins=K, min=0, max=K - 1) 97 | area_union = area_output + area_target - area_intersection 98 | return area_intersection, area_union, area_target 99 | 100 | 101 | def check_mkdir(dir_name): 102 | if not os.path.exists(dir_name): 103 | os.mkdir(dir_name) 104 | 105 | 106 | def check_makedirs(dir_name): 107 | if not os.path.exists(dir_name): 108 | os.makedirs(dir_name) 109 | 110 | 111 | def del_file(path): 112 | for i in os.listdir(path): 113 | path_file = os.path.join(path, i) 114 | if os.path.isfile(path_file): 115 | os.remove(path_file) 116 | else: 117 | del_file(path_file) 118 | 119 | 120 | def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'): 121 | """ 122 | :param model: Pytorch Model which is nn.Module 123 | :param conv: 'kaiming' or 'xavier' 124 | :param batchnorm: 'normal' or 'constant' 125 | :param linear: 'kaiming' or 'xavier' 126 | :param lstm: 'kaiming' or 'xavier' 127 | """ 128 | for m in model.modules(): 129 | if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 130 | if conv == 'kaiming': 131 | initer.kaiming_normal_(m.weight) 132 | elif conv == 'xavier': 133 | initer.xavier_normal_(m.weight) 134 | else: 135 | raise ValueError("init type of conv error.\n") 136 | if m.bias is not None: 137 | initer.constant_(m.bias, 0) 138 | 139 | elif isinstance(m, 140 | (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): # , BatchNorm1d, BatchNorm2d, BatchNorm3d)): 141 | if batchnorm == 'normal': 142 | initer.normal_(m.weight, 1.0, 0.02) 143 | elif batchnorm == 'constant': 144 | initer.constant_(m.weight, 1.0) 145 | 
else: 146 | raise ValueError("init type of batchnorm error.\n") 147 | initer.constant_(m.bias, 0.0) 148 | 149 | elif isinstance(m, nn.Linear): 150 | if linear == 'kaiming': 151 | initer.kaiming_normal_(m.weight) 152 | elif linear == 'xavier': 153 | initer.xavier_normal_(m.weight) 154 | else: 155 | raise ValueError("init type of linear error.\n") 156 | if m.bias is not None: 157 | initer.constant_(m.bias, 0) 158 | 159 | elif isinstance(m, nn.LSTM): 160 | for name, param in m.named_parameters(): 161 | if 'weight' in name: 162 | if lstm == 'kaiming': 163 | initer.kaiming_normal_(param) 164 | elif lstm == 'xavier': 165 | initer.xavier_normal_(param) 166 | else: 167 | raise ValueError("init type of lstm error.\n") 168 | elif 'bias' in name: 169 | initer.constant_(param, 0) 170 | 171 | 172 | def colorize(gray, palette): 173 | # gray: numpy array of the label and 1*3N size list palette 174 | color = Image.fromarray(gray.astype(np.uint8)).convert('P') 175 | color.putpalette(palette) 176 | return color 177 | 178 | 179 | # ------------------------------------------------------ 180 | def get_model_para_number(model): 181 | total_number = 0 182 | learnable_number = 0 183 | for para in model.parameters(): 184 | total_number += torch.numel(para) 185 | if para.requires_grad == True: 186 | learnable_number += torch.numel(para) 187 | return total_number, learnable_number 188 | 189 | 190 | def setup_seed(seed=2021, deterministic=False): 191 | if deterministic: 192 | cudnn.benchmark = False 193 | cudnn.deterministic = True 194 | torch.manual_seed(seed) 195 | torch.cuda.manual_seed(seed) 196 | torch.cuda.manual_seed_all(seed) 197 | np.random.seed(seed) 198 | random.seed(seed) 199 | os.environ['PYTHONHASHSEED'] = str(seed) 200 | 201 | 202 | def get_logger(): 203 | logger_name = "main-logger" 204 | logger = logging.getLogger() 205 | logger.setLevel(logging.INFO) 206 | handler = logging.StreamHandler() 207 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] 
%(message)s" 208 | handler.setFormatter(logging.Formatter(fmt)) 209 | logger.addHandler(handler) 210 | return logger 211 | 212 | 213 | def get_save_path(args): 214 | backbone_str = 'vgg' if args.vgg else 'resnet' + str(args.layers) 215 | args.snapshot_path = 'exp/{}/{}/split{}/{}/snapshot'.format(args.data_set, args.arch, args.split, backbone_str) 216 | args.result_path = 'exp/{}/{}/split{}/{}/result'.format(args.data_set, args.arch, args.split, backbone_str) 217 | 218 | 219 | def get_train_val_set(args): 220 | if args.data_set == 'pascal': 221 | class_list = list(range(1, 21)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] 222 | if args.split == 3: 223 | sub_list = list(range(1, 16)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] 224 | sub_val_list = list(range(16, 21)) # [16,17,18,19,20] 225 | elif args.split == 2: 226 | sub_list = list(range(1, 11)) + list(range(16, 21)) # [1,2,3,4,5,6,7,8,9,10,16,17,18,19,20] 227 | sub_val_list = list(range(11, 16)) # [11,12,13,14,15] 228 | elif args.split == 1: 229 | sub_list = list(range(1, 6)) + list(range(11, 21)) # [1,2,3,4,5,11,12,13,14,15,16,17,18,19,20] 230 | sub_val_list = list(range(6, 11)) # [6,7,8,9,10] 231 | elif args.split == 0: 232 | sub_list = list(range(6, 21)) # [6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] 233 | sub_val_list = list(range(1, 6)) # [1,2,3,4,5] 234 | 235 | elif args.data_set == 'coco': 236 | if args.use_split_coco: 237 | print('INFO: using SPLIT COCO (FWB)') 238 | class_list = list(range(1, 81)) 239 | if args.split == 3: 240 | sub_val_list = list(range(4, 81, 4)) 241 | sub_list = list(set(class_list) - set(sub_val_list)) 242 | elif args.split == 2: 243 | sub_val_list = list(range(3, 80, 4)) 244 | sub_list = list(set(class_list) - set(sub_val_list)) 245 | elif args.split == 1: 246 | sub_val_list = list(range(2, 79, 4)) 247 | sub_list = list(set(class_list) - set(sub_val_list)) 248 | elif args.split == 0: 249 | sub_val_list = list(range(1, 78, 4)) 250 | sub_list = list(set(class_list) - 
set(sub_val_list)) 251 | else: 252 | print('INFO: using COCO (PANet)') 253 | class_list = list(range(1, 81)) 254 | if args.split == 3: 255 | sub_list = list(range(1, 61)) 256 | sub_val_list = list(range(61, 81)) 257 | elif args.split == 2: 258 | sub_list = list(range(1, 41)) + list(range(61, 81)) 259 | sub_val_list = list(range(41, 61)) 260 | elif args.split == 1: 261 | sub_list = list(range(1, 21)) + list(range(41, 81)) 262 | sub_val_list = list(range(21, 41)) 263 | elif args.split == 0: 264 | sub_list = list(range(21, 81)) 265 | sub_val_list = list(range(1, 21)) 266 | 267 | return sub_list, sub_val_list 268 | 269 | 270 | def is_same_model(model1, model2): 271 | flag = 0 272 | count = 0 273 | for k, v in model1.state_dict().items(): 274 | model1_val = v 275 | model2_val = model2.state_dict()[k] 276 | if (model1_val == model2_val).all(): 277 | pass 278 | else: 279 | flag += 1 280 | print('value of key <{}> mismatch'.format(k)) 281 | count += 1 282 | 283 | return flag == 0 284 | 285 | 286 | def fix_bn(m): 287 | classname = m.__class__.__name__ 288 | if classname.find('BatchNorm') != -1: 289 | m.eval() 290 | 291 | 292 | def sum_list(values): # parameter renamed so it no longer shadows the built-in 'list' 293 | total = 0 # 'total' avoids shadowing the built-in 'sum' 294 | for item in values: 295 | total += item 296 | return total 297 | --------------------------------------------------------------------------------
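
The fold lists hard-coded in `get_train_val_set` above appear to follow two regular patterns: contiguous blocks of classes (PASCAL and the PANet-style COCO split) and every fourth class (the FWB-style COCO split selected by `use_split_coco`). The sketch below reproduces them with one function. `fold_classes` and its `interleaved` flag are hypothetical names introduced here for illustration; they are not part of `util/util.py`.

```python
def fold_classes(num_classes, num_folds, split, interleaved=False):
    """Return (train_classes, val_classes) for one fold; class ids start at 1.

    Hypothetical helper for illustration only, not part of the repository.
    """
    if not 0 <= split < num_folds:
        raise ValueError("split must be in [0, num_folds)")
    per_fold = num_classes // num_folds
    if interleaved:
        # Every num_folds-th class, e.g. split 0 of 80 classes gives 1, 5, ..., 77.
        val = list(range(split + 1, num_classes + 1, num_folds))
    else:
        # One contiguous block, e.g. split 0 of 20 classes gives 1..5.
        start = split * per_fold + 1
        val = list(range(start, start + per_fold))
    held_out = set(val)
    # Training classes are all remaining ids, kept in ascending order.
    train = [c for c in range(1, num_classes + 1) if c not in held_out]
    return train, val


if __name__ == "__main__":
    train, val = fold_classes(20, 4, split=2)
    print(val)    # [11, 12, 13, 14, 15]
    print(train)  # classes 1..10 followed by 16..20
```

Unlike `list(set(class_list) - set(sub_val_list))` in the original, the list comprehension always returns the training classes in ascending order, which may or may not matter to downstream code.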