├── LICENSE
├── README.md
├── config
│   ├── coco
│   │   ├── coco_split0_resnet50_manet.yaml
│   │   ├── coco_split0_resnet50_manet_5s.yaml
│   │   ├── coco_split0_vgg_manet.yaml
│   │   ├── coco_split0_vgg_manet_5s.yaml
│   │   ├── coco_split1_resnet50_manet.yaml
│   │   ├── coco_split1_resnet50_manet_5s.yaml
│   │   ├── coco_split1_vgg_manet.yaml
│   │   ├── coco_split1_vgg_manet_5s.yaml
│   │   ├── coco_split2_resnet50_manet.yaml
│   │   ├── coco_split2_resnet50_manet_5s.yaml
│   │   ├── coco_split2_vgg_manet.yaml
│   │   ├── coco_split2_vgg_manet_5s.yaml
│   │   ├── coco_split3_resnet50_manet.yaml
│   │   ├── coco_split3_resnet50_manet_5s.yaml
│   │   ├── coco_split3_vgg_manet.yaml
│   │   └── coco_split3_vgg_manet_5s.yaml
│   └── pascal
│       ├── pascal_split0_resnet50_manet.yaml
│       ├── pascal_split0_resnet50_manet_5s.yaml
│       ├── pascal_split0_vgg_manet.yaml
│       ├── pascal_split0_vgg_manet_5s.yaml
│       ├── pascal_split1_resnet50_manet.yaml
│       ├── pascal_split1_resnet50_manet_5s.yaml
│       ├── pascal_split1_vgg_manet.yaml
│       ├── pascal_split1_vgg_manet_5s.yaml
│       ├── pascal_split2_resnet50_manet.yaml
│       ├── pascal_split2_resnet50_manet_5s.yaml
│       ├── pascal_split2_vgg_manet.yaml
│       ├── pascal_split2_vgg_manet_5s.yaml
│       ├── pascal_split3_resnet50_manet.yaml
│       ├── pascal_split3_resnet50_manet_5s.yaml
│       ├── pascal_split3_vgg_manet.yaml
│       └── pascal_split3_vgg_manet_5s.yaml
├── env.yaml
├── lists
│   ├── coco
│   │   ├── fss_list
│   │   │   ├── train
│   │   │   │   ├── data_list_0.txt
│   │   │   │   ├── data_list_1.txt
│   │   │   │   ├── data_list_2.txt
│   │   │   │   ├── data_list_3.txt
│   │   │   │   ├── sub_class_file_list_0.txt
│   │   │   │   ├── sub_class_file_list_1.txt
│   │   │   │   ├── sub_class_file_list_2.txt
│   │   │   │   └── sub_class_file_list_3.txt
│   │   │   └── val
│   │   │       ├── data_list_0.txt
│   │   │       ├── data_list_1.txt
│   │   │       ├── data_list_2.txt
│   │   │       ├── data_list_3.txt
│   │   │       ├── sub_class_file_list_0.txt
│   │   │       ├── sub_class_file_list_1.txt
│   │   │       ├── sub_class_file_list_2.txt
│   │   │       └── sub_class_file_list_3.txt
│   │   ├── train.txt
│   │   ├── train_data_list.txt
│   │   ├── val.txt
│   │   └── val_data_list.txt
│   └── pascal
│       ├── duplicate_removel.py
│       ├── fss_list
│       │   ├── train
│       │   │   ├── data_list_0.txt
│       │   │   ├── data_list_1.txt
│       │   │   ├── data_list_2.txt
│       │   │   ├── data_list_3.txt
│       │   │   ├── sub_class_file_list_0.txt
│       │   │   ├── sub_class_file_list_1.txt
│       │   │   ├── sub_class_file_list_2.txt
│       │   │   └── sub_class_file_list_3.txt
│       │   └── val
│       │       ├── data_list_0.txt
│       │       ├── data_list_1.txt
│       │       ├── data_list_2.txt
│       │       ├── data_list_3.txt
│       │       ├── sub_class_file_list_0.txt
│       │       ├── sub_class_file_list_1.txt
│       │       ├── sub_class_file_list_2.txt
│       │       └── sub_class_file_list_3.txt
│       ├── sbd_data.txt
│       ├── val.txt
│       ├── voc_original_train.txt
│       ├── voc_sbd_merge.txt
│       └── voc_sbd_merge_noduplicate.txt
├── model
│   ├── ASPP.py
│   ├── HMNet.py
│   ├── HMNetAMP.py
│   ├── PPM.py
│   ├── PSPNet.py
│   ├── backbone_res.py
│   ├── backbone_utils.py
│   ├── loss.py
│   ├── mamba_blocks.py
│   ├── resnet.py
│   └── vgg.py
├── test_coco.py
├── test_coco.sh
├── test_pascal.py
├── test_pascal.sh
├── train_coco.py
├── train_coco.sh
├── train_pascal.py
├── train_pascal.sh
└── util
    ├── __init__.py
    ├── config.py
    ├── dataset.py
    ├── get_weak_anns.py
    ├── transform.py
    ├── transform_tri.py
    └── util.py
/LICENSE:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 |
3 | Copyright 2024 S-Lab
4 |
5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
8 |
9 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 |
11 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
12 |
13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
14 |
15 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hybrid Mamba for Few-Shot Segmentation
2 |
3 | This repository contains the code for our NeurIPS 2024 [paper](https://arxiv.org/abs/2409.19613) "*Hybrid Mamba for Few-Shot Segmentation*", in which we design a cross attention-like Mamba method to enable support-query interactions.
4 |
5 | > **Abstract**: *Many few-shot segmentation (FSS) methods use cross attention to fuse support foreground (FG) into query features, despite its quadratic complexity. Mamba, a recent advance, can also capture intra-sequence dependencies well, yet with only linear complexity. Hence, we aim to devise a cross (attention-like) Mamba to capture inter-sequence dependencies for FSS. A simple idea is to scan support features to selectively compress them into the hidden state, which is then used as the initial hidden state to sequentially scan query features. Nevertheless, this suffers from (1) the support forgetting issue: query features are also gradually compressed while they are scanned, so the support features in the hidden state keep diminishing, and many query pixels cannot fuse sufficient support features; (2) the intra-class gap issue: query FG is inherently more similar to itself than to support FG, i.e., query pixels may prefer to fuse their own features from the hidden state rather than support features, yet the effective use of support information is what makes FSS succeed. To tackle these issues, we design a hybrid Mamba network (HMNet), including (1) a support recapped Mamba that periodically recaps the support features when scanning the query, so the hidden state always contains rich support information; and (2) a query intercepted Mamba that forbids mutual interactions among query pixels and encourages them to fuse more support features from the hidden state. Consequently, the support information is better utilized, leading to better performance. Extensive experiments on two public benchmarks show the superiority of HMNet.*
6 |
7 | ## Dependencies
8 |
9 | - Python 3.10
10 | - PyTorch 1.12.0
11 | - CUDA 11.6
12 | - torchvision 0.13.0
13 | ```
14 | > conda env create -f env.yaml
15 | ```
16 |
17 | ## Datasets
18 |
19 | - PASCAL-5i: [VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) + [SBD](http://home.bharathh.info/pubs/codes/SBD/download.html)
20 | - COCO-20i: [COCO2014](https://cocodataset.org/#download)
21 |
22 | You can download the pre-processed PASCAL-5i and COCO-20i datasets [here](https://entuedu-my.sharepoint.com/:f:/g/personal/qianxion001_e_ntu_edu_sg/ErEg1GJF6ldCt1vh00MLYYwBapLiCIbd-VgbPAgCjBb_TQ?e=ibJ4DM) and extract them into the `data/` folder. Then create a symbolic link to the `pascal/VOCdevkit` data folder as follows:
23 | ```
24 | > ln -s /data/pascal/VOCdevkit /data/VOCdevkit2012
25 | ```
26 |
27 | The directory structure is:
28 | ```
29 | ../
30 | ├── HMNet/
31 | └── data/
32 |     ├── VOCdevkit2012/
33 |     │   └── VOC2012/
34 |     │       ├── JPEGImages/
35 |     │       ├── ...
36 |     │       └── SegmentationClassAug/
37 |     └── MSCOCO2014/
38 |         ├── annotations/
39 |         │   ├── train2014/
40 |         │   └── val2014/
41 |         ├── train2014/
42 |         └── val2014/
43 | ```
44 | ## Models
45 |
46 | - Download the pretrained backbones from [here](https://entuedu-my.sharepoint.com/:u:/g/personal/qianxion001_e_ntu_edu_sg/EUHlKdET3mJGie_IjtpzW5kBo45yz0PB2dW9n55Vo5acXw?e=uyuUDX) and put them into the `initmodel` directory.
47 | - Download [exp.tar.gz](https://entuedu-my.sharepoint.com/:u:/g/personal/qianxion001_e_ntu_edu_sg/EbcNC1Ram0lJozZ2qe624uEBhXNmKrI64CM0uhEPJuxaig?e=iDMvrI) to obtain all trained models for PASCAL-5i and COCO-20i.
48 |
49 | ## Testing
50 |
51 | - **Commands**:
52 | ```
53 | sh test_pascal.sh {Split: 0/1/2/3} {Net: resnet50/vgg} {Postfix: manet/manet_5s}
54 | sh test_coco.sh {Split: 0/1/2/3} {Net: resnet50/vgg} {Postfix: manet/manet_5s}
55 |
56 | # e.g., testing split 0 under the 1-shot setting on PASCAL-5i, with ResNet50 as the pretrained backbone:
57 | sh test_pascal.sh 0 resnet50 manet
58 |
59 | # e.g., testing split 0 under the 5-shot setting on COCO-20i, with ResNet50 as the pretrained backbone:
60 | sh test_coco.sh 0 resnet50 manet_5s
61 | ```
62 |
63 | ## References
64 |
65 | This repo is mainly built upon [BAM](https://github.com/chunbolang/BAM). Thanks for their great work!
66 |
67 |
--------------------------------------------------------------------------------
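The test commands take a split, a backbone, and a postfix. Judging from the file names under `config/`, these three arguments most likely select a single YAML file named `<dataset>_split<split>_<net>_<postfix>.yaml`; the contents of `test_pascal.sh` and `test_coco.sh` are not shown here, so this mapping is an assumption. The helpers `config_path` and `shot_of` below are hypothetical and only illustrate that naming scheme, rejecting values outside the choices listed in the README.

```python
from pathlib import PurePosixPath

# Allowed values, taken from the README's usage line.
DATASETS = {"pascal", "coco"}
SPLITS = {0, 1, 2, 3}
NETS = {"resnet50", "vgg"}
POSTFIXES = {"manet", "manet_5s"}  # manet = 1-shot, manet_5s = 5-shot


def config_path(dataset: str, split: int, net: str, postfix: str) -> PurePosixPath:
    """Return the config a test run would presumably load (hypothetical helper)."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset!r}")
    if split not in SPLITS or net not in NETS or postfix not in POSTFIXES:
        raise ValueError(f"invalid arguments: {split}, {net}, {postfix}")
    name = f"{dataset}_split{split}_{net}_{postfix}.yaml"
    return PurePosixPath("config") / dataset / name


def shot_of(postfix: str) -> int:
    """The _5s postfix marks the 5-shot configs; the others are 1-shot."""
    return 5 if postfix.endswith("_5s") else 1


if __name__ == "__main__":
    # Mirrors the README example: sh test_coco.sh 0 resnet50 manet_5s
    print(config_path("coco", 0, "resnet50", "manet_5s"))
    print(shot_of("manet_5s"))
```

The 4 splits × 2 backbones × 2 shot settings give the 16 files per dataset listed in the tree.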
/config/coco/coco_split0_resnet50_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotate
16 | rotate_max: 10 # maximum random rotate
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 0
21 | shot: 1
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # val at the half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
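`split` selects which fold of classes is held out as novel, and `use_split_coco: True` (commented as the FWB setting) changes how COCO's 80 classes are divided. The sketch below assumes the convention widely used in PFENet/BAM-style few-shot codebases: under the FWB split, COCO-20i holds out every fourth class starting at `split + 1`, while PASCAL-5i holds out the five consecutive classes `5 * split + 1` to `5 * split + 5`. `util/dataset.py` is not shown here, so the exact class lists, and the function names below, are assumptions.

```python
from typing import List, Tuple


def coco_fwb_split(split: int) -> Tuple[List[int], List[int]]:
    """(train_classes, val_classes) for COCO-20i under the interleaved (FWB) split."""
    if split not in range(4):
        raise ValueError("split must be 0, 1, 2 or 3")
    all_classes = list(range(1, 81))  # COCO categories remapped to ids 1..80
    val = list(range(split + 1, 81, 4))  # every 4th class, offset by the split index
    train = [c for c in all_classes if c not in val]
    return train, val


def pascal_5i_split(split: int) -> Tuple[List[int], List[int]]:
    """(train_classes, val_classes) for PASCAL-5i: five consecutive held-out classes."""
    if split not in range(4):
        raise ValueError("split must be 0, 1, 2 or 3")
    val = list(range(5 * split + 1, 5 * split + 6))
    train = [c for c in range(1, 21) if c not in val]
    return train, val


if __name__ == "__main__":
    train, val = coco_fwb_split(0)
    print(len(train), len(val), val[:5])
    print(pascal_5i_split(2)[1])
```

Under this convention each fold trains on 60 (COCO) or 15 (PASCAL) base classes and evaluates on the 20 or 5 disjoint novel ones.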
/config/coco/coco_split0_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotate
16 | rotate_max: 10 # maximum random rotate
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 0
21 | shot: 5
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # val at the half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
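`base_lr`, `power`, and `epochs` describe a polynomial ("poly") learning-rate decay, the schedule commonly used in PSPNet- and BAM-style training loops. The comment `0 means no decay` is consistent with it, because `(1 - t) ** 0` is always 1. The sketch below assumes the decay is applied per iteration over `epochs * iters_per_epoch` steps; `iters_per_epoch` is a hypothetical, illustrative value, and the repo's handling of `index_split` (the 10x parameter group) and `warmup` is not reproduced.

```python
def poly_lr(base_lr: float, curr_iter: int, max_iter: int, power: float = 0.9) -> float:
    """Polynomial decay: base_lr * (1 - curr_iter / max_iter) ** power."""
    if not 0 <= curr_iter <= max_iter:
        raise ValueError("curr_iter must lie in [0, max_iter]")
    return base_lr * (1.0 - curr_iter / max_iter) ** power


if __name__ == "__main__":
    base_lr, epochs = 0.005, 75     # values from the config
    iters_per_epoch = 1000          # hypothetical; depends on dataset and batch size
    max_iter = epochs * iters_per_epoch
    for it in (0, max_iter // 2, max_iter):
        print(it, poly_lr(base_lr, it, max_iter))
```

The rate starts at `base_lr`, falls to roughly 54% of it halfway through (0.5 ** 0.9), and reaches 0 at the last iteration.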
/config/coco/coco_split0_vgg_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotate
16 | rotate_max: 10 # maximum random rotate
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 0
21 | shot: 1
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # val at the half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: True
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
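`train_h`, `train_w`, and `val_size` are all set to the odd value 633. Assuming the network follows the PSPNet convention from the semseg codebase that BAM-style code inherits (input sides satisfy `(size - 1) % 8 == 0`, backbone features have stride 8, and the prediction is upsampled to `(size - 1) / 8 * zoom_factor + 1`), the sketch below computes the resulting map sizes. The function names are hypothetical, and `model/` is not shown here, so treat the convention itself as an assumption.

```python
VALID_ZOOM = (1, 2, 4, 8)  # the values allowed by the zoom_factor comment


def feature_size(size: int, stride: int = 8) -> int:
    """Side of the stride-8 feature map for an input side of `size` pixels."""
    if (size - 1) % stride != 0:
        raise ValueError(f"(size - 1) must be divisible by {stride}, got size={size}")
    return (size - 1) // stride + 1


def output_size(size: int, zoom_factor: int = 8, stride: int = 8) -> int:
    """Side of the prediction after upsampling the feature map by `zoom_factor`."""
    if zoom_factor not in VALID_ZOOM:
        raise ValueError(f"zoom_factor must be one of {VALID_ZOOM}")
    return (feature_size(size, stride) - 1) * zoom_factor + 1


if __name__ == "__main__":
    for z in VALID_ZOOM:
        print(z, output_size(633, z))
```

Since 633 = 8 × 79 + 1, the feature map is 80 × 80, and with `zoom_factor: 8` the upsampled prediction matches the 633 × 633 input exactly.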
/config/coco/coco_split0_vgg_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotate
16 | rotate_max: 10 # maximum random rotate
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 0
21 | shot: 5
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # val at the half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: True
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_5shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
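With `classes: 2`, each episode is a binary background/foreground problem, and pixels labelled `ignore_label` (255) should not affect the score. The sketch below computes per-class IoU for a single flattened prediction under that rule, in pure Python for clarity. The function name is hypothetical; it illustrates the ignore-label rule only and does not reproduce how the repo's evaluation aggregates IoU across episodes and novel classes.

```python
from typing import Dict, Sequence

IGNORE_LABEL = 255  # matches ignore_label in the configs


def binary_iou(pred: Sequence[int], target: Sequence[int],
               ignore_label: int = IGNORE_LABEL) -> Dict[int, float]:
    """Per-class IoU for background (0) and foreground (1), skipping ignored pixels."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    inter = {0: 0, 1: 0}
    union = {0: 0, 1: 0}
    for p, t in zip(pred, target):
        if t == ignore_label:
            continue  # ignored pixels count toward neither intersection nor union
        for c in (0, 1):
            if p == c and t == c:
                inter[c] += 1
            if p == c or t == c:
                union[c] += 1
    # NaN marks a class that appears in neither prediction nor target
    return {c: inter[c] / union[c] if union[c] else float("nan") for c in (0, 1)}


if __name__ == "__main__":
    print(binary_iou([1, 1, 0, 0, 1], [1, 0, 0, 255, 1]))
```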
/config/coco/coco_split1_resnet50_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotate
16 | rotate_max: 10 # maximum random rotate
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 1
21 | shot: 1
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # val at the half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
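`stop_interval` is described as stopping training once the best result has not been updated for that many epochs. A minimal patience counter under that reading is sketched below; the class name is hypothetical, and it counts one check per epoch (if `SubEpoch_val` adds mid-epoch validations, the repo may count differently). Under this reading, `stop_interval: 75` together with `epochs: 75` means the rule can never fire within one run: the first epoch always sets the best score, leaving at most 74 non-improving epochs.

```python
from dataclasses import dataclass


@dataclass
class EarlyStopper:
    """Signal a stop once the best score has not improved for `stop_interval` checks."""
    stop_interval: int
    best: float = float("-inf")
    since_best: int = 0

    def update(self, score: float) -> bool:
        """Record a validation score; return True if training should stop."""
        if score > self.best:
            self.best = score
            self.since_best = 0
        else:
            self.since_best += 1
        return self.since_best >= self.stop_interval


if __name__ == "__main__":
    stopper = EarlyStopper(stop_interval=75)
    # 75 epochs that never improve after the first one
    print(any(stopper.update(0.4) for _ in range(75)))
```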
/config/coco/coco_split1_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotate
16 | rotate_max: 10 # maximum random rotate
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 1
21 | shot: 5
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # val at the half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/config/coco/coco_split1_vgg_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotate
16 | rotate_max: 10 # maximum random rotate
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 1
21 | shot: 1
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # val at the half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: True
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/config/coco/coco_split1_vgg_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotate
16 | rotate_max: 10 # maximum random rotate
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 1
21 | shot: 5
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # val at the half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: True
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_5shot.pth # train_epoch_36_0.3953.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/config/coco/coco_split2_resnet50_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotate
16 | rotate_max: 10 # maximum random rotate
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 2
21 | shot: 1
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # val at the half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
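The `zoom_factor` and `train_h`/`train_w` values above interact. Below is a minimal sketch, assuming a PSPNet-style stride-8 backbone that requires input sizes of the form 8k + 1; the helper names `check_input_size` and `prediction_size` are hypothetical and not taken from this repository.

```python
# Hypothetical helpers; assumes a stride-8 backbone as in PSPNet-style code.
VALID_ZOOM_FACTORS = (1, 2, 4, 8)


def check_input_size(size: int) -> None:
    """Reject sizes that are not of the form 8k + 1 (e.g. 633 = 8*79 + 1)."""
    if (size - 1) % 8 != 0:
        raise ValueError(f"size {size} must satisfy (size - 1) % 8 == 0")


def prediction_size(size: int, zoom_factor: int) -> int:
    """Side length of the final prediction for a square input of `size` pixels."""
    if zoom_factor not in VALID_ZOOM_FACTORS:
        raise ValueError(f"zoom_factor must be one of {VALID_ZOOM_FACTORS}")
    check_input_size(size)
    # (size - 1) // 8 is the number of stride-8 steps; zooming rescales those steps.
    return (size - 1) // 8 * zoom_factor + 1


for size in (633, 473):  # train_h of the COCO and PASCAL configs
    print(size, [prediction_size(size, z) for z in VALID_ZOOM_FACTORS])
```

Under this assumption, `zoom_factor: 8` restores the full input resolution (633 stays 633 for COCO, 473 stays 473 for PASCAL), while smaller factors yield coarser predictions such as 317 for zoom 4.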
/config/coco/coco_split2_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 2
21 | shot: 5
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
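The 1-shot and 5-shot files for the same split and backbone differ only in `shot` and `weight`. The sketch below makes that easy to verify, assuming (as in PFENet-style loaders) that the sections `Data`, `Train`, `Method` and `Test_Finetune` are merged into one flat namespace. The dicts are hand-copied excerpts; for real files, PyYAML's `yaml.safe_load` would produce the same nested structure. `flatten` and `diff` are hypothetical helpers.

```python
from typing import Any, Dict, Tuple

Config = Dict[str, Dict[str, Any]]


def flatten(cfg: Config) -> Dict[str, Any]:
    """Merge the top-level sections (Data, Train, Method, ...) into one flat dict."""
    flat: Dict[str, Any] = {}
    for section, params in cfg.items():
        for key, value in params.items():
            if key in flat:  # a silent override would hide a config mistake
                raise KeyError(f"duplicate key {key!r} in section {section!r}")
            flat[key] = value
    return flat


def diff(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Tuple[Any, Any]]:
    """Keys whose values differ between two flat configs, as (a_value, b_value)."""
    return {k: (a.get(k), b.get(k)) for k in sorted(set(a) | set(b)) if a.get(k) != b.get(k)}


# Hand-copied excerpts of coco_split2_resnet50_manet.yaml and its _5s variant.
one_shot: Config = {
    "Train": {"split": 2, "shot": 1, "base_lr": 0.005, "epochs": 75},
    "Method": {"layers": 50, "vgg": False},
    "Test_Finetune": {"weight": "best_1shot.pth", "ann_type": "mask"},
}
five_shot: Config = {
    "Train": {"split": 2, "shot": 5, "base_lr": 0.005, "epochs": 75},
    "Method": {"layers": 50, "vgg": False},
    "Test_Finetune": {"weight": "best_5shot.pth", "ann_type": "mask"},
}

print(diff(flatten(one_shot), flatten(five_shot)))
# {'shot': (1, 5), 'weight': ('best_1shot.pth', 'best_5shot.pth')}
```

Raising on duplicate keys is a deliberate choice: once sections are flattened, a key repeated in two sections would otherwise override the first one without warning.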
/config/coco/coco_split2_vgg_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 2
21 | shot: 1
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: True
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # load weights for fine-tuning or testing (e.g. train_epoch_36_0.3953.pth or train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
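The optimizer settings `base_lr: 0.005` and `power: 0.9` ("0 means no decay") suggest the poly learning-rate schedule common in PSPNet-derived code. The following is a sketch under that assumption; `poly_lr` is a hypothetical name, the iterations-per-epoch figure is made up, and neither `warmup` nor the 10x parameter group selected by `index_split` is modeled.

```python
def poly_lr(base_lr: float, curr_iter: int, max_iter: int, power: float = 0.9) -> float:
    """Poly schedule: base_lr * (1 - curr_iter / max_iter) ** power.

    power = 0 keeps the rate constant, matching the '0 means no decay' comment.
    """
    if max_iter <= 0 or not 0 <= curr_iter <= max_iter:
        raise ValueError("need max_iter > 0 and 0 <= curr_iter <= max_iter")
    return base_lr * (1.0 - curr_iter / max_iter) ** power


# Hypothetical run length: max_iter = epochs * iterations per epoch.
epochs, iters_per_epoch = 75, 1000
max_iter = epochs * iters_per_epoch
for frac in (0.0, 0.25, 0.5, 0.75, 1.0):
    it = int(frac * max_iter)
    print(f"iter {it:>6}: lr = {poly_lr(0.005, it, max_iter):.6f}")
```

With `power: 0.9` the rate falls slightly faster than linearly near the start and reaches zero at the last iteration.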
/config/coco/coco_split2_vgg_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 2
21 | shot: 5
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: True
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_5shot.pth # load weights for fine-tuning or testing (e.g. train_epoch_36_0.3953.pth or train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
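The `stop_interval` comment reads: stop when the best result has not improved for `stop_interval` epochs. A minimal sketch of that rule follows, assuming one score update per epoch; `EarlyStopper` is a hypothetical class, not the repository's code.

```python
class EarlyStopper:
    """Stop once the best validation score has not improved for `stop_interval` epochs."""

    def __init__(self, stop_interval: int) -> None:
        self.stop_interval = stop_interval
        self.best = float("-inf")
        self.epochs_since_best = 0

    def update(self, score: float) -> bool:
        """Record one epoch's score; return True when training should stop."""
        if score > self.best:
            self.best = score
            self.epochs_since_best = 0
        else:
            self.epochs_since_best += 1
        return self.epochs_since_best >= self.stop_interval


stopper = EarlyStopper(stop_interval=2)
for epoch, miou in enumerate([0.30, 0.35, 0.34, 0.33, 0.40], start=1):
    if stopper.update(miou):
        print(f"stopping after epoch {epoch}; best mIoU {stopper.best:.2f}")
        break
```

Under these assumptions, the COCO setting `stop_interval: 75` with `epochs: 75` can never fire before the epoch budget is exhausted, because the first update always sets a new best; the PASCAL setting (`stop_interval: 100`, `epochs: 300`) can.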
/config/coco/coco_split3_resnet50_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 3
21 | shot: 1
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
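With `use_split_coco: True` (the FWB setting), we believe the 80 COCO classes are divided into four interleaved folds, as in PFENet-derived code: fold `split` holds out class ids `split + 1, split + 5, ..., split + 77` for testing and trains on the remaining 60. The sketch below encodes that assumed convention; `coco_fwb_split` is a hypothetical helper.

```python
from typing import List, Tuple


def coco_fwb_split(split: int) -> Tuple[List[int], List[int]]:
    """(train_classes, val_classes) for COCO-20^i with interleaved class ids 1..80."""
    if split not in range(4):
        raise ValueError("split must be 0, 1, 2 or 3")
    val_classes = list(range(split + 1, 81, 4))  # every 4th class, offset by split
    train_classes = [c for c in range(1, 81) if c not in val_classes]
    return train_classes, val_classes


train, val = coco_fwb_split(3)
print(len(train), len(val), val[:5])  # 60 20 [4, 8, 12, 16, 20]
```

The interleaved grouping spreads the COCO super-categories over all four folds; with `use_split_coco: False`, the alternative is presumably contiguous blocks of 20 class ids.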
/config/coco/coco_split3_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 3
21 | shot: 5
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
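The augmentation block specifies ranges for a random scale and a random rotation. A sketch of sampling one pair uniformly from those ranges follows, assuming angles are in degrees and that the sampling is uniform; `sample_aug_params` is hypothetical and applies no image transform.

```python
import random
from typing import Tuple


def sample_aug_params(rng: random.Random,
                      scale_range: Tuple[float, float] = (0.8, 1.25),
                      rotate_range: Tuple[float, float] = (-10.0, 10.0)) -> Tuple[float, float]:
    """Draw one (scale, angle) pair uniformly from the configured ranges."""
    scale = rng.uniform(*scale_range)
    angle = rng.uniform(*rotate_range)  # assumed to be in degrees
    return scale, angle


rng = random.Random(321)  # manual_seed from the config, used here for reproducibility
for _ in range(3):
    scale, angle = sample_aug_params(rng)
    print(f"scale={scale:.3f}  angle={angle:+.2f}")
```

The defaults mirror the COCO files; the PASCAL files use the narrower scale range `(0.9, 1.1)` with the same rotation range.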
/config/coco/coco_split3_vgg_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 3
21 | shot: 1
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: True
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # load weights for fine-tuning or testing (e.g. train_epoch_36_0.3953.pth or train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
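With `classes: 2` (background and foreground) and `ignore_label: 255`, evaluation counts only pixels whose label is not 255. The sketch below shows per-class intersection and union counting on flattened toy masks under that assumption; `intersection_and_union` is a hypothetical pure-Python helper, and accumulation across episodes is not shown.

```python
from typing import List, Sequence, Tuple


def intersection_and_union(pred: Sequence[int], target: Sequence[int],
                           num_classes: int = 2, ignore_label: int = 255
                           ) -> Tuple[List[int], List[int]]:
    """Per-class intersection and union pixel counts, skipping ignored target pixels."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of pixels")
    inter = [0] * num_classes
    area_pred = [0] * num_classes
    area_target = [0] * num_classes
    for p, t in zip(pred, target):
        if t == ignore_label:  # e.g. object boundaries or padding (padding_label: 255)
            continue
        area_pred[p] += 1
        area_target[t] += 1
        if p == t:
            inter[t] += 1
    union = [ap + at - i for ap, at, i in zip(area_pred, area_target, inter)]
    return inter, union


# Flattened toy masks: 0 = background, 1 = foreground, 255 = ignored.
target = [0, 1, 1, 255, 0, 1]
pred = [0, 1, 0, 1, 1, 1]
inter, union = intersection_and_union(pred, target)
print(inter, union, [i / u for i, u in zip(inter, union)])
```

The fourth pixel is predicted as foreground but carries the ignore label, so it affects neither the intersection nor the union; foreground IoU here is 2 / 4.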
/config/coco/coco_split3_vgg_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/MSCOCO2014
3 | train_list: ./lists/coco/train.txt
4 | val_list: ./lists/coco/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 633
11 | train_w: 633
12 | val_size: 633
13 | scale_min: 0.8 # minimum random scale
14 | scale_max: 1.25 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 3
21 | shot: 5
22 | data_set: 'coco'
23 | use_split_coco: True # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 75
28 | start_epoch: 0
29 | stop_interval: 75 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 32
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: True
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_5shot.pth # load weights for fine-tuning or testing (e.g. train_epoch_36_0.3953.pth or train5_epoch_47.5_0.4926.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
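`fix_random_seed_val: True` presumably makes validation reproducible, so every validation pass scores the same query/support episodes. A sketch of that idea follows, assuming episodes are drawn from a private seeded RNG; `sample_episodes`, the `images_by_class` layout, and the file names are hypothetical (the repository's lists under `lists/` may be organized differently).

```python
import random
from typing import Dict, List, Tuple

Episode = Tuple[int, str, List[str]]  # (class id, query image, support images)


def sample_episodes(images_by_class: Dict[int, List[str]], shot: int,
                    n_episodes: int, seed: int) -> List[Episode]:
    """Draw validation episodes from a private RNG so every call is reproducible."""
    rng = random.Random(seed)  # local RNG: does not disturb the global random state
    classes = sorted(images_by_class)
    episodes: List[Episode] = []
    for _ in range(n_episodes):
        cls = rng.choice(classes)
        # Query and supports are distinct images of the same class;
        # rng.sample raises ValueError if the class has fewer than shot + 1 images.
        picks = rng.sample(images_by_class[cls], shot + 1)
        episodes.append((cls, picks[0], picks[1:]))
    return episodes


# Toy data with hypothetical file names.
toy = {c: [f"img_{c}_{i}.jpg" for i in range(8)] for c in (1, 2, 3)}
first = sample_episodes(toy, shot=5, n_episodes=4, seed=321)
print(first == sample_episodes(toy, shot=5, n_episodes=4, seed=321))  # True
```

Using a local `random.Random` instead of reseeding the global generator keeps the training-time augmentation stream unaffected by validation.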
/config/pascal/pascal_split0_resnet50_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/VOCdevkit2012/VOC2012
3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 | val_list: ./lists/pascal/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 473
11 | train_w: 473
12 | val_size: 473
13 | scale_min: 0.9 # minimum random scale
14 | scale_max: 1.1 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 0
21 | shot: 1
22 | data_set: 'pascal'
23 | use_split_coco: False # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 300
28 | start_epoch: 0
29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 8
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_43_0.6753.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
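For PASCAL (`data_set: 'pascal'`, `use_split_coco: False`), the usual PASCAL-5^i protocol is assumed here: the 20 VOC classes form four contiguous folds of five, and `split: 0` tests on class ids 1 to 5 while training on the other 15. `pascal_split` is a hypothetical helper illustrating that convention.

```python
from typing import List, Tuple


def pascal_split(split: int) -> Tuple[List[int], List[int]]:
    """(train_classes, val_classes) for PASCAL-5^i: class ids 1..20 in blocks of five."""
    if split not in range(4):
        raise ValueError("split must be 0, 1, 2 or 3")
    val_classes = list(range(5 * split + 1, 5 * split + 6))
    train_classes = [c for c in range(1, 21) if c not in val_classes]
    return train_classes, val_classes


print(pascal_split(0)[1])  # [1, 2, 3, 4, 5]
print(pascal_split(1)[1])  # [6, 7, 8, 9, 10]
```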
/config/pascal/pascal_split0_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/VOCdevkit2012/VOC2012
3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 | val_list: ./lists/pascal/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 473
11 | train_w: 473
12 | val_size: 473
13 | scale_min: 0.9 # minimum random scale
14 | scale_max: 1.1 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 0
21 | shot: 5
22 | data_set: 'pascal'
23 | use_split_coco: False # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 300
28 | start_epoch: 0
29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 8
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_43_0.6753.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
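The example checkpoint names in the `weight` comments (such as `win8_train_epoch_43_0.6753.pth` and `train5_epoch_47.5_0.4926.pth`) appear to encode an epoch, possibly fractional because of half-epoch validation, followed by a score. The sketch below parses that pattern, assuming it holds and that the score is validation mIoU; the regex and `parse_checkpoint_name` are inferred from those examples, not taken from the training code.

```python
import re
from typing import Optional, Tuple

# Pattern inferred from the example names: ..._epoch_<epoch>_<score>.pth
CKPT_RE = re.compile(r"epoch_(?P<epoch>\d+(?:\.\d+)?)_(?P<score>\d+\.\d+)\.pth$")


def parse_checkpoint_name(name: str) -> Optional[Tuple[float, float]]:
    """Return (epoch, score) for names matching the pattern, else None."""
    m = CKPT_RE.search(name)
    if m is None:
        return None
    return float(m.group("epoch")), float(m.group("score"))


for name in ("win8_train_epoch_43_0.6753.pth", "train5_epoch_47.5_0.4926.pth", "best_1shot.pth"):
    print(name, parse_checkpoint_name(name))
```

Names such as `best_1shot.pth` carry no epoch or score and return `None`.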
/config/pascal/pascal_split0_vgg_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/VOCdevkit2012/VOC2012
3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 | val_list: ./lists/pascal/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 473
11 | train_w: 473
12 | val_size: 473
13 | scale_min: 0.9 # minimum random scale
14 | scale_max: 1.1 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 0
21 | shot: 1
22 | data_set: 'pascal'
23 | use_split_coco: False # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 300
28 | start_epoch: 0
29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 8
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: True
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_43_0.6753.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/config/pascal/pascal_split0_vgg_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/VOCdevkit2012/VOC2012
3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 | val_list: ./lists/pascal/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 473
11 | train_w: 473
12 | val_size: 473
13 | scale_min: 0.9 # minimum random scale
14 | scale_max: 1.1 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 0
21 | shot: 5
22 | data_set: 'pascal'
23 | use_split_coco: False # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 300
28 | start_epoch: 0
29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 8
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: True
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_5shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_43_0.6753.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/config/pascal/pascal_split1_resnet50_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 | data_root: ../data/VOCdevkit2012/VOC2012
3 | train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 | val_list: ./lists/pascal/val.txt
5 | classes: 2
6 |
7 |
8 | Train:
9 | # Aug
10 | train_h: 473
11 | train_w: 473
12 | val_size: 473
13 | scale_min: 0.9 # minimum random scale
14 | scale_max: 1.1 # maximum random scale
15 | rotate_min: -10 # minimum random rotation angle
16 | rotate_max: 10 # maximum random rotation angle
17 | ignore_label: 255
18 | padding_label: 255
19 | # Dataset & Mode
20 | split: 1
21 | shot: 1
22 | data_set: 'pascal'
23 | use_split_coco: False # True means FWB setting
24 | # Optimizer
25 | batch_size: 2 # batch size for training (bs8 for 1GPU)
26 | base_lr: 0.005
27 | epochs: 300
28 | start_epoch: 0
29 | stop_interval: 100 # stop when the best result is not updated for "stop_interval" epochs
30 | index_split: -1 # index for determining the params group with 10x learning rate
31 | power: 0.9 # 0 means no decay
32 | momentum: 0.9
33 | weight_decay: 0.0001
34 | warmup: False
35 | # Viz & Save & Resume
36 | print_freq: 10
37 | save_freq: 10
38 | resume: # path to latest checkpoint (default: none, such as epoch_10.pth)
39 | # Validate
40 | evaluate: True
41 | SubEpoch_val: True # validate at every half epoch
42 | fix_random_seed_val: True
43 | batch_size_val: 1
44 | resized_val: True
45 | ori_resize: True # use original label for evaluation
46 | # Else
47 | workers: 8
48 | fix_bn: True
49 | manual_seed: 321
50 | seed_deterministic: False
51 | zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 | layers: 50
55 | vgg: False
56 | aux_weight1: 1.0
57 | aux_weight2: 1.0
58 | low_fea: 'layer2' # low_fea for computing the Gram matrix
59 | kshot_trans_dim: 2 # K-shot dimensionality reduction
60 | merge: 'final' # fusion scheme for GFSS ('base' Eq(S1) | 'final' Eq(18) )
61 | merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 | weight: best_1shot.pth # load weight for fine-tuning or testing (such as win8_train_epoch_124_0.7262.pth)
65 | ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/config/pascal/pascal_split1_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 1
21 |   shot: 5
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: False
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_5shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_124_0.7262.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
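`stop_interval: 100` ends training once the best validation result has not improved for 100 epochs. A minimal sketch of that rule follows, assuming that a higher score (e.g. mIoU) is better and that one score is reported per epoch. `EarlyStopper` is a hypothetical name, and the repository's training loop may count epochs slightly differently.

```python
class EarlyStopper:
    """Signal a stop once the best score has not improved for `stop_interval` epochs."""

    def __init__(self, stop_interval: int = 100):
        self.stop_interval = stop_interval
        self.best = float("-inf")
        self.epochs_since_best = 0

    def step(self, score: float) -> bool:
        """Record one epoch's validation score; return True when training should stop."""
        if score > self.best:
            self.best = score
            self.epochs_since_best = 0
        else:
            self.epochs_since_best += 1
        return self.epochs_since_best >= self.stop_interval


stopper = EarlyStopper(stop_interval=2)
for epoch, miou in enumerate([0.50, 0.48, 0.49, 0.51]):
    if stopper.step(miou):
        print("stopping after epoch", epoch, "best", stopper.best)
        break
```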
--------------------------------------------------------------------------------
/config/pascal/pascal_split1_vgg_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 1
21 |   shot: 1
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: True
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_1shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_124_0.7262.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
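Each config groups its keys under top-level sections (`Data`, `Train`, `Method`, `Test_Finetune`). Few-shot segmentation code commonly merges these sections into one flat namespace after parsing. The sketch below assumes the file has already been loaded into a dict (for example with PyYAML's `yaml.safe_load`), and it also enforces the documented `zoom_factor` values. `flatten_cfg` is a hypothetical name; the repository's own loader is not shown here and may behave differently.

```python
from types import SimpleNamespace


def flatten_cfg(sections: dict) -> SimpleNamespace:
    """Merge {'Data': {...}, 'Train': {...}, ...} into one attribute namespace."""
    flat = {}
    for section, params in sections.items():
        for key, value in (params or {}).items():  # an empty section parses as None
            if key in flat:
                raise KeyError(f"key {key!r} in section {section!r} is defined twice")
            flat[key] = value
    cfg = SimpleNamespace(**flat)
    # The config comment restricts zoom_factor to 1, 2, 4 or 8.
    if getattr(cfg, "zoom_factor", None) not in (None, 1, 2, 4, 8):
        raise ValueError(f"zoom_factor must be one of 1, 2, 4, 8, got {cfg.zoom_factor}")
    return cfg


cfg = flatten_cfg({
    "Data": {"classes": 2},
    "Train": {"split": 1, "shot": 1, "zoom_factor": 8},
    "Method": {"vgg": True, "merge": "final", "merge_tau": 0.9},
})
print(cfg.split, cfg.shot, cfg.vgg, cfg.merge_tau)
```

Rejecting duplicate keys matters with a flat namespace: a key repeated in two sections would otherwise be silently overwritten by whichever section is parsed last.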
--------------------------------------------------------------------------------
/config/pascal/pascal_split1_vgg_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 1
21 |   shot: 5
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: True
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_5shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_124_0.7262.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
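The `# Aug` keys describe random scaling in [`scale_min`, `scale_max`] and random rotation in [`rotate_min`, `rotate_max`] before cropping to `train_h` x `train_w`. Because `padding_label` equals `ignore_label` (both 255), pixels introduced by padding would presumably be excluded from the loss. The sketch below only shows how one (scale, angle) pair could be drawn per sample. `sample_aug_params` is hypothetical, and the uniform distribution and the degree unit are assumptions; the actual transforms live in the repository's data pipeline.

```python
import random


def sample_aug_params(rng: random.Random,
                      scale_min: float = 0.9, scale_max: float = 1.1,
                      rotate_min: float = -10, rotate_max: float = 10):
    """Draw a scale factor and a rotation angle (assumed degrees) uniformly."""
    scale = rng.uniform(scale_min, scale_max)
    angle = rng.uniform(rotate_min, rotate_max)
    return scale, angle


rng = random.Random(321)  # a seeded generator makes the draws reproducible
scale, angle = sample_aug_params(rng)
# Length of a 473-pixel side after scaling, before the crop back to train_h x train_w.
print(scale, angle, round(473 * scale))
```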
--------------------------------------------------------------------------------
/config/pascal/pascal_split2_resnet50_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 2
21 |   shot: 1
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: False
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_1shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_47_0.6718.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
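`low_fea: 'layer2'` names the backbone stage whose low-level features feed a Gram matrix. For a feature map flattened to C channels by N = H*W positions, the Gram matrix is G = F F^T, a C x C matrix of channel co-activations. The pure-Python sketch below L2-normalises each channel first, so that entries are cosine similarities. That normalisation and the name `gram_matrix` are assumptions, and how the model compares or consumes the matrix is not specified in this config.

```python
def gram_matrix(features):
    """G = F F^T for a C x N feature map given as C rows of N floats.

    Each row is L2-normalised first (an assumption), so G[i][j] is the
    cosine similarity between channels i and j; all-zero rows stay zero.
    """
    normed = []
    for row in features:
        norm = sum(v * v for v in row) ** 0.5
        normed.append([v / norm if norm > 0 else 0.0 for v in row])
    return [[sum(a * b for a, b in zip(r_i, r_j)) for r_j in normed] for r_i in normed]


# Two channels over two spatial positions.
print(gram_matrix([[1.0, 0.0], [0.0, 2.0]]))
```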
--------------------------------------------------------------------------------
/config/pascal/pascal_split2_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 2
21 |   shot: 5
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: False
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_5shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_47_0.6718.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
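`manual_seed: 321`, `seed_deterministic` and `fix_random_seed_val` control reproducibility. The following is a minimal sketch of what seeding typically involves, using the standard Python, NumPy and PyTorch seeding calls. NumPy and PyTorch are imported only if they are installed. `set_seed` is a hypothetical name, and mapping `seed_deterministic` onto cuDNN's deterministic flags is an assumption about this repository.

```python
import random


def set_seed(seed: int = 321, deterministic: bool = False) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch when they are installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU and CUDA generators
        if deterministic:
            # Trade speed for repeatable cuDNN kernel selection.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    except ImportError:
        pass


set_seed(321)
print([random.random() for _ in range(3)])
```

If the validation loop re-seeds with the same value before it starts, as `fix_random_seed_val` suggests, the sampled support/query episodes repeat from run to run, and scores become comparable across checkpoints.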
--------------------------------------------------------------------------------
/config/pascal/pascal_split2_vgg_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 2
21 |   shot: 1
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: True
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_1shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_47_0.6718.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
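With `classes: 2`, every episode is a binary background/foreground problem, and pixels labelled `ignore_label` (255) carry no supervision. The sketch below computes per-class IoU over flattened masks while skipping ignored pixels. Averaging the two values gives a foreground-background IoU (FB-IoU) for one image. `binary_iou` and the flat-list mask representation are assumptions, and the repository's evaluation code may accumulate intersections and unions over the whole validation set rather than per image.

```python
def binary_iou(pred, target, ignore_label=255):
    """Per-class IoU [background, foreground] for flattened 0/1 masks.

    Pixels whose target equals ignore_label are excluded from both the
    intersection and the union; a class absent from both masks yields NaN.
    """
    inter = [0, 0]
    union = [0, 0]
    for p, t in zip(pred, target):
        if t == ignore_label:
            continue  # unlabeled or padded pixel
        for c in (0, 1):
            inter[c] += int(p == c and t == c)
            union[c] += int(p == c or t == c)
    return [i / u if u else float("nan") for i, u in zip(inter, union)]


bg_iou, fg_iou = binary_iou(pred=[1, 1, 0, 0], target=[1, 0, 0, 255])
print(bg_iou, fg_iou, (bg_iou + fg_iou) / 2)  # the last value is this image's FB-IoU
```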
--------------------------------------------------------------------------------
/config/pascal/pascal_split2_vgg_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 2
21 |   shot: 5
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: True
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_5shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_47_0.6718.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/config/pascal/pascal_split3_resnet50_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 3
21 |   shot: 1
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: False
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_1shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_93_0.6053.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/config/pascal/pascal_split3_resnet50_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 3
21 |   shot: 5
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: False
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_5shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_93_0.6053.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/config/pascal/pascal_split3_vgg_manet.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 3
21 |   shot: 1
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: True
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_1shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_93_0.6053.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/config/pascal/pascal_split3_vgg_manet_5s.yaml:
--------------------------------------------------------------------------------
1 | Data:
2 |   data_root: ../data/VOCdevkit2012/VOC2012
3 |   train_list: ./lists/pascal/voc_sbd_merge_noduplicate.txt
4 |   val_list: ./lists/pascal/val.txt
5 |   classes: 2
6 |
7 |
8 | Train:
9 |   # Aug
10 |   train_h: 473
11 |   train_w: 473
12 |   val_size: 473
13 |   scale_min: 0.9 # minimum random scale
14 |   scale_max: 1.1 # maximum random scale
15 |   rotate_min: -10 # minimum random rotation angle
16 |   rotate_max: 10 # maximum random rotation angle
17 |   ignore_label: 255
18 |   padding_label: 255
19 |   # Dataset & Mode
20 |   split: 3
21 |   shot: 5
22 |   data_set: 'pascal'
23 |   use_split_coco: False # True means FWB setting
24 |   # Optimizer
25 |   batch_size: 2 # batch size for training (bs8 for 1GPU)
26 |   base_lr: 0.005
27 |   epochs: 300
28 |   start_epoch: 0
29 |   stop_interval: 100 # stop when the best result has not improved for "stop_interval" epochs
30 |   index_split: -1 # index that determines which parameter groups use a 10x learning rate
31 |   power: 0.9 # learning-rate decay power; 0 means no decay
32 |   momentum: 0.9
33 |   weight_decay: 0.0001
34 |   warmup: False
35 |   # Viz & Save & Resume
36 |   print_freq: 10
37 |   save_freq: 10
38 |   resume: # path to the latest checkpoint, e.g. epoch_10.pth (default: none)
39 |   # Validate
40 |   evaluate: True
41 |   SubEpoch_val: True # validate at the half epoch
42 |   fix_random_seed_val: True
43 |   batch_size_val: 1
44 |   resized_val: True
45 |   ori_resize: True # use the original label for evaluation
46 |   # Else
47 |   workers: 8
48 |   fix_bn: True
49 |   manual_seed: 321
50 |   seed_deterministic: False
51 |   zoom_factor: 8 # zoom factor for the final prediction during training; must be one of [1, 2, 4, 8]
52 |
53 | Method:
54 |   layers: 50
55 |   vgg: True
56 |   aux_weight1: 1.0
57 |   aux_weight2: 1.0
58 |   low_fea: 'layer2' # low-level feature layer used to compute the Gram matrix
59 |   kshot_trans_dim: 2 # K-shot dimensionality reduction
60 |   merge: 'final' # fusion scheme for GFSS ('base': Eq. (S1) | 'final': Eq. (18))
61 |   merge_tau: 0.9 # fusion threshold tau
62 |
63 | Test_Finetune:
64 |   weight: best_5shot.pth # weights to load for fine-tuning or testing (e.g. win8_train_epoch_93_0.6053.pth)
65 |   ann_type: 'mask' # mask/bbox
66 |
67 |
68 |
69 | ## deprecated multi-processing training
70 | # Distributed:
71 | # dist_url: tcp://127.0.0.1:6789
72 | # dist_backend: 'nccl'
73 | # multiprocessing_distributed: False
74 | # world_size: 1
75 | # rank: 0
76 | # use_apex: False
77 | # opt_level: 'O0'
78 | # keep_batchnorm_fp32:
79 | # loss_scale:
80 |
81 |
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: hmnet
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
10 | - defaults
11 | dependencies:
12 | - _libgcc_mutex=0.1=main
13 | - _openmp_mutex=5.1=1_gnu
14 | - blas=1.0=mkl
15 | - brotli-python=1.0.9=py310hd8f1fbe_7
16 | - bzip2=1.0.8=h5eee18b_5
17 | - ca-certificates=2024.2.2=hbcca054_0
18 | - certifi=2024.2.2=pyhd8ed1ab_0
19 | - charset-normalizer=3.3.2=pyhd8ed1ab_0
20 | - cudatoolkit=11.6.0=hecad31d_10
21 | - ffmpeg=4.3=hf484d3e_0
22 | - freetype=2.10.4=h0708190_1
23 | - giflib=5.2.1=h36c2ea0_2
24 | - gmp=6.2.1=h58526e2_0
25 | - gnutls=3.6.13=h85f3911_1
26 | - idna=3.7=pyhd8ed1ab_0
27 | - intel-openmp=2021.4.0=h06a4308_3561
28 | - jbig=2.1=h7f98852_2003
29 | - jpeg=9e=h166bdaf_1
30 | - lame=3.100=h7f98852_1001
31 | - lcms2=2.12=hddcbb42_0
32 | - ld_impl_linux-64=2.38=h1181459_1
33 | - lerc=2.2.1=h9c3ff4c_0
34 | - libdeflate=1.7=h7f98852_5
35 | - libffi=3.4.4=h6a678d5_0
36 | - libgcc-ng=11.2.0=h1234567_1
37 | - libgomp=11.2.0=h1234567_1
38 | - libiconv=1.17=h166bdaf_0
39 | - libpng=1.6.37=h21135ba_2
40 | - libstdcxx-ng=11.2.0=h1234567_1
41 | - libtiff=4.3.0=hf544144_1
42 | - libuuid=1.41.5=h5eee18b_0
43 | - libwebp=1.2.2=h3452ae3_0
44 | - libwebp-base=1.2.2=h7f98852_1
45 | - lz4-c=1.9.3=h9c3ff4c_1
46 | - mkl=2021.4.0=h06a4308_640
47 | - mkl-service=2.4.0=py310ha2c4b55_0
48 | - mkl_fft=1.3.1=py310h2b4bcf5_1
49 | - mkl_random=1.2.2=py310h00e6091_0
50 | - ncurses=6.4=h6a678d5_0
51 | - nettle=3.6=he412f7d_0
52 | - numpy=1.24.3=py310hd5efca6_0
53 | - numpy-base=1.24.3=py310h8e6c178_0
54 | - openh264=2.1.1=h780b84a_0
55 | - openssl=3.0.13=h7f8727e_0
56 | - pillow=9.4.0=py310h6a678d5_0
57 | - pip=23.3.1=py310h06a4308_0
58 | - pysocks=1.7.1=pyha2e5f31_6
59 | - python=3.10.14=h955ad1f_0
60 | - python_abi=3.10=2_cp310
61 | - pytorch=1.12.0=py3.10_cuda11.6_cudnn8.3.2_0
62 | - pytorch-mutex=1.0=cuda
63 | - readline=8.2=h5eee18b_0
64 | - requests=2.31.0=pyhd8ed1ab_0
65 | - setuptools=68.2.2=py310h06a4308_0
66 | - six=1.16.0=pyh6c4a22f_0
67 | - sqlite=3.41.2=h5eee18b_0
68 | - tk=8.6.12=h1ccaba5_0
69 | - torchaudio=0.12.0=py310_cu116
70 | - torchvision=0.13.0=py310_cu116
71 | - typing_extensions=4.11.0=pyha770c72_0
72 | - urllib3=2.2.1=pyhd8ed1ab_0
73 | - wheel=0.41.2=py310h06a4308_0
74 | - xz=5.4.6=h5eee18b_0
75 | - zlib=1.2.13=h5eee18b_0
76 | - zstd=1.5.0=ha95c52a_0
77 | - pip:
78 |     - addict==2.4.0
79 |     - args==0.1.0
80 |     - chardet==5.2.0
81 |     - clint==0.5.1
82 |     - cloudpickle==3.0.0
83 |     - contourpy==1.2.1
84 |     - coverage==7.5.0
85 |     - cycler==0.12.1
86 |     - einops==0.7.0
87 |     - exceptiongroup==1.2.1
88 |     - filelock==3.13.4
89 |     - fonttools==4.51.0
90 |     - fsspec==2024.3.1
91 |     - huggingface-hub==0.22.2
92 |     - imageio==2.34.1
93 |     - importlib-metadata==7.1.0
94 |     - iniconfig==2.0.0
95 |     - joblib==1.4.0
96 |     - jsonpatch==1.33
97 |     - jsonpointer==2.4
98 |     - kiwisolver==1.4.5
99 |     - lazy-loader==0.4
100 |     - mamba==0.11.3
101 |     - mamba-ssm==1.2.0.post1
102 |     - matplotlib==3.8.4
103 |     - mmcls==0.25.0
104 |     - mmcv-full==1.6.1
105 |     - mmsegmentation==0.27.0
106 |     - networkx==3.3
107 |     - ninja==1.11.1.1
108 |     - opencv-python==4.5.5.64
109 |     - packaging==24.0
110 |     - pandas==2.2.2
111 |     - platformdirs==4.2.1
112 |     - pluggy==1.5.0
113 |     - prettytable==3.10.0
114 |     - protobuf==5.26.1
115 |     - pyparsing==3.1.2
116 |     - pytest==8.1.1
117 |     - python-dateutil==2.9.0.post0
118 |     - pytz==2024.1
119 |     - pyyaml==6.0.1
120 |     - regex==2024.4.16
121 |     - safetensors==0.4.3
122 |     - scikit-image==0.23.2
123 |     - scikit-learn==1.4.2
124 |     - scipy==1.13.0
125 |     - seaborn==0.13.2
126 |     - submitit==1.5.1
127 |     - tensorboardx==2.6.2.2
128 |     - termcolor==2.4.0
129 |     - thop==0.1.1-2209072238
130 |     - threadpoolctl==3.4.0
131 |     - tifffile==2024.4.18
132 |     - timm==0.4.12
133 |     - tokenizers==0.19.1
134 |     - tomli==2.0.1
135 |     - tornado==6.4
136 |     - tqdm==4.66.2
137 |     - transformers==4.40.1
138 |     - triton==2.3.0
139 |     - tzdata==2024.1
140 |     - visdom==0.2.4
141 |     - wcwidth==0.2.13
142 |     - websocket-client==1.8.0
143 |     - yacs==0.1.8
144 |     - yapf==0.40.2
145 |     - zipp==3.18.1
146 |
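The `pip:` section pins exact versions (for example `mmcv-full==1.6.1`, `mmsegmentation==0.27.0` and `timm==0.4.12`). Below is a minimal sketch for spotting drift between those pins and the active interpreter. It uses only the standard library's `importlib.metadata`. The helper names `parse_pip_pins` and `version_mismatches` are hypothetical and not part of the repository. Conda entries, which use a single `=` (`name=version=build`), are deliberately skipped.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Tuple


def parse_pip_pins(text: str) -> Dict[str, str]:
    """Collect `name==version` entries from an environment file's pip section.

    Conda entries use a single `=` (name=version=build) and are skipped.
    """
    pins: Dict[str, str] = {}
    for raw in text.splitlines():
        entry = raw.strip().lstrip('-').strip()
        if '==' in entry:
            name, ver = entry.split('==', 1)
            pins[name.strip()] = ver.strip()
    return pins


def version_mismatches(pins: Dict[str, str]) -> List[Tuple[str, str, str]]:
    """Return (name, pinned, installed) for every pin the interpreter does not match."""
    result = []
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = 'missing'
        if installed != pinned:
            result.append((name, pinned, installed))
    return result


example = """
  - numpy=1.24.3=py310hd5efca6_0
  - pip:
    - mmcv-full==1.6.1
    - timm==0.4.12
"""
pins = parse_pip_pins(example)
print(pins)  # {'mmcv-full': '1.6.1', 'timm': '0.4.12'}
for row in version_mismatches(pins):
    print(row)
```

The comparison is a plain string match. Versions that differ only in normalization are therefore reported as mismatches and need a manual look.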
--------------------------------------------------------------------------------
/lists/pascal/duplicate_removel.py:
--------------------------------------------------------------------------------
1 | # Remove duplicate entries from voc_sbd_merge.txt, keeping the first occurrence of each line.
2 | raw_path = './voc_sbd_merge.txt'
3 | new_path = './voc_sbd_merge_noduplicate.txt'
4 |
5 | with open(raw_path) as f:
6 |     lines = f.readlines()
7 |
8 | # dict keys preserve insertion order, so this dedups in O(n) without reordering lines.
9 | unique_lines = list(dict.fromkeys(lines))
10 |
11 | with open(new_path, 'w') as new_f:
12 |     new_f.writelines(unique_lines)
13 |
14 | print('Ori: {}, new: {}'.format(len(lines), len(unique_lines)))
15 | print('Finished.')
--------------------------------------------------------------------------------
/model/ASPP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.data
5 |
6 | class ASPP(nn.Module):
7 | def __init__(self, out_channels=256):
8 | super(ASPP, self).__init__()
9 | self.layer6_0 = nn.Sequential(
10 | nn.Conv2d(out_channels , out_channels, kernel_size=1, stride=1, padding=0, bias=True),
11 | nn.ReLU(),
12 | )
13 | self.layer6_1 = nn.Sequential(
14 | nn.Conv2d(out_channels , out_channels, kernel_size=1, stride=1, padding=0, bias=True),
15 | nn.ReLU(),
16 | )
17 | self.layer6_2 = nn.Sequential(
18 | nn.Conv2d(out_channels , out_channels , kernel_size=3, stride=1, padding=6,dilation=6, bias=True),
19 | nn.ReLU(),
20 | )
21 | self.layer6_3 = nn.Sequential(
22 | nn.Conv2d(out_channels , out_channels, kernel_size=3, stride=1, padding=12, dilation=12, bias=True),
23 | nn.ReLU(),
24 | )
25 | self.layer6_4 = nn.Sequential(
26 | nn.Conv2d(out_channels , out_channels , kernel_size=3, stride=1, padding=18, dilation=18, bias=True),
27 | nn.ReLU(),
28 | )
29 |
30 | self._init_weight()
31 |
32 | def _init_weight(self):
33 | for m in self.modules():
34 | if isinstance(m, nn.Conv2d):
35 | torch.nn.init.kaiming_normal_(m.weight)
36 | elif isinstance(m, nn.BatchNorm2d):
37 | m.weight.data.fill_(1)
38 | m.bias.data.zero_()
39 |
40 | def forward(self, x):
41 | feature_size = x.shape[-2:]
42 | global_feature = F.avg_pool2d(x, kernel_size=feature_size)
43 |
44 | global_feature = self.layer6_0(global_feature)
45 |
46 | global_feature = global_feature.expand(-1, -1, feature_size[0], feature_size[1])
47 | out = torch.cat(
48 | [global_feature, self.layer6_1(x), self.layer6_2(x), self.layer6_3(x), self.layer6_4(x)], dim=1)
49 | return out
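The ASPP block above runs five parallel branches and concatenates them, so its output has five times `out_channels` channels. Every convolution takes `out_channels` inputs, which means the incoming feature map must already have that many channels. The dilated 3x3 branches keep the spatial size because their padding equals their dilation.

The stdlib-only sketch below checks that arithmetic with PyTorch's documented `Conv2d` output-size formula. `conv2d_out_size` and `aspp_out_channels` are hypothetical helpers, not part of the repository.

```python
def conv2d_out_size(size: int, kernel: int, stride: int = 1,
                    padding: int = 0, dilation: int = 1) -> int:
    """Output length of nn.Conv2d along one spatial axis (PyTorch's formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def aspp_out_channels(out_channels: int = 256, num_branches: int = 5) -> int:
    """Channels after concatenating the global branch and the four conv branches."""
    return out_channels * num_branches


# 1x1 branch: no padding is needed to keep a 60x60 map at 60x60.
assert conv2d_out_size(60, kernel=1) == 60
# 3x3 dilated branches: padding == dilation preserves the spatial size.
for d in (6, 12, 18):
    assert conv2d_out_size(60, kernel=3, padding=d, dilation=d) == 60

print(aspp_out_channels(256))  # 1280
```

The global branch is pooled to 1x1 and then broadcast back with `expand`, so all five branches agree on spatial size before `torch.cat`.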
--------------------------------------------------------------------------------
/model/PPM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class PPM(nn.Module):
6 | def __init__(self, in_dim, reduction_dim, bins):
7 | super(PPM, self).__init__()
8 | self.features = []
9 | for bin in bins:
10 | self.features.append(nn.Sequential(
11 | nn.AdaptiveAvgPool2d(bin),
12 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
13 | nn.BatchNorm2d(reduction_dim),
14 | nn.ReLU(inplace=True)
15 | ))
16 | self.features = nn.ModuleList(self.features)
17 |
18 | def forward(self, x):
19 | x_size = x.size()
20 | out = [x]
21 | for f in self.features:
22 | out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
23 | return torch.cat(out, 1)
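PPM works in four steps:

1. Pool the feature map to each size in `bins`.
2. Reduce channels with a 1x1 conv.
3. Upsample back to the input size.
4. Concatenate everything with the input.

PSPNet.py uses `bins=(1, 2, 3, 6)` and `reduction_dim = fea_dim / len(bins)`. With those settings the output has `2 * fea_dim` channels, which matches the `fea_dim*2` input of the first conv in `self.cls`.

The sketch below assumes PyTorch's adaptive pooling averages window `i` over `[floor(i*H/bins), ceil((i+1)*H/bins))`. The helper names are hypothetical.

```python
from typing import List, Sequence, Tuple


def adaptive_pool_bins(size: int, bins: int) -> List[Tuple[int, int]]:
    """[start, end) windows averaged by adaptive pooling along one axis.

    Window i covers floor(i * size / bins) .. ceil((i + 1) * size / bins),
    computed here with integer arithmetic.
    """
    return [((i * size) // bins, ((i + 1) * size + bins - 1) // bins)
            for i in range(bins)]


def ppm_out_channels(in_dim: int, bins: Sequence[int] = (1, 2, 3, 6)) -> int:
    """PPM keeps the input and appends one reduced map per bin size."""
    reduction_dim = int(in_dim / len(bins))  # same choice as PSPNet.py
    return in_dim + reduction_dim * len(bins)


print(adaptive_pool_bins(60, 6)[:2])  # [(0, 10), (10, 20)]
print(ppm_out_channels(2048))         # 4096, i.e. fea_dim * 2 for ResNet
print(ppm_out_channels(512))          # 1024, i.e. fea_dim * 2 for VGG
```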
--------------------------------------------------------------------------------
/model/PSPNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn import BatchNorm2d as BatchNorm
5 | import model.resnet as models
6 | import model.vgg as vgg_models
7 | from model.PPM import PPM
8 |
9 |
10 | def get_vgg16_layer(model):
11 | layer0_idx = range(0,7)
12 | layer1_idx = range(7,14)
13 | layer2_idx = range(14,24)
14 | layer3_idx = range(24,34)
15 | layer4_idx = range(34,43)
16 | layers_0 = []
17 | layers_1 = []
18 | layers_2 = []
19 | layers_3 = []
20 | layers_4 = []
21 | for idx in layer0_idx:
22 | layers_0 += [model.features[idx]]
23 | for idx in layer1_idx:
24 | layers_1 += [model.features[idx]]
25 | for idx in layer2_idx:
26 | layers_2 += [model.features[idx]]
27 | for idx in layer3_idx:
28 | layers_3 += [model.features[idx]]
29 | for idx in layer4_idx:
30 | layers_4 += [model.features[idx]]
31 | layer0 = nn.Sequential(*layers_0)
32 | layer1 = nn.Sequential(*layers_1)
33 | layer2 = nn.Sequential(*layers_2)
34 | layer3 = nn.Sequential(*layers_3)
35 | layer4 = nn.Sequential(*layers_4)
36 | return layer0,layer1,layer2,layer3,layer4
37 |
38 | class OneModel(nn.Module):
39 | def __init__(self, args):
40 | super(OneModel, self).__init__()
41 |
42 | self.layers = args.layers
43 | self.zoom_factor = args.zoom_factor
44 | self.vgg = args.vgg
45 | self.dataset = args.data_set
46 | self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
47 |
48 | self.pretrained = True
49 | self.classes = 16 if self.dataset=='pascal' else 61
50 |
51 | assert self.layers in [50, 101, 152]
52 |
53 | if self.vgg:
54 | print('INFO: Using VGG_16 bn')
55 | vgg_models.BatchNorm = BatchNorm
56 | vgg16 = vgg_models.vgg16_bn(pretrained=self.pretrained)
57 | print(vgg16)
58 | self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = get_vgg16_layer(vgg16)
59 | else:
60 | print('INFO: Using ResNet {}'.format(self.layers))
61 | if self.layers == 50:
62 | resnet = models.resnet50(pretrained=self.pretrained)
63 | elif self.layers == 101:
64 | resnet = models.resnet101(pretrained=self.pretrained)
65 | else:
66 | resnet = models.resnet152(pretrained=self.pretrained)
67 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu1, resnet.conv2, resnet.bn2, resnet.relu2, resnet.conv3, resnet.bn3, resnet.relu3, resnet.maxpool)
68 | self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
69 |
70 | for n, m in self.layer3.named_modules():
71 | if 'conv2' in n:
72 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
73 | elif 'downsample.0' in n:
74 | m.stride = (1, 1)
75 | for n, m in self.layer4.named_modules():
76 | if 'conv2' in n:
77 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
78 | elif 'downsample.0' in n:
79 | m.stride = (1, 1)
80 |
81 | # Base Learner
82 | self.encoder = nn.Sequential(self.layer0, self.layer1, self.layer2, self.layer3, self.layer4)
83 | fea_dim = 512 if self.vgg else 2048
84 | bins=(1, 2, 3, 6)
85 | self.ppm = PPM(fea_dim, int(fea_dim/len(bins)), bins)
86 | self.cls = nn.Sequential(
87 | nn.Conv2d(fea_dim*2, 512, kernel_size=3, padding=1, bias=False),
88 | nn.BatchNorm2d(512),
89 | nn.ReLU(inplace=True),
90 | nn.Dropout2d(p=0.1),
91 | nn.Conv2d(512, self.classes, kernel_size=1))
92 |
93 | def get_optim(self, model, args, LR):
94 | optimizer = torch.optim.SGD(
95 | [
96 | {'params': model.encoder.parameters()},
97 | {'params': model.ppm.parameters()},
98 | {'params': model.cls.parameters()},
99 | ], lr=LR, momentum=args.momentum, weight_decay=args.weight_decay)
100 | return optimizer
101 |
102 |
103 | def forward(self, x, y_m):
104 | x_size = x.size()
105 | h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1)
106 | w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1) # e.g. 473 -> 473 when zoom_factor == 8
107 |
108 | x = self.encoder(x)
109 | x = self.ppm(x)
110 | x = self.cls(x)
111 |
112 | if self.zoom_factor != 1:
113 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
114 |
115 | if self.training:
116 | main_loss = self.criterion(x, y_m.long())
117 | return x.max(1)[1], main_loss
118 | else:
119 | return x
120 |
121 |
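`OneModel.forward` computes the upsampled size as `int((H - 1) / 8 * zoom_factor + 1)`, where 8 is the encoder's output stride after layer3 and layer4 are dilated. Below is a minimal sketch of that arithmetic. `pspnet_out_size` is a hypothetical helper, and `output_stride` is exposed only for illustration; the source hardcodes 8.

```python
def pspnet_out_size(in_size: int, zoom_factor: int = 8, output_stride: int = 8) -> int:
    """Mirror of the h/w computation in OneModel.forward.

    `output_stride` is exposed only for illustration; the source hardcodes 8.
    """
    return int((in_size - 1) / output_stride * zoom_factor + 1)


print(pspnet_out_size(473))                 # 473: logits upsampled back to the crop size
print(pspnet_out_size(473, zoom_factor=1))  # 60: the encoder's feature-map size
```

With `zoom_factor=1`, `forward` skips the interpolation. The logits then stay at the encoder resolution, which is 60x60 for a 473x473 crop.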
--------------------------------------------------------------------------------
/model/backbone_res.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | # from torchvision.models.utils import load_state_dict_from_url
5 | from torch.hub import load_state_dict_from_url
6 | from typing import Type, Any, Callable, Union, List, Optional
7 | from collections import OrderedDict
8 |
9 |
10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
11 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
12 | 'wide_resnet50_2', 'wide_resnet101_2']
13 |
14 |
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
25 | }
26 |
27 |
28 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
29 | """3x3 convolution with padding"""
30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
31 | padding=dilation, groups=groups, bias=False, dilation=dilation)
32 |
33 |
34 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
35 | """1x1 convolution"""
36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
37 |
38 |
39 | class BasicBlock(nn.Module):
40 | expansion: int = 1
41 |
42 | def __init__(
43 | self,
44 | inplanes: int,
45 | planes: int,
46 | stride: int = 1,
47 | downsample: Optional[nn.Module] = None,
48 | groups: int = 1,
49 | base_width: int = 64,
50 | dilation: int = 1,
51 | norm_layer: Optional[Callable[..., nn.Module]] = None
52 | ) -> None:
53 | super(BasicBlock, self).__init__()
54 | if norm_layer is None:
55 | norm_layer = nn.BatchNorm2d
56 | if groups != 1 or base_width != 64:
57 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
58 | if dilation > 1:
59 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
60 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
61 | self.conv1 = conv3x3(inplanes, planes, stride)
62 | self.bn1 = norm_layer(planes)
63 | self.relu = nn.ReLU(inplace=True)
64 | self.conv2 = conv3x3(planes, planes)
65 | self.bn2 = norm_layer(planes)
66 | self.downsample = downsample
67 | self.stride = stride
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | identity = x
71 |
72 | out = self.conv1(x)
73 | out = self.bn1(out)
74 | out = self.relu(out)
75 |
76 | out = self.conv2(out)
77 | out = self.bn2(out)
78 |
79 | if self.downsample is not None:
80 | identity = self.downsample(x)
81 |
82 | out += identity
83 | out = self.relu(out)
84 |
85 | return out
86 |
87 |
88 | class Bottleneck(nn.Module):
89 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
90 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
91 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
92 | # This variant is also known as ResNet V1.5 and improves accuracy according to
93 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
94 |
95 | expansion: int = 4
96 |
97 | def __init__(
98 | self,
99 | inplanes: int,
100 | planes: int,
101 | stride: int = 1,
102 | downsample: Optional[nn.Module] = None,
103 | groups: int = 1,
104 | base_width: int = 64,
105 | dilation: int = 1,
106 | norm_layer: Optional[Callable[..., nn.Module]] = None
107 | ) -> None:
108 | super(Bottleneck, self).__init__()
109 | if norm_layer is None:
110 | norm_layer = nn.BatchNorm2d
111 | width = int(planes * (base_width / 64.)) * groups
112 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
113 | self.conv1 = conv1x1(inplanes, width)
114 | self.bn1 = norm_layer(width)
115 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
116 | self.bn2 = norm_layer(width)
117 | self.conv3 = conv1x1(width, planes * self.expansion)
118 | self.bn3 = norm_layer(planes * self.expansion)
119 | self.relu = nn.ReLU(inplace=True)
120 | self.downsample = downsample
121 | self.stride = stride
122 |
123 | def forward(self, x: Tensor) -> Tensor:
124 | identity = x
125 |
126 | out = self.conv1(x)
127 | out = self.bn1(out)
128 | out = self.relu(out)
129 |
130 | out = self.conv2(out)
131 | out = self.bn2(out)
132 | out = self.relu(out)
133 |
134 | out = self.conv3(out)
135 | out = self.bn3(out)
136 |
137 | if self.downsample is not None:
138 | identity = self.downsample(x)
139 |
140 | out += identity
141 | out = self.relu(out)
142 |
143 | return out
144 |
145 |
146 | class ResNet(nn.Module):
147 |
148 | def __init__(
149 | self,
150 | block: Type[Union[BasicBlock, Bottleneck]],
151 | layers: List[int],
152 | num_classes: int = 1000,
153 | zero_init_residual: bool = False,
154 | groups: int = 1,
155 | width_per_group: int = 64,
156 | replace_stride_with_dilation: Optional[List[bool]] = None,
157 | norm_layer: Optional[Callable[..., nn.Module]] = None,
158 | deep_stem: bool = True,
159 | ) -> None:
160 | super(ResNet, self).__init__()
161 | if norm_layer is None:
162 | norm_layer = nn.BatchNorm2d
163 | self._norm_layer = norm_layer
164 |
165 | self.inplanes = 128  # the deep stem outputs 128 channels (torchvision's 7x7 stem uses 64)
166 | self.dilation = 1
167 | if replace_stride_with_dilation is None:
168 | # each element in the tuple indicates if we should replace
169 | # the 2x2 stride with a dilated convolution instead
170 | replace_stride_with_dilation = [False, False, False]
171 | if len(replace_stride_with_dilation) != 3:
172 | raise ValueError("replace_stride_with_dilation should be None "
173 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
174 | self.groups = groups
175 | self.deep_stem = deep_stem
176 | self.base_width = width_per_group
177 | self._make_stem_layer(3, self.inplanes)
178 | self.layer1 = self._make_layer(block, 64, layers[0])
179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
180 | dilate=replace_stride_with_dilation[0])
181 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
182 | dilate=replace_stride_with_dilation[1])
183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
184 | dilate=replace_stride_with_dilation[2])
185 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
186 | self.fc = nn.Linear(512 * block.expansion, num_classes)
187 |
188 | for m in self.modules():
189 | if isinstance(m, nn.Conv2d):
190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
192 | nn.init.constant_(m.weight, 1)
193 | nn.init.constant_(m.bias, 0)
194 |
195 | # Zero-initialize the last BN in each residual branch,
196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
198 | if zero_init_residual:
199 | for m in self.modules():
200 | if isinstance(m, Bottleneck):
201 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
202 | elif isinstance(m, BasicBlock):
203 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
204 |
205 | def _make_stem_layer(self, in_channels, stem_channels):
206 | """Make stem layer for ResNet."""
207 | if self.deep_stem:
208 | self.conv1 = conv3x3(in_channels, 64, stride=2)
209 | self.bn1 = self._norm_layer(64)
210 | self.relu1 = nn.ReLU(inplace=True)
211 | self.conv2 = conv3x3(64, 64)
212 | self.bn2 = self._norm_layer(64)
213 | self.relu2 = nn.ReLU(inplace=True)
214 | self.conv3 = conv3x3(64, 128)
215 | self.bn3 = self._norm_layer(128)
216 | self.relu3 = nn.ReLU(inplace=True)
217 | else:
218 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
219 | bias=False)
220 | self.bn1 = self._norm_layer(self.inplanes)
221 | self.relu = nn.ReLU(inplace=True)
222 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
223 |
224 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
225 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
226 | norm_layer = self._norm_layer
227 | downsample = None
228 | previous_dilation = self.dilation
229 | if dilate:
230 | self.dilation *= stride
231 | stride = 1
232 | if stride != 1 or self.inplanes != planes * block.expansion:
233 | downsample = nn.Sequential(
234 | conv1x1(self.inplanes, planes * block.expansion, stride),
235 | norm_layer(planes * block.expansion),
236 | )
237 |
238 | layers = []
239 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
240 | self.base_width, previous_dilation, norm_layer))
241 | self.inplanes = planes * block.expansion
242 | for _ in range(1, blocks):
243 | layers.append(block(self.inplanes, planes, groups=self.groups,
244 | base_width=self.base_width, dilation=self.dilation,
245 | norm_layer=norm_layer))
246 |
247 | return nn.Sequential(*layers)
248 |
249 | def _forward_impl(self, x: Tensor) -> Tensor:
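250 |         if self.deep_stem:  # three conv-bn-relu stages built in _make_stem_layer
251 |             x = self.relu1(self.bn1(self.conv1(x)))
252 |             x = self.relu2(self.bn2(self.conv2(x)))
253 |             x = self.relu3(self.bn3(self.conv3(x)))
254 |         else:
255 |             x = self.relu(self.bn1(self.conv1(x)))
256 |         x = self.maxpool(x)
257 |
258 |         x = self.layer1(x)
259 |         x = self.layer2(x)
260 |         x = self.layer3(x)
261 |         x = self.layer4(x)
262 |
263 |         # avgpool/flatten/fc are skipped on purpose: the backbone returns
264 |         # the layer4 feature map for dense prediction.
265 |         return x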
266 |
267 | def forward(self, x: Tensor) -> Tensor:
268 | return self._forward_impl(x)
269 |
270 |
271 | def _resnet(
272 | arch: str,
273 | block: Type[Union[BasicBlock, Bottleneck]],
274 | layers: List[int],
275 | pretrained: Union[bool, str],
276 | progress: bool,
277 | **kwargs: Any
278 | ) -> ResNet:
279 |     model = ResNet(block, layers, **kwargs)
280 |     if isinstance(pretrained, str):  # local checkpoint path
281 |         if 'warp' in pretrained:
282 |             model.fc = nn.Identity()
283 |         state_dict = torch.load(pretrained, map_location='cpu')
284 |         model.load_state_dict(state_dict)
285 |     elif pretrained:
286 |         raise ValueError('pretrained must be a checkpoint path, got {!r}'.format(pretrained))
287 |     return model
288 |
289 |
290 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
291 | r"""ResNet-18 model from
292 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
293 |
294 | Args:
295 | pretrained (bool): If True, returns a model pre-trained on ImageNet
296 | progress (bool): If True, displays a progress bar of the download to stderr
297 | """
298 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
299 | **kwargs)
300 |
301 |
302 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
303 | r"""ResNet-34 model from
304 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
305 |
306 | Args:
307 | pretrained (bool): If True, returns a model pre-trained on ImageNet
308 | progress (bool): If True, displays a progress bar of the download to stderr
309 | """
310 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
311 | **kwargs)
312 |
313 |
314 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
315 | r"""ResNet-50 model from
316 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
317 |
318 | Args:
319 | pretrained (bool): If True, returns a model pre-trained on ImageNet
320 | progress (bool): If True, displays a progress bar of the download to stderr
321 | """
322 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
323 | **kwargs)
324 |
325 |
326 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
327 | r"""ResNet-101 model from
328 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
329 |
330 | Args:
331 | pretrained (bool): If True, returns a model pre-trained on ImageNet
332 | progress (bool): If True, displays a progress bar of the download to stderr
333 | """
334 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
335 | **kwargs)
336 |
337 |
338 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
339 | r"""ResNet-152 model from
340 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
341 |
342 | Args:
343 | pretrained (bool): If True, returns a model pre-trained on ImageNet
344 | progress (bool): If True, displays a progress bar of the download to stderr
345 | """
346 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
347 | **kwargs)
348 |
349 |
350 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
351 | r"""ResNeXt-50 32x4d model from
352 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
353 |
354 | Args:
355 | pretrained (bool): If True, returns a model pre-trained on ImageNet
356 | progress (bool): If True, displays a progress bar of the download to stderr
357 | """
358 | kwargs['groups'] = 32
359 | kwargs['width_per_group'] = 4
360 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
361 | pretrained, progress, **kwargs)
362 |
363 |
364 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
365 | r"""ResNeXt-101 32x8d model from
366 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
367 |
368 | Args:
369 | pretrained (bool): If True, returns a model pre-trained on ImageNet
370 | progress (bool): If True, displays a progress bar of the download to stderr
371 | """
372 | kwargs['groups'] = 32
373 | kwargs['width_per_group'] = 8
374 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
375 | pretrained, progress, **kwargs)
376 |
377 |
378 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
379 | r"""Wide ResNet-50-2 model from
380 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
381 |
382 | The model is the same as ResNet except for the bottleneck number of channels
383 | which is twice larger in every block. The number of channels in outer 1x1
384 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
385 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
386 |
387 | Args:
388 | pretrained (bool): If True, returns a model pre-trained on ImageNet
389 | progress (bool): If True, displays a progress bar of the download to stderr
390 | """
391 | kwargs['width_per_group'] = 64 * 2
392 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
393 | pretrained, progress, **kwargs)
394 |
395 |
396 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
397 | r"""Wide ResNet-101-2 model from
398 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
399 |
400 | The model is the same as ResNet except for the bottleneck number of channels
401 | which is twice larger in every block. The number of channels in outer 1x1
402 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
403 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
404 |
405 | Args:
406 | pretrained (bool): If True, returns a model pre-trained on ImageNet
407 | progress (bool): If True, displays a progress bar of the download to stderr
408 | """
409 | kwargs['width_per_group'] = 64 * 2
410 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
411 | pretrained, progress, **kwargs)
412 |
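`_make_layer` turns a stride-2 stage into a dilated one by multiplying `self.dilation` by the stride and resetting the stride to 1. The first block of each stage still uses the previous dilation.

The sketch below replays only that bookkeeping to show how `replace_stride_with_dilation` sets the overall output stride. `resnet_strides` is a hypothetical helper. It assumes the stem downsamples by 4 (a stride-2 conv1 plus the stride-2 max pool), which holds for both stem variants in this file.

`BasicBlock` raises `NotImplementedError` for dilation > 1. Only the `Bottleneck` models (resnet50 and deeper) can therefore be dilated.

```python
from typing import List, Tuple


def resnet_strides(replace_stride_with_dilation: List[bool]) -> Tuple[int, List[int]]:
    """Replay ResNet._make_layer's stride/dilation bookkeeping.

    Returns the overall output stride and, for layer1..layer4, the dilation
    used by every block after the first one in that stage.
    """
    output_stride = 4        # stem: stride-2 conv1 followed by the stride-2 max pool
    dilation = 1
    dilations = [dilation]   # layer1 always runs with stride 1
    for dilate in replace_stride_with_dilation:  # layer2, layer3, layer4 (stride 2 each)
        if dilate:
            dilation *= 2    # mirrors: self.dilation *= stride; stride = 1
        else:
            output_stride *= 2
        dilations.append(dilation)
    return output_stride, dilations


print(resnet_strides([False, False, False]))  # (32, [1, 1, 1, 1])
print(resnet_strides([False, True, True]))    # (8, [1, 1, 2, 4])
```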
--------------------------------------------------------------------------------
/model/backbone_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torchvision.models._utils import IntermediateLayerGetter
5 | from model.backbone_res import *
6 |
7 |
8 | class FrozenBatchNorm2d(torch.nn.Module):
9 | """
10 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
11 | Copied from torchvision.ops.misc with an eps added before rsqrt;
12 | without it, models other than torchvision.models.resnet[18,34,50,101]
13 | produce NaNs.
14 | """
15 |
16 | def __init__(self, n):
17 | super(FrozenBatchNorm2d, self).__init__()
18 | self.register_buffer("weight", torch.ones(n))
19 | self.register_buffer("bias", torch.zeros(n))
20 | self.register_buffer("running_mean", torch.zeros(n))
21 | self.register_buffer("running_var", torch.ones(n))
22 |
23 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
24 | missing_keys, unexpected_keys, error_msgs):
25 | num_batches_tracked_key = prefix + 'num_batches_tracked'
26 | if num_batches_tracked_key in state_dict:
27 | del state_dict[num_batches_tracked_key]
28 |
29 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
30 | state_dict, prefix, local_metadata, strict,
31 | missing_keys, unexpected_keys, error_msgs)
32 |
33 | def forward(self, x):
34 | # move reshapes to the beginning
35 | # to make it fuser-friendly
36 | w = self.weight.reshape(1, -1, 1, 1)
37 | b = self.bias.reshape(1, -1, 1, 1)
38 | rv = self.running_var.reshape(1, -1, 1, 1)
39 | rm = self.running_mean.reshape(1, -1, 1, 1)
40 | eps = 1e-5
41 | scale = w * (rv + eps).rsqrt()
42 | bias = b - rm * scale
43 | return x * scale + bias
44 |
45 |
46 | class BackboneBase(nn.Module):
47 |
48 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
49 | super().__init__()
50 | for name, parameter in backbone.named_parameters():
51 | if not train_backbone or ('layer2' not in name and 'layer3' not in name and 'layer4' not in name):
52 | parameter.requires_grad_(False)
53 | if return_interm_layers:
54 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
55 | else:
56 | return_layers = {'layer4': "0"}
57 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
58 | self.num_channels = num_channels
59 |
60 | def forward(self, x):
61 | x = self.body(x)
62 | return x
63 |
64 |
65 | resnets_dict = {
66 | 'resnet50': (resnet50, 'initmodel/resnet50_v2.pth'),
67 | 'resnet101': (resnet101, 'initmodel/resnet101_v2.pth'),
68 | }
69 |
70 |
71 | class Backbone(BackboneBase):
72 | """ResNet backbone with frozen BatchNorm."""
73 |
74 | def __init__(self, name: str,
75 | train_backbone: bool,
76 | return_interm_layers: bool,
77 | dilation: list):
78 | backbone = resnets_dict[name][0](
79 | replace_stride_with_dilation=dilation,
80 | pretrained=resnets_dict[name][1], norm_layer=FrozenBatchNorm2d)
81 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
82 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
83 |
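`FrozenBatchNorm2d.forward` folds the frozen statistics into one scale and one shift: `scale = w / sqrt(rv + eps)` and `shift = b - rm * scale`. As a result, `x * scale + shift` equals the usual inference-mode batch norm `(x - rm) / sqrt(rv + eps) * w + b`. The scalar sketch below checks that identity for a single channel. `frozen_bn` and `bn_inference` are hypothetical names.

```python
import math


def frozen_bn(x: float, weight: float, bias: float, running_mean: float,
              running_var: float, eps: float = 1e-5) -> float:
    """Single-channel version of FrozenBatchNorm2d.forward."""
    scale = weight / math.sqrt(running_var + eps)  # w * rsqrt(rv + eps)
    shift = bias - running_mean * scale
    return x * scale + shift


def bn_inference(x: float, weight: float, bias: float, running_mean: float,
                 running_var: float, eps: float = 1e-5) -> float:
    """Textbook inference-mode batch norm, for comparison."""
    return (x - running_mean) / math.sqrt(running_var + eps) * weight + bias


args = (2.0, 1.5, 0.1, 0.5, 4.0)  # x, weight, bias, running_mean, running_var
print(frozen_bn(*args), bn_inference(*args))  # equal up to floating-point rounding
```

Precomputing `scale` and `shift` turns the per-element work into a single multiply-add, which is the point of the "fuser-friendly" comment in the source.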
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # ========================================
7 | # Weighted Dice loss
8 | # By default, weight is uniformly fixed as 1
9 | # From https://github.com/YanFangCS/CyCTR-Pytorch
10 | # ========================================
11 | def weighted_dice_loss(prediction, target_seg, weighted_val: float = 1.0, reduction: str = "sum", eps: float = 1e-8):
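12 | """
13 | Weighted Dice loss computed separately for the background and foreground channels.
14 |
15 | Args:
16 | prediction: logits of shape (N, 2, H, W); a sigmoid is applied to each channel.
17 | target_seg: labels of shape (N, H, W), 0 = background, 1 = foreground; other labels (e.g. 255) count as negatives in both channels.
18 | weighted_val: scalar weight multiplied into each per-channel loss.
19 | reduction: 'none' | 'mean' | 'sum'
20 | 'none': return the per-(sample, channel) losses, shape (N * 2,).
21 | 'mean': average over all samples and channels.
22 | 'sum' : sum over samples and channels, then divide by the batch size N.
23 | eps: lower bound on the denominator, to avoid division by zero.
24 | """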
25 | target_seg_fg = target_seg == 1
26 | target_seg_bg = target_seg == 0
27 | target_seg = torch.stack([target_seg_bg, target_seg_fg], dim=1).float()
28 |
29 | n, _, h, w = target_seg.shape
30 |
31 | prediction = prediction.reshape(-1, h, w)
32 | target_seg = target_seg.reshape(-1, h, w)
33 | prediction = torch.sigmoid(prediction)
34 | prediction = prediction.reshape(-1, h * w)
35 | target_seg = target_seg.reshape(-1, h * w)
36 |
37 | # calculate dice loss
38 | loss_part = (prediction ** 2).sum(dim=-1) + (target_seg ** 2).sum(dim=-1)
39 | loss = 1 - 2 * (target_seg * prediction).sum(dim=-1) / torch.clamp(loss_part, min=eps)
40 | # scale each per-channel loss by weighted_val
41 | loss = loss * weighted_val
42 |
43 | if reduction == "sum":
44 | loss = loss.sum() / n
45 | elif reduction == "mean":
46 | loss = loss.mean()
47 | return loss
48 |
49 |
50 | class WeightedDiceLoss(nn.Module):
51 | def __init__(self, weighted_val: float = 1.0, reduction: str = "sum"):
52 | super(WeightedDiceLoss, self).__init__()
53 | self.weighted_val = weighted_val
54 | self.reduction = reduction
55 |
56 | def forward(self, prediction, target_seg):
57 | return weighted_dice_loss(prediction, target_seg, self.weighted_val, self.reduction)
58 |
59 |
60 | # ========================================
61 | # Cross-Entropy loss + Dice loss
62 | # ========================================
63 | class CEDiceLoss(nn.Module):
64 | def __init__(self, reduction="mean", ignore_index=255):
65 | super(CEDiceLoss, self).__init__()
66 | self.reduction = reduction
67 | self.ignore_index = ignore_index
68 |
69 | self.ce = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction=self.reduction)
70 | self.dice = WeightedDiceLoss(reduction=reduction)
71 |
72 | def forward(self, output, target):
73 | ce = self.ce(output, target)
74 | dice = self.dice(output, target)
75 | return ce + dice
76 |
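`weighted_dice_loss` builds two target channels, background and foreground. It applies a sigmoid to each prediction channel independently, not a softmax across the two. It then computes `1 - 2 * sum(p * t) / (sum(p^2) + sum(t^2))` for every (sample, channel) pair. The stdlib sketch below reproduces that per-channel term on plain lists. `dice_loss_1d` is a hypothetical name.

```python
import math
from typing import Sequence


def dice_loss_1d(logits: Sequence[float], target: Sequence[float],
                 eps: float = 1e-8) -> float:
    """Per-(sample, channel) Dice term of weighted_dice_loss on flat lists."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]  # sigmoid, as in the source
    inter = sum(p * t for p, t in zip(probs, target))
    denom = sum(p * p for p in probs) + sum(t * t for t in target)
    return 1.0 - 2.0 * inter / max(denom, eps)


print(dice_loss_1d([0.0, 0.0], [1, 0]))     # 0.333...: p = 0.5 everywhere
print(dice_loss_1d([50.0, -50.0], [1, 0]))  # ~0.0: confident and correct
```

A channel whose target is empty, such as the foreground of an image with no object, always contributes a loss of 1. Its intersection term is zero whatever the prediction. With the default `reduction="sum"`, the two channel losses of a sample are added, so each sample contributes between 0 and 2 before the division by the batch size.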
--------------------------------------------------------------------------------
/model/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import torch.utils.model_zoo as model_zoo
5 |
6 | BatchNorm = nn.BatchNorm2d
7 |
8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
9 | 'resnet152']
10 |
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | }
19 |
20 |
21 | def conv3x3(in_planes, out_planes, stride=1):
22 | """3x3 convolution with padding"""
23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
24 | padding=1, bias=False)
25 |
26 |
27 | class BasicBlock(nn.Module):
28 | expansion = 1
29 |
30 | def __init__(self, inplanes, planes, stride=1, downsample=None):
31 | super(BasicBlock, self).__init__()
32 | self.conv1 = conv3x3(inplanes, planes, stride)
33 | self.bn1 = BatchNorm(planes)
34 | self.relu = nn.ReLU(inplace=True)
35 | self.conv2 = conv3x3(planes, planes)
36 | self.bn2 = BatchNorm(planes)
37 | self.downsample = downsample
38 | self.stride = stride
39 |
40 | def forward(self, x):
41 | residual = x
42 |
43 | out = self.conv1(x)
44 | out = self.bn1(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv2(out)
48 | out = self.bn2(out)
49 |
50 | if self.downsample is not None:
51 | residual = self.downsample(x)
52 |
53 | out += residual
54 | out = self.relu(out)
55 |
56 | return out
57 |
58 |
59 | class Bottleneck(nn.Module):
60 | expansion = 4
61 |
62 | def __init__(self, inplanes, planes, stride=1, downsample=None):
63 | super(Bottleneck, self).__init__()
64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65 | self.bn1 = BatchNorm(planes)
66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67 | padding=1, bias=False)
68 | self.bn2 = BatchNorm(planes)
69 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
70 | self.bn3 = BatchNorm(planes * self.expansion)
71 | self.relu = nn.ReLU(inplace=True)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class ResNet(nn.Module):
99 |
100 | def __init__(self, block, layers, num_classes=1000, deep_base=True):
101 | super(ResNet, self).__init__()
102 | self.deep_base = deep_base
103 | if not self.deep_base:
104 | self.inplanes = 64
105 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
106 | self.bn1 = BatchNorm(64)
107 | self.relu1 = nn.ReLU(inplace=True)  # forward() always calls relu1, in both stem variants
108 | else:
109 | self.inplanes = 128
110 | self.conv1 = conv3x3(3, 64, stride=2)
111 | self.bn1 = BatchNorm(64)
112 | self.relu1 = nn.ReLU(inplace=True)
113 | self.conv2 = conv3x3(64, 64)
114 | self.bn2 = BatchNorm(64)
115 | self.relu2 = nn.ReLU(inplace=True)
116 | self.conv3 = conv3x3(64, 128)
117 | self.bn3 = BatchNorm(128)
118 | self.relu3 = nn.ReLU(inplace=True)
119 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
120 | self.layer1 = self._make_layer(block, 64, layers[0])
121 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
122 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
123 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
124 | self.avgpool = nn.AvgPool2d(7, stride=1)
125 | self.fc = nn.Linear(512 * block.expansion, num_classes)
126 |
127 | for m in self.modules():
128 | if isinstance(m, nn.Conv2d):
129 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
130 | elif isinstance(m, BatchNorm):
131 | nn.init.constant_(m.weight, 1)
132 | nn.init.constant_(m.bias, 0)
133 |
134 | def _make_layer(self, block, planes, blocks, stride=1):
135 | downsample = None
136 | if stride != 1 or self.inplanes != planes * block.expansion:
137 | downsample = nn.Sequential(
138 | nn.Conv2d(self.inplanes, planes * block.expansion,
139 | kernel_size=1, stride=stride, bias=False),
140 | BatchNorm(planes * block.expansion),
141 | )
142 |
143 | layers = []
144 | layers.append(block(self.inplanes, planes, stride, downsample))
145 | self.inplanes = planes * block.expansion
146 | for i in range(1, blocks):
147 | layers.append(block(self.inplanes, planes))
148 |
149 | return nn.Sequential(*layers)
150 |
151 | def forward(self, x):
152 | x = self.relu1(self.bn1(self.conv1(x)))
153 | if self.deep_base:
154 | x = self.relu2(self.bn2(self.conv2(x)))
155 | x = self.relu3(self.bn3(self.conv3(x)))
156 | x = self.maxpool(x)
157 |
158 | x = self.layer1(x)
159 | x = self.layer2(x)
160 | x = self.layer3(x)
161 | x = self.layer4(x)
162 |
163 | x = self.avgpool(x)
164 | x = x.view(x.size(0), -1)
165 | x = self.fc(x)
166 |
167 | return x
168 |
169 |
170 | def resnet18(pretrained=False, **kwargs):
171 | """Constructs a ResNet-18 model.
172 |
173 | Args:
174 | pretrained (bool): If True, returns a model pre-trained on ImageNet
175 | """
176 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
177 | if pretrained:
178 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
179 | return model
180 |
181 |
182 | def resnet34(pretrained=False, **kwargs):
183 | """Constructs a ResNet-34 model.
184 |
185 | Args:
186 | pretrained (bool): If True, returns a model pre-trained on ImageNet
187 | """
188 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
189 | if pretrained:
190 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
191 | return model
192 |
193 |
194 | def resnet50(pretrained=True, **kwargs):
195 | """Constructs a ResNet-50 model.
196 |
197 | Args:
198 | pretrained (bool): If True, returns a model pre-trained on ImageNet
199 | """
200 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
201 | if pretrained:
202 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
203 | model_path = './initmodel/resnet50_v2.pth'
204 | model.load_state_dict(torch.load(model_path), strict=False)
205 | return model
206 |
207 |
208 | def resnet101(pretrained=False, **kwargs):
209 | """Constructs a ResNet-101 model.
210 |
211 | Args:
212 | pretrained (bool): If True, returns a model pre-trained on ImageNet
213 | """
214 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
215 | if pretrained:
216 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
217 | model_path = './initmodel/resnet101_v2.pth'
218 | model.load_state_dict(torch.load(model_path), strict=False)
219 | return model
220 |
221 |
222 | def resnet152(pretrained=False, **kwargs):
223 | """Constructs a ResNet-152 model.
224 |
225 | Args:
226 | pretrained (bool): If True, returns a model pre-trained on ImageNet
227 | """
228 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
229 | if pretrained:
230 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
231 | model_path = './initmodel/resnet152_v2.pth'
232 | model.load_state_dict(torch.load(model_path), strict=False)
233 | return model
234 |
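To see what `ResNet._make_layer` builds, note that the first block of each stage receives the stride and a 1x1 projection shortcut whenever the stride or the channel count changes, and that every stage outputs `planes * block.expansion` channels. The minimal sketch below re-derives that bookkeeping in plain Python for the ResNet-50 layout `[3, 4, 6, 3]`. It does not import this repository, and the helper name `stage_plan` is hypothetical.

```python
def stage_plan(layers, block_expansion=4, deep_base=True):
    """Mirror the channel/stride bookkeeping of ResNet._make_layer (hypothetical helper)."""
    # Stem output channels: 128 for the deep (three 3x3 convs) stem, 64 for the 7x7 stem.
    inplanes = 128 if deep_base else 64
    plan = []
    for planes, blocks, stride in zip((64, 128, 256, 512), layers, (1, 2, 2, 2)):
        out_ch = planes * block_expansion
        # Only the first block of a stage can need a projection shortcut.
        needs_downsample = stride != 1 or inplanes != out_ch
        plan.append({"in": inplanes, "out": out_ch, "blocks": blocks,
                     "stride": stride, "downsample": needs_downsample})
        inplanes = out_ch
    return plan


for i, stage in enumerate(stage_plan([3, 4, 6, 3]), start=1):
    print(f"layer{i}: {stage}")
```

For Bottleneck stages, the first block of `layer1` always gets a projection even at stride 1, because the channel count changes from 128 (or 64) to 256. BasicBlock networks without the deep stem skip it.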
--------------------------------------------------------------------------------
/model/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | BatchNorm = nn.BatchNorm2d
6 |
7 | __all__ = [
8 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
9 | 'vgg19_bn', 'vgg19',
10 | ]
11 |
12 |
13 | model_urls = {
14 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
15 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
16 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
17 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
18 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
19 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
20 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
21 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
22 | }
23 |
24 |
25 | class VGG(nn.Module):
26 |
27 | def __init__(self, features, num_classes=1000, init_weights=True):
28 | super(VGG, self).__init__()
29 | self.features = features
30 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
31 | self.classifier = nn.Sequential(
32 | nn.Linear(512 * 7 * 7, 4096),
33 | nn.ReLU(True),
34 | nn.Dropout(),
35 | nn.Linear(4096, 4096),
36 | nn.ReLU(True),
37 | nn.Dropout(),
38 | nn.Linear(4096, num_classes),
39 | )
40 | if init_weights:
41 | self._initialize_weights()
42 |
43 | def forward(self, x):
44 | x = self.features(x)
45 | x = self.avgpool(x)
46 | x = x.view(x.size(0), -1)
47 | x = self.classifier(x)
48 | return x
49 |
50 | def _initialize_weights(self):
51 | for m in self.modules():
52 | if isinstance(m, nn.Conv2d):
53 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
54 | if m.bias is not None:
55 | nn.init.constant_(m.bias, 0)
56 | elif isinstance(m, BatchNorm):
57 | nn.init.constant_(m.weight, 1)
58 | nn.init.constant_(m.bias, 0)
59 | elif isinstance(m, nn.Linear):
60 | nn.init.normal_(m.weight, 0, 0.01)
61 | nn.init.constant_(m.bias, 0)
62 |
63 |
64 | def make_layers(cfg, batch_norm=False):
65 | layers = []
66 | in_channels = 3
67 | for v in cfg:
68 | if v == 'M':
69 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
70 | else:
71 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
72 | if batch_norm:
73 | layers += [conv2d, BatchNorm(v), nn.ReLU(inplace=True)]
74 | else:
75 | layers += [conv2d, nn.ReLU(inplace=True)]
76 | in_channels = v
77 | return nn.Sequential(*layers)
78 |
79 |
80 | cfg = {
81 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
82 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
83 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
84 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
85 | }
86 |
87 |
88 | def vgg11(pretrained=False, **kwargs):
89 | """VGG 11-layer model (configuration "A")
90 | Args:
91 | pretrained (bool): If True, returns a model pre-trained on ImageNet
92 | """
93 | if pretrained:
94 | kwargs['init_weights'] = False
95 | model = VGG(make_layers(cfg['A']), **kwargs)
96 | if pretrained:
97 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
98 | return model
99 |
100 |
101 | def vgg11_bn(pretrained=False, **kwargs):
102 | """VGG 11-layer model (configuration "A") with batch normalization
103 | Args:
104 | pretrained (bool): If True, returns a model pre-trained on ImageNet
105 | """
106 | if pretrained:
107 | kwargs['init_weights'] = False
108 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
109 | if pretrained:
110 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
111 | return model
112 |
113 |
114 | def vgg13(pretrained=False, **kwargs):
115 | """VGG 13-layer model (configuration "B")
116 | Args:
117 | pretrained (bool): If True, returns a model pre-trained on ImageNet
118 | """
119 | if pretrained:
120 | kwargs['init_weights'] = False
121 | model = VGG(make_layers(cfg['B']), **kwargs)
122 | if pretrained:
123 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
124 | return model
125 |
126 |
127 | def vgg13_bn(pretrained=False, **kwargs):
128 | """VGG 13-layer model (configuration "B") with batch normalization
129 | Args:
130 | pretrained (bool): If True, returns a model pre-trained on ImageNet
131 | """
132 | if pretrained:
133 | kwargs['init_weights'] = False
134 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
135 | if pretrained:
136 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
137 | return model
138 |
139 |
140 | def vgg16(pretrained=False, **kwargs):
141 | """VGG 16-layer model (configuration "D")
142 | Args:
143 | pretrained (bool): If True, returns a model pre-trained on ImageNet
144 | """
145 | if pretrained:
146 | kwargs['init_weights'] = False
147 | model = VGG(make_layers(cfg['D']), **kwargs)
148 | if pretrained:
149 | #model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
150 | model_path = './initmodel/vgg16.pth'
151 | model.load_state_dict(torch.load(model_path), strict=False)
152 | return model
153 |
154 |
155 | def vgg16_bn(pretrained=False, **kwargs):
156 | """VGG 16-layer model (configuration "D") with batch normalization
157 | Args:
158 | pretrained (bool): If True, returns a model pre-trained on ImageNet
159 | """
160 | if pretrained:
161 | kwargs['init_weights'] = False
162 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
163 | if pretrained:
164 | #model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
165 | model_path = './initmodel/vgg16_bn.pth'
166 | model.load_state_dict(torch.load(model_path), strict=False)
167 | return model
168 |
169 |
170 | def vgg19(pretrained=False, **kwargs):
171 | """VGG 19-layer model (configuration "E")
172 | Args:
173 | pretrained (bool): If True, returns a model pre-trained on ImageNet
174 | """
175 | if pretrained:
176 | kwargs['init_weights'] = False
177 | model = VGG(make_layers(cfg['E']), **kwargs)
178 | if pretrained:
179 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
180 | return model
181 |
182 |
183 | def vgg19_bn(pretrained=False, **kwargs):
184 | """VGG 19-layer model (configuration 'E') with batch normalization
185 | Args:
186 | pretrained (bool): If True, returns a model pre-trained on ImageNet
187 | """
188 | if pretrained:
189 | kwargs['init_weights'] = False
190 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
191 | if pretrained:
192 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
193 | return model
194 |
195 | if __name__ == '__main__':
196 | import os
197 | # os.environ["CUDA_VISIBLE_DEVICES"] = '7'
198 | input = torch.rand(4, 3, 473, 473).cuda()
199 | target = torch.rand(4, 473, 473).cuda()*1.0
200 | model = vgg16_bn(pretrained=False).cuda()
201 | model.train()
202 | layer0_idx = range(0,6)
203 | layer1_idx = range(6,13)
204 | layer2_idx = range(13,23)
205 | layer3_idx = range(23,33)
206 | layer4_idx = range(34,43)
207 | #layer4_idx = range(34,43)
208 | print(model.features)
209 | layers_0 = []
210 | layers_1 = []
211 | layers_2 = []
212 | layers_3 = []
213 | layers_4 = []
214 | for idx in layer0_idx:
215 | layers_0 += [model.features[idx]]
216 | for idx in layer1_idx:
217 | layers_1 += [model.features[idx]]
218 | for idx in layer2_idx:
219 | layers_2 += [model.features[idx]]
220 | for idx in layer3_idx:
221 | layers_3 += [model.features[idx]]
222 | for idx in layer4_idx:
223 | layers_4 += [model.features[idx]]
224 |
225 | layer0 = nn.Sequential(*layers_0)
226 | layer1 = nn.Sequential(*layers_1)
227 | layer2 = nn.Sequential(*layers_2)
228 | layer3 = nn.Sequential(*layers_3)
229 | layer4 = nn.Sequential(*layers_4)
230 |
231 | output = layer0(input)
232 | print(layer0)
233 | print('layer 0: {}'.format(output.size()))
234 | output = layer1(output)
235 | print(layer1)
236 | print('layer 1: {}'.format(output.size()))
237 | output = layer2(output)
238 | print(layer2)
239 | print('layer 2: {}'.format(output.size()))
240 | output = layer3(output)
241 | print(layer3)
242 | print('layer 3: {}'.format(output.size()))
243 | output = layer4(output)
244 | print(layer4)
245 | print('layer 4: {}'.format(output.size()))
246 |
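The hard-coded index ranges in the `__main__` block are easier to check against the module layout that `make_layers(cfg['D'], batch_norm=True)` produces. Each conv entry expands to three modules (Conv2d, BatchNorm, ReLU), and each `'M'` expands to one MaxPool2d. The minimal sketch below recomputes that layout from a copy of the `cfg['D']` list, so it needs no torch. The helper name `feature_layout` is hypothetical.

```python
# Copy of cfg['D'] (VGG-16) from this file.
CFG_D = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M']


def feature_layout(cfg, batch_norm=True):
    """Return the module kinds make_layers would emit, in index order (hypothetical helper)."""
    layout = []
    for v in cfg:
        if v == 'M':
            layout.append('pool')
        elif batch_norm:
            layout.extend([f'conv{v}', 'bn', 'relu'])
        else:
            layout.extend([f'conv{v}', 'relu'])
    return layout


layout = feature_layout(CFG_D)
pools = [i for i, name in enumerate(layout) if name == 'pool']
print(len(layout), pools)
```

With this layout, `layer1` to `layer3` in the demo each begin with a max-pool (indices 6, 13 and 23). `range(34, 43)` covers only the last three conv/BN/ReLU triples and skips the pools at 33 and 43, so `layer4` keeps the spatial size of `layer3`'s output.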
--------------------------------------------------------------------------------
/test_coco.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import random
4 | import time
5 | import cv2
6 | import numpy as np
7 | import logging
8 | import argparse
9 | import math
10 | from visdom import Visdom
11 | import os.path as osp
12 |
13 | import torch
14 | import torch.backends.cudnn as cudnn
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.nn.parallel
18 | import torch.optim
19 | import torch.utils.data
20 | import torch.multiprocessing as mp
21 | import torch.distributed as dist
22 | from torch.utils.data.distributed import DistributedSampler
23 | from torch.cuda.amp import autocast
24 |
25 | from tensorboardX import SummaryWriter
26 |
27 | from model import HMNetAMP
28 |
29 | from util import dataset
30 | from util import transform, transform_tri, config
31 | from util.util import AverageMeter, poly_learning_rate, intersectionAndUnionGPU, get_model_para_number, setup_seed, \
32 | get_logger, get_save_path, \
33 | is_same_model, fix_bn, sum_list, check_makedirs
34 | import matplotlib.pyplot as plt
35 |
36 | cv2.ocl.setUseOpenCL(False)
37 | cv2.setNumThreads(0)
38 | val_manual_seed = 123
39 | setup_seed(val_manual_seed, False)
40 | seed_array = [321]
41 | val_num = len(seed_array)
42 |
43 |
44 | def get_parser():
45 | parser = argparse.ArgumentParser(description='PyTorch Few-Shot Semantic Segmentation')
46 | parser.add_argument('--arch', type=str, default='HMNetAMP')
47 | parser.add_argument('--viz', action='store_true', default=False)
48 | parser.add_argument('--config', type=str, default='config/coco/coco_split3_resnet50_manet.yaml',
49 | help='config file')  # e.g. config/coco/coco_split0_vgg_manet_5s.yaml
50 | parser.add_argument('--episode', help='number of test episodes', type=int, default=4000)
51 | parser.add_argument('--opts', help='override config options as KEY VALUE pairs', default=None,
52 | nargs=argparse.REMAINDER)
53 | args = parser.parse_args()
54 | assert args.config is not None
55 | cfg = config.load_cfg_from_cfg_file(args.config)
56 | cfg = config.merge_cfg_from_args(cfg, args)
57 | if args.opts is not None:
58 | cfg = config.merge_cfg_from_list(cfg, args.opts)
59 | return cfg
60 |
61 |
62 | def get_model(args):
63 | model = eval(args.arch).OneModel(args)
64 | optimizer, optimizer_swin = model.get_optim(model, args, LR=args.base_lr)
65 |
66 | model = model.cuda()
67 |
68 | # Resume
69 | get_save_path(args)
70 | check_makedirs(args.snapshot_path)
71 | check_makedirs(args.result_path)
72 |
73 | if args.weight:
74 | weight_path = osp.join(args.snapshot_path, args.weight)
75 | if os.path.isfile(weight_path):
76 | logger.info("=> loading checkpoint '{}'".format(weight_path))
77 | checkpoint = torch.load(weight_path, map_location=torch.device('cpu'))
78 | args.start_epoch = checkpoint['epoch']
79 | new_param = checkpoint['state_dict']
80 | try:
81 | model.load_state_dict(new_param)
82 | except RuntimeError: # 1GPU loads mGPU model
83 | for key in list(new_param.keys()):
84 | new_param[key[7:]] = new_param.pop(key)
85 | model.load_state_dict(new_param)
86 | optimizer.load_state_dict(checkpoint['optimizer'])
87 | optimizer_swin.load_state_dict(checkpoint['optimizer_swin'])
88 | logger.info("=> loaded checkpoint '{}' (epoch {})".format(weight_path, checkpoint['epoch']))
89 | else:
90 | logger.info("=> no checkpoint found at '{}'".format(weight_path))
91 |
92 | # Get model para.
93 | total_number, learnable_number = get_model_para_number(model)
94 | print('Number of Parameters: %d' % (total_number))
95 | print('Number of Learnable Parameters: %d' % (learnable_number))
96 |
97 | time.sleep(5)
98 | return model, optimizer, optimizer_swin
99 |
100 |
101 | def main():
102 | global args, logger, writer
103 | args = get_parser()
104 | logger = get_logger()
105 | args.distributed = True if torch.cuda.device_count() > 1 else False
106 | print(args)
107 |
108 | if args.manual_seed is not None:
109 | setup_seed(args.manual_seed, args.seed_deterministic)
110 |
111 | assert args.classes > 1
112 | assert args.zoom_factor in [1, 2, 4, 8]
113 | assert (args.train_h - 1) % 8 == 0 and (args.train_w - 1) % 8 == 0
114 |
115 | logger.info("=> creating model ...")
116 | model, optimizer, optimizer_swin = get_model(args)
117 | logger.info(model)
118 |
119 | # ---------------------- DATASET ----------------------
120 | value_scale = 255
121 | mean = [0.485, 0.456, 0.406]
122 | mean = [item * value_scale for item in mean]
123 | std = [0.229, 0.224, 0.225]
124 | std = [item * value_scale for item in std]
125 | # Val
126 | if args.evaluate:
127 | if args.resized_val:
128 | val_transform = transform.Compose([
129 | transform.Resize(size=args.val_size),
130 | transform.ToTensor(),
131 | transform.Normalize(mean=mean, std=std)])
132 | else:
133 | val_transform = transform.Compose([
134 | transform.test_Resize(size=args.val_size),
135 | transform.ToTensor(),
136 | transform.Normalize(mean=mean, std=std)])
137 | if args.data_set == 'pascal' or args.data_set == 'coco':
138 | val_data = dataset.SemData(split=args.split, shot=args.shot, data_root=args.data_root,
139 | data_list=args.val_list, transform=val_transform, mode='val',
140 | ann_type=args.ann_type, data_set=args.data_set, use_split_coco=args.use_split_coco)
141 | val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False,
142 | num_workers=args.workers, pin_memory=False, sampler=None)
143 |
144 | # ---------------------- VAL ----------------------
145 | start_time = time.time()
146 | FBIoU_array = np.zeros(val_num)
147 | mIoU_array = np.zeros(val_num)
148 | pIoU_array = np.zeros(val_num)
149 | for val_id in range(val_num):
150 | val_seed = seed_array[val_id]
151 | print('Val: [{}/{}] \t Seed: {}'.format(val_id + 1, val_num, val_seed))
152 | fb_iou, miou, piou = validate(val_loader, model, val_seed, args.episode)
153 | FBIoU_array[val_id], mIoU_array[val_id], pIoU_array[val_id] = \
154 | fb_iou, miou, piou
155 |
156 | total_time = time.time() - start_time
157 | t_m, t_s = divmod(total_time, 60)
158 | t_h, t_m = divmod(t_m, 60)
159 | total_time = '{:02d}h {:02d}m {:02d}s'.format(int(t_h), int(t_m), int(t_s))
160 |
161 | print('\nTotal running time: {}'.format(total_time))
162 | print('Seed0: {}'.format(val_manual_seed))
163 | print('Seed: {}'.format(seed_array))
164 | print('mIoU: {}'.format(np.round(mIoU_array, 4)))
165 | print('FBIoU: {}'.format(np.round(FBIoU_array, 4)))
166 | print('pIoU: {}'.format(np.round(pIoU_array, 4)))
167 | print('-' * 43)
168 | print('Best_Seed_m: {} \t Best_Seed_F: {} \t Best_Seed_p: {}'.format(seed_array[mIoU_array.argmax()],
169 | seed_array[FBIoU_array.argmax()],
170 | seed_array[pIoU_array.argmax()]))
171 | print('Best_mIoU: {:.4f} \t Best_FBIoU: {:.4f} \t Best_pIoU: {:.4f}'.format(
172 | mIoU_array.max(), FBIoU_array.max(), pIoU_array.max()))
173 | print('Mean_mIoU: {:.4f} \t Mean_FBIoU: {:.4f} \t Mean_pIoU: {:.4f}'.format(
174 | mIoU_array.mean(), FBIoU_array.mean(), pIoU_array.mean()))
175 |
176 |
177 | def validate(val_loader, model, val_seed, episode):
178 | logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
179 | batch_time = AverageMeter()
180 | model_time = AverageMeter()
181 | data_time = AverageMeter()
182 | loss_meter = AverageMeter()
183 |
184 | intersection_meter = AverageMeter() # final
185 | union_meter = AverageMeter()
186 | target_meter = AverageMeter()
187 |
188 | if args.data_set == 'pascal':
189 | test_num = 1000
190 | split_gap = 5
191 | elif args.data_set == 'coco':
192 | test_num = episode
193 | split_gap = 20
194 |
195 | class_intersection_meter = [0] * split_gap
196 | class_union_meter = [0] * split_gap
197 |
198 | setup_seed(val_seed, args.seed_deterministic)
199 |
200 | criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
201 |
202 | model.eval()
203 | end = time.time()
204 | val_start = end
205 |
206 | assert test_num % args.batch_size_val == 0
207 | db_epoch = math.ceil(test_num / (len(val_loader) - args.batch_size_val))
208 | iter_num = 0
209 |
210 | for e in range(db_epoch):
211 | for i, (input, target, s_input, s_mask, subcls, ori_label) in enumerate(val_loader):
212 | if iter_num * args.batch_size_val >= test_num:
213 | break
214 | iter_num += 1
215 | data_time.update(time.time() - end)
216 |
217 | s_input = s_input.cuda(non_blocking=True)
218 | s_mask = s_mask.cuda(non_blocking=True)
219 | input = input.cuda(non_blocking=True)
220 | target = target.cuda(non_blocking=True)
221 | ori_label = ori_label.cuda(non_blocking=True)
222 |
223 | start_time = time.time()
224 |
225 | with torch.no_grad(), autocast():  # inference only: do not build an autograd graph
226 | output = model(s_x=s_input, s_y=s_mask, x=input, y_m=target, cat_idx=subcls)
227 | model_time.update(time.time() - start_time)
228 |
229 | if args.ori_resize:
230 | output = F.interpolate(output, size=ori_label.size()[-2:], mode='bilinear', align_corners=True)
231 | target = ori_label.long()
232 |
233 | output = F.interpolate(output, size=target.size()[1:], mode='bilinear', align_corners=True)
234 | loss = criterion(output, target)
235 |
236 | output = output.max(1)[1]
237 | subcls = subcls[0].cpu().numpy()[0]  # assumes batch_size_val == 1
238 |
239 | intersection, union, new_target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
240 | intersection, union, new_target = intersection.cpu().numpy(), union.cpu().numpy(), new_target.cpu().numpy()
241 | intersection_meter.update(intersection), union_meter.update(union), target_meter.update(new_target)
242 | class_intersection_meter[subcls] += intersection[1]
243 | class_union_meter[subcls] += union[1]
244 |
245 | accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
246 | loss_meter.update(loss.item(), input.size(0))
247 | batch_time.update(time.time() - end)
248 | end = time.time()
249 |
250 | remain_iter = test_num / args.batch_size_val - iter_num
251 | remain_time = remain_iter * batch_time.avg
252 | t_m, t_s = divmod(remain_time, 60)
253 | t_h, t_m = divmod(t_m, 60)
254 | remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))
255 |
256 | if (i + 1) % max(1, round(test_num / 100)) == 0:  # avoid modulo by zero when test_num < 50
257 | logger.info('Test: [{}/{}] '
258 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
259 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
260 | 'Remain {remain_time} '
261 | 'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) '
262 | 'Accuracy {accuracy:.4f}.'.format(iter_num * args.batch_size_val, test_num,
263 | data_time=data_time,
264 | batch_time=batch_time,
265 | remain_time=remain_time,
266 | loss_meter=loss_meter,
267 | accuracy=accuracy))
268 | val_time = time.time() - val_start
269 |
270 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
271 | mIoU = np.mean(iou_class)
272 |
273 | class_iou_class = []
274 | class_miou = 0
275 | for i in range(len(class_intersection_meter)):
276 | class_iou = class_intersection_meter[i] / (class_union_meter[i] + 1e-10)
277 | class_iou_class.append(class_iou)
278 | class_miou += class_iou
279 |
280 | class_miou = class_miou * 1.0 / len(class_intersection_meter)
281 | logger.info('meanIoU---Val result: mIoU {:.4f}.'.format(class_miou)) # final
282 | logger.info('<<<<<<< Novel Results <<<<<<<')
283 | for i in range(split_gap):
284 | logger.info('Class_{} Result: iou {:.4f}.'.format(i + 1, class_iou_class[i]))
285 |
286 | logger.info('FBIoU---Val result: FBIoU {:.4f}.'.format(mIoU))
287 | for i in range(args.classes):
288 | logger.info('Class_{} Result: iou_f {:.4f}.'.format(i, iou_class[i]))
289 | logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<')
290 |
291 | print('total time: {:.4f}, avg inference time: {:.4f}, count: {}'.format(val_time, model_time.avg, test_num))
292 |
293 | return mIoU, class_miou, iou_class[1]
294 |
295 |
296 | if __name__ == '__main__':
297 | main()
298 |
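`validate` reports two different numbers. FB-IoU is the mean, over the `args.classes` labels (background and foreground), of accumulated intersection divided by accumulated union. The class mIoU instead accumulates only the foreground intersection and union per novel class (`subcls`) and averages those per-class IoUs over all `split_gap` classes. The minimal sketch below reproduces that arithmetic with NumPy. The per-episode counts are made up and stand in for what `intersectionAndUnionGPU` returns, and the function name `summarize` is hypothetical.

```python
import numpy as np


def summarize(episodes, num_novel):
    """episodes: iterable of (novel_class_idx, intersection[2], union[2]) (hypothetical helper)."""
    inter_sum = np.zeros(2)
    union_sum = np.zeros(2)
    cls_inter = np.zeros(num_novel)
    cls_union = np.zeros(num_novel)
    for cls, inter, union in episodes:
        inter_sum += inter            # background + foreground, pooled over all episodes
        union_sum += union
        cls_inter[cls] += inter[1]    # foreground only, bucketed by novel class
        cls_union[cls] += union[1]
    fb_iou = np.mean(inter_sum / (union_sum + 1e-10))
    # Classes with no episodes contribute 0, as in validate().
    class_miou = np.mean(cls_inter / (cls_union + 1e-10))
    return fb_iou, class_miou


episodes = [
    (0, np.array([90.0, 30.0]), np.array([100.0, 40.0])),
    (1, np.array([80.0, 10.0]), np.array([100.0, 40.0])),
]
fb, miou = summarize(episodes, num_novel=2)
```

Pooling first and averaging per class second explains why the two metrics can diverge: a large, easy background inflates FB-IoU, while a poorly segmented novel class pulls the class mIoU down.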
--------------------------------------------------------------------------------
/test_coco.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 |
4 | dataset=coco
5 | exp_name=split$1 # 0
6 | arch=HMNetAMP
7 | net=$2 # vgg/resnet50
8 | postfix=$3 # manet/manet_5s
9 |
10 | config=config/${dataset}/${dataset}_${exp_name}_${net}_${postfix}.yaml
11 |
12 | python test_coco.py --config=${config} --arch=${arch} --episode=4000
13 |
--------------------------------------------------------------------------------
/test_pascal.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import random
4 | import time
5 | import cv2
6 | import numpy as np
7 | import logging
8 | import argparse
9 | import math
10 | from visdom import Visdom
11 | import os.path as osp
12 |
13 | import torch
14 | import torch.backends.cudnn as cudnn
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.nn.parallel
18 | import torch.optim
19 | import torch.utils.data
20 | import torch.multiprocessing as mp
21 | import torch.distributed as dist
22 | from torch.utils.data.distributed import DistributedSampler
23 |
24 | from tensorboardX import SummaryWriter
25 |
26 | from model import HMNet
27 |
28 | from util import dataset
29 | from util import transform, transform_tri, config
30 | from util.util import AverageMeter, poly_learning_rate, intersectionAndUnionGPU, get_model_para_number, setup_seed, \
31 | get_logger, get_save_path, \
32 | is_same_model, fix_bn, sum_list, check_makedirs
33 | import matplotlib.pyplot as plt
34 |
35 | cv2.ocl.setUseOpenCL(False)
36 | cv2.setNumThreads(0)
37 | val_manual_seed = 123
38 | setup_seed(val_manual_seed, False)
39 | seed_array = [321]
40 | val_num = len(seed_array)
41 |
42 |
43 | def get_parser():
44 | parser = argparse.ArgumentParser(description='PyTorch Few-Shot Semantic Segmentation')
45 | parser.add_argument('--arch', type=str, default='HMNet')
46 | parser.add_argument('--viz', action='store_true', default=False)
47 | parser.add_argument('--config', type=str, default='config/pascal/pascal_split0_resnet50_manet.yaml',
48 | help='config file')  # e.g. config/pascal/pascal_split0_vgg_manet_5s.yaml
49 | parser.add_argument('--episode', help='number of test episodes', type=int, default=1000)
50 | parser.add_argument('--opts', help='override config options as KEY VALUE pairs', default=None,
51 | nargs=argparse.REMAINDER)
52 | args = parser.parse_args()
53 | assert args.config is not None
54 | cfg = config.load_cfg_from_cfg_file(args.config)
55 | cfg = config.merge_cfg_from_args(cfg, args)
56 | if args.opts is not None:
57 | cfg = config.merge_cfg_from_list(cfg, args.opts)
58 | return cfg
59 |
60 |
61 | def get_model(args):
62 | model = eval(args.arch).OneModel(args)
63 | optimizer, optimizer_swin = model.get_optim(model, args, LR=args.base_lr)
64 |
65 | model = model.cuda()
66 |
67 | # Resume
68 | get_save_path(args)
69 | check_makedirs(args.snapshot_path)
70 | check_makedirs(args.result_path)
71 |
72 | if args.weight:
73 | weight_path = osp.join(args.snapshot_path, args.weight)
74 | if os.path.isfile(weight_path):
75 | logger.info("=> loading checkpoint '{}'".format(weight_path))
76 | checkpoint = torch.load(weight_path, map_location=torch.device('cpu'))
77 | args.start_epoch = checkpoint['epoch']
78 | new_param = checkpoint['state_dict']
79 | try:
80 | model.load_state_dict(new_param)
81 | except RuntimeError: # 1GPU loads mGPU model
82 | for key in list(new_param.keys()):
83 | new_param[key[7:]] = new_param.pop(key)
84 | model.load_state_dict(new_param)
85 | optimizer.load_state_dict(checkpoint['optimizer'])
86 | optimizer_swin.load_state_dict(checkpoint['optimizer_swin'])
87 | logger.info("=> loaded checkpoint '{}' (epoch {})".format(weight_path, checkpoint['epoch']))
88 | else:
89 | logger.info("=> no checkpoint found at '{}'".format(weight_path))
90 |
91 | # Get model para.
92 | total_number, learnable_number = get_model_para_number(model)
93 | print('Number of Parameters: %d' % (total_number))
94 | print('Number of Learnable Parameters: %d' % (learnable_number))
95 |
96 | time.sleep(5)
97 | return model, optimizer, optimizer_swin
98 |
99 |
100 | def main():
101 | global args, logger, writer
102 | args = get_parser()
103 | logger = get_logger()
104 | args.distributed = True if torch.cuda.device_count() > 1 else False
105 | print(args)
106 |
107 | if args.manual_seed is not None:
108 | setup_seed(args.manual_seed, args.seed_deterministic)
109 |
110 | assert args.classes > 1
111 | assert args.zoom_factor in [1, 2, 4, 8]
112 | assert (args.train_h - 1) % 8 == 0 and (args.train_w - 1) % 8 == 0
113 |
114 | logger.info("=> creating model ...")
115 | model, optimizer, optimizer_swin = get_model(args)
116 | logger.info(model)
117 |
118 | # ---------------------- DATASET ----------------------
119 | value_scale = 255
120 | mean = [0.485, 0.456, 0.406]
121 | mean = [item * value_scale for item in mean]
122 | std = [0.229, 0.224, 0.225]
123 | std = [item * value_scale for item in std]
124 | # Val
125 | if args.evaluate:
126 | if args.resized_val:
127 | val_transform = transform.Compose([
128 | transform.Resize(size=args.val_size),
129 | transform.ToTensor(),
130 | transform.Normalize(mean=mean, std=std)])
131 | else:
132 | val_transform = transform.Compose([
133 | transform.test_Resize(size=args.val_size),
134 | transform.ToTensor(),
135 | transform.Normalize(mean=mean, std=std)])
136 | if args.data_set == 'pascal' or args.data_set == 'coco':
137 | val_data = dataset.SemData(split=args.split, shot=args.shot, data_root=args.data_root,
138 | data_list=args.val_list, transform=val_transform, mode='val',
139 | ann_type=args.ann_type, data_set=args.data_set, use_split_coco=args.use_split_coco)
140 | val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False,
141 | num_workers=args.workers, pin_memory=False, sampler=None)
142 |
143 | # ---------------------- VAL ----------------------
144 | start_time = time.time()
145 | FBIoU_array = np.zeros(val_num)
146 | mIoU_array = np.zeros(val_num)
147 | pIoU_array = np.zeros(val_num)
148 | for val_id in range(val_num):
149 | val_seed = seed_array[val_id]
150 | print('Val: [{}/{}] \t Seed: {}'.format(val_id + 1, val_num, val_seed))
151 | fb_iou, miou, piou = validate(val_loader, model, val_seed, args.episode)
152 | FBIoU_array[val_id], mIoU_array[val_id], pIoU_array[val_id] = \
153 | fb_iou, miou, piou
154 |
155 | total_time = time.time() - start_time
156 | t_m, t_s = divmod(total_time, 60)
157 | t_h, t_m = divmod(t_m, 60)
158 | total_time = '{:02d}h {:02d}m {:02d}s'.format(int(t_h), int(t_m), int(t_s))
159 |
160 | print('\nTotal running time: {}'.format(total_time))
161 | print('Seed0: {}'.format(val_manual_seed))
162 | print('Seed: {}'.format(seed_array))
163 | print('mIoU: {}'.format(np.round(mIoU_array, 4)))
164 | print('FBIoU: {}'.format(np.round(FBIoU_array, 4)))
165 | print('pIoU: {}'.format(np.round(pIoU_array, 4)))
166 | print('-' * 43)
167 | print('Best_Seed_m: {} \t Best_Seed_F: {} \t Best_Seed_p: {}'.format(seed_array[mIoU_array.argmax()],
168 | seed_array[FBIoU_array.argmax()],
169 | seed_array[pIoU_array.argmax()]))
170 | print('Best_mIoU: {:.4f} \t Best_FBIoU: {:.4f} \t Best_pIoU: {:.4f}'.format(
171 | mIoU_array.max(), FBIoU_array.max(), pIoU_array.max()))
172 | print('Mean_mIoU: {:.4f} \t Mean_FBIoU: {:.4f} \t Mean_pIoU: {:.4f}'.format(
173 | mIoU_array.mean(), FBIoU_array.mean(), pIoU_array.mean()))
174 |
175 |
176 | def validate(val_loader, model, val_seed, episode):
177 | logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
178 | batch_time = AverageMeter()
179 | model_time = AverageMeter()
180 | data_time = AverageMeter()
181 | loss_meter = AverageMeter()
182 |
183 | intersection_meter = AverageMeter() # final
184 | union_meter = AverageMeter()
185 | target_meter = AverageMeter()
186 |
187 | if args.data_set == 'pascal':
188 | test_num = 1000
189 | split_gap = 5
190 | elif args.data_set == 'coco':
191 | test_num = episode
192 | split_gap = 20
193 |
194 | class_intersection_meter = [0] * split_gap
195 | class_union_meter = [0] * split_gap
196 |
197 | setup_seed(val_seed, args.seed_deterministic)
198 |
199 | criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
200 |
201 | model.eval()
202 | end = time.time()
203 | val_start = end
204 |
205 | assert test_num % args.batch_size_val == 0
206 | db_epoch = math.ceil(test_num / (len(val_loader) - args.batch_size_val))
207 | iter_num = 0
208 |
209 | for e in range(db_epoch):
210 | for i, (input, target, s_input, s_mask, subcls, ori_label) in enumerate(val_loader):
211 | if iter_num * args.batch_size_val >= test_num:
212 | break
213 | iter_num += 1
214 | data_time.update(time.time() - end)
215 |
216 | s_input = s_input.cuda(non_blocking=True)
217 | s_mask = s_mask.cuda(non_blocking=True)
218 | input = input.cuda(non_blocking=True)
219 | target = target.cuda(non_blocking=True)
220 | ori_label = ori_label.cuda(non_blocking=True)
221 |
222 | start_time = time.time()
223 |
224 | output = model(s_x=s_input, s_y=s_mask, x=input, y_m=target, cat_idx=subcls)
225 | model_time.update(time.time() - start_time)
226 |
227 | if args.ori_resize:
228 | output = F.interpolate(output, size=ori_label.size()[-2:], mode='bilinear', align_corners=True)
229 | target = ori_label.long()
230 |
231 | output = F.interpolate(output, size=target.size()[1:], mode='bilinear', align_corners=True)
232 | loss = criterion(output, target)
233 |
234 | output = output.max(1)[1]
235 | subcls = subcls[0].cpu().numpy()[0]
236 |
237 | intersection, union, new_target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
238 | intersection, union, new_target = intersection.cpu().numpy(), union.cpu().numpy(), new_target.cpu().numpy()
239 | intersection_meter.update(intersection), union_meter.update(union), target_meter.update(new_target)
240 | class_intersection_meter[subcls] += intersection[1]
241 | class_union_meter[subcls] += union[1]
242 |
243 | accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
244 | loss_meter.update(loss.item(), input.size(0))
245 | batch_time.update(time.time() - end)
246 | end = time.time()
247 |
248 | remain_iter = test_num / args.batch_size_val - iter_num
249 | remain_time = remain_iter * batch_time.avg
250 | t_m, t_s = divmod(remain_time, 60)
251 | t_h, t_m = divmod(t_m, 60)
252 | remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))
253 |
254 | if ((i + 1) % round((test_num / 100)) == 0):
255 | logger.info('Test: [{}/{}] '
256 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
257 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
258 | 'Remain {remain_time} '
259 | 'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) '
260 | 'Accuracy {accuracy:.4f}.'.format(iter_num * args.batch_size_val, test_num,
261 | data_time=data_time,
262 | batch_time=batch_time,
263 | remain_time=remain_time,
264 | loss_meter=loss_meter,
265 | accuracy=accuracy))
266 | val_time = time.time() - val_start
267 |
268 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
269 | mIoU = np.mean(iou_class)
270 |
271 | class_iou_class = []
272 | class_miou = 0
273 | for i in range(len(class_intersection_meter)):
274 | class_iou = class_intersection_meter[i] / (class_union_meter[i] + 1e-10)
275 | class_iou_class.append(class_iou)
276 | class_miou += class_iou
277 |
278 | class_miou = class_miou * 1.0 / len(class_intersection_meter)
279 | logger.info('meanIoU---Val result: mIoU {:.4f}.'.format(class_miou)) # final
280 | logger.info('<<<<<<< Novel Results <<<<<<<')
281 | for i in range(split_gap):
282 | logger.info('Class_{} Result: iou {:.4f}.'.format(i + 1, class_iou_class[i]))
283 |
284 | logger.info('FBIoU---Val result: FBIoU {:.4f}.'.format(mIoU))
285 | for i in range(args.classes):
286 | logger.info('Class_{} Result: iou_f {:.4f}.'.format(i, iou_class[i]))
287 | logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<')
288 |
289 | print('total time: {:.4f}, avg inference time: {:.4f}, count: {}'.format(val_time, model_time.avg, test_num))
290 |
291 | return mIoU, class_miou, iou_class[1]
292 |
293 |
294 | if __name__ == '__main__':
295 | main()
296 |
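`validate` reports two different scores per seed. The first is FB-IoU (the variable `mIoU` inside `validate`): background and foreground intersections and unions are summed over all episodes, and the two per-class IoUs are averaged. The second is the novel-class mIoU (`class_miou`): foreground intersection and union are summed separately for each of the `split_gap` novel classes (5 on PASCAL, 20 on COCO), and those per-class IoUs are averaged. The NumPy sketch below reproduces that bookkeeping on binary masks. `intersection_and_union` is a hypothetical CPU stand-in for `intersectionAndUnionGPU`, whose source is not part of this excerpt. It assumes labels in {0, 1} with 255 as the ignore value.

```python
import numpy as np

IGNORE = 255  # assumed ignore label, matching args.ignore_label


def intersection_and_union(pred, target, num_classes=2, ignore=IGNORE):
    """Hypothetical CPU stand-in for intersectionAndUnionGPU: per-class pixel counts."""
    valid = target != ignore
    pred, target = pred[valid], target[valid]
    inter = pred[pred == target]  # pixels where prediction and label agree
    area_inter = np.bincount(inter, minlength=num_classes)
    area_pred = np.bincount(pred, minlength=num_classes)
    area_target = np.bincount(target, minlength=num_classes)
    area_union = area_pred + area_target - area_inter
    return area_inter, area_union


def evaluate(episodes, num_novel):
    """episodes: iterable of (pred, target, novel_class_index) triples."""
    inter_sum = np.zeros(2)
    union_sum = np.zeros(2)
    cls_inter = [0.0] * num_novel
    cls_union = [0.0] * num_novel
    for pred, target, c in episodes:
        inter, union = intersection_and_union(pred, target)
        inter_sum += inter
        union_sum += union
        cls_inter[c] += inter[1]  # foreground only, like class_intersection_meter
        cls_union[c] += union[1]
    fb_iou = np.mean(inter_sum / (union_sum + 1e-10))
    # Classes never sampled contribute an IoU of 0, as in the original loop.
    class_miou = np.mean([i / (u + 1e-10) for i, u in zip(cls_inter, cls_union)])
    return fb_iou, class_miou
```

Because the per-class sums are taken before dividing, a class seen in many episodes is scored on its pooled pixels rather than on an average of per-episode IoUs.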
--------------------------------------------------------------------------------
/test_pascal.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 |
4 | dataset=pascal
5 | exp_name=split$1 # 0
6 | arch=HMNet
7 | net=$2 # vgg/resnet50
8 | postfix=$3 # manet/manet_5s
9 |
10 | config=config/${dataset}/${dataset}_${exp_name}_${net}_${postfix}.yaml
11 |
12 | python test_pascal.py --config=${config} --arch=${arch} --episode=1000
13 |
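The shell launchers derive the YAML path from their positional arguments as `config/${dataset}/${dataset}_split${N}_${net}_${postfix}.yaml`. A small Python helper can build the same path and reject typos such as `renet50` before a job is started. The hypothetical `config_path` below is one way to do that. Its allowed values are taken from the script comments and the file names under `config/`. Reading `manet_5s` as the 5-shot variant is an assumption based only on naming.

```python
from pathlib import PurePosixPath

DATASETS = ("pascal", "coco")
NETS = ("vgg", "resnet50")
POSTFIXES = ("manet", "manet_5s")  # manet_5s presumably = 5-shot, judging by file names


def config_path(dataset: str, split: int, net: str, postfix: str) -> str:
    """Mirror config=config/${dataset}/${dataset}_${exp_name}_${net}_${postfix}.yaml."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset {dataset!r}")
    if split not in (0, 1, 2, 3):
        raise ValueError(f"split must be 0-3, got {split}")
    if net not in NETS:
        raise ValueError(f"unknown backbone {net!r}")
    if postfix not in POSTFIXES:
        raise ValueError(f"unknown postfix {postfix!r}")
    exp_name = f"split{split}"
    name = f"{dataset}_{exp_name}_{net}_{postfix}.yaml"
    return str(PurePosixPath("config", dataset, name))
```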
--------------------------------------------------------------------------------
/train_coco.sh:
--------------------------------------------------------------------------------
1 | gpu=8
2 | port=$1 # 1234
3 | dataset=coco
4 | exp_name=split$2 # 0/1/2/3
5 | arch=HMNetAMP
6 | net=$3 # vgg/resnet50
7 | postfix=$4 # manet/manet_5s
8 |
9 | exp_dir=exp/${dataset}/${arch}/${exp_name}/${net}
10 | snapshot_dir=${exp_dir}/snapshot
11 | result_dir=${exp_dir}/result
12 | config=config/${dataset}/${dataset}_${exp_name}_${net}_${postfix}.yaml
13 | mkdir -p ${snapshot_dir} ${result_dir}
14 | now=$(date +"%Y%m%d_%H%M%S")
15 | cp train_coco.sh train_coco.py ${config} ${exp_dir}
16 |
17 | echo ${arch}
18 | echo ${config}
19 |
20 | python3 -m torch.distributed.launch --nproc_per_node=${gpu} --master_port=${port} train_coco.py \
21 | --config=${config} \
22 | --arch=${arch} \
23 | 2>&1 | tee ${result_dir}/train-$now.log
24 |
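`train_coco.sh` starts one process per GPU with `python3 -m torch.distributed.launch`. That launcher passes each process its local rank as a command-line flag. The flag is `--local_rank`, spelled `--local-rank` from PyTorch 2.0 onward. When `--use_env` is given, the rank is passed only through the `LOCAL_RANK` environment variable, which is also what `torchrun` does. The training script's argument parser is not shown in this excerpt. The sketch below is therefore an assumption about how a script could accept every variant, and `get_local_rank` is a hypothetical name.

```python
import argparse
import os


def get_local_rank(argv=None, environ=None):
    """Local rank from --local_rank / --local-rank, else $LOCAL_RANK, else 0."""
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser(add_help=False)
    # Both spellings write to the same destination.
    parser.add_argument("--local_rank", "--local-rank", dest="local_rank",
                        type=int, default=None)
    args, _ = parser.parse_known_args(argv)  # leave --config, --arch, ... alone
    if args.local_rank is not None:
        return args.local_rank
    return int(environ.get("LOCAL_RANK", 0))
```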
--------------------------------------------------------------------------------
/train_pascal.sh:
--------------------------------------------------------------------------------
1 | gpu=4
2 | port=$1 # 1234
3 | dataset=pascal
4 | exp_name=split$2 # 0/1/2/3
5 | arch=HMNet
6 | net=$3 # vgg/resnet50
7 | postfix=$4 # manet/manet_5s
8 |
9 | exp_dir=exp/${dataset}/${arch}/${exp_name}/${net}
10 | snapshot_dir=${exp_dir}/snapshot
11 | result_dir=${exp_dir}/result
12 | config=config/${dataset}/${dataset}_${exp_name}_${net}_${postfix}.yaml
13 | mkdir -p ${snapshot_dir} ${result_dir}
14 | now=$(date +"%Y%m%d_%H%M%S")
15 | cp train_pascal.sh train_pascal.py ${config} ${exp_dir}
16 |
17 | echo ${arch}
18 | echo ${config}
19 |
20 | python3 -m torch.distributed.launch --nproc_per_node=${gpu} --master_port=${port} train_pascal.py \
21 | --config=${config} \
22 | --arch=${arch} \
23 | 2>&1 | tee ${result_dir}/train-$now.log
24 |
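Both training launchers lay out their outputs the same way. The directory `exp/${dataset}/${arch}/split${N}/${net}` receives copies of the launcher, the training script and the YAML config, plus `snapshot/` and `result/` subdirectories. Stdout and stderr are tee'd to `result/train-<YYYYmmdd_HHMMSS>.log`. Note that `postfix` is not part of `exp_dir`, so a `manet` run and a `manet_5s` run with the same split and backbone write into the same directory. The hypothetical `experiment_paths` below reproduces that naming in Python, e.g. for locating logs or checkpoints from a notebook. It only builds paths and creates nothing.

```python
from datetime import datetime
from pathlib import PurePosixPath


def experiment_paths(dataset, arch, split, net, now=None):
    """Mirror exp_dir, snapshot_dir, result_dir and the tee'd log name in train_*.sh."""
    now = datetime.now() if now is None else now
    exp_dir = PurePosixPath("exp", dataset, arch, f"split{split}", net)
    result_dir = exp_dir / "result"
    return {
        "exp_dir": str(exp_dir),
        "snapshot_dir": str(exp_dir / "snapshot"),
        "result_dir": str(result_dir),
        # date +"%Y%m%d_%H%M%S" uses the same strftime directives
        "log": str(result_dir / f"train-{now.strftime('%Y%m%d_%H%M%S')}.log"),
    }
```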
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sam1224/HMNet/556916cc2036f32d68d8e0b810eb9ebcbee6ae1a/util/__init__.py
--------------------------------------------------------------------------------
/util/config.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # Functions for parsing args
3 | # -----------------------------------------------------------------------------
4 | import yaml
5 | import os
6 | from ast import literal_eval
7 | import copy
8 |
9 |
10 | class CfgNode(dict):
11 | """
12 | CfgNode represents an internal node in the configuration tree. It's a simple
13 | dict-like container that allows for attribute-based access to keys.
14 | """
15 |
16 | def __init__(self, init_dict=None, key_list=None, new_allowed=False):
17 | # Recursively convert nested dictionaries in init_dict into CfgNodes
18 | init_dict = {} if init_dict is None else init_dict
19 | key_list = [] if key_list is None else key_list
20 | for k, v in init_dict.items():
21 | if type(v) is dict:
22 | # Convert dict to CfgNode
23 | init_dict[k] = CfgNode(v, key_list=key_list + [k])
24 | super(CfgNode, self).__init__(init_dict)
25 |
26 | def __getattr__(self, name):
27 | if name in self:
28 | return self[name]
29 | else:
30 | raise AttributeError(name)
31 |
32 | def __setattr__(self, name, value):
33 | self[name] = value
34 |
35 | def __str__(self):
36 | def _indent(s_, num_spaces):
37 | s = s_.split("\n")
38 | if len(s) == 1:
39 | return s_
40 | first = s.pop(0)
41 | s = [(num_spaces * " ") + line for line in s]
42 | s = "\n".join(s)
43 | s = first + "\n" + s
44 | return s
45 |
46 | r = ""
47 | s = []
48 | for k, v in sorted(self.items()):
49 | separator = "\n" if isinstance(v, CfgNode) else " "
50 | attr_str = "{}:{}{}".format(str(k), separator, str(v))
51 | attr_str = _indent(attr_str, 2)
52 | s.append(attr_str)
53 | r += "\n".join(s)
54 | return r
55 |
56 | def __repr__(self): # print
57 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())
58 |
59 |
60 | def load_cfg_from_cfg_file(file):
61 | cfg = {}
62 | assert os.path.isfile(file) and file.endswith('.yaml'), \
63 | '{} is not a yaml file'.format(file)
64 |
65 | with open(file, 'r') as f:
66 | cfg_from_file = yaml.safe_load(f)
67 |
68 | for key in cfg_from_file:
69 | for k, v in cfg_from_file[key].items():
70 | cfg[k] = v
71 |
72 | cfg = CfgNode(cfg)
73 | return cfg
74 |
75 | def merge_cfg_from_args(cfg, args):
76 | args_dict = args.__dict__
77 | for k, v in args_dict.items():
78 | if k not in ('config', 'opts'): # skip both the config path and the override list
79 | cfg[k] = v
80 |
81 | return cfg
82 |
83 | def merge_cfg_from_list(cfg, cfg_list):
84 | new_cfg = copy.deepcopy(cfg)
85 | assert len(cfg_list) % 2 == 0
86 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
87 | subkey = full_key.split('.')[-1]
88 | assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
89 | value = _decode_cfg_value(v)
90 | value = _check_and_coerce_cfg_value_type(
91 | value, cfg[subkey], subkey, full_key
92 | )
93 | setattr(new_cfg, subkey, value)
94 |
95 | return new_cfg
96 |
97 |
98 | def _decode_cfg_value(v):
99 | """Decodes a raw config value (e.g., from a yaml config files or command
100 | line argument) into a Python object.
101 | """
102 | # All remaining processing is only applied to strings
103 | if not isinstance(v, str):
104 | return v
105 | # Try to interpret `v` as a:
106 | # string, number, tuple, list, dict, boolean, or None
107 | try:
108 | v = literal_eval(v)
109 | # The following two excepts allow v to pass through when it represents a
110 | # string.
111 | #
112 | # Longer explanation:
113 | # The type of v is always a string (before calling literal_eval), but
114 | # sometimes it *represents* a string and other times a data structure, like
115 | # a list. In the case that v represents a string, what we got back from the
116 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
117 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
118 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
119 | # will raise a SyntaxError.
120 | except ValueError:
121 | pass
122 | except SyntaxError:
123 | pass
124 | return v
125 |
126 |
127 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
128 | """Checks that `replacement`, which is intended to replace `original` is of
129 | the right type. The type is correct if it matches exactly or is one of a few
130 | cases in which the type can be easily coerced.
131 | """
132 | original_type = type(original)
133 | replacement_type = type(replacement)
134 |
135 | # The types must match (with some exceptions)
136 | if replacement_type == original_type:
137 | return replacement
138 |
139 | # Cast replacement from from_type to to_type if the replacement and original
140 | # types match from_type and to_type
141 | def conditional_cast(from_type, to_type):
142 | if replacement_type == from_type and original_type == to_type:
143 | return True, to_type(replacement)
144 | else:
145 | return False, None
146 |
147 | # Conditionally casts
148 | # list <-> tuple
149 | casts = [(tuple, list), (list, tuple)]
150 | # For py2: allow converting from str (bytes) to a unicode string
151 | try:
152 | casts.append((str, unicode)) # noqa: F821
153 | except Exception:
154 | pass
155 |
156 | for (from_type, to_type) in casts:
157 | converted, converted_value = conditional_cast(from_type, to_type)
158 | if converted:
159 | return converted_value
160 |
161 | raise ValueError(
162 | "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
163 | "key: {}".format(
164 | original_type, replacement_type, original, replacement, full_key
165 | )
166 | )
167 |
168 |
169 | def _assert_with_logging(cond, msg):
170 | import logging # this module defines no global `logger`
171 | if not cond:
172 | logging.getLogger(__name__).debug(msg)
173 | assert cond, msg
174 |
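`load_cfg_from_cfg_file` flattens every top-level YAML section into one namespace. `merge_cfg_from_list` then applies `KEY VALUE` override pairs. It keeps only the last dotted component of each key, decodes the value with `literal_eval` (plain strings and paths pass through unchanged), and rejects type changes other than list/tuple swaps. The standalone sketch below re-implements that flow on plain dicts, without YAML parsing, so the behavior is easy to check. `flatten_sections`, `decode` and `apply_overrides` are hypothetical names and are not part of `util/config.py`.

```python
from ast import literal_eval
import copy


def flatten_sections(sections):
    """Like load_cfg_from_cfg_file: {'DATA': {'a': 1}, 'TRAIN': {'b': 2}} -> {'a': 1, 'b': 2}."""
    flat = {}
    for section in sections.values():
        flat.update(section)  # later sections win on duplicate keys
    return flat


def decode(v):
    """Like _decode_cfg_value: parse Python literals, pass plain strings through."""
    if not isinstance(v, str):
        return v
    try:
        return literal_eval(v)
    except (ValueError, SyntaxError):
        return v


def apply_overrides(cfg, opts):
    """Like merge_cfg_from_list: opts = ['DATA.split', '2', 'TRAIN.base_lr', '0.01']."""
    assert len(opts) % 2 == 0, "overrides must be KEY VALUE pairs"
    new_cfg = copy.deepcopy(cfg)
    for full_key, raw in zip(opts[0::2], opts[1::2]):
        key = full_key.split(".")[-1]  # the section prefix is dropped
        if key not in cfg:
            raise KeyError(f"Non-existent key: {full_key}")
        value, original = decode(raw), cfg[key]
        if type(value) is not type(original):
            if isinstance(value, (list, tuple)) and isinstance(original, (list, tuple)):
                value = type(original)(value)  # list <-> tuple is the only allowed cast
            else:
                raise ValueError(
                    f"Type mismatch for {full_key}: {type(original)} vs {type(value)}")
        new_cfg[key] = value
    return new_cfg
```

One consequence of the strict type check is that `TRAIN.base_lr 1` is rejected when the YAML value is a float. The override has to be written as `1.0`.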
--------------------------------------------------------------------------------
/util/dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import os.path
4 | import cv2
5 | import numpy as np
6 | import copy
7 |
8 | from torch.utils.data import Dataset
9 | import torch.nn.functional as F
10 | import torch
11 | import random
12 | import time
13 | from tqdm import tqdm
14 |
15 | from .get_weak_anns import transform_anns
16 |
17 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
18 |
19 |
20 | def is_image_file(filename):
21 | filename_lower = filename.lower()
22 | return any(filename_lower.endswith(extension) for extension in IMG_EXTENSIONS)
23 |
24 |
25 | def make_dataset(split=0, data_root=None, data_list=None, sub_list=None, filter_intersection=False):
26 | assert split in [0, 1, 2, 3]
27 | if not os.path.isfile(data_list):
28 | raise (RuntimeError("Image list file does not exist: " + data_list + "\n"))
29 |
30 | # Shaban uses these lines to remove small objects:
31 | # if util.change_coordinates(mask, 32.0, 0.0).sum() > 2:
32 | # filtered_item.append(item)
33 | # i.e. the mask is downsampled by 32 along each side and must keep more than 2 valid pixels,
34 | # so at the original resolution the object should cover roughly 2 * 32 * 32 = 2048 pixels or more
35 | image_label_list = []
36 | list_read = open(data_list).readlines()
37 | print("Processing data...".format(sub_list))
38 | sub_class_file_list = {}
39 | for sub_c in sub_list:
40 | sub_class_file_list[sub_c] = []
41 |
42 | for l_idx in tqdm(range(len(list_read))):
43 | line = list_read[l_idx]
44 | line = line.strip()
45 | line_split = line.split(' ')
46 | image_name = os.path.join(data_root, line_split[0])
47 | label_name = os.path.join(data_root, line_split[1])
48 | item = (image_name, label_name)
49 | label = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE)
50 | label_class = np.unique(label).tolist()
51 |
52 | if 0 in label_class:
53 | label_class.remove(0)
54 | if 255 in label_class:
55 | label_class.remove(255)
56 |
57 | new_label_class = []
58 |
59 | if filter_intersection: # filter images containing objects of novel categories during meta-training
60 | if set(label_class).issubset(set(sub_list)):
61 | for c in label_class:
62 | if c in sub_list:
63 | tmp_label = np.zeros_like(label)
64 | target_pix = np.where(label == c)
65 | tmp_label[target_pix[0], target_pix[1]] = 1
66 | if tmp_label.sum() >= 2 * 32 * 32:
67 | new_label_class.append(c)
68 | else:
69 | for c in label_class:
70 | if c in sub_list:
71 | tmp_label = np.zeros_like(label)
72 | target_pix = np.where(label == c)
73 | tmp_label[target_pix[0], target_pix[1]] = 1
74 | if tmp_label.sum() >= 2 * 32 * 32:
75 | new_label_class.append(c)
76 |
77 | label_class = new_label_class
78 |
79 | if len(label_class) > 0:
80 | image_label_list.append(item)
81 | for c in label_class:
82 | if c in sub_list:
83 | sub_class_file_list[c].append(item)
84 |
85 | print("Checking image&label pair {} list done! ".format(split))
86 | return image_label_list, sub_class_file_list
87 |
88 |
89 | class SemData(Dataset):
90 | def __init__(self, split=3, shot=1, data_root=None, base_data_root=None, data_list=None, data_set=None,
91 | use_split_coco=False,
92 | transform=None, transform_tri=None, mode='train', ann_type='mask',
93 | ft_transform=None, ft_aug_size=None):
94 |
95 | assert mode in ['train', 'val', 'demo', 'finetune']
96 | assert data_set in ['pascal', 'coco']
97 | if mode == 'finetune':
98 | assert ft_transform is not None
99 | assert ft_aug_size is not None
100 |
101 | if data_set == 'pascal':
102 | self.num_classes = 20
103 | elif data_set == 'coco':
104 | self.num_classes = 80
105 |
106 | self.mode = mode
107 | self.split = split
108 | self.shot = shot
109 | self.data_root = data_root
110 | self.base_data_root = base_data_root
111 | self.ann_type = ann_type
112 |
113 | if data_set == 'pascal':
114 | self.class_list = list(range(1, 21)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
115 | if self.split == 3:
116 | self.sub_list = list(range(1, 16)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
117 | self.sub_val_list = list(range(16, 21)) # [16,17,18,19,20]
118 | elif self.split == 2:
119 | self.sub_list = list(range(1, 11)) + list(range(16, 21)) # [1,2,3,4,5,6,7,8,9,10,16,17,18,19,20]
120 | self.sub_val_list = list(range(11, 16)) # [11,12,13,14,15]
121 | elif self.split == 1:
122 | self.sub_list = list(range(1, 6)) + list(range(11, 21)) # [1,2,3,4,5,11,12,13,14,15,16,17,18,19,20]
123 | self.sub_val_list = list(range(6, 11)) # [6,7,8,9,10]
124 | elif self.split == 0:
125 | self.sub_list = list(range(6, 21)) # [6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
126 | self.sub_val_list = list(range(1, 6)) # [1,2,3,4,5]
127 |
128 | elif data_set == 'coco':
129 | if use_split_coco:
130 | print('INFO: using SPLIT COCO (FWB)')
131 | self.class_list = list(range(1, 81))
132 | if self.split == 3:
133 | self.sub_val_list = list(range(4, 81, 4))
134 | self.sub_list = list(set(self.class_list) - set(self.sub_val_list))
135 | elif self.split == 2:
136 | self.sub_val_list = list(range(3, 80, 4))
137 | self.sub_list = list(set(self.class_list) - set(self.sub_val_list))
138 | elif self.split == 1:
139 | self.sub_val_list = list(range(2, 79, 4))
140 | self.sub_list = list(set(self.class_list) - set(self.sub_val_list))
141 | elif self.split == 0:
142 | self.sub_val_list = list(range(1, 78, 4))
143 | self.sub_list = list(set(self.class_list) - set(self.sub_val_list))
144 | else:
145 | print('INFO: using COCO (PANet)')
146 | self.class_list = list(range(1, 81))
147 | if self.split == 3:
148 | self.sub_list = list(range(1, 61))
149 | self.sub_val_list = list(range(61, 81))
150 | elif self.split == 2:
151 | self.sub_list = list(range(1, 41)) + list(range(61, 81))
152 | self.sub_val_list = list(range(41, 61))
153 | elif self.split == 1:
154 | self.sub_list = list(range(1, 21)) + list(range(41, 81))
155 | self.sub_val_list = list(range(21, 41))
156 | elif self.split == 0:
157 | self.sub_list = list(range(21, 81))
158 | self.sub_val_list = list(range(1, 21))
159 |
160 | print('sub_list: ', self.sub_list)
161 | print('sub_val_list: ', self.sub_val_list)
162 |
163 | # @@@ For convenience, we skip the step of building datasets and instead use the pre-generated lists @@@
164 | # if self.mode == 'train':
165 | # self.data_list, self.sub_class_file_list = make_dataset(split, data_root, data_list, self.sub_list, True)
166 | # assert len(self.sub_class_file_list.keys()) == len(self.sub_list)
167 | # elif self.mode == 'val' or self.mode == 'demo' or self.mode == 'finetune':
168 | # self.data_list, self.sub_class_file_list = make_dataset(split, data_root, data_list, self.sub_val_list, False)
169 | # assert len(self.sub_class_file_list.keys()) == len(self.sub_val_list)
170 |
171 | mode = 'train' if self.mode == 'train' else 'val'
172 |
173 | fss_list_root = './lists/{}/fss_list/{}/'.format(data_set, mode)
174 | fss_data_list_path = fss_list_root + 'data_list_{}.txt'.format(split)
175 | fss_sub_class_file_list_path = fss_list_root + 'sub_class_file_list_{}.txt'.format(split)
176 |
177 | # Write FSS Data
178 | # with open(fss_data_list_path, 'w') as f:
179 | # for item in self.data_list:
180 | # img, label = item
181 | # f.write(img + ' ')
182 | # f.write(label + '\n')
183 | # with open(fss_sub_class_file_list_path, 'w') as f:
184 | # f.write(str(self.sub_class_file_list))
185 |
186 | # Read FSS Data
187 | with open(fss_data_list_path, 'r') as f:
188 | f_str = f.readlines()
189 | self.data_list = []
190 | for line in f_str:
191 | img, mask = line.split(' ')
192 | self.data_list.append((img, mask.strip()))
193 |
194 | with open(fss_sub_class_file_list_path, 'r') as f:
195 | f_str = f.read()
196 | self.sub_class_file_list = eval(f_str)
197 |
198 | self.transform = transform
199 | self.transform_tri = transform_tri
200 | self.ft_transform = ft_transform
201 | self.ft_aug_size = ft_aug_size
202 |
203 | def __len__(self):
204 | return len(self.data_list)
205 |
206 | def __getitem__(self, index):
207 | image_path, label_path = self.data_list[index]
208 | image = cv2.imread(image_path, cv2.IMREAD_COLOR)
209 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
210 | image = np.float32(image)
211 | label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
212 |
213 | if image.shape[0] != label.shape[0] or image.shape[1] != label.shape[1]:
214 | raise (RuntimeError("Query Image & label shape mismatch: " + image_path + " " + label_path + "\n"))
215 | label_class = np.unique(label).tolist()
216 | if 0 in label_class:
217 | label_class.remove(0)
218 | if 255 in label_class:
219 | label_class.remove(255)
220 | new_label_class = []
221 | for c in label_class:
222 | if c in self.sub_val_list:
223 | if self.mode == 'val' or self.mode == 'demo' or self.mode == 'finetune':
224 | new_label_class.append(c)
225 | if c in self.sub_list:
226 | if self.mode == 'train':
227 | new_label_class.append(c)
228 | label_class = new_label_class
229 | assert len(label_class) > 0
230 |
231 | class_chosen = label_class[random.randint(1, len(label_class)) - 1]
232 | target_pix = np.where(label == class_chosen)
233 | ignore_pix = np.where(label == 255)
234 | label[:, :] = 0
235 | if target_pix[0].shape[0] > 0:
236 | label[target_pix[0], target_pix[1]] = 1
237 | label[ignore_pix[0], ignore_pix[1]] = 255
238 |
239 | file_class_chosen = self.sub_class_file_list[class_chosen]
240 | num_file = len(file_class_chosen)
241 |
242 | support_image_path_list = []
243 | support_label_path_list = []
244 | support_idx_list = []
245 | for k in range(self.shot):
246 | support_idx = random.randint(1, num_file) - 1
247 | support_image_path = image_path
248 | support_label_path = label_path
249 | while ((
250 | support_image_path == image_path and support_label_path == label_path) or support_idx in support_idx_list):
251 | support_idx = random.randint(1, num_file) - 1
252 | support_image_path, support_label_path = file_class_chosen[support_idx]
253 | support_idx_list.append(support_idx)
254 | support_image_path_list.append(support_image_path)
255 | support_label_path_list.append(support_label_path)
256 |
257 | support_image_list_ori = []
258 | support_label_list_ori = []
259 | support_label_list_ori_mask = []
260 | subcls_list = []
261 | if self.mode == 'train':
262 | subcls_list.append(self.sub_list.index(class_chosen))
263 | else:
264 | subcls_list.append(self.sub_val_list.index(class_chosen))
265 | for k in range(self.shot):
266 | support_image_path = support_image_path_list[k]
267 | support_label_path = support_label_path_list[k]
268 | support_image = cv2.imread(support_image_path, cv2.IMREAD_COLOR)
269 | support_image = cv2.cvtColor(support_image, cv2.COLOR_BGR2RGB)
270 | support_image = np.float32(support_image)
271 | support_label = cv2.imread(support_label_path, cv2.IMREAD_GRAYSCALE)
272 | target_pix = np.where(support_label == class_chosen)
273 | ignore_pix = np.where(support_label == 255)
274 | support_label[:, :] = 0
275 | support_label[target_pix[0], target_pix[1]] = 1
276 |
277 | support_label, support_label_mask = transform_anns(support_label, self.ann_type) # mask/bbox
278 | support_label[ignore_pix[0], ignore_pix[1]] = 255
279 | support_label_mask[ignore_pix[0], ignore_pix[1]] = 255
280 | if support_image.shape[0] != support_label.shape[0] or support_image.shape[1] != support_label.shape[1]:
281 | raise (RuntimeError(
282 | "Support Image & label shape mismatch: " + support_image_path + " " + support_label_path + "\n"))
283 | support_image_list_ori.append(support_image)
284 | support_label_list_ori.append(support_label)
285 | support_label_list_ori_mask.append(support_label_mask)
286 | assert len(support_label_list_ori) == self.shot and len(support_image_list_ori) == self.shot
287 |
288 | raw_image = image.copy()
289 | raw_label = label.copy()
290 | support_image_list = [[] for _ in range(self.shot)]
291 | support_label_list = [[] for _ in range(self.shot)]
292 | if self.transform is not None:
293 | image, label = self.transform(image, label)
294 | for k in range(self.shot):
295 | support_image_list[k], support_label_list[k] = self.transform(support_image_list_ori[k],
296 | support_label_list_ori[k])
297 |
298 | s_xs = support_image_list
299 | s_ys = support_label_list
300 | s_x = s_xs[0].unsqueeze(0)
301 | for i in range(1, self.shot):
302 | s_x = torch.cat([s_xs[i].unsqueeze(0), s_x], 0)
303 | s_y = s_ys[0].unsqueeze(0)
304 | for i in range(1, self.shot):
305 | s_y = torch.cat([s_ys[i].unsqueeze(0), s_y], 0)
306 |
307 | # Return
308 | if self.mode == 'train':
309 | return image, label, s_x, s_y, subcls_list
310 | elif self.mode == 'val':
311 | return image, label, s_x, s_y, subcls_list, raw_label
312 |
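`SemData.__getitem__` builds each episode in three steps:

1. Keep only the query's classes that belong to the current phase. These are base classes (`sub_list`) in `train` mode and novel classes (`sub_val_list`) otherwise.
2. Pick one of those classes uniformly and binarize the query mask to 0 = background, 1 = chosen class, while preserving 255 as ignore.
3. Draw `shot` distinct support images of that class, none of which may be the query itself.

The NumPy sketch below isolates that logic from file I/O and transforms. `choose_class`, `binarize` and `sample_supports` are hypothetical helpers. `rng` is an explicit `random.Random` standing in for the module-level `random` calls. One intentional difference is that the original `while` loop never terminates if a class has too few other images, whereas this version raises an error.

```python
import random

import numpy as np

IGNORE = 255


def choose_class(label, phase_classes, rng):
    """Uniformly pick one class present in the mask and allowed in this phase."""
    present = [c for c in np.unique(label).tolist()
               if c not in (0, IGNORE) and c in phase_classes]
    assert present, "query has no class of the current phase"
    return rng.choice(present)


def binarize(label, class_chosen):
    """Map class_chosen -> 1, everything else -> 0, but keep 255 as ignore."""
    out = np.zeros_like(label)
    out[label == class_chosen] = 1
    out[label == IGNORE] = IGNORE
    return out


def sample_supports(query_item, files_of_class, shot, rng):
    """Pick `shot` distinct (image, label) pairs of the class, excluding the query."""
    candidates = [i for i, item in enumerate(files_of_class) if item != query_item]
    if len(candidates) < shot:
        raise ValueError("not enough support images for this class")
    idx = rng.sample(candidates, shot)  # distinct indices, like support_idx_list
    return [files_of_class[i] for i in idx]
```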
--------------------------------------------------------------------------------
/util/get_weak_anns.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division
2 |
3 | import networkx as nx
4 | import numpy as np
5 | from scipy.ndimage import binary_dilation, binary_erosion, maximum_filter
6 | from scipy.special import comb
7 | from skimage.filters import rank
8 | from skimage.morphology import dilation, disk, erosion, medial_axis
9 | from sklearn.neighbors import radius_neighbors_graph
10 | import cv2
11 | import matplotlib.pyplot as plt
12 | import matplotlib.patches as mpatches
13 | from scipy import ndimage
14 |
15 |
16 | def find_bbox(mask):
17 | _, labels, stats, centroids = cv2.connectedComponentsWithStats(mask.astype(np.uint8))
18 | return stats[1:] # remove bg stat
19 |
20 |
21 | def transform_anns(mask, ann_type):
22 | mask_ori = mask.copy()
23 |
24 | if ann_type == 'bbox':
25 | bboxs = find_bbox(mask)
26 | for j in bboxs:
27 | cv2.rectangle(mask, (j[0], j[1]), (j[0] + j[2], j[1] + j[3]), 1, -1) # -1->fill; 2->draw_rec
28 | return mask, mask_ori
29 |
30 | elif ann_type == 'mask':
31 | return mask, mask_ori
32 |
33 |
34 | if __name__ == '__main__':
35 | label_path = '2008_001227.png'
36 | mask = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
37 | bboxs = find_bbox(mask)
38 | mask_color = cv2.imread(label_path, cv2.IMREAD_COLOR)
39 | for j in bboxs:
40 | cv2.rectangle(mask_color, (j[0], j[1]), (j[0] + j[2], j[1] + j[3]), (0, 255, 0), -1)
41 | cv2.imwrite('bbox.png', mask_color)
42 |
43 | print('done')
44 |
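With `ann_type='bbox'`, `transform_anns` replaces each connected component of the support mask with its filled bounding box, using `cv2.connectedComponentsWithStats` followed by `cv2.rectangle(..., -1)`. The untouched mask is returned as the second value. The dependency-light sketch below does the same with a breadth-first flood fill using 8-connectivity, which is OpenCV's default for `connectedComponentsWithStats`. `mask_to_boxes` and `boxes_to_mask` are hypothetical names. The sketch assumes a 0/1 mask, which matches the dataset code because it calls `transform_anns` before re-inserting the 255 ignore pixels. `cv2.rectangle` treats its second corner `(x + w, y + h)` as inclusive, so the filled box extends one pixel past the component on the right and bottom. The sketch reproduces that.

```python
from collections import deque

import numpy as np


def mask_to_boxes(mask):
    """Return (x, y, w, h) for each 8-connected foreground component of a 0/1 mask."""
    h, w = mask.shape
    seen = np.zeros_like(mask, dtype=bool)
    boxes = []
    for sy in range(h):
        for sx in range(w):
            if mask[sy, sx] == 0 or seen[sy, sx]:
                continue
            seen[sy, sx] = True
            queue = deque([(sy, sx)])
            ys, xs = [], []
            while queue:  # flood-fill one component
                y, x = queue.popleft()
                ys.append(y)
                xs.append(x)
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        ny, nx = y + dy, x + dx
                        if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and not seen[ny, nx]:
                            seen[ny, nx] = True
                            queue.append((ny, nx))
            boxes.append((min(xs), min(ys), max(xs) - min(xs) + 1, max(ys) - min(ys) + 1))
    return boxes


def boxes_to_mask(mask):
    """Fill each component's box, like cv2.rectangle(m, (x, y), (x + w, y + h), 1, -1)."""
    out = mask.copy()
    for x, y, bw, bh in mask_to_boxes(mask):
        # Inclusive second corner: one extra row/column, clipped by slicing at the border.
        out[y:y + bh + 1, x:x + bw + 1] = 1
    return out
```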
--------------------------------------------------------------------------------
/util/transform.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 | import numpy as np
4 | import numbers
5 | import collections.abc # binds `collections` and guarantees collections.abc is loaded
6 | collections.Iterable = collections.abc.Iterable
7 | import cv2
8 |
9 | import torch
10 |
11 | manual_seed = 123
12 | torch.manual_seed(manual_seed)
13 | np.random.seed(manual_seed)
14 | torch.cuda.manual_seed_all(manual_seed)
15 | random.seed(manual_seed)
16 |
17 |
18 |
19 | class Compose(object):
20 | # Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()])
21 | def __init__(self, segtransform):
22 | self.segtransform = segtransform
23 |
24 | def __call__(self, image, label):
25 | for t in self.segtransform:
26 | image, label = t(image, label)
27 | return image, label
28 |
29 |
30 | import time
31 |
32 |
33 | class ToTensor(object):
34 | # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
35 | def __call__(self, image, label):
36 | if not isinstance(image, np.ndarray) or not isinstance(label, np.ndarray):
37 | raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray "
38 | "[e.g. data read by cv2.imread()].\n"))
39 | if len(image.shape) > 3 or len(image.shape) < 2:
40 | raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray with 3 dims or 2 dims.\n"))
41 | if len(image.shape) == 2:
42 | image = np.expand_dims(image, axis=2)
43 | if not len(label.shape) == 2:
44 | raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray label with 2 dims.\n"))
45 |
46 | image = torch.from_numpy(image.transpose((2, 0, 1)))
47 | if not isinstance(image, torch.FloatTensor):
48 | image = image.float()
49 | label = torch.from_numpy(label)
50 | if not isinstance(label, torch.LongTensor):
51 | label = label.long()
52 | return image, label
53 |
54 |
55 | class CLAHE(object):
56 | # Apply Contrast Limited Adaptive Histogram Equalization to the input image.
57 | def __init__(self, clip_limit=4.0, tile_grid_size=(8, 8)):
58 | self.clip_limit = clip_limit
59 | self.tile_grid_size = tile_grid_size
60 |
61 | def __call__(self, image, label):
62 | if image.dtype != np.uint8:
63 | image = image.astype(np.uint8)
64 |
65 | clahe_mat = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
66 |
67 | if len(image.shape) == 2 or image.shape[2] == 1:
68 | image = clahe_mat.apply(image)
69 | else:
70 | image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
71 | image[:, :, 0] = clahe_mat.apply(image[:, :, 0])
72 | image = cv2.cvtColor(image, cv2.COLOR_LAB2RGB)
73 | image = image.astype(np.float32)
74 | return image, label
75 |
76 |
77 | class ToNumpy(object):
78 | # Converts torch.FloatTensor of shape (C x H x W) to a numpy.ndarray (H x W x C).
79 | def __call__(self, image, label):
80 | if not isinstance(image, torch.Tensor) or not isinstance(label, torch.Tensor):
81 | raise (RuntimeError("segtransform.ToNumpy() only handles torch.Tensor"))
82 |
83 | image = image.cpu().numpy().transpose((1, 2, 0))
84 | if not image.dtype == np.uint8:
85 | image = image.astype(np.uint8)
86 | label = label.cpu().numpy()  # labels from ToTensor are 2-D (H x W): there is no channel axis to move
87 | if not label.dtype == np.uint8:
88 | label = label.astype(np.uint8)
89 | return image, label
90 |
91 |
92 | class Normalize(object):
93 | # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std
94 | def __init__(self, mean, std=None):
95 | if std is None:
96 | assert len(mean) > 0
97 | else:
98 | assert len(mean) == len(std)
99 | self.mean = mean
100 | self.std = std
101 |
102 | def __call__(self, image, label):
103 | if self.std is None:
104 | for t, m in zip(image, self.mean):
105 | t.sub_(m)
106 | else:
107 | for t, m, s in zip(image, self.mean, self.std):
108 | t.sub_(m).div_(s)
109 | return image, label
110 |
111 |
112 | class UnNormalize(object):
113 | # UnNormalize tensor with mean and standard deviation along channel: channel = (channel * std) + mean
114 | def __init__(self, mean, std=None):
115 | if std is None:
116 | assert len(mean) > 0
117 | else:
118 | assert len(mean) == len(std)
119 | self.mean = mean
120 | self.std = std
121 |
122 | def __call__(self, image, label):
123 | if self.std is None:
124 | for t, m in zip(image, self.mean):
125 | t.add_(m)
126 | else:
127 | for t, m, s in zip(image, self.mean, self.std):
128 | t.mul_(s).add_(m)
129 | return image, label
130 |
131 |
132 | class Resize(object):
133 | # Resize image and label to a square of side 'size' (an int): bilinear for the image, nearest-neighbor for the label.
134 | def __init__(self, size):
135 | self.size = size
136 |
137 | def __call__(self, image, label):
138 |
139 | # value_scale = 255
140 | # mean = [0.485, 0.456, 0.406]
141 | # mean = [item * value_scale for item in mean]
142 | # std = [0.229, 0.224, 0.225]
143 | # std = [item * value_scale for item in std]
144 | #
145 | # def find_new_hw(ori_h, ori_w, test_size):
146 | # if ori_h >= ori_w:
147 | # ratio = test_size * 1.0 / ori_h
148 | # new_h = test_size
149 | # new_w = int(ori_w * ratio)
150 | # elif ori_w > ori_h:
151 | # ratio = test_size * 1.0 / ori_w
152 | # new_h = int(ori_h * ratio)
153 | # new_w = test_size
154 | #
155 | # if new_h % 8 != 0:
156 | # new_h = (int(new_h / 8)) * 8
157 | # else:
158 | # new_h = new_h
159 | # if new_w % 8 != 0:
160 | # new_w = (int(new_w / 8)) * 8
161 | # else:
162 | # new_w = new_w
163 | # return new_h, new_w
164 | #
165 | # test_size = self.size
166 | # new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size)
167 | # # new_h, new_w = test_size, test_size
168 | # image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR)
169 | # back_crop = np.zeros((test_size, test_size, 3))
170 | # # back_crop[:,:,0] = mean[0]
171 | # # back_crop[:,:,1] = mean[1]
172 | # # back_crop[:,:,2] = mean[2]
173 | # back_crop[:new_h, :new_w, :] = image_crop
174 | # image = back_crop
175 | #
176 | # s_mask = label
177 | # new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size)
178 | # # new_h, new_w = test_size, test_size
179 | # s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_NEAREST)
180 | # back_crop_s_mask = np.ones((test_size, test_size)) * 255
181 | # back_crop_s_mask[:new_h, :new_w] = s_mask
182 | # label = back_crop_s_mask
183 |
184 | image = cv2.resize(image, dsize=(self.size, self.size), interpolation=cv2.INTER_LINEAR)
185 | label = cv2.resize(label, dsize=(self.size, self.size), interpolation=cv2.INTER_NEAREST)
186 |
187 | return image, label
188 |
189 |
190 | class test_Resize(object):
191 | # Resize image and label to a square of side 'size' (an int): bilinear for the image, nearest-neighbor for the label.
192 | def __init__(self, size):
193 | self.size = size
194 |
195 | def __call__(self, image, label):
196 |
197 | # value_scale = 255
198 | # mean = [0.485, 0.456, 0.406]
199 | # mean = [item * value_scale for item in mean]
200 | # std = [0.229, 0.224, 0.225]
201 | # std = [item * value_scale for item in std]
202 | #
203 | # def find_new_hw(ori_h, ori_w, test_size):
204 | # if max(ori_h, ori_w) > test_size:
205 | # if ori_h >= ori_w:
206 | # ratio = test_size * 1.0 / ori_h
207 | # new_h = test_size
208 | # new_w = int(ori_w * ratio)
209 | # elif ori_w > ori_h:
210 | # ratio = test_size * 1.0 / ori_w
211 | # new_h = int(ori_h * ratio)
212 | # new_w = test_size
213 | #
214 | # if new_h % 8 != 0:
215 | # new_h = (int(new_h / 8)) * 8
216 | # else:
217 | # new_h = new_h
218 | # if new_w % 8 != 0:
219 | # new_w = (int(new_w / 8)) * 8
220 | # else:
221 | # new_w = new_w
222 | # return new_h, new_w
223 | # else:
224 | # return ori_h, ori_w
225 | #
226 | # test_size = self.size
227 | # new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size)
228 | # if new_w != image.shape[0] or new_h != image.shape[1]:
229 | # image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)), interpolation=cv2.INTER_LINEAR)
230 | # else:
231 | # image_crop = image.copy()
232 | # back_crop = np.zeros((test_size, test_size, 3))
233 | # back_crop[:new_h, :new_w, :] = image_crop
234 | # image = back_crop
235 | #
236 | # s_mask = label
237 | # new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size)
238 | # if new_w != s_mask.shape[0] or new_h != s_mask.shape[1]:
239 | # s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)),
240 | # interpolation=cv2.INTER_NEAREST)
241 | # back_crop_s_mask = np.ones((test_size, test_size)) * 255
242 | # back_crop_s_mask[:new_h, :new_w] = s_mask
243 | # label = back_crop_s_mask
244 |
245 | image = cv2.resize(image, dsize=(self.size, self.size), interpolation=cv2.INTER_LINEAR)
246 | label = cv2.resize(label, dsize=(self.size, self.size), interpolation=cv2.INTER_NEAREST)
247 |
248 | return image, label
249 |
250 |
251 | class Direct_Resize(object):
252 | # Resize to a square of side 'size' (an int); the label is cast to float32 before nearest-neighbor resizing.
253 | def __init__(self, size):
254 | self.size = size
255 |
256 | def __call__(self, image, label):
257 | test_size = self.size
258 |
259 | image = cv2.resize(image, dsize=(test_size, test_size), interpolation=cv2.INTER_LINEAR)
260 | label = cv2.resize(label.astype(np.float32), dsize=(test_size, test_size), interpolation=cv2.INTER_NEAREST)
261 |
262 | return image, label
263 |
264 |
265 | class RandScale(object):
266 | # Randomly resize image & label with scale factor in [scale_min, scale_max]
267 | def __init__(self, scale, aspect_ratio=None):
268 | assert (isinstance(scale, collections.abc.Iterable) and len(scale) == 2)
269 | if isinstance(scale, collections.abc.Iterable) and len(scale) == 2 \
270 | and isinstance(scale[0], numbers.Number) and isinstance(scale[1], numbers.Number) \
271 | and 0 < scale[0] < scale[1]:
272 | self.scale = scale
273 | else:
274 | raise (RuntimeError("segtransform.RandScale() scale param error.\n"))
275 | if aspect_ratio is None:
276 | self.aspect_ratio = aspect_ratio
277 | elif isinstance(aspect_ratio, collections.abc.Iterable) and len(aspect_ratio) == 2 \
278 | and isinstance(aspect_ratio[0], numbers.Number) and isinstance(aspect_ratio[1], numbers.Number) \
279 | and 0 < aspect_ratio[0] < aspect_ratio[1]:
280 | self.aspect_ratio = aspect_ratio
281 | else:
282 | raise (RuntimeError("segtransform.RandScale() aspect_ratio param error.\n"))
283 |
284 | def __call__(self, image, label):
285 | temp_scale = self.scale[0] + (self.scale[1] - self.scale[0]) * random.random()
286 | temp_aspect_ratio = 1.0
287 | if self.aspect_ratio is not None:
288 | temp_aspect_ratio = self.aspect_ratio[0] + (self.aspect_ratio[1] - self.aspect_ratio[0]) * random.random()
289 | temp_aspect_ratio = math.sqrt(temp_aspect_ratio)
290 | scale_factor_x = temp_scale * temp_aspect_ratio
291 | scale_factor_y = temp_scale / temp_aspect_ratio
292 | image = cv2.resize(image, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_LINEAR)
293 | label = cv2.resize(label, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_NEAREST)
294 | return image, label
295 |
296 |
297 | class Crop(object):
298 | """Crops the given ndarray image (H*W*C or H*W).
299 | Args:
300 | size (sequence or int): Desired output size of the crop. If size is an
301 | int instead of sequence like (h, w), a square crop (size, size) is made.
302 | """
303 |
304 | def __init__(self, size, crop_type='center', padding=None, ignore_label=255):
305 | self.size = size
306 | if isinstance(size, int):
307 | self.crop_h = size
308 | self.crop_w = size
309 | elif isinstance(size, collections.abc.Iterable) and len(size) == 2 \
310 | and isinstance(size[0], int) and isinstance(size[1], int) \
311 | and size[0] > 0 and size[1] > 0:
312 | self.crop_h = size[0]
313 | self.crop_w = size[1]
314 | else:
315 | raise (RuntimeError("crop size error.\n"))
316 | if crop_type == 'center' or crop_type == 'rand':
317 | self.crop_type = crop_type
318 | else:
319 | raise (RuntimeError("crop type error: rand | center\n"))
320 | if padding is None:
321 | self.padding = padding
322 | elif isinstance(padding, list):
323 | if all(isinstance(i, numbers.Number) for i in padding):
324 | self.padding = padding
325 | else:
326 | raise (RuntimeError("padding in Crop() should be a number list\n"))
327 | if len(padding) != 3:
328 | raise (RuntimeError("padding must have 3 values, one per image channel\n"))
329 | else:
330 | raise (RuntimeError("padding in Crop() should be a number list\n"))
331 | if isinstance(ignore_label, int):
332 | self.ignore_label = ignore_label
333 | else:
334 | raise (RuntimeError("ignore_label should be an integer number\n"))
335 |
336 | def __call__(self, image, label):
337 | h, w = label.shape
338 |
339 | pad_h = max(self.crop_h - h, 0)
340 | pad_w = max(self.crop_w - w, 0)
341 | pad_h_half = int(pad_h / 2)
342 | pad_w_half = int(pad_w / 2)
343 | if pad_h > 0 or pad_w > 0:
344 | if self.padding is None:
345 | raise (RuntimeError("segtransform.Crop() needs padding but the padding argument is None\n"))
346 | image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half,
347 | cv2.BORDER_CONSTANT, value=self.padding)
348 | label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half,
349 | cv2.BORDER_CONSTANT, value=self.ignore_label)
350 | h, w = label.shape
351 | raw_label = label
352 | raw_image = image
353 |
354 | if self.crop_type == 'rand':
355 | h_off = random.randint(0, h - self.crop_h)
356 | w_off = random.randint(0, w - self.crop_w)
357 | else:
358 | h_off = int((h - self.crop_h) / 2)
359 | w_off = int((w - self.crop_w) / 2)
360 | image = image[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w]
361 | label = label[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w]
362 | raw_pos_num = np.sum(raw_label == 1)
363 | pos_num = np.sum(label == 1)
364 | crop_cnt = 0
365 | while (pos_num < 0.85 * raw_pos_num and crop_cnt <= 30):
366 | image = raw_image
367 | label = raw_label
368 | if self.crop_type == 'rand':
369 | h_off = random.randint(0, h - self.crop_h)
370 | w_off = random.randint(0, w - self.crop_w)
371 | else:
372 | h_off = int((h - self.crop_h) / 2)
373 | w_off = int((w - self.crop_w) / 2)
374 | image = image[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w]
375 | label = label[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w]
376 | raw_pos_num = np.sum(raw_label == 1)
377 | pos_num = np.sum(label == 1)
378 | crop_cnt += 1
379 | if pos_num < 0.85 * raw_pos_num:  # no crop kept 85% of the foreground: fall back to resizing the whole sample
380 | image = cv2.resize(raw_image, (self.crop_w, self.crop_h), interpolation=cv2.INTER_LINEAR)  # cv2 dsize is (w, h)
381 | label = cv2.resize(raw_label, (self.crop_w, self.crop_h), interpolation=cv2.INTER_NEAREST)
382 |
383 | if image.shape[:2] != (self.crop_h, self.crop_w):
384 | image = cv2.resize(image, (self.crop_w, self.crop_h), interpolation=cv2.INTER_LINEAR)
385 | label = cv2.resize(label, (self.crop_w, self.crop_h), interpolation=cv2.INTER_NEAREST)
386 |
387 | return image, label
388 |
389 |
390 | class RandRotate(object):
391 | # Randomly rotate image & label with rotate factor in [rotate_min, rotate_max]
392 | def __init__(self, rotate, padding, ignore_label=255, p=0.5):
393 | assert (isinstance(rotate, collections.abc.Iterable) and len(rotate) == 2)
394 | if isinstance(rotate[0], numbers.Number) and isinstance(rotate[1], numbers.Number) and rotate[0] < rotate[1]:
395 | self.rotate = rotate
396 | else:
397 | raise (RuntimeError("segtransform.RandRotate() rotate param error.\n"))
398 | assert padding is not None
399 | assert isinstance(padding, list) and len(padding) == 3
400 | if all(isinstance(i, numbers.Number) for i in padding):
401 | self.padding = padding
402 | else:
403 | raise (RuntimeError("padding in RandRotate() should be a number list\n"))
404 | assert isinstance(ignore_label, int)
405 | self.ignore_label = ignore_label
406 | self.p = p
407 |
408 | def __call__(self, image, label):
409 | if random.random() < self.p:
410 | angle = self.rotate[0] + (self.rotate[1] - self.rotate[0]) * random.random()
411 | h, w = label.shape
412 | matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
413 | image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
414 | borderValue=self.padding)
415 | label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT,
416 | borderValue=self.ignore_label)
417 | return image, label
418 |
419 |
420 | class RandomHorizontalFlip(object):
421 | def __init__(self, p=0.5):
422 | self.p = p
423 |
424 | def __call__(self, image, label):
425 | if random.random() < self.p:
426 | image = cv2.flip(image, 1)
427 | label = cv2.flip(label, 1)
428 | return image, label
429 |
430 |
431 | class RandomVerticalFlip(object):
432 | def __init__(self, p=0.5):
433 | self.p = p
434 |
435 | def __call__(self, image, label):
436 | if random.random() < self.p:
437 | image = cv2.flip(image, 0)
438 | label = cv2.flip(label, 0)
439 | return image, label
440 |
441 |
442 | class RandomGaussianBlur(object):
443 | def __init__(self, radius=5):
444 | self.radius = radius
445 |
446 | def __call__(self, image, label):
447 | if random.random() < 0.5:
448 | image = cv2.GaussianBlur(image, (self.radius, self.radius), 0)
449 | return image, label
450 |
451 |
452 | class RGB2BGR(object):
453 | # Converts image from RGB order to BGR order, for model initialized from Caffe
454 | def __call__(self, image, label):
455 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
456 | return image, label
457 |
458 |
459 | class BGR2RGB(object):
460 | # Converts image from BGR order to RGB order, for model initialized from Pytorch
461 | def __call__(self, image, label):
462 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
463 | return image, label
464 |
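`Crop` pads an undersized sample before cropping: image borders are filled per channel with `padding` (typically the dataset mean) and label borders with `ignore_label`, so padded pixels never count toward the loss. The sketch below mirrors only the center-crop path of `Crop.__call__` so it can run without OpenCV; `pad_and_center_crop` is a hypothetical helper, `np.pad` stands in for `cv2.copyMakeBorder`, and the foreground-retry loop is deliberately left out.

```python
import numpy as np

# Assumption: 255 matches the default ignore_label of Crop and RandRotate.
IGNORE_LABEL = 255


def pad_and_center_crop(image, label, crop_h, crop_w, padding, ignore_label=IGNORE_LABEL):
    """Hypothetical numpy-only sketch of the pad-then-center-crop path of Crop.__call__.

    image: H x W x C array; label: H x W array of class indices;
    padding: one fill value per image channel (e.g. the dataset mean).
    """
    h, w = label.shape
    pad_h = max(crop_h - h, 0)
    pad_w = max(crop_w - w, 0)
    if pad_h > 0 or pad_w > 0:
        # Split the padding evenly; the odd pixel goes to the bottom/right, as in Crop.
        pads = ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2))
        image = np.stack(
            [np.pad(image[:, :, c], pads, constant_values=padding[c])
             for c in range(image.shape[2])],
            axis=2,
        )
        # Padded label pixels are marked ignore_label so the loss skips them.
        label = np.pad(label, pads, constant_values=ignore_label)
        h, w = label.shape
    # Centered window; the offset is 0 on any axis that was just padded to size.
    h_off = (h - crop_h) // 2
    w_off = (w - crop_w) // 2
    return (image[h_off:h_off + crop_h, w_off:w_off + crop_w],
            label[h_off:h_off + crop_h, w_off:w_off + crop_w])


# Usage: a 3 x 4 sample padded up to a 5 x 6 crop.
image = np.zeros((3, 4, 3), dtype=np.float32)
label = np.ones((3, 4), dtype=np.uint8)
img_c, lbl_c = pad_and_center_crop(image, label, 5, 6, padding=[0.485, 0.456, 0.406])
print(img_c.shape, lbl_c.shape)  # (5, 6, 3) (5, 6)
```

Marking the border with `ignore_label` instead of background (0) matters for the binary masks used here: a background fill would teach the model that padding is a negative region.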
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | import random
5 | import logging
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | from matplotlib.pyplot import MultipleLocator
9 | from matplotlib.ticker import FuncFormatter, FormatStrFormatter
10 | from matplotlib import font_manager
11 | from matplotlib import rcParams
12 | import seaborn as sns
13 | import pandas as pd
14 | import math
15 | from seaborn.distributions import distplot
16 | from tqdm import tqdm
17 | from scipy import ndimage
18 |
19 | # from util.get_weak_anns import find_bbox
20 |
21 | import torch
22 | from torch import nn
23 | import torch.backends.cudnn as cudnn
24 | import torch.nn.init as initer
25 |
26 |
27 | class AverageMeter(object):
28 | """Computes and stores the average and current value"""
29 |
30 | def __init__(self):
31 | self.reset()
32 |
33 | def reset(self):
34 | self.val = 0
35 | self.avg = 0
36 | self.sum = 0
37 | self.count = 0
38 |
39 | def update(self, val, n=1):
40 | self.val = val
41 | self.sum += val * n
42 | self.count += n
43 | self.avg = self.sum / self.count
44 |
45 |
46 | def step_learning_rate(optimizer, base_lr, epoch, step_epoch, multiplier=0.1):
47 | """Sets the learning rate to base_lr * multiplier ** (epoch // step_epoch), i.e. decays it by `multiplier` (0.1 by default) every `step_epoch` epochs"""
48 | lr = base_lr * (multiplier ** (epoch // step_epoch))
49 | for param_group in optimizer.param_groups:
50 | param_group['lr'] = lr
51 |
52 |
53 | def poly_learning_rate(optimizer, base_lr, curr_iter, max_iter, power=0.9, index_split=-1, scale_lr=10., warmup=False,
54 | warmup_step=500):
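55 | """Poly policy: lr = base_lr * (1 - curr_iter / max_iter) ** power, after an optional linear warmup from 0.1 * base_lr to base_lr over warmup_step iterations. Param groups with index <= index_split get lr; the rest get lr * scale_lr (with the default index_split=-1, that is every group)."""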
56 | if warmup and curr_iter < warmup_step:
57 | lr = base_lr * (0.1 + 0.9 * (curr_iter / warmup_step))
58 | else:
59 | lr = base_lr * (1 - float(curr_iter) / max_iter) ** power
60 |
61 | # if curr_iter % 50 == 0:
62 | # print('Base LR: {:.4f}, Curr LR: {:.4f}, Warmup: {}.'.format(base_lr, lr, (warmup and curr_iter < warmup_step)))
63 |
64 | for index, param_group in enumerate(optimizer.param_groups):
65 | if index <= index_split:
66 | param_group['lr'] = lr
67 | else:
68 | param_group['lr'] = lr * scale_lr  # scale_lr x LR (10x by default)
69 |
70 |
71 | def intersectionAndUnion(output, target, K, ignore_index=255):
72 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
73 | assert (output.ndim in [1, 2, 3])
74 | assert output.shape == target.shape
75 | output = output.reshape(output.size).copy()
76 | target = target.reshape(target.size)
77 | output[np.where(target == ignore_index)[0]] = ignore_index
78 | intersection = output[np.where(output == target)[0]]
79 | area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
80 | area_output, _ = np.histogram(output, bins=np.arange(K + 1))
81 | area_target, _ = np.histogram(target, bins=np.arange(K + 1))
82 | area_union = area_output + area_target - area_intersection
83 | return area_intersection, area_union, area_target
84 |
85 |
86 | def intersectionAndUnionGPU(output, target, K, ignore_index=255):
87 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
88 | assert (output.dim() in [1, 2, 3])
89 | assert output.shape == target.shape
90 | output = output.view(-1)
91 | target = target.view(-1)
92 | output[target == ignore_index] = ignore_index
93 | intersection = output[output == target]
94 | area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
95 | area_output = torch.histc(output, bins=K, min=0, max=K - 1)
96 | area_target = torch.histc(target, bins=K, min=0, max=K - 1)
97 | area_union = area_output + area_target - area_intersection
98 | return area_intersection, area_union, area_target
99 |
100 |
101 | def check_mkdir(dir_name):
102 | if not os.path.exists(dir_name):
103 | os.mkdir(dir_name)
104 |
105 |
106 | def check_makedirs(dir_name):
107 | if not os.path.exists(dir_name):
108 | os.makedirs(dir_name)
109 |
110 |
111 | def del_file(path):
112 | for i in os.listdir(path):
113 | path_file = os.path.join(path, i)
114 | if os.path.isfile(path_file):
115 | os.remove(path_file)
116 | else:
117 | del_file(path_file)
118 |
119 |
120 | def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'):
121 | """
122 | :param model: Pytorch Model which is nn.Module
123 | :param conv: 'kaiming' or 'xavier'
124 | :param batchnorm: 'normal' or 'constant'
125 | :param linear: 'kaiming' or 'xavier'
126 | :param lstm: 'kaiming' or 'xavier'
127 | """
128 | for m in model.modules():
129 | if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
130 | if conv == 'kaiming':
131 | initer.kaiming_normal_(m.weight)
132 | elif conv == 'xavier':
133 | initer.xavier_normal_(m.weight)
134 | else:
135 | raise ValueError("init type of conv error.\n")
136 | if m.bias is not None:
137 | initer.constant_(m.bias, 0)
138 |
139 | elif isinstance(m,
140 | (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
141 | if batchnorm == 'normal':
142 | initer.normal_(m.weight, 1.0, 0.02)
143 | elif batchnorm == 'constant':
144 | initer.constant_(m.weight, 1.0)
145 | else:
146 | raise ValueError("init type of batchnorm error.\n")
147 | initer.constant_(m.bias, 0.0)
148 |
149 | elif isinstance(m, nn.Linear):
150 | if linear == 'kaiming':
151 | initer.kaiming_normal_(m.weight)
152 | elif linear == 'xavier':
153 | initer.xavier_normal_(m.weight)
154 | else:
155 | raise ValueError("init type of linear error.\n")
156 | if m.bias is not None:
157 | initer.constant_(m.bias, 0)
158 |
159 | elif isinstance(m, nn.LSTM):
160 | for name, param in m.named_parameters():
161 | if 'weight' in name:
162 | if lstm == 'kaiming':
163 | initer.kaiming_normal_(param)
164 | elif lstm == 'xavier':
165 | initer.xavier_normal_(param)
166 | else:
167 | raise ValueError("init type of lstm error.\n")
168 | elif 'bias' in name:
169 | initer.constant_(param, 0)
170 |
171 |
172 | def colorize(gray, palette):
173 | # gray: 2-D numpy array of label indices; palette: flat list [r0, g0, b0, r1, g1, b1, ...] of length 3N
174 | color = Image.fromarray(gray.astype(np.uint8)).convert('P')
175 | color.putpalette(palette)
176 | return color
177 |
178 |
179 | # ------------------------------------------------------
180 | def get_model_para_number(model):
181 | total_number = 0
182 | learnable_number = 0
183 | for para in model.parameters():
184 | total_number += torch.numel(para)
185 | if para.requires_grad:
186 | learnable_number += torch.numel(para)
187 | return total_number, learnable_number
188 |
189 |
190 | def setup_seed(seed=2021, deterministic=False):
191 | if deterministic:
192 | cudnn.benchmark = False
193 | cudnn.deterministic = True
194 | torch.manual_seed(seed)
195 | torch.cuda.manual_seed(seed)
196 | torch.cuda.manual_seed_all(seed)
197 | np.random.seed(seed)
198 | random.seed(seed)
199 | os.environ['PYTHONHASHSEED'] = str(seed)  # takes effect only in Python processes started after this call
200 |
201 |
202 | def get_logger():
203 | logger_name = "main-logger"
204 | logger = logging.getLogger(logger_name)
205 | logger.setLevel(logging.INFO)
206 | handler = logging.StreamHandler()
207 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
208 | handler.setFormatter(logging.Formatter(fmt))
209 | logger.addHandler(handler)
210 | return logger
211 |
212 |
213 | def get_save_path(args):
214 | backbone_str = 'vgg' if args.vgg else 'resnet' + str(args.layers)
215 | args.snapshot_path = 'exp/{}/{}/split{}/{}/snapshot'.format(args.data_set, args.arch, args.split, backbone_str)
216 | args.result_path = 'exp/{}/{}/split{}/{}/result'.format(args.data_set, args.arch, args.split, backbone_str)
217 |
218 |
219 | def get_train_val_set(args):
220 | if args.data_set == 'pascal':
221 | class_list = list(range(1, 21)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
222 | if args.split == 3:
223 | sub_list = list(range(1, 16)) # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
224 | sub_val_list = list(range(16, 21)) # [16,17,18,19,20]
225 | elif args.split == 2:
226 | sub_list = list(range(1, 11)) + list(range(16, 21)) # [1,2,3,4,5,6,7,8,9,10,16,17,18,19,20]
227 | sub_val_list = list(range(11, 16)) # [11,12,13,14,15]
228 | elif args.split == 1:
229 | sub_list = list(range(1, 6)) + list(range(11, 21)) # [1,2,3,4,5,11,12,13,14,15,16,17,18,19,20]
230 | sub_val_list = list(range(6, 11)) # [6,7,8,9,10]
231 | elif args.split == 0:
232 | sub_list = list(range(6, 21)) # [6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
233 | sub_val_list = list(range(1, 6)) # [1,2,3,4,5]
234 |
235 | elif args.data_set == 'coco':
236 | if args.use_split_coco:
237 | print('INFO: using SPLIT COCO (FWB)')
238 | class_list = list(range(1, 81))
239 | if args.split == 3:
240 | sub_val_list = list(range(4, 81, 4))
241 | sub_list = sorted(set(class_list) - set(sub_val_list))
242 | elif args.split == 2:
243 | sub_val_list = list(range(3, 80, 4))
244 | sub_list = sorted(set(class_list) - set(sub_val_list))
245 | elif args.split == 1:
246 | sub_val_list = list(range(2, 79, 4))
247 | sub_list = sorted(set(class_list) - set(sub_val_list))
248 | elif args.split == 0:
249 | sub_val_list = list(range(1, 78, 4))
250 | sub_list = sorted(set(class_list) - set(sub_val_list))
251 | else:
252 | print('INFO: using COCO (PANet)')
253 | class_list = list(range(1, 81))
254 | if args.split == 3:
255 | sub_list = list(range(1, 61))
256 | sub_val_list = list(range(61, 81))
257 | elif args.split == 2:
258 | sub_list = list(range(1, 41)) + list(range(61, 81))
259 | sub_val_list = list(range(41, 61))
260 | elif args.split == 1:
261 | sub_list = list(range(1, 21)) + list(range(41, 81))
262 | sub_val_list = list(range(21, 41))
263 | elif args.split == 0:
264 | sub_list = list(range(21, 81))
265 | sub_val_list = list(range(1, 21))
266 |
267 | return sub_list, sub_val_list
268 |
269 |
270 | def is_same_model(model1, model2):
271 | flag = 0
272 | count = 0
273 | for k, v in model1.state_dict().items():
274 | model1_val = v
275 | model2_val = model2.state_dict()[k]
276 | if (model1_val == model2_val).all():
277 | pass
278 | else:
279 | flag += 1
280 | print('value of key <{}> mismatch'.format(k))
281 | count += 1
282 |
283 | return flag == 0
284 |
285 |
286 | def fix_bn(m):
287 | classname = m.__class__.__name__
288 | if classname.find('BatchNorm') != -1:
289 | m.eval()
290 |
291 |
292 | def sum_list(items):
293 | total = 0  # avoid shadowing the built-ins `sum` and `list`
294 | for item in items:
295 | total += item
296 | return total
297 |
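`intersectionAndUnion` returns per-class pixel counts rather than a ratio; the IoU is formed only after the counts are summed over the evaluation set. The sketch below restates that function under a hypothetical name so it runs without the repository, then shows one common way to derive per-class IoU and mIoU for the binary background/foreground case. The `mean_iou` helper and its `eps` guard against classes that never appear are assumptions, not code from this repository.

```python
import numpy as np


def intersection_and_union(output, target, num_classes, ignore_index=255):
    """Hypothetical restatement of intersectionAndUnion so the sketch runs standalone."""
    output = output.reshape(-1).copy()
    target = target.reshape(-1)
    # Ground-truth ignore pixels are forced to ignore_index in the prediction too;
    # for ignore_index > num_classes (e.g. 255) they fall outside every bin below.
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    bins = np.arange(num_classes + 1)
    area_inter, _ = np.histogram(intersection, bins=bins)
    area_out, _ = np.histogram(output, bins=bins)
    area_tgt, _ = np.histogram(target, bins=bins)
    return area_inter, area_out + area_tgt - area_inter, area_tgt


def mean_iou(pairs, num_classes, ignore_index=255, eps=1e-10):
    """Sum per-class counts over all (prediction, ground truth) pairs, then divide."""
    inter_sum = np.zeros(num_classes)
    union_sum = np.zeros(num_classes)
    for pred, gt in pairs:
        inter, union, _ = intersection_and_union(pred, gt, num_classes, ignore_index)
        inter_sum += inter
        union_sum += union
    iou = inter_sum / (union_sum + eps)  # eps (an assumption) guards classes never seen
    return iou, iou.mean()


# Usage: binary background (0) / foreground (1) masks with one ignored pixel.
pred = np.array([[0, 1], [1, 1]])
gt = np.array([[0, 1], [0, 255]])
iou, miou = mean_iou([(pred, gt)], num_classes=2)
print(iou.round(3), round(float(miou), 3))  # [0.5 0.5] 0.5
```

Summing counts before dividing weights each class by its total pixel area across the set; averaging per-image IoUs instead would give a different number, so the two should not be compared directly.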
--------------------------------------------------------------------------------