├── README.md ├── configs ├── ssformer-L.yaml ├── ssformer-S.yaml └── train.yaml ├── images └── ssformer.png ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── build_model.cpython-37.pyc │ ├── build_model.cpython-38.pyc │ └── test.cpython-37.pyc ├── build_model.py ├── cvt │ ├── __pycache__ │ │ ├── cvt_SD.cpython-37.pyc │ │ ├── cvt_mla.cpython-37.pyc │ │ ├── cvt_pup.cpython-37.pyc │ │ └── cvt_srm.cpython-37.pyc │ ├── cvt_PLD.py │ ├── cvt_SD.py │ ├── cvt_mla.py │ └── cvt_pup.py ├── mit │ ├── __pycache__ │ │ ├── mit_PLD_b2.cpython-37.pyc │ │ ├── mit_PLD_b2.cpython-38.pyc │ │ ├── mit_PLD_b4.cpython-37.pyc │ │ ├── mit_PLD_b4.cpython-38.pyc │ │ ├── mit_SD.cpython-37.pyc │ │ ├── mit_mla.cpython-37.pyc │ │ ├── mit_pup.cpython-37.pyc │ │ ├── mit_srm.cpython-37.pyc │ │ ├── mit_srm.cpython-38.pyc │ │ ├── mit_srm_add.cpython-37.pyc │ │ ├── mit_srm_b4.cpython-37.pyc │ │ └── segformer.cpython-37.pyc │ ├── mit_PLD_b2.py │ ├── mit_PLD_b4.py │ ├── mit_PPD.py │ ├── mit_SD.py │ ├── mit_mla.py │ ├── mit_pup.py │ └── mit_srm_add.py ├── pvt │ ├── __pycache__ │ │ ├── pvt_PPD.cpython-37.pyc │ │ ├── pvt_SD.cpython-37.pyc │ │ ├── pvt_mla.cpython-37.pyc │ │ ├── pvt_pup.cpython-37.pyc │ │ └── pvt_srm.cpython-37.pyc │ ├── pvt_PLD.py │ ├── pvt_PPD.py │ ├── pvt_SD.py │ ├── pvt_mla.py │ └── pvt_pup.py ├── simVit │ ├── __pycache__ │ │ ├── simVit_SD.cpython-37.pyc │ │ ├── simVit_mla.cpython-37.pyc │ │ ├── simVit_pup.cpython-37.pyc │ │ └── simVit_srm.cpython-37.pyc │ ├── simVit_PLD.py │ ├── simVit_SD.py │ ├── simVit_mla.py │ └── simVit_pup.py └── ssa │ ├── __pycache__ │ ├── ssa_SD.cpython-37.pyc │ ├── ssa_mla.cpython-37.pyc │ ├── ssa_pup.cpython-37.pyc │ └── ssa_srm.cpython-37.pyc │ ├── ssa_PLD.py │ ├── ssa_SD.py │ ├── ssa_mla.py │ └── ssa_pup.py ├── requirements.txt ├── result ├── .DS_Store ├── ssformer_L │ ├── .DS_Store │ └── mit_srm_b4.png └── ssformer_S │ ├── .DS_Store │ └── ssformer_S.png ├── test.py ├── train.py └── utils ├── PolynomialLRDecay.py ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── custom_transforms.cpython-37.pyc ├── custom_transforms.cpython-38.pyc ├── dataloader.cpython-37.pyc ├── eval_other.cpython-37.pyc ├── loss.cpython-37.pyc ├── loss.cpython-38.pyc ├── my_dataset.cpython-37.pyc ├── my_dataset.cpython-38.pyc ├── swd.cpython-37.pyc ├── swd.cpython-38.pyc ├── test_dataset.cpython-37.pyc ├── test_dataset.cpython-38.pyc ├── test_transforms.cpython-37.pyc ├── test_transforms.cpython-38.pyc ├── tools.cpython-37.pyc ├── tools.cpython-38.pyc └── utils.cpython-37.pyc ├── custom_transforms.py ├── dataloader.py ├── eval_FPS.py ├── eval_other.py ├── hlper.py ├── loss.py ├── loss2.py ├── my_dataset.py ├── swd.py ├── test_dataset.py ├── test_transforms.py ├── tools.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Stepwise Feature Fusion: Local Guides Global 2 | This is the official implementation for [Stepwise Feature Fusion: Local Guides Global](https://arxiv.org/abs/2203.03635) 3 | 4 | ![SSformer](/images/ssformer.png) 5 | 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stepwise-feature-fusion-local-guides-global/medical-image-segmentation-on-cvc-clinicdb)](https://paperswithcode.com/sota/medical-image-segmentation-on-cvc-clinicdb?p=stepwise-feature-fusion-local-guides-global) 7 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stepwise-feature-fusion-local-guides-global/medical-image-segmentation-on-etis)](https://paperswithcode.com/sota/medical-image-segmentation-on-etis?p=stepwise-feature-fusion-local-guides-global) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stepwise-feature-fusion-local-guides-global/medical-image-segmentation-on-kvasir-seg)](https://paperswithcode.com/sota/medical-image-segmentation-on-kvasir-seg?p=stepwise-feature-fusion-local-guides-global) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stepwise-feature-fusion-local-guides-global/medical-image-segmentation-on-cvc-colondb)](https://paperswithcode.com/sota/medical-image-segmentation-on-cvc-colondb?p=stepwise-feature-fusion-local-guides-global) 10 | 11 | ## Packages 12 | - Please see `requirements.txt` 13 | 14 | ## Dataset 15 | - The dataset we used can be downloaded from [here](https://drive.google.com/file/d/1z48bsJftdp4akAlWOziqt6032huYYN9k/view?usp=sharing) 16 | 17 | ### Checkpoints 18 | - The checkpoint for ssformer-S can be downloaded from [here](https://drive.google.com/file/d/1CdX0K1_ZDMrEVGK2cmBfp33lYxLEBwlw/view?usp=sharing) 19 | - The checkpoint for ssformer-L can be downloaded from [here](https://drive.google.com/file/d/1CEwUOPm1otoEGfXSvcX-y1x80583-Q9C/view?usp=sharing) 20 | 21 | ## Usage 22 | ### Test 23 | 1. Modify `configs/ssformer-S.yaml` 24 | - `dataset` : set to your data path 25 | - `test.checkpoint_save_path` : path to your downloaded checkpoint 26 | 2. Run `python test.py configs/ssformer-S.yaml` 27 | 28 | ### Train 29 | 1. Modify `configs/train.yaml` 30 | - `model.pretrained_path` : path to the MiT pre-trained checkpoint 31 | - `other` : paths to save your training checkpoints and log file 32 | 2. Run `python train.py configs/train.yaml` 33 | 34 | ## Citation 35 | ``` 36 | Wang, J., Huang, Q., Tang, F., Meng, J., Su, J., Song, S. (2022). 37 | Stepwise Feature Fusion: Local Guides Global. 38 | In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds) 39 | Medical Image Computing and Computer Assisted Intervention – MICCAI 2022. 40 | MICCAI 2022. Lecture Notes in Computer Science, vol 13433. Springer, Cham. 
41 | https://doi.org/10.1007/978-3-031-16437-8_11 42 | 43 | ``` 44 | -------------------------------------------------------------------------------- /configs/ssformer-L.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | train_img_root : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TrainDataset/images/ 3 | train_label_root: /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TrainDataset/masks/ 4 | 5 | test_CVC-300_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-300/images/ 6 | test_CVC-300_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-300/masks/ 7 | 8 | test_CVC-ClinicDB_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ClinicDB/images/ 9 | test_CVC-ClinicDB_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ClinicDB/masks/ 10 | 11 | test_CVC-ColonDB_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ColonDB/images/ 12 | test_CVC-ColonDB_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ColonDB/masks/ 13 | 14 | test_ETIS-LaribPolypDB_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/ETIS-LaribPolypDB/images/ 15 | test_ETIS-LaribPolypDB_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/ETIS-LaribPolypDB/masks/ 16 | 17 | test_Kvasir_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/Kvasir/images/ 18 | test_Kvasir_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/Kvasir/masks/ 19 | 20 | class_num: 1 21 | 22 | crop_size : 23 | w: 352 24 | h: 352 25 | batch_size : 32 26 | num_workers : 8 27 | 28 | Train_transform_list: 29 | resize: 30 | size: [352, 352] 31 | random_scale_crop: 32 | range: [0.75, 1.25] 33 | random_flip: 34 | lr: True 35 | ud: True 36 | random_rotate: 37 | range: [0, 359] 38 | random_image_enhance: 39 | methods: ['contrast', 'sharpness', 'brightness'] 40 | random_dilation_erosion: 41 | kernel_range: [2, 5] 42 | tonumpy: NULL 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | totensor: NULL 47 | 48 | Val_transform_list: 49 | resize: 50 | size: [352, 352] 51 | tonumpy: NULL 52 | normalize: 53 | mean: [0.485, 0.456, 0.406] 54 | std: [0.229, 0.224, 0.225] 55 | totensor: NULL 56 | 57 | 58 | model: 59 | model_name : mit_PLD_b4 60 | is_pretrained : False 61 | pretrained_path : /mnt/DATA-1/DATA-2/Feilong/classification/ssformer/ssformer/ssformer_L/ssformer_L.pth 62 | from_epoch : 0 63 | 64 | training: 65 | device : cuda 66 | lr : 1e-4 67 | max_epoch : 2000 68 | evl_epoch : 0 69 | 70 | other: 71 | checkpoint_save_path : /mnt/DATA-1/DATA-2/Feilong/scformer/train_package/mit/mit_mla 72 | logger_path : /mnt/DATA-1/DATA-2/Feilong/scformer/train_package/mit/mit_mla/mit_mla.log 73 | 74 | test: 75 | checkpoint_save_path : result/ssformer_L/ssformer_L.pth 76 | Test_transform_list: 77 | resize: 78 | size: [352, 352] 79 | tonumpy: NULL 80 | normalize: 81 | mean: [0.485, 0.456, 0.406] 82 | std: [0.229, 0.224, 0.225] 83 | totensor: NULL -------------------------------------------------------------------------------- /configs/ssformer-S.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | train_img_root : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TrainDataset/images/ 3 | train_label_root: /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TrainDataset/masks/ 4 | 5 | test_CVC-300_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-300/images/ 6 | test_CVC-300_label 
: /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-300/masks/ 7 | 8 | test_CVC-ClinicDB_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ClinicDB/images/ 9 | test_CVC-ClinicDB_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ClinicDB/masks/ 10 | 11 | test_CVC-ColonDB_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ColonDB/images/ 12 | test_CVC-ColonDB_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ColonDB/masks/ 13 | 14 | test_ETIS-LaribPolypDB_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/ETIS-LaribPolypDB/images/ 15 | test_ETIS-LaribPolypDB_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/ETIS-LaribPolypDB/masks/ 16 | 17 | test_Kvasir_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/Kvasir/images/ 18 | test_Kvasir_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/Kvasir/masks/ 19 | 20 | class_num: 1 21 | 22 | crop_size : 23 | w: 352 24 | h: 352 25 | batch_size : 32 26 | num_workers : 8 27 | 28 | Train_transform_list: 29 | resize: 30 | size: [352, 352] 31 | random_scale_crop: 32 | range: [0.75, 1.25] 33 | random_flip: 34 | lr: True 35 | ud: True 36 | random_rotate: 37 | range: [0, 359] 38 | random_image_enhance: 39 | methods: ['contrast', 'sharpness', 'brightness'] 40 | random_dilation_erosion: 41 | kernel_range: [2, 5] 42 | tonumpy: NULL 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | totensor: NULL 47 | 48 | Val_transform_list: 49 | resize: 50 | size: [352, 352] 51 | tonumpy: NULL 52 | normalize: 53 | mean: [0.485, 0.456, 0.406] 54 | std: [0.229, 0.224, 0.225] 55 | totensor: NULL 56 | 57 | 58 | model: 59 | model_name : mit_PLD_b2 60 | is_pretrained : False 61 | pretrained_path : /mnt/DATA-1/DATA-2/Feilong/classification/ssformer/ssformer/ssformer_S/ssformer_S.pth 62 | from_epoch : 0 63 | 64 | training: 65 | device : cuda 66 | lr : 1e-4 67 | max_epoch : 2000 68 | evl_epoch : 0 69 | 70 | other: 71 | checkpoint_save_path : /mnt/DATA-1/DATA-2/Feilong/classification/ssformer/ssformer/ssformer_S 72 | logger_path : /mnt/DATA-1/DATA-2/Feilong/classification/ssformer/ssformer/ssformer_S/segformer.log 73 | 74 | test: 75 | checkpoint_save_path : result/ssformer_S/ssformer_S.pth 76 | Test_transform_list: 77 | resize: 78 | size: [352, 352] 79 | tonumpy: NULL 80 | normalize: 81 | mean: [0.485, 0.456, 0.406] 82 | std: [0.229, 0.224, 0.225] 83 | totensor: NULL -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | train_img_root : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TrainDataset/images/ 3 | train_label_root: /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TrainDataset/masks/ 4 | 5 | test_CVC-300_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-300/images/ 6 | test_CVC-300_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-300/masks/ 7 | 8 | test_CVC-ClinicDB_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ClinicDB/images/ 9 | test_CVC-ClinicDB_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ClinicDB/masks/ 10 | 11 | test_CVC-ColonDB_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ColonDB/images/ 12 | test_CVC-ColonDB_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/CVC-ColonDB/masks/ 13 | 14 | test_ETIS-LaribPolypDB_img : 
/mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/ETIS-LaribPolypDB/images/ 15 | test_ETIS-LaribPolypDB_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/ETIS-LaribPolypDB/masks/ 16 | 17 | test_Kvasir_img : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/Kvasir/images/ 18 | test_Kvasir_label : /mnt/DATA-1/DATA-2/Feilong/scformer/data/polyp/TestDataset/Kvasir/masks/ 19 | 20 | class_num: 1 21 | 22 | crop_size : 23 | w: 352 24 | h: 352 25 | batch_size : 32 26 | num_workers : 8 27 | 28 | Train_transform_list: 29 | resize: 30 | size: [352, 352] 31 | random_scale_crop: 32 | range: [0.75, 1.25] 33 | random_flip: 34 | lr: True 35 | ud: True 36 | random_rotate: 37 | range: [0, 359] 38 | random_image_enhance: 39 | methods: ['contrast', 'sharpness', 'brightness'] 40 | random_dilation_erosion: 41 | kernel_range: [2, 5] 42 | tonumpy: NULL 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | totensor: NULL 47 | 48 | Val_transform_list: 49 | resize: 50 | size: [352, 352] 51 | tonumpy: NULL 52 | normalize: 53 | mean: [0.485, 0.456, 0.406] 54 | std: [0.229, 0.224, 0.225] 55 | totensor: NULL 56 | 57 | 58 | model: 59 | model_name : mit_PLD_b2 60 | is_pretrained : False 61 | pretrained_path : YOUR PRETRAINED MODEL PATH 62 | from_epoch : 0 63 | 64 | training: 65 | device : cuda 66 | lr : 1e-4 67 | max_epoch : 2000 68 | evl_epoch : 0 69 | 70 | other: 71 | checkpoint_save_path : /mnt/DATA-1/DATA-2/Feilong/classification/ssformer/lib 72 | logger_path : /mnt/DATA-1/DATA-2/Feilong/classification/ssformer/lib/segformer.log 73 | 74 | test: 75 | checkpoint_save_path : /mnt/DATA-1/DATA-2/Feilong/classification/ssformer/lib/ssformer.pth 76 | Test_transform_list: 77 | resize: 78 | size: [352, 352] 79 | tonumpy: NULL 80 | normalize: 81 | mean: [0.485, 0.456, 0.406] 82 | std: [0.229, 0.224, 0.225] 83 | totensor: NULL 84 | 85 | 86 | -------------------------------------------------------------------------------- /images/ssformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/images/ssformer.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build_model import build -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/build_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/__pycache__/build_model.cpython-37.pyc -------------------------------------------------------------------------------- 
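For reference, the snippet below is a minimal sketch of how the pieces above fit together: it loads one of the YAML configs and builds the model named under `model.model_name` through `models.build` (re-exported by `models/__init__.py`). The repository's real entry points are `train.py` and `test.py`, whose contents are not reproduced in this listing, so the loader code and variable names here are illustrative assumptions rather than the project's own scripts.

```python
# Illustrative sketch only -- train.py / test.py are the actual entry points.
# Assumes PyYAML is installed and the repository root is on PYTHONPATH.
import yaml
import torch

from models import build  # models/__init__.py re-exports build_model.build

with open("configs/train.yaml") as f:
    cfg = yaml.safe_load(f)

# model.model_name is e.g. "mit_PLD_b2" (ssformer-S) or "mit_PLD_b4" (ssformer-L)
model = build(cfg["model"]["model_name"], class_num=cfg["dataset"]["class_num"])
# NOTE: the mit_PLD_b2 constructor (see models/mit/mit_PLD_b2.py below) loads a
# hard-coded mit_b2.pth backbone checkpoint path, which must exist beforehand.
model = model.to(cfg["training"]["device"]).eval()

# Forward a dummy batch at the configured crop size (352x352).
h, w = cfg["dataset"]["crop_size"]["h"], cfg["dataset"]["crop_size"]["w"]
x = torch.randn(1, 3, h, w, device=cfg["training"]["device"])
with torch.no_grad():
    out = model(x)  # (1, class_num, 352, 352) for the PLD decoders
print(out.shape)
```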
/models/__pycache__/build_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/__pycache__/build_model.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/test.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/__pycache__/test.cpython-37.pyc -------------------------------------------------------------------------------- /models/build_model.py: -------------------------------------------------------------------------------- 1 | 2 | def build(model_name, class_num=1): 3 | ############################################ 1. CVT 4 | if model_name == "cvt_pup": 5 | from .cvt.cvt_pup import cvt_PUP 6 | model = cvt_PUP(class_num=class_num) 7 | return model 8 | 9 | if model_name == "cvt_mla": 10 | from .cvt.cvt_mla import cvt_mla 11 | model = cvt_mla(class_num=class_num) 12 | return model 13 | 14 | if model_name == "cvt_PPD": 15 | from .cvt.cvt_PPD import cvt_PPD 16 | model = cvt_PPD(class_num=class_num) 17 | return model 18 | 19 | if model_name == "cvt_SD": 20 | from .cvt.cvt_SD import cvt_SD 21 | model = cvt_SD(class_num=class_num) 22 | return model 23 | 24 | if model_name == "cvt_PLD": 25 | from .cvt.cvt_PLD import cvt_PLD 26 | model = cvt_PLD(class_num=class_num) 27 | return model 28 | 29 | #################################################### 2. MIT 30 | if model_name == "mit_PLD_b2": 31 | from .mit.mit_PLD_b2 import mit_PLD_b2 32 | model = mit_PLD_b2(class_num=class_num) 33 | return model 34 | 35 | if model_name == "mit_PLD_add": 36 | from .mit.mit_PLD_add import mit_PLD_add 37 | model = mit_PLD_add(class_num=class_num) 38 | return model 39 | 40 | if model_name == "mit_mla": 41 | from .mit.mit_mla import mit_mla 42 | model = mit_mla(class_num=class_num) 43 | return model 44 | 45 | if model_name == "mit_pup": 46 | from .mit.mit_pup import mit_pup 47 | model = mit_pup(class_num=class_num) 48 | return model 49 | 50 | if model_name == "mit_SD": 51 | from .mit.mit_SD import mit_SD 52 | model = mit_SD(class_num=class_num) 53 | return model 54 | 55 | if model_name == "mit_PPD": 56 | from .mit.mit_PPD import mit_PPD 57 | model = mit_PPD(class_num=class_num) 58 | return model 59 | 60 | if model_name == "mit_PLD_b4": 61 | from .mit.mit_PLD_b4 import mit_PLD_b4 62 | model = mit_PLD_b4(class_num=class_num) 63 | return model 64 | 65 | #################################################### 3. 
PVT 66 | 67 | if model_name == "pvt_PLD": 68 | from .pvt.pvt_PLD import pvt_PLD 69 | model = pvt_PLD(class_num=class_num) 70 | return model 71 | 72 | if model_name == "pvt_mla": 73 | from .pvt.pvt_mla import pvt_mla 74 | model = pvt_mla(class_num=class_num) 75 | return model 76 | 77 | if model_name == "pvt_pup": 78 | from .pvt.pvt_pup import pvt_pup 79 | model = pvt_pup(class_num=class_num) 80 | return model 81 | 82 | if model_name == "pvt_SD": 83 | from .pvt.pvt_SD import pvt_SD 84 | model = pvt_SD(class_num=class_num) 85 | return model 86 | 87 | if model_name == "pvt_PPD": 88 | from .pvt.pvt_PPD import pvt_PPD 89 | model = pvt_PPD(class_num=class_num) 90 | return model 91 | 92 | #################################################### simVit 93 | 94 | 95 | if model_name == "simVit_PLD": 96 | from .simVit.simVit_PLD import simVit_PLD 97 | model = simVit_PLD(class_num=class_num) 98 | return model 99 | 100 | if model_name == "simVit_mla": 101 | from .simVit.simVit_mla import simVit_mla 102 | model = simVit_mla(class_num=class_num) 103 | return model 104 | 105 | if model_name == "simVit_pup": 106 | from .simVit.simVit_pup import simVit_pup 107 | model = simVit_pup(class_num=class_num) 108 | return model 109 | 110 | if model_name == "simVit_SD": 111 | from .simVit.simVit_SD import simVit_SD 112 | model = simVit_SD(class_num=class_num) 113 | return model 114 | 115 | if model_name == "simVit_PPD": 116 | from .simVit.simVit_PPD import simVit_PPD 117 | model = simVit_PPD(class_num=class_num) 118 | return model 119 | 120 | #################################################### 121 | 122 | 123 | if model_name == "ssa_PLD": 124 | from .ssa.ssa_PLD import ssa_PLD 125 | model = ssa_PLD(class_num=class_num) 126 | return model 127 | 128 | if model_name == "ssa_mla": 129 | from .ssa.ssa_mla import ssa_mla 130 | model = ssa_mla(class_num=class_num) 131 | return model 132 | 133 | if model_name == "ssa_pup": 134 | from .ssa.ssa_pup import ssa_pup 135 | model = ssa_pup(class_num=class_num) 136 | return model 137 | 138 | if model_name == "ssa_SD": 139 | from .ssa.ssa_SD import ssa_SD 140 | model = ssa_SD(class_num=class_num) 141 | return model 142 | 143 | if model_name == "ssa_PPD": 144 | from .ssa.ssa_PPD import ssa_PPD 145 | model = ssa_PPD(class_num=class_num) 146 | return model 147 | 148 | -------------------------------------------------------------------------------- /models/cvt/__pycache__/cvt_SD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/cvt/__pycache__/cvt_SD.cpython-37.pyc -------------------------------------------------------------------------------- /models/cvt/__pycache__/cvt_mla.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/cvt/__pycache__/cvt_mla.cpython-37.pyc -------------------------------------------------------------------------------- /models/cvt/__pycache__/cvt_pup.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/cvt/__pycache__/cvt_pup.cpython-37.pyc -------------------------------------------------------------------------------- /models/cvt/__pycache__/cvt_srm.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/cvt/__pycache__/cvt_srm.cpython-37.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/mit_PLD_b2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_PLD_b2.cpython-37.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/mit_PLD_b2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_PLD_b2.cpython-38.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/mit_PLD_b4.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_PLD_b4.cpython-37.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/mit_PLD_b4.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_PLD_b4.cpython-38.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/mit_SD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_SD.cpython-37.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/mit_mla.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_mla.cpython-37.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/mit_pup.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_pup.cpython-37.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/mit_srm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_srm.cpython-37.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/mit_srm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_srm.cpython-38.pyc -------------------------------------------------------------------------------- 
/models/mit/__pycache__/mit_srm_add.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_srm_add.cpython-37.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/mit_srm_b4.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/mit_srm_b4.cpython-37.pyc -------------------------------------------------------------------------------- /models/mit/__pycache__/segformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/mit/__pycache__/segformer.cpython-37.pyc -------------------------------------------------------------------------------- /models/mit/mit_PLD_b2.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch.nn.functional as F 3 | from functools import partial 4 | from timm.models.layers import to_2tuple, trunc_normal_ 5 | import math 6 | from timm.models.layers import DropPath 7 | from torch.nn import Module 8 | from mmcv.cnn import ConvModule 9 | from torch.nn import Conv2d, UpsamplingBilinear2d 10 | import torch.nn as nn 11 | import torch 12 | 13 | class Mlp(nn.Module): 14 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 15 | super().__init__() 16 | out_features = out_features or in_features 17 | hidden_features = hidden_features or in_features 18 | self.fc1 = nn.Linear(in_features, hidden_features) 19 | self.dwconv = DWConv(hidden_features) 20 | self.act = act_layer() 21 | self.fc2 = nn.Linear(hidden_features, out_features) 22 | self.drop = nn.Dropout(drop) 23 | 24 | self.apply(self._init_weights) 25 | 26 | def _init_weights(self, m): 27 | if isinstance(m, nn.Linear): 28 | trunc_normal_(m.weight, std=.02) 29 | if isinstance(m, nn.Linear) and m.bias is not None: 30 | nn.init.constant_(m.bias, 0) 31 | elif isinstance(m, nn.LayerNorm): 32 | nn.init.constant_(m.bias, 0) 33 | nn.init.constant_(m.weight, 1.0) 34 | elif isinstance(m, nn.Conv2d): 35 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 36 | fan_out //= m.groups 37 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | 41 | def forward(self, x, H, W): 42 | x = self.fc1(x) 43 | x = self.dwconv(x, H, W) 44 | x = self.act(x) 45 | x = self.drop(x) 46 | x = self.fc2(x) 47 | x = self.drop(x) 48 | return x 49 | 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 53 | super().__init__() 54 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
55 | 56 | self.dim = dim 57 | self.num_heads = num_heads 58 | head_dim = dim // num_heads 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 62 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 63 | self.attn_drop = nn.Dropout(attn_drop) 64 | self.proj = nn.Linear(dim, dim) 65 | self.proj_drop = nn.Dropout(proj_drop) 66 | 67 | self.sr_ratio = sr_ratio 68 | if sr_ratio > 1: 69 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 70 | self.norm = nn.LayerNorm(dim) 71 | 72 | self.apply(self._init_weights) 73 | 74 | def _init_weights(self, m): 75 | if isinstance(m, nn.Linear): 76 | trunc_normal_(m.weight, std=.02) 77 | if isinstance(m, nn.Linear) and m.bias is not None: 78 | nn.init.constant_(m.bias, 0) 79 | elif isinstance(m, nn.LayerNorm): 80 | nn.init.constant_(m.bias, 0) 81 | nn.init.constant_(m.weight, 1.0) 82 | elif isinstance(m, nn.Conv2d): 83 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | fan_out //= m.groups 85 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 86 | if m.bias is not None: 87 | m.bias.data.zero_() 88 | 89 | def forward(self, x, H, W): 90 | B, N, C = x.shape 91 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 92 | 93 | if self.sr_ratio > 1: 94 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 95 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 96 | x_ = self.norm(x_) 97 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 98 | else: 99 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 100 | k, v = kv[0], kv[1] 101 | 102 | attn = (q @ k.transpose(-2, -1)) * self.scale 103 | attn = attn.softmax(dim=-1) 104 | attn = self.attn_drop(attn) 105 | 106 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 107 | x = self.proj(x) 108 | x = self.proj_drop(x) 109 | 110 | return x 111 | 112 | 113 | class Block(nn.Module): 114 | 115 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 116 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 117 | super().__init__() 118 | self.norm1 = norm_layer(dim) 119 | self.attn = Attention( 120 | dim, 121 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 122 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 123 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 124 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 125 | self.norm2 = norm_layer(dim) 126 | mlp_hidden_dim = int(dim * mlp_ratio) 127 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 128 | 129 | def forward(self, x, H, W): 130 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 131 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 132 | 133 | return x 134 | 135 | 136 | class OverlapPatchEmbed(nn.Module): 137 | """ Image to Patch Embedding 138 | """ 139 | 140 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 141 | super().__init__() 142 | img_size = to_2tuple(img_size) 143 | patch_size = to_2tuple(patch_size) 144 | 145 | self.img_size = img_size 146 | self.patch_size = patch_size 147 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 148 | self.num_patches = self.H * self.W 149 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 150 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 151 | self.norm = nn.LayerNorm(embed_dim) 152 | 153 | def forward(self, x): 154 | x = self.proj(x) 155 | _, _, H, W = x.shape 156 | x = x.flatten(2).transpose(1, 2) 157 | x = self.norm(x) 158 | 159 | return x, H, W 160 | 161 | 162 | class MixVisionTransformer(nn.Module): 163 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 164 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 165 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 166 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 167 | super().__init__() 168 | self.num_classes = num_classes 169 | self.depths = depths 170 | 171 | # patch_embed 172 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 173 | embed_dim=embed_dims[0]) 174 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 175 | embed_dim=embed_dims[1]) 176 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 177 | embed_dim=embed_dims[2]) 178 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], 179 | embed_dim=embed_dims[3]) 180 | 181 | # transformer encoder 182 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 183 | cur = 0 184 | self.block1 = nn.ModuleList([Block( 185 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 186 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 187 | sr_ratio=sr_ratios[0]) 188 | for i in range(depths[0])]) 189 | self.norm1 = norm_layer(embed_dims[0]) 190 | 191 | cur += depths[0] 192 | self.block2 = nn.ModuleList([Block( 193 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 194 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 195 | sr_ratio=sr_ratios[1]) 196 | for i in range(depths[1])]) 197 | self.norm2 = norm_layer(embed_dims[1]) 198 | 199 | cur += depths[1] 200 | self.block3 = nn.ModuleList([Block( 201 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 202 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 203 | sr_ratio=sr_ratios[2]) 204 | for i in 
range(depths[2])]) 205 | self.norm3 = norm_layer(embed_dims[2]) 206 | 207 | cur += depths[2] 208 | self.block4 = nn.ModuleList([Block( 209 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 210 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 211 | sr_ratio=sr_ratios[3]) 212 | for i in range(depths[3])]) 213 | self.norm4 = norm_layer(embed_dims[3]) 214 | 215 | def forward_features(self, x): 216 | B = x.shape[0] 217 | outs = [] 218 | 219 | # stage 1 220 | x, H, W = self.patch_embed1(x) 221 | for i, blk in enumerate(self.block1): 222 | x = blk(x, H, W) 223 | x = self.norm1(x) 224 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 225 | outs.append(x) 226 | 227 | # stage 2 228 | x, H, W = self.patch_embed2(x) 229 | for i, blk in enumerate(self.block2): 230 | x = blk(x, H, W) 231 | x = self.norm2(x) 232 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 233 | outs.append(x) 234 | 235 | # stage 3 236 | x, H, W = self.patch_embed3(x) 237 | for i, blk in enumerate(self.block3): 238 | x = blk(x, H, W) 239 | x = self.norm3(x) 240 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 241 | outs.append(x) 242 | 243 | # stage 4 244 | x, H, W = self.patch_embed4(x) 245 | for i, blk in enumerate(self.block4): 246 | x = blk(x, H, W) 247 | x = self.norm4(x) 248 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 249 | outs.append(x) 250 | 251 | return outs 252 | 253 | def forward(self, x): 254 | x = self.forward_features(x) 255 | 256 | # x = self.head(x[3]) 257 | 258 | return x 259 | 260 | 261 | class DWConv(nn.Module): 262 | def __init__(self, dim=768): 263 | super(DWConv, self).__init__() 264 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 265 | 266 | def forward(self, x, H, W): 267 | B, N, C = x.shape 268 | x = x.transpose(1, 2).view(B, C, H, W) 269 | x = self.dwconv(x) 270 | x = x.flatten(2).transpose(1, 2) 271 | 272 | return x 273 | 274 | 275 | class mit_b0(MixVisionTransformer): 276 | def __init__(self, **kwargs): 277 | super(mit_b0, self).__init__( 278 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 279 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 280 | drop_rate=0.0, drop_path_rate=0.1) 281 | 282 | 283 | class mit_b1(MixVisionTransformer): 284 | def __init__(self, **kwargs): 285 | super(mit_b1, self).__init__( 286 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 287 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 288 | drop_rate=0.0, drop_path_rate=0.1) 289 | 290 | 291 | class mit_b2(MixVisionTransformer): 292 | def __init__(self, **kwargs): 293 | super(mit_b2, self).__init__( 294 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 295 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 296 | drop_rate=0.0, drop_path_rate=0.1) 297 | 298 | 299 | class mit_b3(MixVisionTransformer): 300 | def __init__(self, **kwargs): 301 | super(mit_b3, self).__init__( 302 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 303 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 304 | drop_rate=0.0, drop_path_rate=0.1) 305 | 306 | 307 | class 
mit_b4(MixVisionTransformer): 308 | def __init__(self, **kwargs): 309 | super(mit_b4, self).__init__( 310 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 311 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 312 | drop_rate=0.0, drop_path_rate=0.1) 313 | 314 | 315 | class mit_b5(MixVisionTransformer): 316 | def __init__(self, **kwargs): 317 | super(mit_b5, self).__init__( 318 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 319 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 320 | drop_rate=0.0, drop_path_rate=0.1) 321 | 322 | 323 | 324 | 325 | def resize(input, 326 | size=None, 327 | scale_factor=None, 328 | mode='nearest', 329 | align_corners=None, 330 | warning=True): 331 | if warning: 332 | if size is not None and align_corners: 333 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 334 | output_h, output_w = tuple(int(x) for x in size) 335 | if output_h > input_h or output_w > output_h: 336 | if ((output_h > 1 and output_w > 1 and input_h > 1 337 | and input_w > 1) and (output_h - 1) % (input_h - 1) 338 | and (output_w - 1) % (input_w - 1)): 339 | warnings.warn( 340 | f'When align_corners={align_corners}, ' 341 | 'the output would more aligned if ' 342 | f'input size {(input_h, input_w)} is `x+1` and ' 343 | f'out size {(output_h, output_w)} is `nx+1`') 344 | return F.interpolate(input, size, scale_factor, mode, align_corners) 345 | 346 | 347 | class MLP(nn.Module): 348 | """ 349 | Linear Embedding 350 | """ 351 | 352 | def __init__(self, input_dim=512, embed_dim=768): 353 | super().__init__() 354 | self.proj = nn.Linear(input_dim, embed_dim) 355 | 356 | def forward(self, x): 357 | x = x.flatten(2).transpose(1, 2) 358 | x = self.proj(x) 359 | return x 360 | 361 | 362 | class conv(nn.Module): 363 | """ 364 | Linear Embedding 365 | """ 366 | 367 | def __init__(self, input_dim=512, embed_dim=768, k_s=3): 368 | super().__init__() 369 | 370 | self.proj = nn.Sequential(nn.Conv2d(input_dim, embed_dim, 3, padding=1, bias=False), nn.ReLU(), 371 | nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=False), nn.ReLU()) 372 | 373 | def forward(self, x): 374 | x = self.proj(x) 375 | x = x.flatten(2).transpose(1, 2) 376 | return x 377 | 378 | 379 | import cv2 380 | import random 381 | 382 | 383 | class Decoder(Module): 384 | """ 385 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 386 | """ 387 | 388 | def __init__(self, dims, dim, class_num=2): 389 | super(Decoder, self).__init__() 390 | self.num_classes = class_num 391 | 392 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = dims[0], dims[1], dims[2], dims[3] 393 | embedding_dim = dim 394 | 395 | self.linear_c4 = conv(input_dim=c4_in_channels, embed_dim=embedding_dim) 396 | self.linear_c3 = conv(input_dim=c3_in_channels, embed_dim=embedding_dim) 397 | self.linear_c2 = conv(input_dim=c2_in_channels, embed_dim=embedding_dim) 398 | self.linear_c1 = conv(input_dim=c1_in_channels, embed_dim=embedding_dim) 399 | 400 | self.linear_fuse = ConvModule(in_channels=embedding_dim * 4, out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True)) 401 | self.linear_fuse34 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True)) 402 | self.linear_fuse2 = ConvModule(in_channels=embedding_dim * 2, 
out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True)) 403 | self.linear_fuse1 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True)) 404 | 405 | self.linear_pred = Conv2d(embedding_dim, self.num_classes, kernel_size=1) 406 | self.dropout = nn.Dropout(0.1) 407 | 408 | def forward(self, inputs): 409 | c1, c2, c3, c4 = inputs 410 | ############## MLP decoder on C1-C4 ########### 411 | n, _, h, w = c4.shape 412 | 413 | _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) 414 | _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False) 415 | _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) 416 | _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False) 417 | _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) 418 | _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False) 419 | _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) 420 | 421 | L34 = self.linear_fuse34(torch.cat([_c4, _c3], dim=1)) 422 | L2 = self.linear_fuse2(torch.cat([L34, _c2], dim=1)) 423 | _c = self.linear_fuse1(torch.cat([L2, _c1], dim=1)) 424 | 425 | 426 | x = self.dropout(_c) 427 | x = self.linear_pred(x) 428 | 429 | return x 430 | 431 | 432 | class mit_PLD_b2(nn.Module): 433 | def __init__(self, class_num=2, **kwargs): 434 | super(mit_PLD_b2, self).__init__() 435 | self.class_num = class_num 436 | self.backbone = mit_b2() 437 | self.decode_head = Decoder(dims=[64, 128, 320, 512], dim=256, class_num=class_num) 438 | self._init_weights() # load pretrain 439 | 440 | def forward(self, x): 441 | features = self.backbone(x) 442 | 443 | features = self.decode_head(features) 444 | up = UpsamplingBilinear2d(scale_factor=4) 445 | features = up(features) 446 | return features 447 | def _init_weights(self): 448 | pretrained_dict = torch.load('/mnt/DATA-1/DATA-2/Feilong/scformer/models/mit/mit_b2.pth') 449 | model_dict = self.backbone.state_dict() 450 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 451 | model_dict.update(pretrained_dict) 452 | self.backbone.load_state_dict(model_dict) 453 | print("successfully loaded!!!!") 454 | 455 | 456 | -------------------------------------------------------------------------------- /models/mit/mit_PPD.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from mit_srm import mit_b2 10 | import os 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class BasicConv2d(nn.Module): 17 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 18 | super(BasicConv2d, self).__init__() 19 | 20 | self.conv = nn.Conv2d(in_planes, out_planes, 21 | kernel_size=kernel_size, stride=stride, 22 | padding=padding, dilation=dilation, bias=False) 23 | self.bn = nn.BatchNorm2d(out_planes) 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = self.bn(x) 29 | return x 30 | 31 | 32 | class CFM(nn.Module): 33 | def __init__(self, channel): 34 | super(CFM, self).__init__() 35 | self.relu = nn.ReLU(True) 36 | 37 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 38 | self.conv_upsample1 = BasicConv2d(channel, channel, 
3, padding=1) 39 | self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1) 40 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1) 41 | self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1) 42 | self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1) 43 | 44 | self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1) 45 | self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1) 46 | self.conv4 = BasicConv2d(3 * channel, channel, 3, padding=1) 47 | 48 | def forward(self, x1, x2, x3): 49 | x1_1 = x1 50 | x2_1 = self.conv_upsample1(self.upsample(x1)) * x2 51 | x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \ 52 | * self.conv_upsample3(self.upsample(x2)) * x3 53 | 54 | x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1) 55 | x2_2 = self.conv_concat2(x2_2) 56 | 57 | x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1) 58 | x3_2 = self.conv_concat3(x3_2) 59 | 60 | x1 = self.conv4(x3_2) 61 | 62 | return x1 63 | 64 | 65 | 66 | 67 | class GCN(nn.Module): 68 | def __init__(self, num_state, num_node, bias=False): 69 | super(GCN, self).__init__() 70 | self.conv1 = nn.Conv1d(num_node, num_node, kernel_size=1) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.conv2 = nn.Conv1d(num_state, num_state, kernel_size=1, bias=bias) 73 | 74 | def forward(self, x): 75 | h = self.conv1(x.permute(0, 2, 1)).permute(0, 2, 1) 76 | h = h - x 77 | h = self.relu(self.conv2(h)) 78 | return h 79 | 80 | 81 | class SAM(nn.Module): 82 | def __init__(self, num_in=32, plane_mid=16, mids=4, normalize=False): 83 | super(SAM, self).__init__() 84 | 85 | self.normalize = normalize 86 | self.num_s = int(plane_mid) 87 | self.num_n = (mids) * (mids) 88 | self.priors = nn.AdaptiveAvgPool2d(output_size=(mids + 2, mids + 2)) 89 | 90 | self.conv_state = nn.Conv2d(num_in, self.num_s, kernel_size=1) 91 | self.conv_proj = nn.Conv2d(num_in, self.num_s, kernel_size=1) 92 | self.gcn = GCN(num_state=self.num_s, num_node=self.num_n) 93 | self.conv_extend = nn.Conv2d(self.num_s, num_in, kernel_size=1, bias=False) 94 | 95 | def forward(self, x, edge): 96 | edge = F.upsample(edge, (x.size()[-2], x.size()[-1])) 97 | 98 | n, c, h, w = x.size() 99 | edge = torch.nn.functional.softmax(edge, dim=1)[:, 1, :, :].unsqueeze(1) 100 | 101 | x_state_reshaped = self.conv_state(x).view(n, self.num_s, -1) 102 | x_proj = self.conv_proj(x) 103 | x_mask = x_proj * edge 104 | 105 | x_anchor1 = self.priors(x_mask) 106 | x_anchor2 = self.priors(x_mask)[:, :, 1:-1, 1:-1].reshape(n, self.num_s, -1) 107 | x_anchor = self.priors(x_mask)[:, :, 1:-1, 1:-1].reshape(n, self.num_s, -1) 108 | 109 | x_proj_reshaped = torch.matmul(x_anchor.permute(0, 2, 1), x_proj.reshape(n, self.num_s, -1)) 110 | x_proj_reshaped = torch.nn.functional.softmax(x_proj_reshaped, dim=1) 111 | 112 | x_rproj_reshaped = x_proj_reshaped 113 | 114 | x_n_state = torch.matmul(x_state_reshaped, x_proj_reshaped.permute(0, 2, 1)) 115 | if self.normalize: 116 | x_n_state = x_n_state * (1. 
/ x_state_reshaped.size(2)) 117 | x_n_rel = self.gcn(x_n_state) 118 | 119 | x_state_reshaped = torch.matmul(x_n_rel, x_rproj_reshaped) 120 | x_state = x_state_reshaped.view(n, self.num_s, *x.size()[2:]) 121 | out = x + (self.conv_extend(x_state)) 122 | 123 | return out 124 | 125 | 126 | class ChannelAttention(nn.Module): 127 | def __init__(self, in_planes, ratio=16): 128 | super(ChannelAttention, self).__init__() 129 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 130 | self.max_pool = nn.AdaptiveMaxPool2d(1) 131 | 132 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 133 | self.relu1 = nn.ReLU() 134 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 135 | 136 | self.sigmoid = nn.Sigmoid() 137 | 138 | def forward(self, x): 139 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 140 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 141 | out = avg_out + max_out 142 | return self.sigmoid(out) 143 | 144 | 145 | class SpatialAttention(nn.Module): 146 | def __init__(self, kernel_size=7): 147 | super(SpatialAttention, self).__init__() 148 | 149 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 150 | padding = 3 if kernel_size == 7 else 1 151 | 152 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 153 | self.sigmoid = nn.Sigmoid() 154 | 155 | def forward(self, x): 156 | avg_out = torch.mean(x, dim=1, keepdim=True) 157 | max_out, _ = torch.max(x, dim=1, keepdim=True) 158 | x = torch.cat([avg_out, max_out], dim=1) 159 | x = self.conv1(x) 160 | return self.sigmoid(x) 161 | 162 | 163 | class PolypPVT(nn.Module): 164 | def __init__(self, channel=32): 165 | super(PolypPVT, self).__init__() 166 | 167 | self.backbone = mit_b2() # [64, 128, 320, 512] 168 | path = '/mnt/DATA-1/DATA-2/Feilong/scformer/models/mit/mit_b2.pth' 169 | save_model = torch.load(path) 170 | model_dict = self.backbone.state_dict() 171 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 172 | model_dict.update(state_dict) 173 | self.backbone.load_state_dict(model_dict) 174 | 175 | self.Translayer2_0 = BasicConv2d(64, channel, 1) 176 | self.Translayer2_1 = BasicConv2d(128, channel, 1) 177 | self.Translayer3_1 = BasicConv2d(320, channel, 1) 178 | self.Translayer4_1 = BasicConv2d(512, channel, 1) 179 | 180 | self.CFM = CFM(channel) 181 | self.ca = ChannelAttention(64) 182 | self.sa = SpatialAttention() 183 | self.SAM = SAM() 184 | 185 | self.down05 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True) 186 | self.out_SAM = nn.Conv2d(channel, 1, 1) 187 | self.out_CFM = nn.Conv2d(channel, 1, 1) 188 | 189 | 190 | def forward(self, x): 191 | 192 | # backbone 193 | pvt = self.backbone(x) 194 | x1 = pvt[0] 195 | x2 = pvt[1] 196 | x3 = pvt[2] 197 | x4 = pvt[3] 198 | 199 | # CIM 200 | x1 = self.ca(x1) * x1 # channel attention 201 | cim_feature = self.sa(x1) * x1 # spatial attention 202 | 203 | 204 | # CFM 205 | x2_t = self.Translayer2_1(x2) 206 | x3_t = self.Translayer3_1(x3) 207 | x4_t = self.Translayer4_1(x4) 208 | cfm_feature = self.CFM(x4_t, x3_t, x2_t) 209 | 210 | # SAM 211 | T2 = self.Translayer2_0(cim_feature) 212 | T2 = self.down05(T2) 213 | sam_feature = self.SAM(cfm_feature, T2) 214 | 215 | prediction1 = self.out_CFM(cfm_feature) 216 | prediction2 = self.out_SAM(sam_feature) 217 | 218 | prediction1_8 = F.interpolate(prediction1, scale_factor=8, mode='bilinear') 219 | prediction2_8 = F.interpolate(prediction2, scale_factor=8, mode='bilinear') 220 | return prediction1_8, prediction2_8 221 | 222 | 223 | if __name__ == 
'__main__': 224 | model = PolypPVT().cuda() 225 | input_tensor = torch.randn(1, 3, 352, 352).cuda() 226 | 227 | prediction1, prediction2 = model(input_tensor) 228 | print(prediction1.size(), prediction2.size()) -------------------------------------------------------------------------------- /models/mit/mit_mla.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import partial 6 | from torch.nn import Sequential, Conv2d, UpsamplingBilinear2d 7 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 8 | from timm.models.registry import register_model 9 | from timm.models.vision_transformer import _cfg 10 | import math 11 | import cv2 12 | 13 | 14 | class Mlp(nn.Module): 15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 16 | super().__init__() 17 | out_features = out_features or in_features 18 | hidden_features = hidden_features or in_features 19 | self.fc1 = nn.Linear(in_features, hidden_features) 20 | self.dwconv = DWConv(hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | self.apply(self._init_weights) 26 | 27 | def _init_weights(self, m): 28 | if isinstance(m, nn.Linear): 29 | trunc_normal_(m.weight, std=.02) 30 | if isinstance(m, nn.Linear) and m.bias is not None: 31 | nn.init.constant_(m.bias, 0) 32 | elif isinstance(m, nn.LayerNorm): 33 | nn.init.constant_(m.bias, 0) 34 | nn.init.constant_(m.weight, 1.0) 35 | elif isinstance(m, nn.Conv2d): 36 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 37 | fan_out //= m.groups 38 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 39 | if m.bias is not None: 40 | m.bias.data.zero_() 41 | 42 | def forward(self, x, H, W): 43 | x = self.fc1(x) 44 | x = self.dwconv(x, H, W) 45 | x = self.act(x) 46 | x = self.drop(x) 47 | x = self.fc2(x) 48 | x = self.drop(x) 49 | return x 50 | 51 | 52 | class Attention(nn.Module): 53 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 54 | super().__init__() 55 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
56 | 57 | self.dim = dim 58 | self.num_heads = num_heads 59 | head_dim = dim // num_heads 60 | self.scale = qk_scale or head_dim ** -0.5 61 | 62 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 63 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop) 67 | 68 | self.sr_ratio = sr_ratio 69 | if sr_ratio > 1: 70 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 71 | self.norm = nn.LayerNorm(dim) 72 | 73 | self.apply(self._init_weights) 74 | 75 | def _init_weights(self, m): 76 | if isinstance(m, nn.Linear): 77 | trunc_normal_(m.weight, std=.02) 78 | if isinstance(m, nn.Linear) and m.bias is not None: 79 | nn.init.constant_(m.bias, 0) 80 | elif isinstance(m, nn.LayerNorm): 81 | nn.init.constant_(m.bias, 0) 82 | nn.init.constant_(m.weight, 1.0) 83 | elif isinstance(m, nn.Conv2d): 84 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 85 | fan_out //= m.groups 86 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 87 | if m.bias is not None: 88 | m.bias.data.zero_() 89 | 90 | def forward(self, x, H, W): 91 | B, N, C = x.shape 92 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 93 | 94 | if self.sr_ratio > 1: 95 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 96 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 97 | x_ = self.norm(x_) 98 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | else: 100 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 101 | k, v = kv[0], kv[1] 102 | 103 | attn = (q @ k.transpose(-2, -1)) * self.scale 104 | attn = attn.softmax(dim=-1) 105 | attn = self.attn_drop(attn) 106 | 107 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 108 | x = self.proj(x) 109 | x = self.proj_drop(x) 110 | 111 | return x 112 | 113 | 114 | class Block(nn.Module): 115 | 116 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 118 | super().__init__() 119 | self.norm1 = norm_layer(dim) 120 | self.attn = Attention( 121 | dim, 122 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 123 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 124 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 125 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 126 | self.norm2 = norm_layer(dim) 127 | mlp_hidden_dim = int(dim * mlp_ratio) 128 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 129 | 130 | def forward(self, x, H, W): 131 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 132 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 133 | 134 | return x 135 | 136 | 137 | class OverlapPatchEmbed(nn.Module): 138 | """ Image to Patch Embedding 139 | """ 140 | 141 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 142 | super().__init__() 143 | img_size = to_2tuple(img_size) 144 | patch_size = to_2tuple(patch_size) 145 | 146 | self.img_size = img_size 147 | self.patch_size = patch_size 148 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 149 | self.num_patches = self.H * self.W 150 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2)) 151 | self.norm = nn.LayerNorm(embed_dim) 152 | 153 | def forward(self, x): 154 | x = self.proj(x) 155 | _, _, H, W = x.shape 156 | x = x.flatten(2).transpose(1, 2) 157 | x = self.norm(x) 158 | 159 | return x, H, W 160 | 161 | 162 | class MixVisionTransformer(nn.Module): 163 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 164 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 165 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 166 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 167 | super().__init__() 168 | self.num_classes = num_classes 169 | self.depths = depths 170 | 171 | # patch_embed 172 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 173 | embed_dim=embed_dims[0]) 174 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 175 | embed_dim=embed_dims[1]) 176 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 177 | embed_dim=embed_dims[2]) 178 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], 179 | embed_dim=embed_dims[3]) 180 | 181 | # transformer encoder 182 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 183 | cur = 0 184 | self.block1 = nn.ModuleList([Block( 185 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 186 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 187 | sr_ratio=sr_ratios[0]) 188 | for i in range(depths[0])]) 189 | self.norm1 = norm_layer(embed_dims[0]) 190 | 191 | cur += depths[0] 192 | self.block2 = nn.ModuleList([Block( 193 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 194 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 195 | sr_ratio=sr_ratios[1]) 196 | for i in range(depths[1])]) 197 | self.norm2 = norm_layer(embed_dims[1]) 198 | 199 | cur += depths[1] 200 | self.block3 = nn.ModuleList([Block( 201 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 202 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 203 | sr_ratio=sr_ratios[2]) 204 | for i in 
range(depths[2])]) 205 | self.norm3 = norm_layer(embed_dims[2]) 206 | 207 | cur += depths[2] 208 | self.block4 = nn.ModuleList([Block( 209 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 210 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 211 | sr_ratio=sr_ratios[3]) 212 | for i in range(depths[3])]) 213 | self.norm4 = norm_layer(embed_dims[3]) 214 | 215 | 216 | def forward_features(self, x): 217 | B = x.shape[0] 218 | outs = [] 219 | 220 | # stage 1 221 | x, H, W = self.patch_embed1(x) 222 | for i, blk in enumerate(self.block1): 223 | x = blk(x, H, W) 224 | x = self.norm1(x) 225 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 226 | outs.append(x) 227 | 228 | # stage 2 229 | x, H, W = self.patch_embed2(x) 230 | for i, blk in enumerate(self.block2): 231 | x = blk(x, H, W) 232 | x = self.norm2(x) 233 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 234 | outs.append(x) 235 | 236 | # stage 3 237 | x, H, W = self.patch_embed3(x) 238 | for i, blk in enumerate(self.block3): 239 | x = blk(x, H, W) 240 | x = self.norm3(x) 241 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 242 | outs.append(x) 243 | 244 | # stage 4 245 | x, H, W = self.patch_embed4(x) 246 | for i, blk in enumerate(self.block4): 247 | x = blk(x, H, W) 248 | x = self.norm4(x) 249 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 250 | outs.append(x) 251 | 252 | return outs 253 | 254 | def forward(self, x): 255 | x = self.forward_features(x) 256 | 257 | # x = self.head(x[3]) 258 | 259 | return x 260 | 261 | 262 | class DWConv(nn.Module): 263 | def __init__(self, dim=768): 264 | super(DWConv, self).__init__() 265 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 266 | 267 | def forward(self, x, H, W): 268 | B, N, C = x.shape 269 | x = x.transpose(1, 2).view(B, C, H, W) 270 | x = self.dwconv(x) 271 | x = x.flatten(2).transpose(1, 2) 272 | 273 | return x 274 | 275 | 276 | class mit_b0(MixVisionTransformer): 277 | def __init__(self, **kwargs): 278 | super(mit_b0, self).__init__( 279 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 280 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 281 | drop_rate=0.0, drop_path_rate=0.1) 282 | 283 | 284 | class mit_b1(MixVisionTransformer): 285 | def __init__(self, **kwargs): 286 | super(mit_b1, self).__init__( 287 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 288 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 289 | drop_rate=0.0, drop_path_rate=0.1) 290 | 291 | 292 | class mit_b2(MixVisionTransformer): 293 | def __init__(self, **kwargs): 294 | super(mit_b2, self).__init__( 295 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 296 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 297 | drop_rate=0.0, drop_path_rate=0.1) 298 | 299 | 300 | class mit_b3(MixVisionTransformer): 301 | def __init__(self, **kwargs): 302 | super(mit_b3, self).__init__( 303 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 304 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 305 | drop_rate=0.0, drop_path_rate=0.1) 306 | 307 | 308 | class 
mit_b4(MixVisionTransformer): 309 | def __init__(self, **kwargs): 310 | super(mit_b4, self).__init__( 311 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 312 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 313 | drop_rate=0.0, drop_path_rate=0.1) 314 | 315 | 316 | class mit_b5(MixVisionTransformer): 317 | def __init__(self, **kwargs): 318 | super(mit_b5, self).__init__( 319 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 320 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 321 | drop_rate=0.0, drop_path_rate=0.1) 322 | 323 | 324 | from einops import rearrange 325 | from torch.nn import * 326 | from mmcv.cnn import build_activation_layer, build_norm_layer 327 | from timm.models.layers import DropPath 328 | from einops.layers.torch import Rearrange 329 | import numpy as np 330 | import torch 331 | from torch.nn import Module, ModuleList, Upsample 332 | from mmcv.cnn import ConvModule 333 | from torch.nn import Sequential, Conv2d, UpsamplingBilinear2d 334 | import torch.nn as nn 335 | 336 | 337 | def resize(input, 338 | size=None, 339 | scale_factor=None, 340 | mode='nearest', 341 | align_corners=None, 342 | warning=True): 343 | if warning: 344 | if size is not None and align_corners: 345 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 346 | output_h, output_w = tuple(int(x) for x in size) 347 | if output_h > input_h or output_w > output_h: 348 | if ((output_h > 1 and output_w > 1 and input_h > 1 349 | and input_w > 1) and (output_h - 1) % (input_h - 1) 350 | and (output_w - 1) % (input_w - 1)): 351 | warnings.warn( 352 | f'When align_corners={align_corners}, ' 353 | 'the output would more aligned if ' 354 | f'input size {(input_h, input_w)} is `x+1` and ' 355 | f'out size {(output_h, output_w)} is `nx+1`') 356 | return F.interpolate(input, size, scale_factor, mode, align_corners) 357 | 358 | 359 | class MLP(nn.Module): 360 | """ 361 | Linear Embedding 362 | """ 363 | 364 | def __init__(self, input_dim=512, embed_dim=768): 365 | super().__init__() 366 | self.proj = nn.Linear(input_dim, embed_dim) 367 | 368 | def forward(self, x): 369 | x = x.flatten(2).transpose(1, 2) 370 | x = self.proj(x) 371 | return x 372 | 373 | class conv(nn.Module): 374 | """ 375 | Linear Embedding 376 | """ 377 | 378 | def __init__(self, input_dim=512, embed_dim=768): 379 | super().__init__() 380 | 381 | self.proj = nn.Sequential(nn.Conv2d(input_dim, embed_dim, 3, padding=1, bias=False), build_norm_layer(dict(type='BN', requires_grad=True), embed_dim)[1],nn.ReLU(), 382 | nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=False), build_norm_layer(dict(type='BN', requires_grad=True), embed_dim)[1],nn.ReLU()) 383 | 384 | def forward(self, x): 385 | x = self.proj(x) 386 | x = x.flatten(2).transpose(1, 2) 387 | return x 388 | 389 | from mmcv.cnn import build_norm_layer 390 | 391 | class Decoder(Module): 392 | 393 | def __init__(self, dims, dim, class_num=2): 394 | super(Decoder, self).__init__() 395 | self.num_classes = class_num 396 | 397 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = dims[0], dims[1], dims[2], dims[3] 398 | embedding_dim = dim 399 | 400 | self.linear_c4 = conv(input_dim=c4_in_channels, embed_dim=embedding_dim) 401 | self.linear_c3 = conv(input_dim=c3_in_channels, embed_dim=embedding_dim) 402 | self.linear_c2 = conv(input_dim=c2_in_channels, 
embed_dim=embedding_dim) 403 | self.linear_c1 = conv(input_dim=c1_in_channels, embed_dim=embedding_dim) 404 | 405 | self.linear_pred = nn.Conv2d(4 * embedding_dim,self.num_classes, 3, padding=1) 406 | self.dropout = nn.Dropout(0.1) 407 | 408 | def forward(self, inputs): 409 | 410 | c1, c2, c3, c4 = inputs 411 | ############## MLP decoder on C1-C4 ########### 412 | n, _, h, w = c4.shape 413 | 414 | _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) 415 | _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=True) 416 | _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) 417 | _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=True) 418 | _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) 419 | _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=True) 420 | _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) 421 | 422 | _c = self.linear_pred(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 423 | 424 | return _c,_c 425 | 426 | 427 | 428 | class mit_mla(nn.Module): 429 | def __init__(self, class_num=2, **kwargs): 430 | super(mit_mla, self).__init__() 431 | self.class_num = class_num 432 | ######################################load_weight 433 | self.backbone = mit_b2() 434 | ##################################### 435 | self.decode_head = Decoder(dims=[64, 128, 320, 512], dim=128, class_num=class_num) 436 | self._init_weights() # load pretrain 437 | 438 | def forward(self, x): 439 | features = self.backbone(x) 440 | features, _c = self.decode_head(features) 441 | features = F.interpolate(features, size=x.shape[2:], mode='bilinear',align_corners=False) 442 | # return features, _c 443 | return features 444 | 445 | def _init_weights(self): 446 | pretrained_dict = torch.load('/mnt/DATA-1/DATA-2/Feilong/scformer/models/mit/mit_b2.pth') 447 | model_dict = self.backbone.state_dict() 448 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 449 | model_dict.update(pretrained_dict) 450 | self.backbone.load_state_dict(model_dict) 451 | print("successfully loaded!!!!") 452 | 453 | 454 | 455 | #MitEncoder = mit_mla(class_num=1) 456 | #from torchinfo import summary 457 | #summary(MitEncoder, (16, 3, 352, 352)) 458 | #from thop import profile 459 | #import torch 460 | #input = torch.randn(1, 3, 352, 352).to('cuda') 461 | #macs, params = profile(MitEncoder, inputs=(input, )) 462 | #print('macs:',macs/1000000000) 463 | #print('params:',params/1000000) 464 | 465 | 466 | -------------------------------------------------------------------------------- /models/pvt/__pycache__/pvt_PPD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/pvt/__pycache__/pvt_PPD.cpython-37.pyc -------------------------------------------------------------------------------- /models/pvt/__pycache__/pvt_SD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/pvt/__pycache__/pvt_SD.cpython-37.pyc -------------------------------------------------------------------------------- /models/pvt/__pycache__/pvt_mla.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/pvt/__pycache__/pvt_mla.cpython-37.pyc -------------------------------------------------------------------------------- /models/pvt/__pycache__/pvt_pup.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/pvt/__pycache__/pvt_pup.cpython-37.pyc -------------------------------------------------------------------------------- /models/pvt/__pycache__/pvt_srm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/pvt/__pycache__/pvt_srm.cpython-37.pyc -------------------------------------------------------------------------------- /models/pvt/pvt_SD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | import math 6 | 7 | from mmcv.cnn import ConvModule 8 | 9 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 10 | from timm.models.registry import register_model 11 | from timm.models.vision_transformer import _cfg 12 | from mmcv.runner import load_checkpoint 13 | from torch.nn import UpsamplingBilinear2d 14 | 15 | def resize(input, 16 | size=None, 17 | scale_factor=None, 18 | mode='nearest', 19 | align_corners=None, 20 | warning=True): 21 | if warning: 22 | if size is not None and align_corners: 23 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 24 | output_h, output_w = tuple(int(x) for x in size) 25 | if output_h > input_h or output_w > output_h: 26 | if ((output_h > 1 and output_w > 1 and input_h > 1 27 | and input_w > 1) and (output_h - 1) % (input_h - 1) 28 | and (output_w - 1) % (input_w - 1)): 29 | warnings.warn( 30 | f'When align_corners={align_corners}, ' 31 | 'the output would more aligned if ' 32 | f'input size {(input_h, input_w)} is `x+1` and ' 33 | f'out size {(output_h, output_w)} is `nx+1`') 34 | return F.interpolate(input, size, scale_factor, mode, align_corners) 35 | 36 | 37 | class Mlp(nn.Module): 38 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 39 | super().__init__() 40 | out_features = out_features or in_features 41 | hidden_features = hidden_features or in_features 42 | self.fc1 = nn.Linear(in_features, hidden_features) 43 | self.act = act_layer() 44 | self.fc2 = nn.Linear(hidden_features, out_features) 45 | self.drop = nn.Dropout(drop) 46 | 47 | def forward(self, x): 48 | x = self.fc1(x) 49 | x = self.act(x) 50 | x = self.drop(x) 51 | x = self.fc2(x) 52 | x = self.drop(x) 53 | return x 54 | 55 | 56 | class Attention(nn.Module): 57 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 58 | super().__init__() 59 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
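# Same spatial-reduction attention as in models/mit/mit_mla.py above: queries stay at full token
# resolution while keys/values are computed from a Conv2d(kernel_size=stride=sr_ratio) reduced
# copy of the tokens. The PVT-specific differences sit elsewhere in this file: a learned
# positional embedding is added per stage in PyramidVisionTransformer.forward_features, and this
# file's Mlp has no depth-wise convolution branch.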
60 | 61 | self.dim = dim 62 | self.num_heads = num_heads 63 | head_dim = dim // num_heads 64 | self.scale = qk_scale or head_dim ** -0.5 65 | 66 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 67 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 68 | self.attn_drop = nn.Dropout(attn_drop) 69 | self.proj = nn.Linear(dim, dim) 70 | self.proj_drop = nn.Dropout(proj_drop) 71 | 72 | self.sr_ratio = sr_ratio 73 | if sr_ratio > 1: 74 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 75 | self.norm = nn.LayerNorm(dim) 76 | 77 | def forward(self, x, H, W): 78 | B, N, C = x.shape 79 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 80 | 81 | if self.sr_ratio > 1: 82 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 83 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 84 | x_ = self.norm(x_) 85 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 86 | else: 87 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 88 | k, v = kv[0], kv[1] 89 | 90 | attn = (q @ k.transpose(-2, -1)) * self.scale 91 | attn = attn.softmax(dim=-1) 92 | attn = self.attn_drop(attn) 93 | 94 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 95 | x = self.proj(x) 96 | x = self.proj_drop(x) 97 | 98 | return x 99 | 100 | 101 | class Block(nn.Module): 102 | 103 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 104 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 105 | super().__init__() 106 | self.norm1 = norm_layer(dim) 107 | self.attn = Attention( 108 | dim, 109 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 110 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 111 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 112 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 113 | self.norm2 = norm_layer(dim) 114 | mlp_hidden_dim = int(dim * mlp_ratio) 115 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 116 | 117 | def forward(self, x, H, W): 118 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 119 | x = x + self.drop_path(self.mlp(self.norm2(x))) 120 | 121 | return x 122 | 123 | 124 | class PatchEmbed(nn.Module): 125 | """ Image to Patch Embedding 126 | """ 127 | 128 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 129 | super().__init__() 130 | img_size = to_2tuple(img_size) 131 | patch_size = to_2tuple(patch_size) 132 | 133 | self.img_size = img_size 134 | self.patch_size = patch_size 135 | assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ 136 | f"img_size {img_size} should be divided by patch_size {patch_size}." 
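# Non-overlapping patch embedding: a Conv2d with kernel_size == stride == patch_size slices the
# image into (H / patch_size) x (W / patch_size) tokens of width embed_dim, followed by LayerNorm.
# In PyramidVisionTransformer below, the first stage uses patch_size=4 and every later stage uses
# patch_size=2, so e.g. a 352x352 input yields token grids of 88x88, 44x44, 22x22 and 11x11.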
137 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 138 | self.num_patches = self.H * self.W 139 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 140 | self.norm = nn.LayerNorm(embed_dim) 141 | 142 | def forward(self, x): 143 | B, C, H, W = x.shape 144 | 145 | x = self.proj(x).flatten(2).transpose(1, 2) 146 | x = self.norm(x) 147 | H, W = H // self.patch_size[0], W // self.patch_size[1] 148 | 149 | return x, (H, W) 150 | 151 | 152 | class PyramidVisionTransformer(nn.Module): 153 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 154 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 155 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], 156 | sr_ratios=[8, 4, 2, 1], num_stages=4, F4=False): 157 | super().__init__() 158 | self.num_classes = num_classes 159 | self.depths = depths 160 | self.F4 = F4 161 | self.num_stages = num_stages 162 | 163 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 164 | cur = 0 165 | 166 | for i in range(num_stages): 167 | patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 168 | patch_size=patch_size if i == 0 else 2, 169 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 170 | embed_dim=embed_dims[i]) 171 | num_patches = patch_embed.num_patches if i != num_stages - 1 else patch_embed.num_patches + 1 172 | pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i])) 173 | pos_drop = nn.Dropout(p=drop_rate) 174 | 175 | block = nn.ModuleList([Block( 176 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, 177 | qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], 178 | norm_layer=norm_layer, sr_ratio=sr_ratios[i]) 179 | for j in range(depths[i])]) 180 | cur += depths[i] 181 | 182 | setattr(self, f"patch_embed{i + 1}", patch_embed) 183 | setattr(self, f"pos_embed{i + 1}", pos_embed) 184 | setattr(self, f"pos_drop{i + 1}", pos_drop) 185 | setattr(self, f"block{i + 1}", block) 186 | 187 | trunc_normal_(pos_embed, std=.02) 188 | 189 | # # init weights 190 | # self.apply(self._init_weights) 191 | # 192 | # def init_weights(self, pretrained=None): 193 | # if isinstance(pretrained, str): 194 | # logger = get_root_logger() 195 | # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 196 | 197 | def _init_weights(self, m): 198 | if isinstance(m, nn.Linear): 199 | trunc_normal_(m.weight, std=.02) 200 | if isinstance(m, nn.Linear) and m.bias is not None: 201 | nn.init.constant_(m.bias, 0) 202 | elif isinstance(m, nn.LayerNorm): 203 | nn.init.constant_(m.bias, 0) 204 | nn.init.constant_(m.weight, 1.0) 205 | 206 | def _get_pos_embed(self, pos_embed, patch_embed, H, W): 207 | if H * W == self.patch_embed1.num_patches: 208 | return pos_embed 209 | else: 210 | return F.interpolate( 211 | pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 212 | size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) 213 | 214 | def forward_features(self, x): 215 | outs = [] 216 | 217 | B = x.shape[0] 218 | 219 | for i in range(self.num_stages): 220 | patch_embed = getattr(self, f"patch_embed{i + 1}") 221 | pos_embed = getattr(self, f"pos_embed{i + 1}") 222 | pos_drop = getattr(self, f"pos_drop{i + 1}") 223 | block = getattr(self, f"block{i + 1}") 224 | x, (H, 
W) = patch_embed(x) 225 | if i == self.num_stages - 1: 226 | pos_embed = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W) 227 | else: 228 | pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W) 229 | 230 | x = pos_drop(x + pos_embed) 231 | for blk in block: 232 | x = blk(x, H, W) 233 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 234 | outs.append(x) 235 | 236 | return outs 237 | 238 | def forward(self, x): 239 | x = self.forward_features(x) 240 | if self.F4: 241 | x = x[3:4] 242 | return x 243 | 244 | 245 | def _conv_filter(state_dict, patch_size=16): 246 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 247 | out_dict = {} 248 | for k, v in state_dict.items(): 249 | if 'patch_embed.proj.weight' in k: 250 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 251 | out_dict[k] = v 252 | 253 | return out_dict 254 | 255 | 256 | ################## 257 | class DWConv(nn.Module): 258 | def __init__(self, dim=768): 259 | super(DWConv, self).__init__() 260 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 261 | self.apply(self._init_weights) 262 | 263 | def _init_weights(self, m): 264 | if isinstance(m, nn.Linear): 265 | trunc_normal_(m.weight, std=.02) 266 | if isinstance(m, nn.Linear) and m.bias is not None: 267 | nn.init.constant_(m.bias, 0) 268 | elif isinstance(m, nn.LayerNorm): 269 | nn.init.constant_(m.bias, 0) 270 | nn.init.constant_(m.weight, 1.0) 271 | elif isinstance(m, nn.Conv2d): 272 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 273 | fan_out //= m.groups 274 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 275 | if m.bias is not None: 276 | m.bias.data.zero_() 277 | def forward(self, x, H, W): 278 | B, N, C = x.shape 279 | x = x.transpose(1, 2).view(B, C, H, W) 280 | x = self.dwconv(x) 281 | x = x.flatten(2).transpose(1, 2) 282 | 283 | return x 284 | 285 | class MLP(nn.Module): 286 | """ 287 | Linear Embedding 288 | """ 289 | def __init__(self, input_dim=512, embed_dim=768): 290 | super().__init__() 291 | self.proj = nn.Linear(input_dim, embed_dim) 292 | 293 | def forward(self, x): 294 | x = x.flatten(2).transpose(1, 2) 295 | x = self.proj(x) 296 | return x 297 | 298 | class Decoder(nn.Module): 299 | """ 300 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 301 | """ 302 | 303 | def __init__(self, dims, dim, class_num=2): 304 | super(Decoder, self).__init__() 305 | self.num_classes = class_num 306 | 307 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = dims[0], dims[1], dims[2], dims[3] 308 | embedding_dim = dim 309 | 310 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 311 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 312 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 313 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 314 | 315 | self.linear_fuse = ConvModule(in_channels=embedding_dim * 4, out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True)) 316 | self.linear_fuse34 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True)) 317 | self.linear_fuse2 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True)) 318 | self.linear_fuse1 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', 
requires_grad=True)) 319 | 320 | self.linear_pred = torch.nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 321 | self.dropout = nn.Dropout(0.1) 322 | 323 | def forward(self, inputs): 324 | 325 | c1, c2, c3, c4 = inputs 326 | ############## MLP decoder on C1-C4 ########### 327 | n, _, h, w = c4.shape 328 | 329 | _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) 330 | _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False) 331 | _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) 332 | _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False) 333 | _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) 334 | _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False) 335 | _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) 336 | 337 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 338 | 339 | x = self.dropout(_c) 340 | x = self.linear_pred(x) 341 | return x 342 | ################## 343 | 344 | 345 | class pvt_tiny(PyramidVisionTransformer): 346 | def __init__(self, **kwargs): 347 | super(pvt_tiny, self).__init__( 348 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 349 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], 350 | sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 351 | 352 | 353 | 354 | class pvt_small(PyramidVisionTransformer): 355 | def __init__(self, **kwargs): 356 | super(pvt_small, self).__init__( 357 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 358 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], 359 | sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 360 | 361 | 362 | 363 | class pvt_medium(PyramidVisionTransformer): 364 | def __init__(self, **kwargs): 365 | super(pvt_medium, self).__init__( 366 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 367 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], 368 | sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 369 | 370 | 371 | 372 | class pvt_large(PyramidVisionTransformer): 373 | def __init__(self, **kwargs): 374 | super(pvt_large, self).__init__( 375 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 376 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], 377 | sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 378 | 379 | ## 开始pvt配合segformerd 380 | class pvt_SD(nn.Module): 381 | def __init__(self, class_num): 382 | super(pvt_SD, self).__init__() 383 | self.backbone = pvt_small() 384 | self._init_weights() 385 | self.decode_head = Decoder(dims=[64, 128, 320, 512], dim=768, class_num=class_num) 386 | 387 | def _init_weights(self): 388 | pretrained_dict = torch.load("/mnt/DATA-1/DATA-2/Feilong/scformer/models/pvt/pvt_small_iter_40000.pth") 389 | model_dict = self.backbone.state_dict() 390 | 391 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 392 | model_dict.update(pretrained_dict) 393 | self.backbone.load_state_dict(model_dict) 394 | print("successfully loaded!!!!") 395 | 396 | def forward(self, x): 397 | features = self.backbone(x) 398 | 399 | features = self.decode_head(features) 400 | up = UpsamplingBilinear2d(scale_factor=4) 401 | features = 
up(features) 402 | 403 | return features 404 | 405 | class pvt_large_seg(nn.Module): 406 | def __init__(self, class_num): 407 | super(pvt_large_seg, self).__init__() 408 | self.backbone = pvt_large() 409 | self.decode_head = Decoder(dims=[64, 128, 320, 512], dim=768, class_num=class_num) 410 | self._init_weights() 411 | 412 | def forward(self, x): 413 | features = self.backbone(x) 414 | 415 | features = self.decode_head(features) 416 | up = UpsamplingBilinear2d(scale_factor=4) 417 | features = up(features) 418 | 419 | return features 420 | 421 | def _init_weights(self): 422 | pretrained_dict = torch.load("/data/segformer/scformer/pretrain/pvt_large_iter_40000.pth") 423 | model_dict = self.backbone.state_dict() 424 | 425 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 426 | model_dict.update(pretrained_dict) 427 | self.backbone.load_state_dict(model_dict) 428 | print("successfully loaded!!!!") 429 | 430 | MitEncoder = pvt_SD(class_num=1) 431 | MitEncoder = MitEncoder.to('cuda') 432 | from torchinfo import summary 433 | 434 | summary(MitEncoder, (1, 3, 512, 512)) 435 | 436 | from thop import profile 437 | import torch 438 | 439 | input = torch.randn(1, 3, 352, 352).to('cuda') 440 | macs, params = profile(MitEncoder, inputs=(input,)) 441 | print('macs:', macs / 1000000000) 442 | print('params:', params / 1000000) -------------------------------------------------------------------------------- /models/pvt/pvt_mla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | import math 6 | 7 | from mmcv.cnn import ConvModule 8 | 9 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 10 | from timm.models.registry import register_model 11 | from timm.models.vision_transformer import _cfg 12 | 13 | 14 | from mmcv.runner import load_checkpoint 15 | from torch.nn import UpsamplingBilinear2d 16 | 17 | def resize(input, 18 | size=None, 19 | scale_factor=None, 20 | mode='nearest', 21 | align_corners=None, 22 | warning=True): 23 | if warning: 24 | if size is not None and align_corners: 25 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 26 | output_h, output_w = tuple(int(x) for x in size) 27 | if output_h > input_h or output_w > output_h: 28 | if ((output_h > 1 and output_w > 1 and input_h > 1 29 | and input_w > 1) and (output_h - 1) % (input_h - 1) 30 | and (output_w - 1) % (input_w - 1)): 31 | warnings.warn( 32 | f'When align_corners={align_corners}, ' 33 | 'the output would more aligned if ' 34 | f'input size {(input_h, input_w)} is `x+1` and ' 35 | f'out size {(output_h, output_w)} is `nx+1`') 36 | return F.interpolate(input, size, scale_factor, mode, align_corners) 37 | 38 | 39 | class Mlp(nn.Module): 40 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 41 | super().__init__() 42 | out_features = out_features or in_features 43 | hidden_features = hidden_features or in_features 44 | self.fc1 = nn.Linear(in_features, hidden_features) 45 | self.act = act_layer() 46 | self.fc2 = nn.Linear(hidden_features, out_features) 47 | self.drop = nn.Dropout(drop) 48 | 49 | def forward(self, x): 50 | x = self.fc1(x) 51 | x = self.act(x) 52 | x = self.drop(x) 53 | x = self.fc2(x) 54 | x = self.drop(x) 55 | return x 56 | 57 | 58 | class Attention(nn.Module): 59 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 
sr_ratio=1): 60 | super().__init__() 61 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 62 | 63 | self.dim = dim 64 | self.num_heads = num_heads 65 | head_dim = dim // num_heads 66 | self.scale = qk_scale or head_dim ** -0.5 67 | 68 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 69 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 70 | self.attn_drop = nn.Dropout(attn_drop) 71 | self.proj = nn.Linear(dim, dim) 72 | self.proj_drop = nn.Dropout(proj_drop) 73 | 74 | self.sr_ratio = sr_ratio 75 | if sr_ratio > 1: 76 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 77 | self.norm = nn.LayerNorm(dim) 78 | 79 | def forward(self, x, H, W): 80 | B, N, C = x.shape 81 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 82 | 83 | if self.sr_ratio > 1: 84 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 85 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 86 | x_ = self.norm(x_) 87 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 88 | else: 89 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 90 | k, v = kv[0], kv[1] 91 | 92 | attn = (q @ k.transpose(-2, -1)) * self.scale 93 | attn = attn.softmax(dim=-1) 94 | attn = self.attn_drop(attn) 95 | 96 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 97 | x = self.proj(x) 98 | x = self.proj_drop(x) 99 | 100 | return x 101 | 102 | 103 | class Block(nn.Module): 104 | 105 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 106 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 107 | super().__init__() 108 | self.norm1 = norm_layer(dim) 109 | self.attn = Attention( 110 | dim, 111 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 112 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 113 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 114 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 115 | self.norm2 = norm_layer(dim) 116 | mlp_hidden_dim = int(dim * mlp_ratio) 117 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 118 | 119 | def forward(self, x, H, W): 120 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 121 | x = x + self.drop_path(self.mlp(self.norm2(x))) 122 | 123 | return x 124 | 125 | 126 | class PatchEmbed(nn.Module): 127 | """ Image to Patch Embedding 128 | """ 129 | 130 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 131 | super().__init__() 132 | img_size = to_2tuple(img_size) 133 | patch_size = to_2tuple(patch_size) 134 | 135 | self.img_size = img_size 136 | self.patch_size = patch_size 137 | assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ 138 | f"img_size {img_size} should be divided by patch_size {patch_size}." 
139 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 140 | self.num_patches = self.H * self.W 141 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 142 | self.norm = nn.LayerNorm(embed_dim) 143 | 144 | def forward(self, x): 145 | B, C, H, W = x.shape 146 | 147 | x = self.proj(x).flatten(2).transpose(1, 2) 148 | x = self.norm(x) 149 | H, W = H // self.patch_size[0], W // self.patch_size[1] 150 | 151 | return x, (H, W) 152 | 153 | 154 | class PyramidVisionTransformer(nn.Module): 155 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 156 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 157 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], 158 | sr_ratios=[8, 4, 2, 1], num_stages=4, F4=False): 159 | super().__init__() 160 | self.num_classes = num_classes 161 | self.depths = depths 162 | self.F4 = F4 163 | self.num_stages = num_stages 164 | 165 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 166 | cur = 0 167 | 168 | for i in range(num_stages): 169 | patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 170 | patch_size=patch_size if i == 0 else 2, 171 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 172 | embed_dim=embed_dims[i]) 173 | num_patches = patch_embed.num_patches if i != num_stages - 1 else patch_embed.num_patches + 1 174 | pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i])) 175 | pos_drop = nn.Dropout(p=drop_rate) 176 | 177 | block = nn.ModuleList([Block( 178 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, 179 | qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], 180 | norm_layer=norm_layer, sr_ratio=sr_ratios[i]) 181 | for j in range(depths[i])]) 182 | cur += depths[i] 183 | 184 | setattr(self, f"patch_embed{i + 1}", patch_embed) 185 | setattr(self, f"pos_embed{i + 1}", pos_embed) 186 | setattr(self, f"pos_drop{i + 1}", pos_drop) 187 | setattr(self, f"block{i + 1}", block) 188 | 189 | trunc_normal_(pos_embed, std=.02) 190 | 191 | # init weights 192 | # self.apply(self._init_weights) 193 | # 194 | # def init_weights(self, pretrained=None): 195 | # if isinstance(pretrained, str): 196 | # logger = get_root_logger() 197 | # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 198 | 199 | def _init_weights(self, m): 200 | if isinstance(m, nn.Linear): 201 | trunc_normal_(m.weight, std=.02) 202 | if isinstance(m, nn.Linear) and m.bias is not None: 203 | nn.init.constant_(m.bias, 0) 204 | elif isinstance(m, nn.LayerNorm): 205 | nn.init.constant_(m.bias, 0) 206 | nn.init.constant_(m.weight, 1.0) 207 | 208 | def _get_pos_embed(self, pos_embed, patch_embed, H, W): 209 | if H * W == self.patch_embed1.num_patches: 210 | return pos_embed 211 | else: 212 | return F.interpolate( 213 | pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 214 | size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) 215 | 216 | def forward_features(self, x): 217 | outs = [] 218 | 219 | B = x.shape[0] 220 | 221 | for i in range(self.num_stages): 222 | patch_embed = getattr(self, f"patch_embed{i + 1}") 223 | pos_embed = getattr(self, f"pos_embed{i + 1}") 224 | pos_drop = getattr(self, f"pos_drop{i + 1}") 225 | block = getattr(self, f"block{i + 1}") 226 | x, (H, W) 
= patch_embed(x) 227 | if i == self.num_stages - 1: 228 | pos_embed = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W) 229 | else: 230 | pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W) 231 | 232 | x = pos_drop(x + pos_embed) 233 | for blk in block: 234 | x = blk(x, H, W) 235 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 236 | outs.append(x) 237 | 238 | return outs 239 | 240 | def forward(self, x): 241 | x = self.forward_features(x) 242 | if self.F4: 243 | x = x[3:4] 244 | 245 | return x 246 | 247 | 248 | def _conv_filter(state_dict, patch_size=16): 249 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 250 | out_dict = {} 251 | for k, v in state_dict.items(): 252 | if 'patch_embed.proj.weight' in k: 253 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 254 | out_dict[k] = v 255 | 256 | return out_dict 257 | 258 | 259 | ################## 260 | class DWConv(nn.Module): 261 | def __init__(self, dim=768): 262 | super(DWConv, self).__init__() 263 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 264 | self.apply(self._init_weights) 265 | 266 | def _init_weights(self, m): 267 | if isinstance(m, nn.Linear): 268 | trunc_normal_(m.weight, std=.02) 269 | if isinstance(m, nn.Linear) and m.bias is not None: 270 | nn.init.constant_(m.bias, 0) 271 | elif isinstance(m, nn.LayerNorm): 272 | nn.init.constant_(m.bias, 0) 273 | nn.init.constant_(m.weight, 1.0) 274 | elif isinstance(m, nn.Conv2d): 275 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 276 | fan_out //= m.groups 277 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 278 | if m.bias is not None: 279 | m.bias.data.zero_() 280 | def forward(self, x, H, W): 281 | B, N, C = x.shape 282 | x = x.transpose(1, 2).view(B, C, H, W) 283 | x = self.dwconv(x) 284 | x = x.flatten(2).transpose(1, 2) 285 | 286 | return x 287 | 288 | class MLP(nn.Module): 289 | """ 290 | Linear Embedding 291 | """ 292 | def __init__(self, input_dim=512, embed_dim=768): 293 | super().__init__() 294 | self.proj = nn.Linear(input_dim, embed_dim) 295 | 296 | def forward(self, x): 297 | x = x.flatten(2).transpose(1, 2) 298 | x = self.proj(x) 299 | return x 300 | 301 | from mmcv.cnn import build_norm_layer 302 | class conv(nn.Module): 303 | """ 304 | Linear Embedding 305 | """ 306 | 307 | def __init__(self, input_dim=512, embed_dim=768): 308 | super().__init__() 309 | 310 | self.proj = nn.Sequential(nn.Conv2d(input_dim, embed_dim, 3, padding=1, bias=False), build_norm_layer(dict(type='BN', requires_grad=True), embed_dim)[1],nn.ReLU(), 311 | nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=False), build_norm_layer(dict(type='BN', requires_grad=True), embed_dim)[1],nn.ReLU()) 312 | 313 | def forward(self, x): 314 | x = self.proj(x) 315 | x = x.flatten(2).transpose(1, 2) 316 | return x 317 | 318 | class Decoder(nn.Module): 319 | """ 320 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 321 | """ 322 | 323 | def __init__(self, dims, dim, class_num=2): 324 | super(Decoder, self).__init__() 325 | self.num_classes = class_num 326 | 327 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = dims[0], dims[1], dims[2], dims[3] 328 | embedding_dim = dim 329 | 330 | self.linear_c4 = conv(input_dim=c4_in_channels, embed_dim=embedding_dim) 331 | self.linear_c3 = conv(input_dim=c3_in_channels, embed_dim=embedding_dim) 332 | self.linear_c2 = conv(input_dim=c2_in_channels, embed_dim=embedding_dim) 333 | self.linear_c1 = 
conv(input_dim=c1_in_channels, embed_dim=embedding_dim) 334 | 335 | self.linear_fuse = ConvModule(in_channels=embedding_dim * 4, out_channels=embedding_dim, kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True)) 336 | self.linear_fuse34 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True)) 337 | self.linear_fuse2 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True)) 338 | self.linear_fuse1 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True)) 339 | 340 | self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 341 | self.dropout = nn.Dropout(0.1) 342 | 343 | def forward(self, inputs): 344 | 345 | c1, c2, c3, c4 = inputs 346 | ############## MLP decoder on C1-C4 ########### 347 | n, _, h, w = c4.shape 348 | 349 | _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) 350 | _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False) 351 | _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) 352 | _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False) 353 | _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) 354 | _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False) 355 | _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) 356 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 357 | 358 | x = self.dropout(_c) 359 | x = self.linear_pred(x) 360 | return x 361 | 362 | 363 | 364 | class pvt_tiny(PyramidVisionTransformer): 365 | def __init__(self, **kwargs): 366 | super(pvt_tiny, self).__init__( 367 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 368 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], 369 | sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 370 | 371 | 372 | 373 | class pvt_small(PyramidVisionTransformer): 374 | def __init__(self, **kwargs): 375 | super(pvt_small, self).__init__( 376 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 377 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], 378 | sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 379 | 380 | 381 | 382 | class pvt_medium(PyramidVisionTransformer): 383 | def __init__(self, **kwargs): 384 | super(pvt_medium, self).__init__( 385 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 386 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], 387 | sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 388 | 389 | 390 | 391 | class pvt_large(PyramidVisionTransformer): 392 | def __init__(self, **kwargs): 393 | super(pvt_large, self).__init__( 394 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 395 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], 396 | sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 397 | 398 | ## 开始pvt配合segformerd 399 | class pvt_mla(nn.Module): 400 | def __init__(self, class_num): 401 | super(pvt_mla, self).__init__() 402 | self.backbone = pvt_small() 403 | self.decode_head = Decoder(dims=[64, 128, 320, 512], dim=128, 
class_num=class_num) 404 | self._init_weights() 405 | 406 | def forward(self, x): 407 | features = self.backbone(x) 408 | features = self.decode_head(features) 409 | up = UpsamplingBilinear2d(scale_factor=4) 410 | features = up(features) 411 | 412 | return features 413 | def _init_weights(self): 414 | pretrained_dict = torch.load("/mnt/DATA-1/DATA-2/Feilong/scformer/models/pvt/pvt_small_iter_40000.pth") 415 | model_dict = self.backbone.state_dict() 416 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 417 | model_dict.update(pretrained_dict) 418 | self.backbone.load_state_dict(model_dict) 419 | print("successfully load!!!") 420 | 421 | #class pvt_srm_large_seg(nn.Module): 422 | # def __init__(self, class_num): 423 | # super(pvt_srm_large_seg, self).__init__() 424 | # self.backbone = pvt_large() 425 | # self.decode_head = Decoder(dims=[64, 128, 320, 512], dim=256, class_num=class_num) 426 | # 427 | # def forward(self, x): 428 | # features = self.backbone(x) 429 | # 430 | # features = self.decode_head(features) 431 | # up = UpsamplingBilinear2d(scale_factor=4) 432 | # features = up(features) 433 | # 434 | # return features 435 | # 436 | # def _init_weights(self): 437 | # pretrained_dict = torch.load("/data/segformer/scformer/pretrain/pvt_large_iter_40000.pth") 438 | # model_dict = self.backbone.state_dict() 439 | # 440 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 441 | # model_dict.update(pretrained_dict) 442 | # self.backbone.load_state_dict(model_dict) 443 | # print("successfully load!!!") 444 | 445 | 446 | MitEncoder = pvt_mla(class_num=1) 447 | MitEncoder = MitEncoder.to('cuda') 448 | from torchinfo import summary 449 | 450 | summary(MitEncoder, (1, 3, 512, 512)) 451 | 452 | 453 | from thop import profile 454 | import torch 455 | 456 | input = torch.randn(1, 3, 352, 352).to('cuda') 457 | macs, params = profile(MitEncoder, inputs=(input, )) 458 | print('macs:',macs/1000000000) 459 | print('params:',params/1000000) 460 | -------------------------------------------------------------------------------- /models/simVit/__pycache__/simVit_SD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/simVit/__pycache__/simVit_SD.cpython-37.pyc -------------------------------------------------------------------------------- /models/simVit/__pycache__/simVit_mla.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/simVit/__pycache__/simVit_mla.cpython-37.pyc -------------------------------------------------------------------------------- /models/simVit/__pycache__/simVit_pup.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/simVit/__pycache__/simVit_pup.cpython-37.pyc -------------------------------------------------------------------------------- /models/simVit/__pycache__/simVit_srm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/simVit/__pycache__/simVit_srm.cpython-37.pyc 
-------------------------------------------------------------------------------- /models/ssa/__pycache__/ssa_SD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/ssa/__pycache__/ssa_SD.cpython-37.pyc -------------------------------------------------------------------------------- /models/ssa/__pycache__/ssa_mla.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/ssa/__pycache__/ssa_mla.cpython-37.pyc -------------------------------------------------------------------------------- /models/ssa/__pycache__/ssa_pup.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/ssa/__pycache__/ssa_pup.cpython-37.pyc -------------------------------------------------------------------------------- /models/ssa/__pycache__/ssa_srm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/models/ssa/__pycache__/ssa_srm.cpython-37.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | addict==2.4.0 3 | appdirs==1.4.4 4 | attr==0.3.1 5 | backcall==0.2.0 6 | cachetools==4.2.4 7 | certifi==2021.10.8 8 | charset-normalizer==2.0.7 9 | cityscapesScripts==2.2.0 10 | coloredlogs==15.0.1 11 | cycler==0.11.0 12 | decorator==5.1.0 13 | fonttools==4.28.5 14 | google-auth==2.3.3 15 | google-auth-oauthlib==0.4.6 16 | grpcio==1.42.0 17 | humanfriendly==10.0 18 | idna==3.3 19 | ipython==7.30.1 20 | jedi==0.18.1 21 | joblib==1.1.0 22 | kiwisolver==1.3.2 23 | kornia==0.6.5 24 | Markdown==3.3.6 25 | matplotlib==3.5.1 26 | matplotlib-inline==0.1.3 27 | mmcv-full==1.2.7 28 | numpy==1.21.4 29 | oauthlib==3.1.1 30 | opencv-python==4.5.4.60 31 | packaging==21.3 32 | pandas==1.3.4 33 | parso==0.8.3 34 | pexpect==4.8.0 35 | pickleshare==0.7.5 36 | Pillow==8.4.0 37 | prompt-toolkit==3.0.23 38 | protobuf==3.19.1 39 | ptyprocess==0.7.0 40 | pyasn1==0.4.8 41 | pyasn1-modules==0.2.8 42 | Pygments==2.10.0 43 | pyparsing==3.0.6 44 | pyquaternion==0.9.9 45 | python-dateutil==2.8.2 46 | pytz==2021.3 47 | PyYAML==6.0 48 | requests==2.26.0 49 | requests-oauthlib==1.3.0 50 | rsa==4.7.2 51 | scikit-learn==1.0.1 52 | scipy==1.7.2 53 | six==1.16.0 54 | sklearn==0.0 55 | tensorboard==2.7.0 56 | tensorboard-data-server==0.6.1 57 | tensorboard-plugin-wit==1.8.0 58 | tensorboardX==2.4 59 | terminaltables==3.1.10 60 | thop==0.0.31.post2005241907 61 | threadpoolctl==3.0.0 62 | timm==0.3.2 63 | torch==1.12.0 64 | torchaudio==0.7.2 65 | torchinfo==1.6.0 66 | torchsummary==1.5.1 67 | torchvision==0.8.2+cu110 68 | tqdm==4.62.3 69 | traitlets==5.1.1 70 | ttach==0.0.3 71 | typing==3.7.4.3 72 | typing_extensions==4.0.1 73 | urllib3==1.26.7 74 | wcwidth==0.2.5 75 | Werkzeug==2.0.2 76 | yapf==0.31.0 77 | -------------------------------------------------------------------------------- /result/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/result/.DS_Store -------------------------------------------------------------------------------- /result/ssformer_L/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/result/ssformer_L/.DS_Store -------------------------------------------------------------------------------- /result/ssformer_L/mit_srm_b4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/result/ssformer_L/mit_srm_b4.png -------------------------------------------------------------------------------- /result/ssformer_S/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/result/ssformer_S/.DS_Store -------------------------------------------------------------------------------- /result/ssformer_S/ssformer_S.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/result/ssformer_S/ssformer_S.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ctypes import c_int 3 | import io 4 | from PIL import Image 5 | from models import build 6 | from loguru import logger 7 | from tqdm import tqdm 8 | import torch.nn as nn 9 | import torch.optim as optmi 10 | import torch.nn.functional as F 11 | from utils.tools import mean_dice, mean_iou, Fmeasure_calu 12 | from utils.test_dataset import CustomDataSet 13 | from torch.utils.data import DataLoader 14 | from torchvision.transforms import Compose 15 | from torchvision import transforms 16 | from torchvision.utils import save_image 17 | import torch 18 | import os 19 | import sys 20 | import numpy as np 21 | import yaml 22 | from tabulate import tabulate 23 | 24 | np.seterr(divide='ignore', invalid='ignore') 25 | 26 | f = open(sys.argv[1]) 27 | config = yaml.safe_load(f) 28 | 29 | device = config['training']['device'] 30 | model = build(model_name=config['model']['model_name'], class_num=config['dataset']['class_num']) 31 | 32 | if device == "cpu": 33 | model.load_state_dict(torch.load(config['test']['checkpoint_save_path'], map_location=torch.device('cpu'))) 34 | else: 35 | model.load_state_dict(torch.load(config['test']['checkpoint_save_path']),strict=False) 36 | 37 | model = model.to(device) 38 | model.eval() 39 | 40 | train_img_root = config['dataset']['train_img_root'] 41 | train_label_root = config['dataset']['train_label_root'] 42 | 43 | 44 | # batch size !!!!
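# The evaluation loop below assumes batch_size == 1: each prediction is interpolated back to the
# sample's original resolution (val_ds.image_size[cot]) before scoring, so images with different
# native sizes cannot share a batch. Per image, Dice is evaluated at 256 thresholds
# (np.linspace(1, 0, 256)) via Fmeasure_calu and averaged; the per-image means are then averaged
# over the whole dataset to give the value reported in the table.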
45 | batch_size = 1 46 | num_workers = config['dataset']['num_workers'] 47 | checkpoint_save_path = config['other']['checkpoint_save_path'] 48 | 49 | # training 50 | max_epoch = config['training']['max_epoch'] 51 | lr = float(config['training']['lr']) 52 | 53 | Test_transform_list = config['Test_transform_list'] 54 | torch.cuda.manual_seed_all(1) 55 | torch.manual_seed(1) 56 | 57 | #dataset = ['Kvasir'] 58 | dataset = ['CVC-300', 'CVC-ColonDB', 'CVC-ClinicDB', 'ETIS-LaribPolypDB','Kvasir'] 59 | model = model.eval() 60 | val = [] 61 | for i in dataset: 62 | print(f" predicting {i}") 63 | val_ds = CustomDataSet(config['dataset']['test_' + str(i) + '_img'], config['dataset']['test_' + str(i) + '_label'], transform_list=Test_transform_list) 64 | 65 | val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers) 66 | cot = 0 67 | total_meanDic = 0 68 | Thresholds = np.linspace(1, 0, 256) 69 | with torch.no_grad(): 70 | for idx, (img, label) in tqdm(enumerate(val_loader)): 71 | img = img.to(device) 72 | label = label.to('cpu') 73 | x = model(img) 74 | pred = torch.sigmoid(x) 75 | pred = F.interpolate(pred, size=(val_ds.image_size[cot][1], val_ds.image_size[cot][0]), mode='bilinear', align_corners=False) 76 | 77 | threshold = torch.tensor([0.5]).to(device) 78 | pred = (pred > threshold).float() * 1 79 | 80 | pre_label = pred.squeeze(1).cpu().numpy() 81 | true_label = label.squeeze(1).cpu().numpy() 82 | threshold_Dice = np.zeros((img.shape[0], len(Thresholds))) 83 | 84 | for each in range(img.shape[0]): 85 | pred = pre_label[each, :].squeeze() 86 | label_ = label[each, :] 87 | label_ = np.array(label_).astype(np.uint8).squeeze() 88 | pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8) 89 | threshold_Dic = np.zeros(len(Thresholds)) 90 | for j, threshold in enumerate(Thresholds): 91 | _, _, _, threshold_Dic[j], _, _ = Fmeasure_calu(pred, label_, threshold) 92 | 93 | threshold_Dice[each, :] = threshold_Dic 94 | column_Dic = np.mean(threshold_Dice, axis=0) 95 | 96 | cot += 1 97 | meanDic = np.mean(column_Dic) 98 | total_meanDic = total_meanDic + meanDic 99 | val.append(total_meanDic / (idx + 1)) 100 | print(val) 101 | 102 | 103 | val = np.array(val) 104 | table_header = ['Dataset', config['model']['model_name']+'_Dice','UACANet_L_Dice','First_Dice'] 105 | table_data = [('CVC-300',str(val[0]), '0.91349','None'), 106 | ('CVC-ColonDB',str(val[1]),'0.75319','0.8474'), 107 | ('CVC-ClinicDB',str(val[2]),'0.92858','0.9420' ), 108 | ('ETIS-LaribPolypDB',str(val[3]),'0.76897','0.766'), 109 | ('Kvasir',str(val[4]),'0.90614','0.9217'), 110 | ('Average',str(val.mean()),'0.853','None'),] 111 | 112 | print(tabulate(table_data, headers=table_header, tablefmt='psql')) 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from models import build 3 | from loguru import logger 4 | from tqdm import tqdm 5 | import torch.nn as nn 6 | import torch.optim as optmi 7 | import torch.nn.functional as F 8 | from utils.tools import Fmeasure_calu 9 | from utils.my_dataset import CustomDataSet 10 | from utils.loss import * 11 | from torch.utils.data import DataLoader 12 | from torchvision.transforms import Compose 13 | from torchvision import transforms 14 | import torch 15 | import os 16 | import sys 17 | import numpy as np 18 | import yaml 19 | from tabulate import tabulate 20 | from torch.optim.lr_scheduler 
import StepLR, ExponentialLR 21 | from warmup_scheduler import GradualWarmupScheduler 22 | from thop import profile 23 | 24 | 25 | def _thresh(img): 26 | img[img > 0.5] = 1 27 | img[img <= 0.5] = 0 28 | return img 29 | 30 | def dsc(y_pred, y_true): 31 | y_pred = _thresh(y_pred) 32 | y_true = _thresh(y_true) 33 | 34 | return dc(y_pred, y_true) 35 | np.seterr(divide='ignore', invalid='ignore') 36 | 37 | np.seterr(divide='ignore', invalid='ignore') 38 | torch.cuda.manual_seed_all(1) 39 | torch.manual_seed(1) 40 | 41 | f = open(sys.argv[1]) 42 | config = yaml.safe_load(f) 43 | 44 | evl_epoch = config['training']['evl_epoch'] 45 | 46 | # 定义模型 47 | device = config['training']['device'] 48 | model = build(model_name=config['model']['model_name'], class_num=config['dataset']['class_num']) 49 | model.to(device) 50 | 51 | input = torch.randn(1, 3, 352, 352).to('cuda') 52 | macs, params = profile(model, inputs=(input, )) 53 | print('macs:',macs/1000000000) 54 | print('params:',params/1000000) 55 | logger.info(f"| model |macs:', {macs/1000000000}, 'params:', {params/1000000}|") 56 | 57 | # if pretrained 58 | if config['model']['is_pretrained']: 59 | model.load_state_dict(torch.load(config['model']['pretrained_path'])) 60 | logger.info("successfully add pretrained model") 61 | 62 | train_img_root = config['dataset']['train_img_root'] 63 | train_label_root = config['dataset']['train_label_root'] 64 | 65 | batch_size = config['dataset']['batch_size'] 66 | num_workers = config['dataset']['num_workers'] 67 | checkpoint_save_path = config['other']['checkpoint_save_path'] 68 | 69 | # transform_list 70 | Train_transform_list = config['Train_transform_list'] 71 | Val_transform_list = config['Val_transform_list'] 72 | 73 | # training 74 | max_epoch = config['training']['max_epoch'] 75 | lr = float(config['training']['lr']) 76 | 77 | train_ds = CustomDataSet(train_img_root, train_label_root, transform_list=Train_transform_list) 78 | train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers) 79 | 80 | # criterion = nn.NLLLoss().to(device) 81 | # criterion =nn.CrossEntropyLoss().to(device) 82 | # criterion = nn.BCELoss().to(device) 83 | # criterion = AsymmetricUnifiedFocalLoss() 84 | # criterion = FocalLoss() 85 | # optimizer 86 | optim = optmi.AdamW(model.parameters(), lr=lr) 87 | 88 | # scheduler_warmup is chained with schduler_steplr 89 | scheduler_steplr = StepLR(optim, step_size=200, gamma=0.1) 90 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=1, after_scheduler=scheduler_steplr) 91 | 92 | 93 | dataset = ['CVC-300', 'CVC-ColonDB', 'CVC-ClinicDB', 'ETIS-LaribPolypDB', 'Kvasir'] 94 | # logger 95 | print(config['other']['logger_path']) 96 | logger.add(config['other']['logger_path']) 97 | # start training 98 | logger.info(f"| start training .... 
| current model {config['model']['model_name']} |") 99 | logger.info(f"Train_transform_list: | {Train_transform_list}|") 100 | logger.info(f"Val_transform_list: |{Val_transform_list}|") 101 | best_val_dice = [0] 102 | best_loss = [100000] 103 | from_epoch = config['model']['from_epoch'] 104 | for epoch in tqdm(range(max_epoch)): 105 | train_loss = 0 106 | model.train() 107 | epoch = epoch + int(from_epoch) 108 | scheduler_warmup.step(epoch) 109 | logger.info(f"lr: |{optim.param_groups[0]['lr']}|") 110 | for idx, (img, label) in tqdm(enumerate(train_loader)): 111 | model = model.train() 112 | img = img.to(device) 113 | label = label.to(device) 114 | out = model(img) 115 | out = nn.Sigmoid()(out) 116 | loss = dice_bce_loss(out, label) 117 | train_loss += loss.item() 118 | optim.zero_grad() 119 | loss.backward() 120 | optim.step() 121 | 122 | if (epoch + 1) % 10 == 0: 123 | logger.critical(f"saving checkpoint at {epoch}") 124 | torch.save(model.state_dict(), os.path.join(checkpoint_save_path, f"{epoch+1}.pth")) 125 | 126 | if train_loss / (idx + 1) < min(best_loss): 127 | best_loss.append(train_loss / (idx + 1)) 128 | print("train epoch done") 129 | logger.info(f"| epoch : {epoch} | training done | best loss: {train_loss / (idx + 1)} |") 130 | else: 131 | logger.info(f"| epoch : {epoch} | training done | No best loss |") 132 | 133 | if epoch >= evl_epoch: 134 | model.eval() 135 | val = [] 136 | model = model.eval() 137 | 138 | for i in dataset: 139 | print("evaluating ", i) 140 | cot = 0 141 | from utils.test_dataset import CustomDataSet as test_DataSet 142 | val_ds = test_DataSet(config['dataset']['test_' + str(i) + '_img'], config['dataset']['test_' + str(i) + '_label'], transform_list=Val_transform_list) 143 | val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=num_workers) 144 | total_meanDic = 0 145 | Thresholds = np.linspace(1, 0, 256) 146 | with torch.no_grad(): 147 | for idx, (img, label) in tqdm(enumerate(val_loader)): 148 | img = img.to(device) 149 | label = label.to('cpu') 150 | x = model(img) 151 | pred = torch.sigmoid(x) 152 | pred = F.interpolate(pred, size=(val_ds.image_size[cot][1], val_ds.image_size[cot][0]), mode='bilinear', align_corners=False) 153 | cot = cot+1 154 | threshold = torch.tensor([0.5]).to(device) 155 | pred = (pred > threshold).float() * 1 156 | pre_label = pred.squeeze(1).cpu().numpy() 157 | true_label = label.squeeze(1).cpu().numpy() 158 | threshold_Dice = np.zeros((img.shape[0], len(Thresholds))) 159 | 160 | for each in range(img.shape[0]): 161 | pred = pre_label[each, :].squeeze() 162 | label_ = true_label[each, :] 163 | label_ = np.array(label_).astype(np.uint8).squeeze() 164 | pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8) 165 | threshold_Dic = np.zeros(len(Thresholds)) 166 | 167 | for j, threshold in enumerate(Thresholds): 168 | if j == 0: 169 | _, _, _, threshold_Dic[j], _, _ = Fmeasure_calu(pred, label_, threshold) 170 | a = threshold_Dic[j] 171 | if j == 255: 172 | _, _, _, threshold_Dic[j], _, _ = Fmeasure_calu(pred, label_, threshold) 173 | if 1 <= j <= 254: 174 | threshold_Dic[j] = a 175 | 176 | threshold_Dice[each, :] = threshold_Dic 177 | column_Dic = np.mean(threshold_Dice, axis=0) 178 | 179 | meanDic = np.mean(column_Dic) 180 | total_meanDic = total_meanDic + meanDic 181 | val.append(total_meanDic / (idx + 1)) 182 | print(val) 183 | val = np.array(val) 184 | mean_total = val.mean() 185 | logger.info(f"| val : {val} | val done |") 186 | if max(best_val_dice) <= mean_total: 187 | 
best_val_dice.append(mean_total) 188 | table_header = ['Dataset', config['model']['model_name'] + '_Dice', 'UACANet_L_Dice', 'First_Dice'] 189 | table_data = [('CVC-300', str(val[0]), '0.91349', 'None'), 190 | ('CVC-ColonDB', str(val[1]), '0.75319', '0.8474'), 191 | ('CVC-ClinicDB', str(val[2]), '0.92858', '0.9420'), 192 | ('ETIS-LaribPolypDB', str(val[3]), '0.76897', '0.766'), 193 | ('Kvasir', str(val[4]), '0.90614', '0.9217'), 194 | ('Average', str(val.mean()), '0.853', 'None')] 195 | 196 | logger.info(tabulate(table_data, headers=table_header, tablefmt='psql')) 197 | torch.save(model.state_dict(), os.path.join(checkpoint_save_path, "best_val.pth")) 198 | else: 199 | logger.info(f"| epoch : {epoch} | val done |") 200 | 201 | -------------------------------------------------------------------------------- /utils/PolynomialLRDecay.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | class PolynomialLRDecay(_LRScheduler): 4 | """Polynomial learning rate decay until step reach to max_decay_step 5 | 6 | Args: 7 | optimizer (Optimizer): Wrapped optimizer. 8 | max_decay_steps: after this step, we stop decreasing learning rate 9 | end_learning_rate: scheduler stoping learning rate decay, value of learning rate must be this value 10 | power: The power of the polynomial. 11 | """ 12 | 13 | def __init__(self, optimizer, max_decay_steps, end_learning_rate=0.0001, power=1.0): 14 | if max_decay_steps <= 1.: 15 | raise ValueError('max_decay_steps should be greater than 1.') 16 | self.max_decay_steps = max_decay_steps 17 | self.end_learning_rate = end_learning_rate 18 | self.power = power 19 | self.last_step = 0 20 | super().__init__(optimizer) 21 | 22 | def get_lr(self): 23 | if self.last_step > self.max_decay_steps: 24 | return [self.end_learning_rate for _ in self.base_lrs] 25 | 26 | return [(base_lr - self.end_learning_rate) * 27 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) + 28 | self.end_learning_rate for base_lr in self.base_lrs] 29 | 30 | def step(self, step=None): 31 | if step is None: 32 | step = self.last_step + 1 33 | self.last_step = step if step != 0 else 1 34 | if self.last_step <= self.max_decay_steps: 35 | decay_lrs = [(base_lr - self.end_learning_rate) * 36 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) + 37 | self.end_learning_rate for base_lr in self.base_lrs] 38 | for param_group, lr in zip(self.optimizer.param_groups, decay_lrs): 39 | param_group['lr'] = lr -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .my_dataset import ISIC2018, DataBuilder, Datareader, CustomDataSet 2 | ## from .my_dataset import CustomisedDataSet 3 | #from .tools import legacy_mean_dice as mean_dice 4 | #from .tools import legacy_mean_iou as mean_iou 5 | #from .tools import Colorize 6 | #from .swd import swd -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/custom_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/custom_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/custom_transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/custom_transforms.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval_other.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/eval_other.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/my_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/my_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/my_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/my_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/swd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/swd.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/swd.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/swd.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/test_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/test_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/test_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/test_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/test_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/test_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/test_transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/test_transforms.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/tools.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/tools.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/tools.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qiming-Huang/ssformer/1351a07dd623c87401aa8d3a316ba03aae5ae123/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import cv2 4 | import torch 5 | import torch.nn.functional as F 6 | from PIL import Image, ImageOps, ImageFilter, ImageEnhance 7 | 8 | class resize: 9 | def __init__(self, size): 10 | self.size = size 11 | 12 | def __call__(self, sample): 13 | if 'image' in sample.keys(): 14 | sample['image'] = sample['image'].resize(self.size, Image.BILINEAR) 15 | if 'gt' in sample.keys(): 16 | sample['gt'] = sample['gt'].resize(self.size, Image.BILINEAR) 17 | if 'mask' in sample.keys(): 18 | sample['mask'] = sample['mask'].resize(self.size, Image.BILINEAR) 19 | 20 | return sample 21 | 22 | class random_scale_crop: 23 | def __init__(self, range=[0.75, 1.25]): 24 | self.range = range 25 | 26 | def __call__(self, sample): 27 | scale = 
np.random.random() * (self.range[1] - self.range[0]) + self.range[0] 28 | if np.random.random() < 0.5: 29 | for key in sample.keys(): 30 | if key in ['image', 'gt', 'contour']: 31 | base_size = sample[key].size 32 | 33 | scale_size = tuple((np.array(base_size) * scale).round().astype(int)) 34 | sample[key] = sample[key].resize(scale_size) 35 | 36 | sample[key] = sample[key].crop(((sample[key].size[0] - base_size[0]) // 2, 37 | (sample[key].size[1] - base_size[1]) // 2, 38 | (sample[key].size[0] + base_size[0]) // 2, 39 | (sample[key].size[1] + base_size[1]) // 2)) 40 | 41 | return sample 42 | 43 | class random_flip: 44 | def __init__(self, lr=True, ud=True): 45 | self.lr = lr 46 | self.ud = ud 47 | 48 | def __call__(self, sample): 49 | lr = np.random.random() < 0.5 and self.lr is True 50 | ud = np.random.random() < 0.5 and self.ud is True 51 | 52 | for key in sample.keys(): 53 | if key in ['image', 'gt', 'contour']: 54 | sample[key] = np.array(sample[key]) 55 | if lr: 56 | sample[key] = np.fliplr(sample[key]) 57 | if ud: 58 | sample[key] = np.flipud(sample[key]) 59 | sample[key] = Image.fromarray(sample[key]) 60 | 61 | return sample 62 | 63 | class random_rotate: 64 | def __init__(self, range=[0, 360], interval=1): 65 | self.range = range 66 | self.interval = interval 67 | 68 | def __call__(self, sample): 69 | rot = (np.random.randint(*self.range) // self.interval) * self.interval 70 | rot = rot + 360 if rot < 0 else rot 71 | 72 | if np.random.random() < 0.5: 73 | for key in sample.keys(): 74 | if key in ['image', 'gt', 'contour']: 75 | base_size = sample[key].size 76 | 77 | sample[key] = sample[key].rotate(rot, expand=True) 78 | 79 | sample[key] = sample[key].crop(((sample[key].size[0] - base_size[0]) // 2, 80 | (sample[key].size[1] - base_size[1]) // 2, 81 | (sample[key].size[0] + base_size[0]) // 2, 82 | (sample[key].size[1] + base_size[1]) // 2)) 83 | 84 | return sample 85 | 86 | class random_image_enhance: 87 | def __init__(self, methods=['contrast', 'brightness', 'sharpness']): 88 | self.enhance_method = [] 89 | if 'contrast' in methods: 90 | self.enhance_method.append(ImageEnhance.Contrast) 91 | if 'brightness' in methods: 92 | self.enhance_method.append(ImageEnhance.Brightness) 93 | if 'sharpness' in methods: 94 | self.enhance_method.append(ImageEnhance.Sharpness) 95 | 96 | def __call__(self, sample): 97 | image = sample['image'] 98 | np.random.shuffle(self.enhance_method) 99 | 100 | for method in self.enhance_method: 101 | if np.random.random() > 0.5: 102 | enhancer = method(image) 103 | factor = float(1 + np.random.random() / 10) 104 | image = enhancer.enhance(factor) 105 | sample['image'] = image 106 | 107 | return sample 108 | 109 | class random_dilation_erosion: 110 | def __init__(self, kernel_range): 111 | self.kernel_range = kernel_range 112 | 113 | def __call__(self, sample): 114 | gt = sample['gt'] 115 | gt = np.array(gt) 116 | key = np.random.random() 117 | # kernel = np.ones(tuple([np.random.randint(*self.kernel_range)]) * 2, dtype=np.uint8) 118 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (np.random.randint(*self.kernel_range), ) * 2) 119 | if key < 1/3: 120 | gt = cv2.dilate(gt, kernel) 121 | elif 1/3 <= key < 2/3: 122 | gt = cv2.erode(gt, kernel) 123 | 124 | sample['gt'] = Image.fromarray(gt) 125 | 126 | return sample 127 | 128 | class random_gaussian_blur: 129 | def __init__(self): 130 | pass 131 | 132 | def __call__(self, sample): 133 | image = sample['image'] 134 | if np.random.random() < 0.5: 135 | image = 
image.filter(ImageFilter.GaussianBlur(radius=np.random.random())) 136 | sample['image'] = image 137 | 138 | return sample 139 | 140 | class tonumpy: 141 | def __init__(self): 142 | pass 143 | 144 | def __call__(self, sample): 145 | image, gt = sample['image'], sample['gt'] 146 | 147 | sample['image'] = np.array(image, dtype=np.float32) 148 | sample['gt'] = np.array(gt, dtype=np.float32) 149 | 150 | return sample 151 | 152 | class normalize: 153 | def __init__(self, mean, std): 154 | self.mean = mean 155 | self.std = std 156 | 157 | def __call__(self, sample): 158 | image, gt = sample['image'], sample['gt'] 159 | image /= 255 160 | image -= self.mean 161 | image /= self.std 162 | 163 | gt /= 255 164 | sample['image'] = image 165 | sample['gt'] = gt 166 | 167 | return sample 168 | 169 | class totensor: 170 | def __init__(self): 171 | pass 172 | 173 | def __call__(self, sample): 174 | image, gt = sample['image'], sample['gt'] 175 | image = image.transpose((2, 0, 1)) 176 | image = torch.from_numpy(image).float() 177 | gt = torch.from_numpy(gt) 178 | gt = gt.unsqueeze(dim=0) 179 | sample['image'] = image 180 | sample['gt'] = gt 181 | 182 | return sample 183 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | import random 7 | import torch 8 | 9 | 10 | class PolypDataset(data.Dataset): 11 | """ 12 | dataloader for polyp segmentation tasks 13 | """ 14 | def __init__(self, image_root, gt_root, trainsize, augmentations): 15 | self.trainsize = trainsize 16 | self.augmentations = augmentations 17 | print(self.augmentations) 18 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 19 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.png')] 20 | self.images = sorted(self.images) 21 | self.gts = sorted(self.gts) 22 | self.filter_files() 23 | self.size = len(self.images) 24 | if self.augmentations == 'True': 25 | print('Using RandomRotation, RandomFlip') 26 | self.img_transform = transforms.Compose([ 27 | transforms.RandomRotation(90, resample=False, expand=False, center=None, fill=None), 28 | transforms.RandomVerticalFlip(p=0.5), 29 | transforms.RandomHorizontalFlip(p=0.5), 30 | transforms.Resize((self.trainsize, self.trainsize)), 31 | transforms.ToTensor(), 32 | transforms.Normalize([0.485, 0.456, 0.406], 33 | [0.229, 0.224, 0.225])]) 34 | self.gt_transform = transforms.Compose([ 35 | transforms.RandomRotation(90, resample=False, expand=False, center=None, fill=None), 36 | transforms.RandomVerticalFlip(p=0.5), 37 | transforms.RandomHorizontalFlip(p=0.5), 38 | transforms.Resize((self.trainsize, self.trainsize)), 39 | transforms.ToTensor()]) 40 | 41 | else: 42 | print('no augmentation') 43 | self.img_transform = transforms.Compose([ 44 | transforms.Resize((self.trainsize, self.trainsize)), 45 | # transforms.RandomVerticalFlip(p=0.5), 46 | # transforms.RandomHorizontalFlip(p=0.5), 47 | transforms.ToTensor(), 48 | transforms.Normalize([0.485, 0.456, 0.406], 49 | [0.229, 0.224, 0.225])]) 50 | 51 | self.gt_transform = transforms.Compose([ 52 | transforms.Resize((self.trainsize, self.trainsize)), 53 | # transforms.RandomVerticalFlip(p=0.5), 54 | # transforms.RandomHorizontalFlip(p=0.5), 55 | transforms.ToTensor()]) 56 | 57 | 58 | def __getitem__(self, 
index): 59 | 60 | image = self.rgb_loader(self.images[index]) 61 | gt = self.binary_loader(self.gts[index]) 62 | 63 | seed = np.random.randint(2147483647) # make a seed with numpy generator 64 | random.seed(seed) # apply this seed to img tranfsorms 65 | torch.manual_seed(seed) # needed for torchvision 0.7 66 | if self.img_transform is not None: 67 | image = self.img_transform(image) 68 | 69 | random.seed(seed) # apply this seed to img tranfsorms 70 | torch.manual_seed(seed) # needed for torchvision 0.7 71 | if self.gt_transform is not None: 72 | gt = self.gt_transform(gt) 73 | return image, gt 74 | 75 | def filter_files(self): 76 | assert len(self.images) == len(self.gts) 77 | images = [] 78 | gts = [] 79 | for img_path, gt_path in zip(self.images, self.gts): 80 | img = Image.open(img_path) 81 | gt = Image.open(gt_path) 82 | if img.size == gt.size: 83 | images.append(img_path) 84 | gts.append(gt_path) 85 | self.images = images 86 | self.gts = gts 87 | 88 | def rgb_loader(self, path): 89 | with open(path, 'rb') as f: 90 | img = Image.open(f) 91 | return img.convert('RGB') 92 | 93 | def binary_loader(self, path): 94 | with open(path, 'rb') as f: 95 | img = Image.open(f) 96 | # return img.convert('1') 97 | return img.convert('L') 98 | 99 | def resize(self, img, gt): 100 | assert img.size == gt.size 101 | w, h = img.size 102 | if h < self.trainsize or w < self.trainsize: 103 | h = max(h, self.trainsize) 104 | w = max(w, self.trainsize) 105 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST) 106 | else: 107 | return img, gt 108 | 109 | def __len__(self): 110 | return self.size 111 | 112 | 113 | def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=8, pin_memory=True, augmentation=False): 114 | 115 | dataset = PolypDataset(image_root, gt_root, trainsize, augmentation) 116 | data_loader = data.DataLoader(dataset=dataset, 117 | batch_size=batchsize, 118 | shuffle=shuffle, 119 | num_workers=num_workers, 120 | pin_memory=pin_memory) 121 | return data_loader 122 | 123 | 124 | class test_dataset: 125 | def __init__(self, image_root, gt_root, testsize): 126 | self.testsize = testsize 127 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 128 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] 129 | self.images = sorted(self.images) 130 | self.gts = sorted(self.gts) 131 | self.transform = transforms.Compose([ 132 | transforms.Resize((self.testsize, self.testsize)), 133 | transforms.ToTensor(), 134 | transforms.Normalize([0.485, 0.456, 0.406], 135 | [0.229, 0.224, 0.225])]) 136 | self.gt_transform = transforms.ToTensor() 137 | self.size = len(self.images) 138 | self.index = 0 139 | 140 | def load_data(self): 141 | image = self.rgb_loader(self.images[self.index]) 142 | image = self.transform(image).unsqueeze(0) 143 | gt = self.binary_loader(self.gts[self.index]) 144 | name = self.images[self.index].split('/')[-1] 145 | if name.endswith('.jpg'): 146 | name = name.split('.jpg')[0] + '.png' 147 | self.index += 1 148 | return image, gt, name 149 | 150 | def rgb_loader(self, path): 151 | with open(path, 'rb') as f: 152 | img = Image.open(f) 153 | return img.convert('RGB') 154 | 155 | def binary_loader(self, path): 156 | with open(path, 'rb') as f: 157 | img = Image.open(f) 158 | return img.convert('L') 159 | -------------------------------------------------------------------------------- /utils/eval_FPS.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.backends.cudnn as cudnn 4 | 5 | from models.mit.mit_srm import SSformer 6 | 7 | 8 | def compute_speed(model, input_size, device, iteration=100): 9 | torch.cuda.set_device(device) 10 | cudnn.benchmark = True 11 | 12 | model.eval() 13 | model = model.cuda() 14 | 15 | input = torch.randn(*input_size, device=device) 16 | 17 | for _ in range(1): 18 | model(input) 19 | 20 | print('=========Speed Testing=========') 21 | torch.cuda.synchronize() 22 | torch.cuda.synchronize() 23 | t_start = time.time() 24 | for _ in range(iteration): 25 | model(input) 26 | torch.cuda.synchronize() 27 | torch.cuda.synchronize() 28 | elapsed_time = time.time() - t_start 29 | 30 | speed_time = elapsed_time / iteration * 1000 31 | fps = iteration / elapsed_time 32 | 33 | print('Elapsed Time: [%.2f s / %d iter]' % (elapsed_time, iteration)) 34 | print('Speed Time: %.2f ms / iter FPS: %.2f' % (speed_time, fps)) 35 | return speed_time, fps 36 | 37 | 38 | if __name__ == '__main__': 39 | 40 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 41 | model = SSformer().to(device) 42 | compute_speed(model, (1, 3, 352, 352), int(0), iteration=100) 43 | -------------------------------------------------------------------------------- /utils/eval_other.py: -------------------------------------------------------------------------------- 1 | from medpy.metric.binary import recall as mp_recall 2 | from medpy.metric.binary import dc 3 | import numpy as np 4 | from medpy.metric.binary import precision as mp_precision 5 | 6 | 7 | def _thresh(img): 8 | img[img > 0.5] = 1 9 | img[img <= 0.5] = 0 10 | return img 11 | 12 | def dsc(y_pred, y_true): 13 | y_pred = _thresh(y_pred) 14 | y_true = _thresh(y_true) 15 | 16 | return dc(y_pred, y_true) 17 | 18 | def iou(y_pred, y_true): 19 | y_pred = _thresh(y_pred) 20 | y_true = _thresh(y_true) 21 | 22 | intersection = np.logical_and(y_pred, y_true) 23 | union = np.logical_or(y_pred, y_true) 24 | if not np.any(union): 25 | return 0 if np.any(y_pred) else 1 26 | 27 | return intersection.sum() / float(union.sum()) 28 | 29 | def precision(y_pred, y_true): 30 | y_pred = _thresh(y_pred).astype(np.int) 31 | y_true = _thresh(y_true).astype(np.int) 32 | 33 | if y_true.sum() <= 5: 34 | # when the example is nearly empty, avoid division by 0 35 | # if the prediction is also empty, precision is 1 36 | # otherwise it's 0 37 | return 1 if y_pred.sum() <= 5 else 0 38 | 39 | if y_pred.sum() <= 5: 40 | return 0. 41 | 42 | return mp_precision(y_pred, y_true) 43 | 44 | def recall(y_pred, y_true): 45 | y_pred = _thresh(y_pred).astype(np.int) 46 | y_true = _thresh(y_true).astype(np.int) 47 | 48 | if y_true.sum() <= 5: 49 | # when the example is nearly empty, avoid division by 0 50 | # if the prediction is also empty, recall is 1 51 | # otherwise it's 0 52 | return 1 if y_pred.sum() <= 5 else 0 53 | 54 | if y_pred.sum() <= 5: 55 | return 0. 
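    # NOTE: np.int (used in the astype calls above) was deprecated in NumPy 1.20 and
    # removed in NumPy 1.24; on newer NumPy, plain int or np.int64 is a drop-in replacement.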
56 | 57 | r = mp_recall(y_pred, y_true) 58 | return r 59 | 60 | 61 | -------------------------------------------------------------------------------- /utils/hlper.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import nibabel as nib 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pydicom 7 | from skimage.metrics import adapted_rand_error 8 | from medpy.metric.binary import precision as mp_precision 9 | from medpy.metric.binary import recall as mp_recall 10 | from medpy.metric.binary import dc 11 | 12 | def _thresh(img): 13 | img[img > 0.5] = 1 14 | img[img <= 0.5] = 0 15 | return img 16 | 17 | def dsc(y_pred, y_true): 18 | y_pred = _thresh(y_pred) 19 | y_true = _thresh(y_true) 20 | 21 | return dc(y_pred, y_true) 22 | 23 | def iou(y_pred, y_true): 24 | y_pred = _thresh(y_pred) 25 | y_true = _thresh(y_true) 26 | 27 | intersection = np.logical_and(y_pred, y_true) 28 | union = np.logical_or(y_pred, y_true) 29 | if not np.any(union): 30 | return 0 if np.any(y_pred) else 1 31 | 32 | return intersection.sum() / float(union.sum()) 33 | 34 | def precision(y_pred, y_true): 35 | y_pred = _thresh(y_pred).astype(np.int) 36 | y_true = _thresh(y_true).astype(np.int) 37 | 38 | if y_true.sum() <= 5: 39 | # when the example is nearly empty, avoid division by 0 40 | # if the prediction is also empty, precision is 1 41 | # otherwise it's 0 42 | return 1 if y_pred.sum() <= 5 else 0 43 | 44 | if y_pred.sum() <= 5: 45 | return 0. 46 | 47 | return mp_precision(y_pred, y_true) 48 | 49 | def recall(y_pred, y_true): 50 | y_pred = _thresh(y_pred).astype(np.int) 51 | y_true = _thresh(y_true).astype(np.int) 52 | 53 | if y_true.sum() <= 5: 54 | # when the example is nearly empty, avoid division by 0 55 | # if the prediction is also empty, recall is 1 56 | # otherwise it's 0 57 | return 1 if y_pred.sum() <= 5 else 0 58 | 59 | if y_pred.sum() <= 5: 60 | return 0. 
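    # These helpers (_thresh, dsc, iou, precision, recall) duplicate utils/eval_other.py,
    # so the same np.int deprecation caveat applies here.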
61 | 62 | r = mp_recall(y_pred, y_true) 63 | return r 64 | 65 | def listdir(path): 66 | """ List files but remove hidden files from list """ 67 | return [item for item in os.listdir(path) if item[0] != '.'] 68 | 69 | def mkdir(path): 70 | if not os.path.exists(path): 71 | os.makedirs(path) 72 | 73 | def show_images_row(imgs, titles=None, rows=1, figsize=(6.4, 4.8), **kwargs): 74 | ''' 75 | Display grid of cv2 images 76 | :param img: list [cv::mat] 77 | :param title: titles 78 | :return: None 79 | ''' 80 | assert ((titles is None) or (len(imgs) == len(titles))) 81 | num_images = len(imgs) 82 | 83 | if titles is None: 84 | titles = ['Image (%d)' % i for i in range(1, num_images + 1)] 85 | 86 | fig = plt.figure(figsize=figsize) 87 | for n, (image, title) in enumerate(zip(imgs, titles)): 88 | ax = fig.add_subplot(rows, np.ceil(num_images / float(rows)), n + 1) 89 | plt.imshow(image, **kwargs) 90 | ax.set_title(title) 91 | plt.axis('off') -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | def bce_iou_loss(pred, mask): 6 | weight = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 7 | 8 | bce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none') 9 | 10 | pred = torch.sigmoid(pred) 11 | inter = pred * mask 12 | union = pred + mask 13 | iou = 1 - (inter + 1) / (union - inter + 1) 14 | 15 | weighted_bce = (weight * bce).sum(dim=(2, 3)) / weight.sum(dim=(2, 3)) 16 | weighted_iou = (weight * iou).sum(dim=(2, 3)) / weight.sum(dim=(2, 3)) 17 | 18 | return (weighted_bce + weighted_iou).mean() 19 | 20 | def dice_bce_loss(pred, mask): 21 | bce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none') 22 | 23 | pred = torch.sigmoid(pred) 24 | inter = pred * mask 25 | union = pred + mask 26 | iou = 1 - (2. * inter + 1) / (union + 1) 27 | 28 | return (bce + iou).mean() 29 | 30 | def tversky_loss(pred, mask, alpha=0.5, beta=0.5, gamma=2): 31 | pred = torch.sigmoid(pred) 32 | 33 | #flatten label and prediction tensors 34 | pred = pred.view(-1) 35 | mask = mask.view(-1) 36 | 37 | #True Positives, False Positives & False Negatives 38 | TP = (pred * mask).sum() 39 | FP = ((1 - mask) * pred).sum() 40 | FN = (mask * (1 - pred)).sum() 41 | 42 | Tversky = (TP + 1) / (TP + alpha * FP + beta * FN + 1) 43 | 44 | return (1 - Tversky) ** gamma 45 | 46 | def tversky_bce_loss(pred, mask, alpha=0.5, beta=0.5, gamma=2): 47 | bce = F.binary_cross_entropy_with_logits(pred, mask, reduction='mean') 48 | 49 | pred = torch.sigmoid(pred) 50 | 51 | #flatten label and prediction tensors 52 | pred = pred.view(-1) 53 | mask = mask.view(-1) 54 | 55 | #True Positives, False Positives & False Negatives 56 | TP = (pred * mask).sum() 57 | FP = ((1 - mask) * pred).sum() 58 | FN = (mask * (1 - pred)).sum() 59 | 60 | Tversky = (TP + 1) / (TP + alpha * FP + beta * FN + 1) 61 | 62 | return bce + (1 - Tversky) ** gamma 63 | 64 | import torch 65 | import torch.nn as nn 66 | import torch.nn.functional as F 67 | import numpy as np 68 | 69 | 70 | class DiceLoss(nn.Module): 71 | """Dice Loss PyTorch 72 | Created by: Zhang Shuai 73 | Email: shuaizzz666@gmail.com 74 | dice_loss = 1 - 2*p*t / (p^2 + t^2). p and t represent predict and target. 
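    Here p is the softmax probability and t the one-hot target, summed per class over the
    spatial positions; t^2 == t for one-hot labels, so only the prediction is squared.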
75 | Args: 76 | weight: An array of shape [C,] 77 | predict: A float32 tensor of shape [N, C, *], for Semantic segmentation task is [N, C, H, W] 78 | target: A int64 tensor of shape [N, *], for Semantic segmentation task is [N, H, W] 79 | Return: 80 | diceloss 81 | """ 82 | def __init__(self, weight=None): 83 | super(DiceLoss, self).__init__() 84 | if weight is not None: 85 | weight = torch.Tensor(weight) 86 | self.weight = weight / torch.sum(weight) # Normalized weight 87 | self.smooth = 1e-5 88 | 89 | def forward(self, predict, target): 90 | N, C = predict.size()[:2] 91 | predict = predict.view(N, C, -1) # (N, C, *) 92 | target = target.view(N, 1, -1) # (N, 1, *) 93 | 94 | predict = F.softmax(predict, dim=1) # (N, C, *) ==> (N, C, *) 95 | ## convert target(N, 1, *) into one hot vector (N, C, *) 96 | target_onehot = torch.zeros(predict.size()).cuda() # (N, 1, *) ==> (N, C, *) 97 | target_onehot.scatter_(1, target, 1) # (N, C, *) 98 | 99 | intersection = torch.sum(predict * target_onehot, dim=2) # (N, C) 100 | union = torch.sum(predict.pow(2), dim=2) + torch.sum(target_onehot, dim=2) # (N, C) 101 | ## p^2 + t^2 >= 2*p*t, target_onehot^2 == target_onehot 102 | dice_coef = (2 * intersection + self.smooth) / (union + self.smooth) # (N, C) 103 | 104 | if hasattr(self, 'weight'): 105 | if self.weight.type() != predict.type(): 106 | self.weight = self.weight.type_as(predict) 107 | dice_coef = dice_coef * self.weight * C # (N, C) 108 | dice_loss = 1 - torch.mean(dice_coef) # 1 109 | 110 | return dice_loss 111 | 112 | def structure_loss(pred, mask): 113 | weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 114 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none') 115 | wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3)) 116 | 117 | pred = torch.sigmoid(pred) 118 | inter = ((pred * mask)*weit).sum(dim=(2, 3)) 119 | union = ((pred + mask)*weit).sum(dim=(2, 3)) 120 | wiou = 1 - (inter + 1)/(union - inter+1) 121 | return (wbce + wiou).mean() 122 | 123 | class Bce_iou_loss(nn.Module): 124 | 125 | def __init__(self): 126 | super(Bce_iou_loss, self).__init__() 127 | 128 | def forward(self, pred, mask): 129 | weight = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 130 | 131 | bce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none') 132 | 133 | pred = torch.sigmoid(pred) 134 | inter = pred * mask 135 | union = pred + mask 136 | iou = 1 - (inter + 1) / (union - inter + 1) 137 | 138 | weighted_bce = (weight * bce).sum(dim=(2, 3)) / weight.sum(dim=(2, 3)) 139 | weighted_iou = (weight * iou).sum(dim=(2, 3)) / weight.sum(dim=(2, 3)) 140 | 141 | return (weighted_bce + weighted_iou).mean() -------------------------------------------------------------------------------- /utils/loss2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Helper function to enable loss function to be flexibly used for 6 | # both 2D or 3D image segmentation - source: https://github.com/frankkramer-lab/MIScnn 7 | 8 | def identify_axis(shape): 9 | # Three dimensional 10 | if len(shape) == 5 : return [1,2,3] 11 | 12 | # Two dimensional 13 | elif len(shape) == 4 : return [1,2] 14 | 15 | # Exception - Unknown 16 | else : raise ValueError('Metric: Shape of tensor is neither 2D or 3D.') 17 | 18 | 19 | class SymmetricFocalLoss(nn.Module): 20 | """ 21 | Parameters 22 | ---------- 23 | delta : float, optional 24 | controls weight given to false 
positive and false negatives, by default 0.7 25 | gamma : float, optional 26 | Focal Tversky loss' focal parameter controls degree of down-weighting of easy examples, by default 2.0 27 | epsilon : float, optional 28 | clip values to prevent division by zero error 29 | """ 30 | def __init__(self, delta=0.7, gamma=2., epsilon=1e-07): 31 | super(SymmetricFocalLoss, self).__init__() 32 | self.delta = delta 33 | self.gamma = gamma 34 | self.epsilon = epsilon 35 | 36 | def forward(self, y_pred, y_true): 37 | 38 | axis = identify_axis(y_true.size()) 39 | y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon) 40 | cross_entropy = -y_true * torch.log(y_pred) 41 | 42 | #calculate losses separately for each class 43 | back_ce = torch.pow(1 - y_pred[:,:,:,0], self.gamma) * cross_entropy[:,:,:,0] 44 | back_ce = (1 - self.delta) * back_ce 45 | 46 | fore_ce = torch.pow(1 - y_pred[:,:,:,1], self.gamma) * cross_entropy[:,:,:,1] 47 | fore_ce = self.delta * fore_ce 48 | 49 | loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], axis=-1), axis=-1)) 50 | 51 | return loss 52 | 53 | 54 | class AsymmetricFocalLoss(nn.Module): 55 | """For Imbalanced datasets 56 | Parameters 57 | ---------- 58 | delta : float, optional 59 | controls weight given to false positive and false negatives, by default 0.25 60 | gamma : float, optional 61 | Focal Tversky loss' focal parameter controls degree of down-weighting of easy examples, by default 2.0 62 | epsilon : float, optional 63 | clip values to prevent division by zero error 64 | """ 65 | def __init__(self, delta=0.25, gamma=2., epsilon=1e-07): 66 | super(AsymmetricFocalLoss, self).__init__() 67 | self.delta = delta 68 | self.gamma = gamma 69 | self.epsilon = epsilon 70 | 71 | def forward(self, y_pred, y_true): 72 | 73 | axis = identify_axis(y_true.size()) 74 | y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon) 75 | cross_entropy = -y_true * torch.log(y_pred) 76 | 77 | #calculate losses separately for each class, only suppressing background class 78 | back_ce = torch.pow(1 - y_pred[:,:,:,0], self.gamma) * cross_entropy[:,:,:,0] 79 | back_ce = (1 - self.delta) * back_ce 80 | 81 | fore_ce = cross_entropy[:,:,:,1] 82 | fore_ce = self.delta * fore_ce 83 | 84 | loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], axis=-1), axis=-1)) 85 | 86 | return loss 87 | 88 | 89 | class SymmetricFocalTverskyLoss(nn.Module): 90 | """This is the implementation for binary segmentation. 91 | Parameters 92 | ---------- 93 | delta : float, optional 94 | controls weight given to false positive and false negatives, by default 0.7 95 | gamma : float, optional 96 | focal parameter controls degree of down-weighting of easy examples, by default 0.75 97 | smooth : float, optional 98 | smooithing constant to prevent division by 0 errors, by default 0.000001 99 | epsilon : float, optional 100 | clip values to prevent division by zero error 101 | """ 102 | def __init__(self, delta=0.7, gamma=0.75, epsilon=1e-07): 103 | super(SymmetricFocalTverskyLoss, self).__init__() 104 | self.delta = delta 105 | self.gamma = gamma 106 | self.epsilon = epsilon 107 | 108 | def forward(self, y_pred, y_true): 109 | y_pred = torch.clamp(y_pred, self.epsilon, 1. 
- self.epsilon) 110 | axis = identify_axis(y_true.size()) 111 | 112 | # Calculate true positives (tp), false negatives (fn) and false positives (fp) 113 | tp = torch.sum(y_true * y_pred, axis=axis) 114 | fn = torch.sum(y_true * (1-y_pred), axis=axis) 115 | fp = torch.sum((1-y_true) * y_pred, axis=axis) 116 | dice_class = (tp + self.epsilon)/(tp + self.delta*fn + (1-self.delta)*fp + self.epsilon) 117 | 118 | #calculate losses separately for each class, enhancing both classes 119 | back_dice = (1-dice_class[:,0]) * torch.pow(1-dice_class[:,0], -self.gamma) 120 | fore_dice = (1-dice_class[:,1]) * torch.pow(1-dice_class[:,1], -self.gamma) 121 | 122 | # Average class scores 123 | loss = torch.mean(torch.stack([back_dice,fore_dice], axis=-1)) 124 | return loss 125 | 126 | 127 | class AsymmetricFocalTverskyLoss(nn.Module): 128 | """This is the implementation for binary segmentation. 129 | Parameters 130 | ---------- 131 | delta : float, optional 132 | controls weight given to false positive and false negatives, by default 0.7 133 | gamma : float, optional 134 | focal parameter controls degree of down-weighting of easy examples, by default 0.75 135 | smooth : float, optional 136 | smooithing constant to prevent division by 0 errors, by default 0.000001 137 | epsilon : float, optional 138 | clip values to prevent division by zero error 139 | """ 140 | def __init__(self, delta=0.7, gamma=0.75, epsilon=1e-07): 141 | super(AsymmetricFocalTverskyLoss, self).__init__() 142 | self.delta = delta 143 | self.gamma = gamma 144 | self.epsilon = epsilon 145 | 146 | def forward(self, y_pred, y_true): 147 | # Clip values to prevent division by zero error 148 | y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon) 149 | axis = identify_axis(y_true.size()) 150 | 151 | # Calculate true positives (tp), false negatives (fn) and false positives (fp) 152 | tp = torch.sum(y_true * y_pred, axis=axis) 153 | fn = torch.sum(y_true * (1-y_pred), axis=axis) 154 | fp = torch.sum((1-y_true) * y_pred, axis=axis) 155 | dice_class = (tp + self.epsilon)/(tp + self.delta*fn + (1-self.delta)*fp + self.epsilon) 156 | 157 | #calculate losses separately for each class, only enhancing foreground class 158 | back_dice = (1-dice_class[:,0]) 159 | fore_dice = (1-dice_class[:,1]) * torch.pow(1-dice_class[:,1], -self.gamma) 160 | 161 | # Average class scores 162 | loss = torch.mean(torch.stack([back_dice,fore_dice], axis=-1)) 163 | return loss 164 | 165 | 166 | class SymmetricUnifiedFocalLoss(nn.Module): 167 | """The Unified Focal loss is a new compound loss function that unifies Dice-based and cross entropy-based loss functions into a single framework. 
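    Note: forward() below builds the Tversky term by instantiating SymmetricUnifiedFocalLoss
    itself, which recurses without terminating; by analogy with AsymmetricUnifiedFocalLoss,
    SymmetricFocalTverskyLoss appears to be the intended class there.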
168 | Parameters 169 | ---------- 170 | weight : float, optional 171 | represents lambda parameter and controls weight given to symmetric Focal Tversky loss and symmetric Focal loss, by default 0.5 172 | delta : float, optional 173 | controls weight given to each class, by default 0.6 174 | gamma : float, optional 175 | focal parameter controls the degree of background suppression and foreground enhancement, by default 0.5 176 | epsilon : float, optional 177 | clip values to prevent division by zero error 178 | """ 179 | def __init__(self, weight=0.5, delta=0.6, gamma=0.5): 180 | super(SymmetricUnifiedFocalLoss, self).__init__() 181 | self.weight = weight 182 | self.delta = delta 183 | self.gamma = gamma 184 | 185 | def forward(self, y_pred, y_true): 186 | symmetric_ftl = SymmetricUnifiedFocalLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true) 187 | symmetric_fl = SymmetricFocalLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true) 188 | if self.weight is not None: 189 | return (self.weight * symmetric_ftl) + ((1-self.weight) * symmetric_fl) 190 | else: 191 | return symmetric_ftl + symmetric_fl 192 | 193 | 194 | class AsymmetricUnifiedFocalLoss(nn.Module): 195 | """The Unified Focal loss is a new compound loss function that unifies Dice-based and cross entropy-based loss functions into a single framework. 196 | Parameters 197 | ---------- 198 | weight : float, optional 199 | represents lambda parameter and controls weight given to asymmetric Focal Tversky loss and asymmetric Focal loss, by default 0.5 200 | delta : float, optional 201 | controls weight given to each class, by default 0.6 202 | gamma : float, optional 203 | focal parameter controls the degree of background suppression and foreground enhancement, by default 0.5 204 | epsilon : float, optional 205 | clip values to prevent division by zero error 206 | """ 207 | def __init__(self, weight=0.5, delta=0.6, gamma=0.2): 208 | super(AsymmetricUnifiedFocalLoss, self).__init__() 209 | self.weight = weight 210 | self.delta = delta 211 | self.gamma = gamma 212 | 213 | def forward(self, y_pred, y_true): 214 | # Obtain Asymmetric Focal Tversky loss 215 | asymmetric_ftl = AsymmetricFocalTverskyLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true) 216 | 217 | # Obtain Asymmetric Focal loss 218 | asymmetric_fl = AsymmetricFocalLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true) 219 | 220 | # return weighted sum of Asymmetrical Focal loss and Asymmetric Focal Tversky loss 221 | if self.weight is not None: 222 | return (self.weight * asymmetric_ftl) + ((1-self.weight) * asymmetric_fl) 223 | else: 224 | return asymmetric_ftl + asymmetric_fl 225 | 226 | 227 | import torch 228 | import torch.nn as nn 229 | import torch.nn.functional as F 230 | 231 | # 针对二分类任务的 Focal Loss 232 | class FocalLoss(nn.Module): 233 | def __init__(self, alpha=0.25, gamma=2, size_average=True): 234 | super(FocalLoss, self).__init__() 235 | self.alpha = torch.tensor(alpha).cuda() 236 | self.gamma = gamma 237 | self.size_average = size_average 238 | 239 | def forward(self, pred, target): 240 | # 如果模型最后没有 nn.Sigmoid(),那么这里就需要对预测结果计算一次 Sigmoid 操作 241 | # pred = nn.Sigmoid()(pred) 242 | 243 | # 展开 pred 和 target,此时 pred.size = target.size = (BatchSize,1) 244 | pred = pred.view(-1,1) 245 | target = target.view(-1,1) 246 | 247 | # 此处将预测样本为正负的概率都计算出来,此时 pred.size = (BatchSize,2) 248 | pred = torch.cat((1-pred,pred),dim=1) 249 | 250 | # 根据 target 生成 mask,即根据 ground truth 选择所需概率 251 | # 用大白话讲就是: 252 | # 当标签为 1 时,我们就将模型预测该样本为正类的概率代入公式中进行计算 253 | # 当标签为 0 
时,我们就将模型预测该样本为负类的概率代入公式中进行计算 254 | class_mask = torch.zeros(pred.shape[0],pred.shape[1]).cuda() 255 | # 这里的 scatter_ 操作不常用,其函数原型为: 256 | # scatter_(dim,index,src)->Tensor 257 | # Writes all values from the tensor src into self at the indices specified in the index tensor. 258 | # For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim. 259 | class_mask.scatter_(1, target.view(-1, 1).long(), 1.) 260 | 261 | # 利用 mask 将所需概率值挑选出来 262 | probs = (pred * class_mask).sum(dim=1).view(-1,1) 263 | probs = probs.clamp(min=0.0001,max=1.0) 264 | 265 | # 计算概率的 log 值 266 | log_p = probs.log() 267 | 268 | # 根据论文中所述,对 alpha 进行设置(该参数用于调整正负样本数量不均衡带来的问题) 269 | alpha = torch.ones(pred.shape[0],pred.shape[1]).cuda() 270 | alpha[:,0] = alpha[:,0] * (1-self.alpha) 271 | alpha[:,1] = alpha[:,1] * self.alpha 272 | alpha = (alpha * class_mask).sum(dim=1).view(-1,1) 273 | 274 | # 根据 Focal Loss 的公式计算 Loss 275 | batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 276 | 277 | # Loss Function的常规操作,mean 与 sum 的区别不大,相当于学习率设置不一样而已 278 | if self.size_average: 279 | loss = batch_loss.mean() 280 | else: 281 | loss = batch_loss.sum() 282 | 283 | return loss 284 | 285 | 286 | -------------------------------------------------------------------------------- /utils/my_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from numpy.random.mtrand import seed 4 | from torch.utils.data import Dataset 5 | from PIL import Image 6 | import torch 7 | import yaml 8 | import sys 9 | import albumentations as A 10 | import torchvision.transforms as transforms 11 | from utils.custom_transforms import * 12 | 13 | #f = open(sys.argv[1]) 14 | #config = yaml.safe_load(f) 15 | 16 | # Datareaderset for isic2018 17 | class ISIC2018(Dataset): 18 | def __init__(self, train_img_root, val_img_root, train_label_root, val_label_root, crop_size, mode='train'): 19 | self.train_img_files = self.read_file(train_img_root) 20 | self.val_img_files = self.read_file(val_img_root) 21 | self.train_label_files = self.read_file(train_label_root) 22 | self.val_label_files = self.read_file(val_label_root) 23 | self.mode = mode 24 | self.crop_size = crop_size 25 | 26 | def __getitem__(self, index): 27 | if self.mode == 'train': 28 | img = Image.open(self.train_img_files[index]) 29 | label = Image.open(self.train_label_files[index]) 30 | 31 | img = img.resize((self.crop_size[0], self.crop_size[1])) 32 | label = label.resize((self.crop_size[0], self.crop_size[1])) 33 | 34 | img = np.array(img) / 255 35 | label = np.array(label) 36 | 37 | img = torch.as_tensor(img) 38 | label = torch.as_tensor(label) 39 | 40 | img = img.permute(2, 0, 1) 41 | 42 | return img.float(), label.long() 43 | 44 | if self.mode == 'val': 45 | img = Image.open(self.val_img_files[index]) 46 | label = Image.open(self.val_label_files[index]) 47 | 48 | img = img.resize((self.crop_size[0], self.crop_size[1])) 49 | label = label.resize((self.crop_size[0], self.crop_size[1])) 50 | 51 | img = np.array(img) / 255 52 | label = np.array(label) 53 | 54 | img = torch.as_tensor(img) 55 | label = torch.as_tensor(label) 56 | 57 | print(img.shape) 58 | img = img.permute(2, 0, 1) 59 | 60 | return img.float(), label.long() 61 | 62 | def __len__(self): 63 | if self.mode == 'train': 64 | total_img = len(self.train_img_files) 65 | return total_img 66 | if self.mode == 'val': 67 | total_img = len(self.val_img_files) 68 | return total_img 69 
| 70 | def read_file(self, path): 71 | files_list = os.listdir(path) 72 | file_path_list = [os.path.join(path, img) for img in files_list] 73 | file_path_list.sort() 74 | return file_path_list 75 | 76 | # dataset fro Kvasir 77 | class Kvasir(Dataset): 78 | def __init__(self, train_img_root, val_img_root, train_label_root, val_label_root, crop_size, mode='train'): 79 | self.train_img_files = self.read_file(train_img_root) 80 | self.val_img_files = self.read_file(val_img_root) 81 | self.train_label_files = self.read_file(train_label_root) 82 | self.val_label_files = self.read_file(val_label_root) 83 | self.mode = mode 84 | self.crop_size = crop_size 85 | 86 | def __getitem__(self, index): 87 | if self.mode == 'train': 88 | img = Image.open(self.train_img_files[index]) 89 | label = Image.open(self.train_label_files[index]) 90 | 91 | img = img.resize((self.crop_size[0], self.crop_size[1])) 92 | label = label.resize((self.crop_size[0], self.crop_size[1])) 93 | 94 | img = np.array(img) / 255 95 | label = np.array(label) 96 | 97 | print(np.max(label)) 98 | 99 | img = torch.as_tensor(img) 100 | label = torch.as_tensor(label) 101 | 102 | img = img.permute(2, 0, 1) 103 | 104 | return img.float(), label.long() 105 | 106 | if self.mode == 'val': 107 | img = Image.open(self.val_img_files[index]) 108 | label = Image.open(self.val_label_files[index]) 109 | 110 | img = img.resize((self.crop_size[0], self.crop_size[1])) 111 | label = label.resize((self.crop_size[0], self.crop_size[1])) 112 | 113 | img = np.array(img) / 255 114 | label = np.array(label) 115 | 116 | img = torch.as_tensor(img) 117 | label = torch.as_tensor(label) 118 | 119 | img = img.permute(2, 0, 1) 120 | 121 | return img.float(), label.long() 122 | 123 | def __len__(self): 124 | if self.mode == 'train': 125 | total_img = len(self.train_img_files) 126 | return total_img 127 | if self.mode == 'val': 128 | total_img = len(self.val_img_files) 129 | return total_img 130 | 131 | def read_file(self, path): 132 | files_list = os.listdir(path) 133 | file_path_list = [os.path.join(path, img) for img in files_list] 134 | file_path_list.sort() 135 | return file_path_list 136 | 137 | # dataset fro CVC-ClinicDB 138 | class CVC(Dataset): 139 | def __init__(self, train_img_root, val_img_root, train_label_root, val_label_root, crop_size, mode='train'): 140 | self.train_img_files = self.read_file(train_img_root) 141 | self.val_img_files = self.read_file(val_img_root) 142 | self.train_label_files = self.read_file(train_label_root) 143 | self.val_label_files = self.read_file(val_label_root) 144 | self.mode = mode 145 | self.crop_size = crop_size 146 | 147 | def __getitem__(self, index): 148 | if self.mode == 'train': 149 | img = Image.open(self.train_img_files[index]) 150 | label = Image.open(self.train_label_files[index]) 151 | 152 | img = img.resize((self.crop_size[0], self.crop_size[1])) 153 | label = label.resize((self.crop_size[0], self.crop_size[1])) 154 | 155 | img = np.array(img) / 255 156 | label = np.array(label) 157 | 158 | img = torch.as_tensor(img) 159 | label = torch.as_tensor(label) 160 | 161 | img = img.permute(2, 0, 1) 162 | 163 | return img.float(), label.long() 164 | 165 | if self.mode == 'val': 166 | img = Image.open(self.val_img_files[index]) 167 | label = Image.open(self.val_label_files[index]) 168 | 169 | img = img.resize((self.crop_size[0], self.crop_size[1])) 170 | label = label.resize((self.crop_size[0], self.crop_size[1])) 171 | 172 | img = np.array(img) / 255 173 | label = np.array(label) 174 | 175 | img = torch.as_tensor(img) 176 
| label = torch.as_tensor(label) 177 | 178 | img = img.permute(2, 0, 1) 179 | 180 | return img.float(), label.long() 181 | 182 | def __len__(self): 183 | if self.mode == 'train': 184 | total_img = len(self.train_img_files) 185 | return total_img 186 | if self.mode == 'val': 187 | total_img = len(self.val_img_files) 188 | return total_img 189 | 190 | def read_file(self, path): 191 | files_list = os.listdir(path) 192 | file_path_list = [os.path.join(path, img) for img in files_list] 193 | file_path_list.sort() 194 | return file_path_list 195 | 196 | 197 | # build dataset 198 | class DataBuilder(Dataset): 199 | def __init__(self, train_img_root, val_img_root, train_label_root, val_label_root, crop_size, mode='train'): 200 | self.train_img_files = self.read_file(train_img_root) 201 | self.val_img_files = self.read_file(val_img_root) 202 | self.train_label_files = self.read_file(train_label_root) 203 | self.val_label_files = self.read_file(val_label_root) 204 | self.mode = mode 205 | self.crop_size = crop_size 206 | 207 | def __getitem__(self, index): 208 | if self.mode == 'train': 209 | img = Image.open(self.train_img_files[index]) 210 | label = Image.open(self.train_label_files[index]) 211 | 212 | img = img.resize((self.crop_size[0], self.crop_size[1])) 213 | label = label.resize((self.crop_size[0], self.crop_size[1])) 214 | 215 | img = np.array(img) / 255 216 | label = np.array(label) 217 | 218 | if 'cvc' in config['dataset']['train_img_root']: 219 | # just for cvc start 220 | label = label[:,:,0] 221 | # just for cvc end 222 | img = torch.as_tensor(img) 223 | label = torch.as_tensor(label) 224 | if 'Seg' in config['dataset']['train_img_root'] or 'BRATS2015' in config['dataset']['train_img_root']: 225 | img = img.unsqueeze(0) 226 | else: 227 | img = img.permute(2, 0, 1) 228 | 229 | return img.float(), label.long() 230 | 231 | if self.mode == 'val': 232 | img = Image.open(self.val_img_files[index]) 233 | label = Image.open(self.val_label_files[index]) 234 | 235 | img = img.resize((self.crop_size[0], self.crop_size[1])) 236 | label = label.resize((self.crop_size[0], self.crop_size[1])) 237 | 238 | img = np.array(img) / 255 239 | label = np.array(label) 240 | 241 | if 'cvc' in config['dataset']['train_img_root'] and 'ETIS-LaribPolypDB' not in config['dataset']['test_label_root']: 242 | # just for cvc start 243 | label = label[:,:,0] 244 | # just for cvc end 245 | 246 | img = torch.as_tensor(img) 247 | label = torch.as_tensor(label) 248 | if 'Seg' in config['dataset']['train_img_root'] or 'BRATS2015' in config['dataset']['train_img_root']: 249 | img = img.unsqueeze(0) 250 | else: 251 | img = img.permute(2, 0, 1) 252 | 253 | return img.float(), label.long() 254 | 255 | def __len__(self): 256 | if self.mode == 'train': 257 | total_img = len(self.train_img_files) 258 | return total_img 259 | if self.mode == 'val': 260 | total_img = len(self.val_img_files) 261 | return total_img 262 | 263 | def read_file(self, path): 264 | files_list = os.listdir(path) 265 | file_path_list = [os.path.join(path, img) for img in files_list] 266 | file_path_list.sort() 267 | return file_path_list 268 | 269 | 270 | # dataset fro Kvasir 271 | class Datareader(Dataset): 272 | def __init__(self, img_root, label_root, crop_size): 273 | self.img_files = self.read_file(img_root) 274 | self.label_files = self.read_file(label_root) 275 | self.crop_size = crop_size 276 | 277 | def __getitem__(self, index): 278 | img = Image.open(self.img_files[index]) 279 | label = Image.open(self.label_files[index]) 280 | 281 | img = 
img.resize((self.crop_size[0], self.crop_size[1])) 282 | label = label.resize((self.crop_size[0], self.crop_size[1])) 283 | 284 | img = np.array(img) / 255 285 | label = np.array(label) 286 | 287 | img = torch.as_tensor(img) 288 | label = torch.as_tensor(label) 289 | 290 | img = img.permute(2, 0, 1) 291 | 292 | return img.float(), label.long() 293 | 294 | def __len__(self): 295 | total_img = len(self.img_files) 296 | return total_img 297 | 298 | def read_file(self, path): 299 | files_list = os.listdir(path) 300 | file_path_list = [os.path.join(path, img) for img in files_list] 301 | file_path_list.sort() 302 | return file_path_list 303 | 304 | # final version of dataset 305 | import cv2 306 | 307 | class CustomDataSet(Dataset): 308 | def __init__(self, img_path, label_path, transform_list): 309 | self.img_path = img_path 310 | self.label_path = label_path 311 | 312 | self.img_files = self.read_file(self.img_path) 313 | self.label_files = self.read_file(self.label_path) 314 | self.transform = self.get_transform(transform_list) 315 | 316 | self.image_size = [] 317 | 318 | # 初始化图片大小 319 | self._init_img_size() 320 | 321 | 322 | def __getitem__(self, index): 323 | 324 | 325 | img = Image.open(self.img_files[index]).convert('RGB') 326 | label = Image.open(self.label_files[index]).convert('L') 327 | 328 | name = self.img_files[index].split('/')[-1] 329 | if name.endswith('.jpg'): 330 | name = name.split('.jpg')[0] + '.png' 331 | shape = label.size[::-1] 332 | sample = {'image': img, 'gt': label, 'name': name, 'shape': shape} 333 | sample = self.transform(sample) 334 | img, label = sample['image'],sample['gt'] 335 | img = torch.as_tensor(np.array(img)) 336 | label = torch.as_tensor(np.array(label)) 337 | return img.float(), label.float() 338 | 339 | def __len__(self): 340 | total_img = len(self.img_files) 341 | return total_img 342 | 343 | 344 | def _init_img_size(self): 345 | for i in range(self.__len__()): 346 | img = Image.open(self.img_files[i]) 347 | self.image_size.append(img.size) 348 | 349 | 350 | def read_file(self, path): 351 | files_list = os.listdir(path) 352 | file_path_list = [os.path.join(path, img) for img in files_list] 353 | file_path_list.sort() 354 | return file_path_list 355 | 356 | @staticmethod 357 | def get_transform(transform_list): 358 | tfs = [] 359 | for key, value in zip(transform_list.keys(), transform_list.values()): 360 | if value is not None: 361 | tf = eval(key)(**value) 362 | else: 363 | tf = eval(key)() 364 | tfs.append(tf) 365 | return transforms.Compose(tfs) 366 | -------------------------------------------------------------------------------- /utils/swd.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import torchvision 6 | 7 | 8 | # Gaussian blur kernel 9 | def get_gaussian_kernel(device="cpu"): 10 | kernel = np.array([ 11 | [1, 4, 6, 4, 1], 12 | [4, 16, 24, 16, 4], 13 | [6, 24, 36, 24, 6], 14 | [4, 16, 24, 16, 4], 15 | [1, 4, 6, 4, 1]], np.float32) / 256.0 16 | gaussian_k = torch.as_tensor(kernel.reshape(1, 1, 5, 5)).to(device) 17 | return gaussian_k 18 | 19 | 20 | def pyramid_down(image, device="cpu"): 21 | gaussian_k = get_gaussian_kernel(device=device) 22 | # channel-wise conv(important) 23 | multiband = [F.conv2d(image[:, i:i + 1, :, :], gaussian_k, padding=2, stride=2) for i in range(3)] 24 | down_image = torch.cat(multiband, dim=1) 25 | return down_image 26 | 27 | 28 | def pyramid_up(image, device="cpu"): 29 | 
gaussian_k = get_gaussian_kernel(device=device) 30 | upsample = F.interpolate(image, scale_factor=2) 31 | multiband = [F.conv2d(upsample[:, i:i + 1, :, :], gaussian_k, padding=2) for i in range(3)] 32 | up_image = torch.cat(multiband, dim=1) 33 | return up_image 34 | 35 | 36 | def gaussian_pyramid(original, n_pyramids, device="cpu"): 37 | x = original 38 | # pyramid down 39 | pyramids = [original] 40 | for i in range(n_pyramids): 41 | x = pyramid_down(x, device=device) 42 | pyramids.append(x) 43 | return pyramids 44 | 45 | 46 | def laplacian_pyramid(original, n_pyramids, device="cpu"): 47 | # create gaussian pyramid 48 | pyramids = gaussian_pyramid(original, n_pyramids, device=device) 49 | 50 | # pyramid up - diff 51 | laplacian = [] 52 | for i in range(len(pyramids) - 1): 53 | diff = pyramids[i] - pyramid_up(pyramids[i + 1], device=device) 54 | laplacian.append(diff) 55 | # Add last gaussian pyramid 56 | laplacian.append(pyramids[len(pyramids) - 1]) 57 | return laplacian 58 | 59 | 60 | def minibatch_laplacian_pyramid(image, n_pyramids, batch_size, device="cpu"): 61 | n = image.size(0) // batch_size + np.sign(image.size(0) % batch_size) 62 | pyramids = [] 63 | for i in range(n): 64 | x = image[i * batch_size:(i + 1) * batch_size] 65 | p = laplacian_pyramid(x.to(device), n_pyramids, device=device) 66 | p = [x.cpu() for x in p] 67 | pyramids.append(p) 68 | del x 69 | result = [] 70 | for i in range(n_pyramids + 1): 71 | x = [] 72 | for j in range(n): 73 | x.append(pyramids[j][i]) 74 | result.append(torch.cat(x, dim=0)) 75 | return result 76 | 77 | 78 | def extract_patches(pyramid_layer, slice_indices, 79 | slice_size=7, unfold_batch_size=128, device="cpu"): 80 | assert pyramid_layer.ndim == 4 81 | n = pyramid_layer.size(0) // unfold_batch_size + np.sign(pyramid_layer.size(0) % unfold_batch_size) 82 | # random slice 7x7 83 | p_slice = [] 84 | for i in range(n): 85 | # [unfold_batch_size, ch, n_slices, slice_size, slice_size] 86 | ind_start = i * unfold_batch_size 87 | ind_end = min((i + 1) * unfold_batch_size, pyramid_layer.size(0)) 88 | x = pyramid_layer[ind_start:ind_end].unfold( 89 | 2, slice_size, 1).unfold(3, slice_size, 1).reshape( 90 | ind_end - ind_start, pyramid_layer.size(1), -1, slice_size, slice_size) 91 | # [unfold_batch_size, ch, n_descriptors, slice_size, slice_size] 92 | x = x[:, :, slice_indices, :, :] 93 | # [unfold_batch_size, n_descriptors, ch, slice_size, slice_size] 94 | p_slice.append(x.permute([0, 2, 1, 3, 4])) 95 | # sliced tensor per layer [batch, n_descriptors, ch, slice_size, slice_size] 96 | x = torch.cat(p_slice, dim=0) 97 | # normalize along ch 98 | std, mean = torch.std_mean(x, dim=(0, 1, 3, 4), keepdim=True) 99 | x = (x - mean) / (std + 1e-8) 100 | # reshape to 2rank 101 | x = x.reshape(-1, 3 * slice_size * slice_size) 102 | return x 103 | 104 | 105 | def swd(image1, image2, 106 | n_pyramids=None, slice_size=7, n_descriptors=128, 107 | n_repeat_projection=128, proj_per_repeat=4, device="cpu", return_by_resolution=False, 108 | pyramid_batchsize=128): 109 | # n_repeat_projectton * proj_per_repeat = 512 110 | # Please change these values according to memory usage. 
111 | # original = n_repeat_projection=4, proj_per_repeat=128 112 | assert image1.size() == image2.size() 113 | assert image1.ndim == 4 and image2.ndim == 4 114 | 115 | if n_pyramids is None: 116 | n_pyramids = int(np.rint(np.log2(image1.size(2) // 16))) 117 | with torch.no_grad(): 118 | # minibatch laplacian pyramid for cuda memory reasons 119 | pyramid1 = minibatch_laplacian_pyramid(image1, n_pyramids, pyramid_batchsize, device=device) 120 | pyramid2 = minibatch_laplacian_pyramid(image2, n_pyramids, pyramid_batchsize, device=device) 121 | result = [] 122 | 123 | for i_pyramid in range(n_pyramids + 1): 124 | # indices 125 | n = (pyramid1[i_pyramid].size(2) - 6) * (pyramid1[i_pyramid].size(3) - 6) 126 | indices = torch.randperm(n)[:n_descriptors] 127 | 128 | # extract patches on CPU 129 | # patch : 2rank (n_image*n_descriptors, slice_size**2*ch) 130 | p1 = extract_patches(pyramid1[i_pyramid], indices, 131 | slice_size=slice_size, device="cpu") 132 | p2 = extract_patches(pyramid2[i_pyramid], indices, 133 | slice_size=slice_size, device="cpu") 134 | 135 | p1, p2 = p1.to(device), p2.to(device) 136 | 137 | distances = [] 138 | for j in range(n_repeat_projection): 139 | # random 140 | rand = torch.randn(p1.size(1), proj_per_repeat).to(device) # (slice_size**2*ch) 141 | rand = rand / torch.std(rand, dim=0, keepdim=True) # noramlize 142 | # projection 143 | proj1 = torch.matmul(p1, rand) 144 | proj2 = torch.matmul(p2, rand) 145 | proj1, _ = torch.sort(proj1, dim=0) 146 | proj2, _ = torch.sort(proj2, dim=0) 147 | d = torch.abs(proj1 - proj2) 148 | distances.append(torch.mean(d)) 149 | 150 | # swd 151 | result.append(torch.mean(torch.stack(distances))) 152 | 153 | # average over resolution 154 | result = torch.stack(result) * 1e3 155 | if return_by_resolution: 156 | return result.cpu() 157 | else: 158 | return torch.mean(result).cpu() -------------------------------------------------------------------------------- /utils/test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from numpy.random.mtrand import seed 4 | from torch.utils.data import Dataset 5 | from PIL import Image 6 | import torch 7 | import yaml 8 | import sys 9 | import albumentations as A 10 | import torchvision.transforms as transforms 11 | from utils.test_transforms import * 12 | 13 | 14 | # final version of dataset 15 | import cv2 16 | class CustomDataSet(Dataset): 17 | def __init__(self, img_path, label_path, transform_list): 18 | self.img_path = img_path 19 | self.label_path = label_path 20 | self.img_files = self.read_file(self.img_path) 21 | self.label_files = self.read_file(self.label_path) 22 | self.transform = self.get_transform(transform_list) 23 | self.image_size = [] 24 | self.image_name = [] 25 | 26 | # 初始化图片大小 27 | self._init_img_size() 28 | 29 | 30 | def __getitem__(self, index): 31 | 32 | img = Image.open(self.img_files[index]).convert('RGB') 33 | label = Image.open(self.label_files[index]).convert('L') 34 | name = self.img_files[index].split('/')[-1] 35 | # self.image_name.append(name) 36 | 37 | if name.endswith('.jpg'): 38 | name = name.split('.jpg')[0] + '.png' 39 | shape = label.size[::-1] 40 | sample = {'image': img, 'gt': label, 'name': name, 'shape': shape} 41 | sample = self.transform(sample) 42 | img, label = sample['image'],sample['gt'] 43 | img = torch.as_tensor(img) 44 | label = torch.as_tensor(label) 45 | return img.float(), label.float() 46 | 47 | def __len__(self): 48 | total_img = len(self.img_files) 49 | return 
total_img 50 | 51 | 52 | def _init_img_size(self): 53 | for i in range(self.__len__()): 54 | img = Image.open(self.img_files[i]) 55 | self.image_size.append(img.size) 56 | name = self.img_files[i].split('/')[-1] 57 | self.image_name.append(name) 58 | 59 | 60 | def read_file(self, path): 61 | files_list = os.listdir(path) 62 | file_path_list = [os.path.join(path, img) for img in files_list] 63 | file_path_list.sort() 64 | return file_path_list 65 | 66 | @staticmethod 67 | def get_transform(transform_list): 68 | tfs = [] 69 | for key, value in zip(transform_list.keys(), transform_list.values()): 70 | if value is not None: 71 | tf = eval(key)(**value) 72 | else: 73 | tf = eval(key)() 74 | tfs.append(tf) 75 | return transforms.Compose(tfs) 76 | -------------------------------------------------------------------------------- /utils/test_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import cv2 4 | import torch 5 | import torch.nn.functional as F 6 | from PIL import Image, ImageOps, ImageFilter, ImageEnhance 7 | 8 | class resize: 9 | def __init__(self, size): 10 | self.size = size 11 | 12 | def __call__(self, sample): 13 | if 'image' in sample.keys(): 14 | sample['image'] = sample['image'].resize(self.size, Image.BILINEAR) 15 | if 'gt' in sample.keys(): 16 | sample['gt'] = sample['gt'] 17 | if 'mask' in sample.keys(): 18 | sample['mask'] = sample['mask'] 19 | 20 | return sample 21 | 22 | class random_scale_crop: 23 | def __init__(self, range=[0.75, 1.25]): 24 | self.range = range 25 | 26 | def __call__(self, sample): 27 | scale = np.random.random() * (self.range[1] - self.range[0]) + self.range[0] 28 | if np.random.random() < 0.5: 29 | for key in sample.keys(): 30 | if key in ['image', 'gt', 'contour']: 31 | base_size = sample[key].size 32 | 33 | scale_size = tuple((np.array(base_size) * scale).round().astype(int)) 34 | sample[key] = sample[key].resize(scale_size) 35 | 36 | sample[key] = sample[key].crop(((sample[key].size[0] - base_size[0]) // 2, 37 | (sample[key].size[1] - base_size[1]) // 2, 38 | (sample[key].size[0] + base_size[0]) // 2, 39 | (sample[key].size[1] + base_size[1]) // 2)) 40 | 41 | return sample 42 | 43 | class random_flip: 44 | def __init__(self, lr=True, ud=True): 45 | self.lr = lr 46 | self.ud = ud 47 | 48 | def __call__(self, sample): 49 | lr = np.random.random() < 0.5 and self.lr is True 50 | ud = np.random.random() < 0.5 and self.ud is True 51 | 52 | for key in sample.keys(): 53 | if key in ['image', 'gt', 'contour']: 54 | sample[key] = np.array(sample[key]) 55 | if lr: 56 | sample[key] = np.fliplr(sample[key]) 57 | if ud: 58 | sample[key] = np.flipud(sample[key]) 59 | sample[key] = Image.fromarray(sample[key]) 60 | 61 | return sample 62 | 63 | class random_rotate: 64 | def __init__(self, range=[0, 360], interval=1): 65 | self.range = range 66 | self.interval = interval 67 | 68 | def __call__(self, sample): 69 | rot = (np.random.randint(*self.range) // self.interval) * self.interval 70 | rot = rot + 360 if rot < 0 else rot 71 | 72 | if np.random.random() < 0.5: 73 | for key in sample.keys(): 74 | if key in ['image', 'gt', 'contour']: 75 | base_size = sample[key].size 76 | 77 | sample[key] = sample[key].rotate(rot, expand=True) 78 | 79 | sample[key] = sample[key].crop(((sample[key].size[0] - base_size[0]) // 2, 80 | (sample[key].size[1] - base_size[1]) // 2, 81 | (sample[key].size[0] + base_size[0]) // 2, 82 | (sample[key].size[1] + base_size[1]) // 2)) 83 | 84 | return 
sample 85 | 86 | class random_image_enhance: 87 | def __init__(self, methods=['contrast', 'brightness', 'sharpness']): 88 | self.enhance_method = [] 89 | if 'contrast' in methods: 90 | self.enhance_method.append(ImageEnhance.Contrast) 91 | if 'brightness' in methods: 92 | self.enhance_method.append(ImageEnhance.Brightness) 93 | if 'sharpness' in methods: 94 | self.enhance_method.append(ImageEnhance.Sharpness) 95 | 96 | def __call__(self, sample): 97 | image = sample['image'] 98 | np.random.shuffle(self.enhance_method) 99 | 100 | for method in self.enhance_method: 101 | if np.random.random() > 0.5: 102 | enhancer = method(image) 103 | factor = float(1 + np.random.random() / 10) 104 | image = enhancer.enhance(factor) 105 | sample['image'] = image 106 | 107 | return sample 108 | 109 | class random_dilation_erosion: 110 | def __init__(self, kernel_range): 111 | self.kernel_range = kernel_range 112 | 113 | def __call__(self, sample): 114 | gt = sample['gt'] 115 | gt = np.array(gt) 116 | key = np.random.random() 117 | # kernel = np.ones(tuple([np.random.randint(*self.kernel_range)]) * 2, dtype=np.uint8) 118 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (np.random.randint(*self.kernel_range), ) * 2) 119 | if key < 1/3: 120 | gt = cv2.dilate(gt, kernel) 121 | elif 1/3 <= key < 2/3: 122 | gt = cv2.erode(gt, kernel) 123 | 124 | sample['gt'] = Image.fromarray(gt) 125 | 126 | return sample 127 | 128 | class random_gaussian_blur: 129 | def __init__(self): 130 | pass 131 | 132 | def __call__(self, sample): 133 | image = sample['image'] 134 | if np.random.random() < 0.5: 135 | image = image.filter(ImageFilter.GaussianBlur(radius=np.random.random())) 136 | sample['image'] = image 137 | 138 | return sample 139 | 140 | class tonumpy: 141 | def __init__(self): 142 | pass 143 | 144 | def __call__(self, sample): 145 | image, gt = sample['image'], sample['gt'] 146 | 147 | sample['image'] = np.array(image, dtype=np.float32) 148 | sample['gt'] = np.array(gt, dtype=np.float32) 149 | 150 | return sample 151 | 152 | class normalize: 153 | def __init__(self, mean, std): 154 | self.mean = mean 155 | self.std = std 156 | 157 | def __call__(self, sample): 158 | image, gt = sample['image'], sample['gt'] 159 | image /= 255 160 | image -= self.mean 161 | image /= self.std 162 | 163 | gt /= 255 164 | sample['image'] = image 165 | sample['gt'] = gt 166 | 167 | return sample 168 | 169 | class totensor: 170 | def __init__(self): 171 | pass 172 | 173 | def __call__(self, sample): 174 | image, gt = sample['image'], sample['gt'] 175 | image = image.transpose((2, 0, 1)) 176 | image = torch.from_numpy(image).float() 177 | gt = torch.from_numpy(gt) 178 | gt = gt.unsqueeze(dim=0) 179 | sample['image'] = image 180 | sample['gt'] = gt 181 | 182 | return sample -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import models 2 | import utils 3 | # from utils import ISIC2018 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | import torch.optim as optmi 10 | import os 11 | import pdb 12 | import numpy as np 13 | 14 | 15 | 16 | def intersect_and_union(pred_label, 17 | label, 18 | num_classes, 19 | ignore_index, 20 | label_map=dict(), 21 | reduce_zero_label=False): 22 | 23 | if isinstance(pred_label, str): 24 | pred_label = np.load(pred_label) 25 | # modify if custom classes 26 | if label_map is not 
None: 27 | for old_id, new_id in label_map.items(): 28 | label[label == old_id] = new_id 29 | if reduce_zero_label: 30 | # avoid using underflow conversion 31 | label[label == 0] = 255 32 | label = label - 1 33 | label[label == 254] = 255 34 | 35 | mask = (label != ignore_index) 36 | pred_label = pred_label[mask] 37 | label = label[mask] 38 | 39 | intersect = pred_label[pred_label == label] 40 | area_intersect, _ = np.histogram( 41 | intersect, bins=np.arange(num_classes + 1)) 42 | area_pred_label, _ = np.histogram( 43 | pred_label, bins=np.arange(num_classes + 1)) 44 | area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1)) 45 | area_union = area_pred_label + area_label - area_intersect 46 | 47 | return area_intersect, area_union, area_pred_label, area_label 48 | 49 | 50 | def total_intersect_and_union(results, 51 | gt_seg_maps, 52 | num_classes, 53 | ignore_index, 54 | label_map=dict(), 55 | reduce_zero_label=False): 56 | """Calculate Total Intersection and Union. 57 | 58 | Args: 59 | results (list[ndarray]): List of prediction segmentation maps. 60 | gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. 61 | num_classes (int): Number of categories. 62 | ignore_index (int): Index that will be ignored in evaluation. 63 | label_map (dict): Mapping old labels to new labels. Default: dict(). 64 | reduce_zero_label (bool): Whether to ignore the zero label. Default: False. 65 | 66 | Returns: 67 | ndarray: The intersection of prediction and ground truth histogram 68 | on all classes. 69 | ndarray: The union of prediction and ground truth histogram on all 70 | classes. 71 | ndarray: The prediction histogram on all classes. 72 | ndarray: The ground truth histogram on all classes. 73 | """ 74 | 75 | num_imgs = len(results) 76 | assert len(gt_seg_maps) == num_imgs 77 | total_area_intersect = np.zeros((num_classes, ), dtype=np.float64) 78 | total_area_union = np.zeros((num_classes, ), dtype=np.float64) 79 | total_area_pred_label = np.zeros((num_classes, ), dtype=np.float64) 80 | total_area_label = np.zeros((num_classes, ), dtype=np.float64) 81 | for i in range(num_imgs): 82 | area_intersect, area_union, area_pred_label, area_label = \ 83 | intersect_and_union(results[i], gt_seg_maps[i], num_classes, 84 | ignore_index, label_map, reduce_zero_label) 85 | total_area_intersect += area_intersect 86 | total_area_union += area_union 87 | total_area_pred_label += area_pred_label 88 | total_area_label += area_label 89 | return total_area_intersect, total_area_union, \ 90 | total_area_pred_label, total_area_label 91 | 92 | 93 | def mean_iou(results, 94 | gt_seg_maps, 95 | num_classes, 96 | ignore_index, 97 | nan_to_num=None, 98 | label_map=dict(), 99 | reduce_zero_label=False): 100 | """Calculate Mean Intersection and Union (mIoU) 101 | 102 | Args: 103 | results (list[ndarray]): List of prediction segmentation maps. 104 | gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. 105 | num_classes (int): Number of categories. 106 | ignore_index (int): Index that will be ignored in evaluation. 107 | nan_to_num (int, optional): If specified, NaN values will be replaced 108 | by the numbers defined by the user. Default: None. 109 | label_map (dict): Mapping old labels to new labels. Default: dict(). 110 | reduce_zero_label (bool): Whether to ignore the zero label. Default: False. 111 | 112 | Returns: 113 | float: Overall accuracy on all images. 114 | ndarray: Per category accuracy, shape (num_classes, ). 115 | ndarray: Per category IoU, shape (num_classes, ).
116 | """ 117 | 118 | all_acc, acc, iou = eval_metrics( 119 | results=results, 120 | gt_seg_maps=gt_seg_maps, 121 | num_classes=num_classes, 122 | ignore_index=ignore_index, 123 | metrics=['mIoU'], 124 | nan_to_num=nan_to_num, 125 | label_map=label_map, 126 | reduce_zero_label=reduce_zero_label) 127 | return all_acc, acc, iou 128 | 129 | 130 | def mean_dice(results, 131 | gt_seg_maps, 132 | num_classes, 133 | ignore_index, 134 | nan_to_num=None, 135 | label_map=dict(), 136 | reduce_zero_label=False): 137 | """Calculate Mean Dice (mDice) 138 | 139 | Args: 140 | results (list[ndarray]): List of prediction segmentation maps. 141 | gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. 142 | num_classes (int): Number of categories. 143 | ignore_index (int): Index that will be ignored in evaluation. 144 | nan_to_num (int, optional): If specified, NaN values will be replaced 145 | by the numbers defined by the user. Default: None. 146 | label_map (dict): Mapping old labels to new labels. Default: dict(). 147 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 148 | 149 | Returns: 150 | float: Overall accuracy on all images. 151 | ndarray: Per category accuracy, shape (num_classes, ). 152 | ndarray: Per category dice, shape (num_classes, ). 153 | """ 154 | 155 | all_acc, acc, dice = eval_metrics( 156 | results=results, 157 | gt_seg_maps=gt_seg_maps, 158 | num_classes=num_classes, 159 | ignore_index=ignore_index, 160 | metrics=['mDice'], 161 | nan_to_num=nan_to_num, 162 | label_map=label_map, 163 | reduce_zero_label=reduce_zero_label) 164 | return all_acc, acc, dice 165 | 166 | 167 | def eval_metrics(results, 168 | gt_seg_maps, 169 | num_classes, 170 | ignore_index, 171 | metrics=['mIoU'], 172 | nan_to_num=None, 173 | label_map=dict(), 174 | reduce_zero_label=False): 175 | """Calculate evaluation metrics 176 | Args: 177 | results (list[ndarray]): List of prediction segmentation maps. 178 | gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. 179 | num_classes (int): Number of categories. 180 | ignore_index (int): Index that will be ignored in evaluation. 181 | metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. 182 | nan_to_num (int, optional): If specified, NaN values will be replaced 183 | by the numbers defined by the user. Default: None. 184 | label_map (dict): Mapping old labels to new labels. Default: dict(). 185 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 186 | Returns: 187 | float: Overall accuracy on all images. 188 | ndarray: Per category accuracy, shape (num_classes, ). 189 | ndarray: Per category evalution metrics, shape (num_classes, ). 
190 | """ 191 | 192 | if isinstance(metrics, str): 193 | metrics = [metrics] 194 | allowed_metrics = ['mIoU', 'mDice'] 195 | if not set(metrics).issubset(set(allowed_metrics)): 196 | raise KeyError('metrics {} is not supported'.format(metrics)) 197 | total_area_intersect, total_area_union, total_area_pred_label, \ 198 | total_area_label = total_intersect_and_union(results, gt_seg_maps, 199 | num_classes, ignore_index, 200 | label_map, 201 | reduce_zero_label) 202 | all_acc = total_area_intersect.sum() / total_area_label.sum() 203 | acc = total_area_intersect / total_area_label 204 | ret_metrics = [all_acc, acc] 205 | for metric in metrics: 206 | if metric == 'mIoU': 207 | iou = total_area_intersect / total_area_union 208 | ret_metrics.append(iou) 209 | elif metric == 'mDice': 210 | dice = 2 * total_area_intersect / ( 211 | total_area_pred_label + total_area_label) 212 | ret_metrics.append(dice) 213 | if nan_to_num is not None: 214 | ret_metrics = [ 215 | np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics 216 | ] 217 | return ret_metrics 218 | 219 | def get_confusion_matrix(pred_label, label, num_classes, ignore_index): 220 | """Intersection over Union 221 | Args: 222 | pred_label (np.ndarray): 2D predict map 223 | label (np.ndarray): label 2D label map 224 | num_classes (int): number of categories 225 | ignore_index (int): index ignore in evaluation 226 | """ 227 | 228 | mask = (label != ignore_index) 229 | pred_label = pred_label[mask] 230 | label = label[mask] 231 | 232 | n = num_classes 233 | inds = n * label + pred_label 234 | mat = np.bincount(inds, minlength=n**2).reshape(n, n) 235 | return mat 236 | 237 | def legacy_mean_dice(results, gt_seg_maps, num_classes, ignore_index): 238 | 239 | num_imgs = len(results) 240 | assert len(gt_seg_maps) == num_imgs 241 | total_mat = np.zeros((num_classes, num_classes), dtype=np.float) 242 | for i in range(num_imgs): 243 | mat = get_confusion_matrix(results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index) 244 | total_mat += mat 245 | # mat = get_confusion_matrix(results, gt_seg_maps, num_classes, ignore_index=ignore_index) 246 | # total_mat = mat 247 | all_acc = np.diag(total_mat).sum() / total_mat.sum() 248 | acc = np.diag(total_mat) / total_mat.sum(axis=1) 249 | dice = 2 * np.diag(total_mat) / (total_mat.sum(axis=1) + total_mat.sum(axis=0)) 250 | 251 | return all_acc, acc, dice 252 | 253 | # This func is deprecated since it's not memory efficient 254 | def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index): 255 | num_imgs = len(results) 256 | assert len(gt_seg_maps) == num_imgs 257 | # total_mat = np.zeros((num_classes, num_classes), dtype=np.float) 258 | # for i in range(num_imgs): 259 | # mat = get_confusion_matrix( 260 | # results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index) 261 | # total_mat += mat 262 | mat = get_confusion_matrix(results, gt_seg_maps, num_classes, ignore_index=ignore_index) 263 | total_mat = mat 264 | all_acc = np.diag(total_mat).sum() / total_mat.sum() 265 | acc = np.diag(total_mat) / total_mat.sum(axis=1) 266 | iou = np.diag(total_mat) / ( 267 | total_mat.sum(axis=1) + total_mat.sum(axis=0) - np.diag(total_mat)) 268 | 269 | return all_acc, acc, iou 270 | 271 | 272 | def Fmeasure_calu(pred, gt, threshold): 273 | if threshold > 1: 274 | threshold = 1 275 | 276 | Label3 = np.zeros_like(gt) 277 | Label3[pred >= threshold] = 1 278 | 279 | NumRec = np.sum(Label3 == 1) 280 | NumNoRec = np.sum(Label3 == 0) 281 | 282 | LabelAnd = (Label3 == 1) & (gt == 1) 283 | NumAnd = 
np.sum(LabelAnd == 1) 284 | num_obj = np.sum(gt) 285 | num_pred = np.sum(Label3) 286 | 287 | FN = num_obj - NumAnd 288 | FP = NumRec - NumAnd 289 | TN = NumNoRec - FN 290 | 291 | if NumAnd == 0: 292 | PreFtem = 0 293 | RecallFtem = 0 294 | FmeasureF = 0 295 | Dice = 0 296 | SpecifTem = 0 297 | IoU = 0 298 | 299 | else: 300 | IoU = NumAnd / (FN + NumRec) 301 | PreFtem = NumAnd / NumRec 302 | RecallFtem = NumAnd / num_obj 303 | SpecifTem = TN / (TN + FP) 304 | Dice = 2 * NumAnd / (num_obj + num_pred) 305 | FmeasureF = ((2.0 * PreFtem * RecallFtem) / (PreFtem + RecallFtem)) 306 | 307 | return PreFtem, RecallFtem, SpecifTem, Dice, FmeasureF, IoU 308 | 309 | 310 | class Colorize: 311 | def __init__(self, n): 312 | self.cmap = self.colormap(256) 313 | self.cmap[n] = self.cmap[-1] 314 | self.cmap = torch.from_numpy(self.cmap[:n])#array->tensor 315 | 316 | def colormap(self, n): 317 | cmap=np.zeros([n, 3]).astype(np.uint8) 318 | cmap[0,:] = np.array([ 0, 0, 0]) 319 | cmap[1,:] = np.array([244, 35,232]) 320 | cmap[2,:] = np.array([ 70, 70, 70]) 321 | cmap[3,:] = np.array([ 102,102,156]) 322 | cmap[4,:] = np.array([ 190,153,153]) 323 | cmap[5,:] = np.array([ 153,153,153]) 324 | 325 | cmap[6,:] = np.array([ 250,170, 30]) 326 | cmap[7,:] = np.array([ 220,220, 0]) 327 | cmap[8,:] = np.array([ 107,142, 35]) 328 | cmap[9,:] = np.array([ 152,251,152]) 329 | cmap[10,:] = np.array([ 70,130,180]) 330 | 331 | cmap[11,:] = np.array([ 220, 20, 60]) 332 | cmap[12,:] = np.array([ 119, 11, 32]) 333 | cmap[13,:] = np.array([ 0, 0,142]) 334 | cmap[14,:] = np.array([ 0, 0, 70]) 335 | cmap[15,:] = np.array([ 0, 60,100]) 336 | 337 | cmap[16,:] = np.array([ 0, 80,100]) 338 | cmap[17,:] = np.array([ 0, 0,230]) 339 | cmap[18,:] = np.array([ 255, 0, 0]) 340 | 341 | return cmap 342 | 343 | def __call__(self, gray_image): 344 | size = gray_image.size()# gray_image is the predicted label map (the "output" mentioned above) 345 | color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0) 346 | 347 | for label in range(0, len(self.cmap)): 348 | mask = gray_image == label 349 | color_image[0][mask] = self.cmap[label][0] 350 | color_image[1][mask] = self.cmap[label][1] 351 | color_image[2][mask] = self.cmap[label][2] 352 | 353 | return color_image 354 | 355 | 356 | from torchvision import transforms 357 | from torchvision.utils import save_image 358 | class ImageSaver(): 359 | def __init__(self): 360 | self.img = None 361 | self.label = None 362 | 363 | self.to_Tensor = transforms.Compose([ 364 | transforms.ToTensor()]) 365 | 366 | self.to_PIL = transforms.ToPILImage() 367 | 368 | def save(self, pred, label, b): 369 | """ 370 | pred : tensor (B, C, W, H) 371 | label: tensor (B, W, H) 372 | b : batch size 373 | """ 374 | cot = 0 375 | for each in range(b): 376 | img = pred[each, :] 377 | mask = label[each, :] 378 | 379 | # img = img.resize((val_ds.image_size[cot][0], val_ds.image_size[cot][1])) 380 | # mask = mask.resize((val_ds.image_size[cot][0], val_ds.image_size[cot][1])) 381 | 382 | # img = self.to_Tensor(img) 383 | # mask = self.to_Tensor(mask) 384 | 385 | save_image(img, "./predict_images/cvc-300/"+str(cot)+".png") 386 | save_image(mask, "./predict_labels/CVC-300/"+str(cot)+".png") 387 | 388 | cot += 1 -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from thop import profile 4 | from thop import clever_format 5 | 6 | 7 | def clip_gradient(optimizer, grad_clip): 8 | """ 9 | For calibrating misaligned
gradients via the gradient clipping technique 10 | :param optimizer: 11 | :param grad_clip: 12 | :return: 13 | """ 14 | for group in optimizer.param_groups: 15 | for param in group['params']: 16 | if param.grad is not None: 17 | param.grad.data.clamp_(-grad_clip, grad_clip) 18 | 19 | 20 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 21 | decay = decay_rate ** (epoch // decay_epoch) 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] *= decay 24 | 25 | 26 | class AvgMeter(object): 27 | def __init__(self, num=40): 28 | self.num = num 29 | self.reset() 30 | 31 | def reset(self): 32 | self.val = 0 33 | self.avg = 0 34 | self.sum = 0 35 | self.count = 0 36 | self.losses = [] 37 | 38 | def update(self, val, n=1): 39 | self.val = val 40 | self.sum += val * n 41 | self.count += n 42 | self.avg = self.sum / self.count 43 | self.losses.append(val) 44 | 45 | def show(self): 46 | return torch.mean(torch.stack(self.losses[np.maximum(len(self.losses)-self.num, 0):])) 47 | 48 | 49 | def CalParams(model, input_tensor): 50 | """ 51 | Usage: 52 | Calculate Params and FLOPs via [THOP](https://github.com/Lyken17/pytorch-OpCounter) 53 | Necessity: 54 | from thop import profile 55 | from thop import clever_format 56 | :param model: 57 | :param input_tensor: 58 | :return: 59 | """ 60 | flops, params = profile(model, inputs=(input_tensor,)) 61 | flops, params = clever_format([flops, params], "%.3f") 62 | print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(flops, params)) --------------------------------------------------------------------------------
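
Below is a minimal usage sketch (not part of the repository files above) showing how the metric helpers in `utils/tools.py` and the training utilities in `utils/utils.py` are typically combined. It assumes the repository root is on `PYTHONPATH`; the random arrays, `num_classes=2`, `ignore_index=255`, and `grad_clip=0.5` are illustrative assumptions rather than values taken from the shipped configs.

```python
# Minimal, self-contained sketch (assumptions: binary polyp segmentation, random
# arrays standing in for real predictions/masks, illustrative hyper-parameters).
import numpy as np
import torch
import torch.nn.functional as F

from utils.tools import eval_metrics              # mIoU / mDice helpers defined above
from utils.utils import AvgMeter, clip_gradient   # training utilities defined above

# --- evaluation: eval_metrics expects lists of per-image integer label maps ---
preds = [np.random.randint(0, 2, (352, 352)) for _ in range(4)]
gts = [np.random.randint(0, 2, (352, 352)) for _ in range(4)]
all_acc, acc, dice = eval_metrics(
    results=preds,
    gt_seg_maps=gts,
    num_classes=2,
    ignore_index=255,          # assumed: a class id that never appears in the masks
    metrics=['mDice'])
print('overall acc:', all_acc, 'per-class Dice:', dice)

# --- training utilities: gradient clipping and a running loss meter ---
model = torch.nn.Conv2d(3, 2, kernel_size=1)       # stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_meter = AvgMeter()

images = torch.randn(2, 3, 352, 352)
masks = torch.randint(0, 2, (2, 352, 352))

logits = model(images)
loss = F.cross_entropy(logits, masks)
loss.backward()
clip_gradient(optimizer, grad_clip=0.5)            # clamp each gradient to [-0.5, 0.5]
optimizer.step()

loss_meter.update(loss.detach(), n=images.size(0))
print('running loss:', loss_meter.show())
```

Note that `clip_gradient` clamps gradient values element-wise (it is not norm-based clipping), so it is applied after `backward()` and before `optimizer.step()`.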