├── README.md ├── cal_PD_and_Fa.py ├── cal_mIoU_and_nIoU.py ├── components ├── __pycache__ │ ├── cal_mean_std.cpython-36.pyc │ ├── cal_mean_std.cpython-38.pyc │ ├── dataset_final_edge_copy_paste_final_2_img_path.cpython-36.pyc │ ├── dataset_final_edge_copy_paste_final_2_img_path.cpython-38.pyc │ ├── edges.cpython-36.pyc │ ├── edges.cpython-38.pyc │ ├── metric_new_crop.cpython-36.pyc │ ├── metric_new_crop.cpython-38.pyc │ ├── utils_all_edge_copy_paste_final_2_img_path.cpython-36.pyc │ └── utils_all_edge_copy_paste_final_2_img_path.cpython-38.pyc ├── cal_mean_std.py ├── dataset_final_edge_copy_paste_final_2_img_path.py ├── edges.py ├── gardient.py ├── metric_new_crop.py └── utils_all_edge_copy_paste_final_2_img_path.py ├── imgs ├── Main results.png ├── PAL framework.png ├── Results on the SIRST3 with centroid point label.png ├── Results on the SIRST3 with coarse point label.png ├── Results on the three separate dataset with centroid point label.png ├── Results on the three separate dataset with coarse point label.png ├── Visualization on the SIRST3 with centroid point label.png └── Visualization on the SIRST3 with coarse point label.png ├── loss ├── Edge_loss.py ├── __init__.py └── __pycache__ │ ├── Edge_loss.cpython-36.pyc │ ├── Edge_loss.cpython-38.pyc │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-38.pyc ├── mm └── attention │ ├── SEAttention.py │ └── __pycache__ │ ├── SEAttention.cpython-36.pyc │ └── SEAttention.cpython-38.pyc ├── model ├── ACM │ ├── ACM_no_sigmoid.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── ACM_no_sigmoid.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── fusion.cpython-36.pyc │ └── fusion.py ├── ALC │ ├── ALC_no_sigmoid.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── ALC_no_sigmoid.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── fusion.cpython-36.pyc │ └── fusion.py ├── ALCL │ ├── ALCL_no_sigmoid.py │ ├── __init__.py │ └── __pycache__ │ │ ├── ALCL_no_sigmoid.cpython-36.pyc │ │ └── __init__.cpython-36.pyc ├── DNA │ ├── DNA_no_sigmoid.py │ ├── __init__.py │ └── __pycache__ │ │ ├── DNA_no_sigmoid.cpython-36.pyc │ │ └── __init__.cpython-36.pyc ├── GGL │ ├── GGL_no_sigmoid.py │ ├── __init__.py │ └── __pycache__ │ │ ├── GGL_no_sigmoid.cpython-36.pyc │ │ └── __init__.cpython-36.pyc ├── MLCL │ ├── MLCL_no_sigmoid.py │ ├── MLCL_small_no_sigmoid.py │ ├── __init__.py │ └── __pycache__ │ │ ├── MLCL_no_sigmoid.cpython-36.pyc │ │ └── __init__.cpython-36.pyc ├── MSDA │ ├── MSDA_no_sigmoid.py │ ├── __init__.py │ └── __pycache__ │ │ ├── MSDA_no_sigmoid.cpython-36.pyc │ │ ├── MSDA_no_sigmoid.cpython-38.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── __init__.cpython-38.pyc └── UIU │ ├── UIU_no_sigmoid.py │ ├── __init__.py │ ├── __pycache__ │ ├── UIU_no_sigmoid.cpython-36.pyc │ ├── UIU_no_sigmoid.cpython-38.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── fusion.cpython-36.pyc │ └── fusion.cpython-38.pyc │ └── fusion.py ├── test_model.py ├── tools ├── centroid_anno.m └── coarse_anno.m ├── train_model.py └── utilts.py /README.md: -------------------------------------------------------------------------------- 1 | ## The official complete code for paper "From Easy to Hard: Progressive Active Learning Framework for Infrared Small Target Detection with Single Point Supervision" [[Paper/arXiv](https://arxiv.org/abs/2412.11154)] 2 | 3 | 8 | 9 | In this project demo, we have integrated multiple SIRST detection networks ([**ACM**](https://arxiv.org/abs/2009.14530), [**ALC**](https://arxiv.org/abs/2012.08573), 
[**MLCL-Net**](https://doi.org/10.1016/j.infrared.2022.104107), [**ALCL-Net**](https://ieeexplore.ieee.org/document/9785618), [**DNANet**](https://arxiv.org/abs/2106.00487), [**GGL-Net**](https://ieeexplore.ieee.org/abstract/document/10230271), [**UIUNet**](https://arxiv.org/abs/2212.00968), [**MSDA-Net**](https://arxiv.org/abs/2406.02037)), label forms (Full supervision, Coarse single-point supervision, Centroid single-point supervision), and datasets ([**SIRST**](https://ieeexplore.ieee.org/document/9423171), [**NUDT-SIRST**](https://ieeexplore.ieee.org/document/9864119), [**IRSTD-1k**](https://ieeexplore.ieee.org/document/9880295) and [**SIRST3**](https://arxiv.org/pdf/2304.01484)). Moreover, more networks and functions will be integrated into the project over time. We hope this project can contribute to the development of this field. 10 | 11 |

12 | Main results
13 |

14 | 15 |
16 | Comparison of different methods on the SIRST3 dataset. CNN Full, CNN Coarse, and CNN Centroid denote CNN-based methods under full supervision, coarse point supervision, and centroid point supervision, respectively. 17 |

18 | 19 | 20 | 25 | 26 | ## Overview 27 | 28 | We consider that an excellent learning process should proceed from easy to hard and take into account the learning ability of the current learner (model), rather than directly treating all tasks (samples) equally. Inspired by organisms gradually adapting to their environment and continuously accumulating knowledge, we first propose an innovative progressive active learning idea, which emphasizes that the network progressively and actively recognizes and learns harder samples to achieve continuous performance enhancement. For details, please see [[Paper/arXiv](https://arxiv.org/abs/2412.11154)]. 29 | 30 |

31 | PAL framework
32 |
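To make this idea concrete, the following is a minimal, runnable sketch of such an easy-to-hard training loop. It only illustrates the progressive active learning concept; the toy model, thresholds, and all function names here are hypothetical and are not the repository's actual implementation (see "train_model.py" for that).
```python
import random

def confidence(weights, features):
    # Toy stand-in for "how well the current model handles one sample"
    # (a logistic score); a real SIRST detector would be used instead.
    z = sum(w * f for w, f in zip(weights, features))
    return 1.0 / (1.0 + 2.718281828459045 ** -z)

def train_round(weights, pool, lr=0.1):
    # Toy update: nudge the model toward higher confidence on pooled samples.
    for features in pool:
        p = confidence(weights, features)
        for i, f in enumerate(features):
            weights[i] += lr * (1.0 - p) * f
    return weights

def progressive_active_learning(easy, hard, rounds=5):
    weights = [0.0, 0.0]
    pool = list(easy)  # start from easy samples only
    for r in range(rounds):
        weights = train_round(weights, pool)
        thresh = 0.9 - 0.1 * r  # relax the admission bar as the model matures
        # The model actively admits the hard samples it is now ready to learn.
        admitted = [s for s in hard if s not in pool and confidence(weights, s) > thresh]
        pool.extend(admitted)
        print(f"round {r}: training pool size = {len(pool)}")
    return weights

if __name__ == "__main__":
    random.seed(0)
    easy = [[random.uniform(0.5, 1.0), 1.0] for _ in range(8)]
    hard = [[random.uniform(-0.2, 0.4), 1.0] for _ in range(8)]
    progressive_active_learning(easy, hard)
```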

33 | 34 | 35 | ## Datasets 36 | 1. Original datasets 37 | * **NUDT-SIRST** [[Original dataset](https://pan.baidu.com/s/1WdA_yOHDnIiyj4C9SbW_Kg?pwd=nudt)] [[paper](https://ieeexplore.ieee.org/document/9864119)] 38 | * **SIRST** [[Original dataset](https://github.com/YimianDai/sirst)] [[paper](https://ieeexplore.ieee.org/document/9423171)] 39 | * **IRSTD-1k** [[Original dataset](https://drive.google.com/file/d/1JoGDGF96v4CncKZprDnoIor0k1opaLZa/view)] [[paper](https://ieeexplore.ieee.org/document/9880295)] 40 | * **SIRST3** [[Original dataset](https://github.com/XinyiYing/LESPS)] [[paper](https://arxiv.org/pdf/2304.01484)] 41 | 42 | 2. Coarse point labels and centroid point labels are produced from the original labels with the "coarse_anno.m" and "centroid_anno.m" scripts in the "tools" folder. (**You can also skip this step and use the complete dataset in step 3 directly.**) 43 | 44 | 3. The datasets we created from the original datasets (**can be used directly in our demo**) 45 | 46 | * [💎 Download the dataset required by our code!!!](https://pan.baidu.com/s/1_QIs9zUM_7MqJgwzO2aC0Q?pwd=1234) 47 | 48 | 49 | ## How to use our code 50 | 1. Download the dataset 51 | 52 |         Click [download datasets](https://pan.baidu.com/s/1_QIs9zUM_7MqJgwzO2aC0Q?pwd=1234) 53 | 54 |         Unzip the downloaded compressed package to the root directory of the project. 55 | 56 | 2. Create an Anaconda virtual environment 57 | 58 | ``` 59 | conda create -n PAL python=3.8 60 | conda activate PAL 61 | ``` 62 | 3. Configure the running environment 63 | 64 | ``` 65 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 66 | pip install segmentation_models_pytorch -i https://pypi.tuna.tsinghua.edu.cn/simple 67 | pip install PyWavelets -i https://pypi.tuna.tsinghua.edu.cn/simple 68 | pip install scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple 69 | pip install albumentations==1.3.0 -i https://pypi.tuna.tsinghua.edu.cn/simple 70 | pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs torchinfo 71 | ``` 72 | 4. Training the model 73 | 74 | The default model, dataset, and label form are MSDA-Net, SIRST3, and coarse point labels. If you need to train others, please modify the corresponding settings in "train_model.py"; just change them to your choice. It is very simple. For details, please see the beginning of "train_model.py".
75 | ``` 76 | python train_model.py 77 | ``` 78 | 5. Testing the model 79 | 80 | The default model, dataset, and label form are MSDA-Net, SIRST3, and coarse point labels. If you need to test others, please modify the corresponding settings in "test_model.py". Notably, you also need to assign the name of the folder containing the weight file to the "test_dir_name" variable so that the program can find the corresponding model weights. For details, please see the beginning of "test_model.py". 81 | ``` 82 | python test_model.py 83 | ``` 84 | 6. Performance evaluation 85 | Use "cal_mIoU_and_nIoU.py" and "cal_PD_and_Fa.py" for performance evaluation. Notably, the corresponding folder paths should be replaced (default: SIRST3). 86 | ``` 87 | python cal_mIoU_and_nIoU.py 88 | python cal_PD_and_Fa.py 89 | ``` 90 | 91 | ## Results 92 | 93 | * **Quantitative Results on the SIRST3 dataset with Coarse point labels:** 94 |

95 | Results on the SIRST3 with coarse point label 96 |

97 | 98 | * **Quantitative Results on the three individual datasets with Coarse point labels:** 99 |

100 | Results on the three separate dataset with coarse point label 101 |

102 | 103 | * **Quantitative Results on the SIRST3 dataset with Centroid point labels:** 104 |

105 | Results on the SIRST3 with centroid point label 106 |

107 | 108 | * **Quantitative Results on the three individual datasets with Centroid point labels:** 109 |

110 | Results on the three separate dataset with centroid point label 111 |

112 | 113 | * **Qualitative results on the SIRST3 dataset with Coarse point labels:** (Red denotes the correct detections, blue denotes the false detections, and yellow denotes the missed detections.) 114 |

115 | Visualization on the SIRST3 with coarse point label 116 |

117 | 118 | * **Qualitative results on the SIRST3 dataset with Centroid point labels:** (Red denotes the correct detections, blue denotes the false detections, and yellow denotes the missed detections.) 119 |

120 | Visualization on the SIRST3 with centroid point label 121 |

122 | 123 | 124 | 125 | ## Citation 126 | 127 | If you find this repo helpful, please give us a 🤩**star**🤩. Please consider citing **PAL** if it benefits your project.
128 | 129 | The BibTeX reference is as follows. 130 | ``` 131 | @misc{yu2024easyhardprogressiveactive, 132 | title={From Easy to Hard: Progressive Active Learning Framework for Infrared Small Target Detection with Single Point Supervision}, 133 | author={Chuang Yu and Jinmiao Zhao and Yunpeng Liu and Sicheng Zhao and Yimian Dai and Xiangyu Yue}, 134 | year={2024}, 135 | eprint={2412.11154}, 136 | archivePrefix={arXiv}, 137 | primaryClass={cs.CV}, 138 | url={https://arxiv.org/abs/2412.11154}, 139 | } 140 | ``` 141 | 142 | The plain-text reference is as follows. 143 | ``` 144 | Chuang Yu, Jinmiao Zhao, Yunpeng Liu, Sicheng Zhao, Yimian Dai, Xiangyu Yue. From Easy to Hard: Progressive Active Learning Framework for Infrared Small Target Detection with Single Point Supervision. arXiv preprint arXiv:2412.11154, 2024. 145 | ``` 146 | 147 | 148 | 149 | ## Other links 150 | 151 | 1. My homepage: [[YuChuang](https://github.com/YuChuang1205)] 152 | 2. "MSDA-Net" demo: [[Link](https://github.com/YuChuang1205/MSDA-Net)] 153 | -------------------------------------------------------------------------------- /cal_PD_and_Fa.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding = gbk 3 | """ 4 | @Author : yuchuang,zhaojinmiao 5 | @Time : 6 | @desc: 7 | """ 8 | import os 9 | import cv2 10 | import numpy as np 11 | import sys 12 | from skimage import measure 13 | import os 14 | from os.path import join, isfile 15 | 16 | 17 | def cal_Pd_Fa(input_pred, input_true): 18 | image = measure.label(input_pred, connectivity=2) 19 | coord_image = measure.regionprops(image) 20 | label = measure.label(input_true, connectivity=2) 21 | coord_label = measure.regionprops(label) 22 | target = len(coord_label) 23 | image_area_total = [] 24 | image_area_match = [] 25 | distance_match = [] 26 | dismatch = [] 27 | for K in range(len(coord_image)): 28 | area_image = np.array(coord_image[K].area) 29 | #print(area_image) 30 | image_area_total.append(area_image) 31 | 32 | for i in range(len(coord_label)): 33 | centroid_label = np.array(list(coord_label[i].centroid)) 34 | for m in range(len(coord_image)): 35 | centroid_image = np.array(list(coord_image[m].centroid)) 36 | distance = np.linalg.norm(centroid_image - centroid_label) 37 | area_image = np.array(coord_image[m].area) 38 | if distance < 3: 39 | distance_match.append(distance) 40 | image_area_match.append(area_image) 41 | del coord_image[m] 42 | break 43 | 44 | FA = np.sum(image_area_total) - np.sum(image_area_match) 45 | PD = len(distance_match) 46 | return target, FA, PD 47 | 48 | # IMAGE_SIZE = 512 49 | root_path = os.path.abspath('.') 50 | # input_path = os.path.join(root_path, 'input') 51 | dataset_path = os.path.join(root_path,'dataset','SIRST3') 52 | test_dataset_path = os.path.join(dataset_path,'val') 53 | 54 | input_pred_path = os.path.join(test_dataset_path, 'pre_results') 55 | input_true_path = os.path.join(test_dataset_path, 'mask') 56 | 57 | # img_num = count_images_in_folder(input_pred_path) 58 | # print(img_num) 59 | input_pred_list = os.listdir(input_true_path) 60 | img_num = len(input_pred_list) 61 | print(img_num) 62 | input_pred_list.sort() 63 | target_all = 0 64 | FA_all = 0 65 | PD_all = 0 66 | all_pixel = 0 67 | for i in range(len(input_pred_list)): 68 | print("Processing:", input_pred_list[i]) 69 | img_name = input_pred_list[i] 70 | input_pred_img_path = os.path.join(input_pred_path, img_name) 71 | input_true_img_path = os.path.join(input_true_path, img_name) 72 | input_pred = cv2.imread(input_pred_img_path, cv2.IMREAD_GRAYSCALE) 73 | print(input_pred.shape) 74 | all_pixel += input_pred.shape[0] * input_pred.shape[1] 75 | input_pred = np.where(input_pred>127,255,0) 76 | input_true = cv2.imread(input_true_img_path, cv2.IMREAD_GRAYSCALE) 77 | target, FA, PD = cal_Pd_Fa(input_pred, input_true) 78 | #print(FA) 79 | target_all = target_all + target 80 | FA_all = FA_all + FA 81 | PD_all = PD_all + PD 82 | Pd = PD_all / target_all 83 | Fa = FA_all / all_pixel 84 | 85 | print("Pd:", Pd) 86 | print("Fa (x1e6):", Fa*1000000) 87 | print("Done!!!!!!!!!!") 88 | 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /cal_mIoU_and_nIoU.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding = gbk 3 | """ 4 | @Author : yuchuang,zhaojinmiao 5 | @Time : 6 | @desc: 7 | """ 8 | import os 9 | import cv2 10 | import numpy as np 11 | import sys 12 | 13 | 14 | def cal_iou(input_pred, input_true): 15 | 16 | input_pred = input_pred/255 17 | input_true = input_true/255 18 | 19 | inter_count = np.sum(input_pred*input_true) 20 | outer_count = np.sum(input_pred) + np.sum(input_true) - inter_count 21 | 22 | 23 | return inter_count, outer_count 24 | 25 | def sigmoid(x): 26 | s = 1 / (1 + np.exp(-x)) 27 | return s 28 | 29 | root_path = os.path.abspath('.') 30 | # input_path = os.path.join(root_path, 'input') 31 | 32 | dataset_path = os.path.join(root_path,'dataset','SIRST3') 33 | test_dataset_path = os.path.join(dataset_path,'val') 34 | 35 | input_pred_path = os.path.join(test_dataset_path, 'pre_results') 36 | input_true_path = os.path.join(test_dataset_path, 'mask') 37 | 38 | input_pred_list = os.listdir(input_true_path) 39 | 40 | input_pred_list.sort() 41 | 42 | inter_count_all = 0 43 | outer_count_all = 0 44 | niou_list = [] 45 | for i in range(len(input_pred_list)): 46 | print("Processing:", input_pred_list[i]) 47 | img_name = input_pred_list[i] 48 | input_pred_img_path = os.path.join(input_pred_path, img_name) 49 | input_true_img_path = os.path.join(input_true_path, img_name) 50 | input_pred = cv2.imread(input_pred_img_path, cv2.IMREAD_GRAYSCALE) 51 | #input_pred = sigmoid(input_pred) 52 | input_pred = np.where(input_pred>127,255,0) 53 | #print(input_pred) 54 | input_true = cv2.imread(input_true_img_path, cv2.IMREAD_GRAYSCALE) 55 | inter_count, outer_count = cal_iou(input_pred, input_true) 56 | inter_count_all = inter_count_all + inter_count 57 | outer_count_all = outer_count_all + outer_count 58 | niou_list.append(inter_count/(outer_count+1e-6)) 59 | 60 | niou = np.mean(niou_list) 61 | miou = inter_count_all/outer_count_all 62 | 63 | print("mIoU:", miou) 64 | print("nIoU:", niou) 65 | print("Done!!!!!!!!!!") 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /components/__pycache__/cal_mean_std.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/cal_mean_std.cpython-36.pyc -------------------------------------------------------------------------------- /components/__pycache__/cal_mean_std.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/cal_mean_std.cpython-38.pyc
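As an illustrative aside on the two evaluation scripts above: the tiny, self-contained NumPy example below (hypothetical 2x2 arrays, not repository code) shows how "cal_mIoU_and_nIoU.py" aggregates its two scores. mIoU pools intersection and union over the whole test set before dividing, while nIoU averages the per-image IoUs, so the two values generally differ.
```python
import numpy as np

# Two hypothetical prediction/mask pairs, binarized to {0, 255} as in the script.
pred_a = np.array([[255, 255], [0, 0]]); mask_a = np.array([[255, 0], [0, 0]])
pred_b = np.array([[255, 0], [0, 0]]);   mask_b = np.array([[255, 0], [0, 0]])

inter_all = outer_all = 0.0
per_image = []
for pred, mask in [(pred_a, mask_a), (pred_b, mask_b)]:
    p, t = pred / 255, mask / 255
    inter = np.sum(p * t)                     # intersection pixel count
    outer = np.sum(p) + np.sum(t) - inter     # union pixel count
    inter_all += inter
    outer_all += outer
    per_image.append(inter / (outer + 1e-6))  # per-image IoU

print("mIoU:", inter_all / outer_all)  # pooled: (1+1)/(2+1) = 0.667
print("nIoU:", np.mean(per_image))     # averaged: (0.5+1.0)/2 = 0.75
```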
-------------------------------------------------------------------------------- /components/__pycache__/dataset_final_edge_copy_paste_final_2_img_path.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/dataset_final_edge_copy_paste_final_2_img_path.cpython-36.pyc -------------------------------------------------------------------------------- /components/__pycache__/dataset_final_edge_copy_paste_final_2_img_path.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/dataset_final_edge_copy_paste_final_2_img_path.cpython-38.pyc -------------------------------------------------------------------------------- /components/__pycache__/edges.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/edges.cpython-36.pyc -------------------------------------------------------------------------------- /components/__pycache__/edges.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/edges.cpython-38.pyc -------------------------------------------------------------------------------- /components/__pycache__/metric_new_crop.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/metric_new_crop.cpython-36.pyc -------------------------------------------------------------------------------- /components/__pycache__/metric_new_crop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/metric_new_crop.cpython-38.pyc -------------------------------------------------------------------------------- /components/__pycache__/utils_all_edge_copy_paste_final_2_img_path.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/utils_all_edge_copy_paste_final_2_img_path.cpython-36.pyc -------------------------------------------------------------------------------- /components/__pycache__/utils_all_edge_copy_paste_final_2_img_path.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/utils_all_edge_copy_paste_final_2_img_path.cpython-38.pyc -------------------------------------------------------------------------------- /components/cal_mean_std.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import os 4 | # 5 | 6 | def Calculate_mean_std(img_dir): 7 | img_list = os.listdir(img_dir) 8 | 9 | mean_list = [] 10 | std_list = [] 11 | 12 | for i in range(len(img_list)): 13 | #print(i) 14 | img_path = 
os.path.join(img_dir,img_list[i]) 15 | img = np.array(Image.open(img_path).convert("L")) 16 | mean_list.append(img.mean()) 17 | std_list.append(img.std()) 18 | mean_out = np.mean(mean_list)/255 19 | std_out = np.mean(std_list)/255 20 | print("Path:", img_dir) 21 | print("Dataset mean:", mean_out) 22 | print("Dataset std:", std_out) 23 | 24 | return mean_out, std_out 25 | 26 | -------------------------------------------------------------------------------- /components/dataset_final_edge_copy_paste_final_2_img_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import albumentations as A 7 | from albumentations.pytorch import ToTensorV2 8 | from torch.utils.data import DataLoader 9 | import time 10 | from components.edges import onehot_to_binary_edges, mask_to_onehot 11 | import cv2 12 | import random 13 | import math 14 | 15 | 16 | 17 | def random_crop(img, mask, patch_size): 18 | h, w, c = img.shape 19 | mh, mw = mask.shape 20 | 21 | assert (h, w) == (mh, mw), "Image and mask must have the same height and width" 22 | 23 | if min(h, w) < patch_size: 24 | img = np.pad(img, ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w), (0, 0)), mode='constant') 25 | mask = np.pad(mask, ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w)), mode='constant') 26 | h, w, _ = img.shape 27 | 28 | h_start = random.randint(0, h - patch_size) 29 | h_end = h_start + patch_size 30 | w_start = random.randint(0, w - patch_size) 31 | w_end = w_start + patch_size 32 | 33 | img_patch = img[h_start:h_end, w_start:w_end, :] 34 | mask_patch = mask[h_start:h_end, w_start:w_end] 35 | 36 | return img_patch, mask_patch 37 | 38 | 39 | class SirstDataset(Dataset): 40 | def __init__(self, image_dir, mask_dir, patch_size, transform=None, mode='None'): 41 | self.image_dir = image_dir 42 | self.mask_dir = mask_dir 43 | self.transform = transform 44 | self.images = np.sort(os.listdir(image_dir)) 45 | self.mode = mode 46 | self.patch_size = patch_size 47 | 48 | def __len__(self): 49 | return len(self.images) 50 | 51 | def __getitem__(self, index): 52 | img_path = os.path.join(self.image_dir, self.images[index]) 53 | mask_path = os.path.join(self.mask_dir, self.images[index]) 54 | image = np.array(Image.open(img_path).convert("RGB")) 55 | # print(image.shape) 56 | mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32) 57 | mask = (mask > 127.5).astype(float) 58 | 59 | if (self.mode == 'train'): 60 | image_patch, mask_patch = random_crop(image, mask, self.patch_size) 61 | if self.transform is not None: 62 | augmentations = self.transform(image=image_patch, mask=mask_patch) 63 | image = augmentations["image"] 64 | mask = augmentations["mask"] 65 | mask_2 = mask.numpy() 66 | mask_2 = mask_2.astype(np.int64) 67 | oneHot_label = mask_to_onehot(mask_2, 2) 68 | edge = onehot_to_binary_edges(oneHot_label, 1, 2) 69 | edge[:1, :] = 0 # zero out the one-pixel image border of the edge map 70 | edge[-1:, :] = 0 71 | edge[:, :1] = 0 72 | edge[:, -1:] = 0 73 | edge = np.expand_dims(edge, axis=0).astype(np.int64) 74 | return image, mask, edge 75 | 76 | elif (self.mode == 'val'): 77 | times = 32 78 | h, w, c = image.shape 79 | pad_height = math.ceil(h / times) * times - h 80 | pad_width = math.ceil(w / times) * times - w 81 | image = np.pad(image, ((0, pad_height), (0, pad_width), (0, 0)), mode='constant') 82 | mask = np.pad(mask, ((0, pad_height), (0, pad_width)), mode='constant') 83 | if self.transform is not None: 84 | augmentations = self.transform(image=image, mask=mask) 85 | image = augmentations["image"] 86 | mask = augmentations["mask"] 87 | 88 | return image, mask, h, w 89 | -------------------------------------------------------------------------------- /components/edges.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage.morphology import distance_transform_edt 2 | import numpy as np 3 | import torch 4 | def onehot_to_multiclass_edges(mask, radius, num_classes): 5 | """ 6 | Converts a segmentation mask (K,H,W) to an edgemap (K,H,W) 7 | """ 8 | if radius < 0: 9 | return mask 10 | 11 | # We need to pad the borders for boundary conditions 12 | mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0) 13 | 14 | channels = [] 15 | for i in range(num_classes): 16 | dist = distance_transform_edt(mask_pad[i, :]) + distance_transform_edt(1.0 - mask_pad[i, :]) 17 | dist = dist[1:-1, 1:-1] 18 | dist[dist > radius] = 0 19 | dist = (dist > 0).astype(np.uint8) 20 | channels.append(dist) 21 | 22 | return np.array(channels) 23 | 24 | 25 | def onehot_to_binary_edges(mask, radius, num_classes): 26 | """ 27 | Converts a segmentation mask (K,H,W) to a binary edgemap (H,W) 28 | """ 29 | 30 | if radius < 0: 31 | return mask 32 | 33 | # We need to pad the borders for boundary conditions 34 | mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0) 35 | 36 | edgemap = np.zeros(mask.shape[1:]) 37 | for i in range(num_classes): 38 | # extract the contour band with the distance transform 39 | dist = distance_transform_edt(mask_pad[i, :]) + distance_transform_edt(1.0 - mask_pad[i, :]) 40 | dist = dist[1:-1, 1:-1] 41 | dist[dist > radius] = 0 42 | edgemap += dist 43 | edgemap = (edgemap > 0).astype(np.uint8)*255 44 | return edgemap 45 | 46 | 47 | def mask_to_onehot(mask, num_classes): 48 | """ 49 | Converts a segmentation mask (H,W) to (K,H,W) where the last dim is a one 50 | hot encoding vector 51 | """ 52 | _mask = [mask == (i) for i in range(num_classes)] 53 | return np.array(_mask).astype(np.uint8) -------------------------------------------------------------------------------- /components/gardient.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | def gradient_1order(x,h_x=None,w_x=None): 4 | if h_x is None and w_x is None: 5 | h_x = x.size()[2] 6 | w_x = x.size()[3] 7 | r = F.pad(x, (0, 1, 0, 0))[:, :, :, 1:] 8 | l = F.pad(x, (1, 0, 0, 0))[:, :, :, :w_x] 9 | t = F.pad(x, (0, 0, 1, 0))[:, :, :h_x, :] 10 | b = F.pad(x, (0, 0, 0, 1))[:, :, 1:, :] 11 | xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5) 12 | return xgrad 13 | if __name__ == '__main__': 14 | input=torch.randn(50,512,7,7) 15 | output = gradient_1order(input) 16 | print(output.shape) -------------------------------------------------------------------------------- /components/metric_new_crop.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from skimage import measure 7 | 8 | TEST_BATCH_SIZE = 1 9 | class SigmoidMetric(): 10 | def __init__(self, score_thresh=0.5): 11 | self.score_thresh = score_thresh 12 | self.reset() 13 | 14 | def update(self, pred, labels): 15 | correct, labeled = self.batch_pix_accuracy(pred, labels) 16 | inter, union = self.batch_intersection_union(pred, labels) 17 | 18 | self.total_correct += correct
19 | self.total_label += labeled 20 | self.total_inter += inter 21 | self.total_union += union 22 | 23 | def get(self): 24 | """Gets the current evaluation result.""" 25 | pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label) 26 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 27 | mIoU = IoU.mean() 28 | return pixAcc, mIoU 29 | 30 | def reset(self): 31 | """Resets the internal evaluation result to initial state.""" 32 | self.total_inter = 0 33 | self.total_union = 0 34 | self.total_correct = 0 35 | self.total_label = 0 36 | 37 | def batch_pix_accuracy(self, output, target): 38 | assert output.shape == target.shape 39 | output = output.cpu().detach().numpy() 40 | target = target.cpu().detach().numpy() 41 | 42 | predict = (output > self.score_thresh).astype('int64') # P 43 | target = (target > self.score_thresh).astype('int64') 44 | # -----------------------------# inputs must already be binarized to 0/1 at this point 45 | pixel_labeled = np.sum(target > 0) # T 46 | pixel_correct = np.sum((predict == target) * (target > 0)) # TP 47 | assert pixel_correct <= pixel_labeled 48 | return pixel_correct, pixel_labeled 49 | 50 | def batch_intersection_union(self, output, target): 51 | mini = 1 52 | maxi = 1 # nclass 53 | nbins = 1 # nclass 54 | predict = (output.cpu().detach().numpy() > self.score_thresh).astype('int64') # P 55 | target = (target.cpu().detach().numpy() > self.score_thresh).astype('int64') # T 56 | # target = target.cpu().numpy().astype('int64') # T 57 | intersection = predict * (predict == target) # TP 58 | 59 | 60 | # areas of intersection and union 61 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi)) # count pixels equal to 1 in the binarized image 62 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi)) 63 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi)) 64 | area_union = area_pred + area_lab - area_inter 65 | assert (area_inter <= area_union).all() 66 | return area_inter, area_union 67 | 68 | 69 | class SamplewiseSigmoidMetric(): 70 | def __init__(self, nclass, score_thresh=0.5): 71 | self.nclass = nclass 72 | self.score_thresh = score_thresh 73 | self.reset() 74 | 75 | def update(self, preds, labels): 76 | """Updates the internal evaluation result.""" 77 | inter_arr, union_arr = self.batch_intersection_union(preds, labels) 78 | self.total_inter = np.append(self.total_inter, inter_arr) 79 | self.total_union = np.append(self.total_union, union_arr) 80 | 81 | def get(self): 82 | """Gets the current evaluation result.""" 83 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 84 | mIoU = IoU.mean() 85 | return IoU, mIoU 86 | 87 | def reset(self): 88 | """Resets the internal evaluation result to initial state.""" 89 | self.total_inter = np.array([]) 90 | self.total_union = np.array([]) 91 | self.total_correct = np.array([]) 92 | self.total_label = np.array([]) 93 | 94 | def batch_intersection_union(self, output, target): 95 | """nIoU""" 96 | # inputs are tensor 97 | # the category 0 is ignored class, typically for background / boundary 98 | mini = 1 99 | maxi = 1 # nclass 100 | nbins = 1 # nclass 101 | 102 | predict = (output.cpu().detach().numpy() > self.score_thresh).astype('int64') # P 103 | target = (target.cpu().detach().numpy() > self.score_thresh).astype('int64') # T 104 | # target = target.cpu().detach().numpy().astype('int64') # T 105 | intersection = predict * (predict == target) # TP 106 | 107 | num_sample = intersection.shape[0] 108 | area_inter_arr = np.zeros(num_sample) 109 | area_pred_arr = np.zeros(num_sample) 110 | area_lab_arr = np.zeros(num_sample) 111 | area_union_arr = np.zeros(num_sample) 112 | 113 | for b in range(num_sample): 114 | # areas of intersection and union 115 | area_inter, _ = np.histogram(intersection[b], bins=nbins, range=(mini, maxi)) 116 | area_inter_arr[b] = area_inter 117 | 118 | area_pred, _ = np.histogram(predict[b], bins=nbins, range=(mini, maxi)) 119 | area_pred_arr[b] = area_pred 120 | 121 | area_lab, _ = np.histogram(target[b], bins=nbins, range=(mini, maxi)) 122 | area_lab_arr[b] = area_lab 123 | 124 | area_union = area_pred + area_lab - area_inter 125 | area_union_arr[b] = area_union 126 | 127 | assert (area_inter <= area_union).all() 128 | 129 | return area_inter_arr, area_union_arr 130 | 131 | 132 | class PD_FA_2(): 133 | def __init__(self, nclass): 134 | super(PD_FA_2, self).__init__() 135 | self.nclass = nclass 136 | self.image_area_total = [] 137 | self.image_area_match = [] 138 | self.FA = 0 139 | self.PD = 0 140 | self.target = 0 141 | self.all_pixel = 0 142 | def update(self, preds, labels): 143 | 144 | 145 | predits = np.array((preds > 0.5).cpu()).astype('int64') 146 | 147 | for i in range(predits.shape[0]): 148 | self.image_h = predits.shape[-2] 149 | self.image_w = predits.shape[-1] 150 | self.all_pixel += self.image_h * self.image_w 151 | 152 | 153 | labelss = np.array((labels > 0.5).cpu()).astype('int64') # P 154 | 155 | image = measure.label(predits, connectivity=2)# label connected components; for 2-D images, connectivity=2 means 8-connectivity 156 | # print('image.size', image.shape) 157 | #print(image) 158 | coord_image = measure.regionprops(image)# per-component properties such as area and centroid 159 | #print(coord_image) 160 | label = measure.label(labelss, connectivity=2) 161 | coord_label = measure.regionprops(label) 162 | 163 | self.target += len(coord_label) # total number of ground-truth targets 164 | self.image_area_total = [] 165 | self.image_area_match = [] 166 | self.distance_match = [] 167 | self.dismatch = [] 168 | 169 | for K in range(len(coord_image)): 170 | area_image = np.array(coord_image[K].area) # pixel count of the K-th predicted component 171 | self.image_area_total.append(area_image) # areas of all predicted components 172 | # compare each ground-truth centroid with the predicted centroids; if the distance is below 3, record the matched distance and area and remove the matched predicted component 173 | for i in range(len(coord_label)): 174 | centroid_label = np.array(list(coord_label[i].centroid)) # centroid of ground-truth component i 175 | for m in range(len(coord_image)): 176 | centroid_image = np.array(list(coord_image[m].centroid)) # centroid of predicted component m 177 | distance = np.linalg.norm(centroid_image - centroid_label) # Euclidean distance between the two centroids 178 | area_image = np.array(coord_image[m].area) # area of predicted component m 179 | if distance < 3: # match if the centroid distance is below 3 (a tunable threshold) 180 | self.distance_match.append(distance) # record the matched centroid distance 181 | self.image_area_match.append(area_image) # record the matched predicted-component area 182 | 183 | del coord_image[m] # remove component m so it cannot be matched twice 184 | break 185 | 186 | self.dismatch = np.sum(self.image_area_total)-np.sum(self.image_area_match) 187 | self.FA += self.dismatch 188 | self.PD+=len(self.distance_match) # number of detected targets 189 | 190 | def get(self,img_num): 191 | # print("image_w:", self.image_w) 192 | # print("image_h:", self.image_h) 193 | Final_FA = self.FA / self.all_pixel 194 | Final_PD = self.PD / self.target 195 | #print("predicted PD:", self.PD) 196 | #print("ground-truth targets:", self.target) 197 | 198 | return Final_FA,Final_PD 199 | 200 | 201 | def
reset(self): 202 | self.FA = 0 203 | self.PD = 0 204 | self.target = 0 205 | self.all_pixel = 0 206 | 207 | if __name__ == '__main__': 208 | pred = torch.rand(8, 1, 512, 512) 209 | target = torch.rand(8, 1, 512, 512) 210 | m1 = SigmoidMetric() 211 | m2 = SamplewiseSigmoidMetric(nclass=1, score_thresh=0.5) 212 | m1.update(pred, target) 213 | m2.update(pred, target) 214 | pixAcc, mIoU = m1.get() 215 | _, nIoU = m2.get() 216 | -------------------------------------------------------------------------------- /components/utils_all_edge_copy_paste_final_2_img_path.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from components.dataset_final_edge_copy_paste_final_2_img_path import SirstDataset 4 | from torch.utils.data import DataLoader 5 | import os 6 | 7 | def make_dir(path): 8 | if os.path.exists(path) == False: 9 | os.makedirs(path) 10 | 11 | 12 | def get_loaders( 13 | train_dir, 14 | train_maskdir, 15 | val_dir, 16 | val_maskdir, 17 | patch_size, 18 | train_batch_size, 19 | test_batch_size, 20 | train_transform, 21 | val_transform, 22 | num_workers=4, 23 | pin_memory=True, 24 | ): 25 | train_ds = SirstDataset( 26 | image_dir=train_dir, 27 | mask_dir=train_maskdir, 28 | patch_size=patch_size, 29 | transform=train_transform, 30 | mode='train', 31 | ) 32 | 33 | train_loader = DataLoader( 34 | train_ds, 35 | batch_size=train_batch_size, 36 | num_workers=num_workers, 37 | pin_memory=pin_memory, 38 | shuffle=True, 39 | ) 40 | 41 | val_ds = SirstDataset( 42 | image_dir=val_dir, 43 | mask_dir=val_maskdir, 44 | patch_size=patch_size, 45 | transform=val_transform, 46 | mode='val', 47 | ) 48 | 49 | val_loader = DataLoader( 50 | val_ds, 51 | batch_size=test_batch_size, 52 | num_workers=num_workers, 53 | pin_memory=pin_memory, 54 | shuffle=False, 55 | ) 56 | 57 | return train_loader, val_loader 58 | 59 | -------------------------------------------------------------------------------- /imgs/Main results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Main results.png -------------------------------------------------------------------------------- /imgs/PAL framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/PAL framework.png -------------------------------------------------------------------------------- /imgs/Results on the SIRST3 with centroid point label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Results on the SIRST3 with centroid point label.png -------------------------------------------------------------------------------- /imgs/Results on the SIRST3 with coarse point label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Results on the SIRST3 with coarse point label.png -------------------------------------------------------------------------------- /imgs/Results on the three separate dataset with centroid point label.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Results on the three separate dataset with centroid point label.png -------------------------------------------------------------------------------- /imgs/Results on the three separate dataset with coarse point label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Results on the three separate dataset with coarse point label.png -------------------------------------------------------------------------------- /imgs/Visualization on the SIRST3 with centroid point label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Visualization on the SIRST3 with centroid point label.png -------------------------------------------------------------------------------- /imgs/Visualization on the SIRST3 with coarse point label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Visualization on the SIRST3 with coarse point label.png -------------------------------------------------------------------------------- /loss/Edge_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding = gbk 3 | """ 4 | @Author : yuchuang 5 | @Time : 2024/3/30 22:07 6 | @desc: 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import torch.nn as nn 13 | from segmentation_models_pytorch.losses import DiceLoss, SoftCrossEntropyLoss, LovaszLoss,SoftBCEWithLogitsLoss 14 | 15 | def edgeSCE_loss(pred, target, edge): 16 | 17 | BinaryCrossEntropy_fn = SoftBCEWithLogitsLoss(smooth_factor=None, reduction='None') 18 | 19 | edge_weight = 4. 20 | loss_sce = BinaryCrossEntropy_fn(pred, target) 21 | #print(loss_sce.size()) 22 | #print(edge.size()) 23 | edge = edge.clone() 24 | edge[edge == 0] = 1. 
25 | edge[edge > 0] = edge_weight 26 | loss_sce *= edge 27 | 28 | loss_sce_, ind = loss_sce.contiguous().view(-1).sort() 29 | min_value = loss_sce_[int(0.5 * loss_sce.numel())] 30 | loss_sce = loss_sce[loss_sce >= min_value] 31 | loss_sce = loss_sce.mean() 32 | loss = loss_sce 33 | return loss 34 | 35 | # if __name__ == '__main__': 36 | # target=torch.ones((2,1,256,256),dtype=torch.float32) 37 | # input=(torch.ones((2,1,256,256))*0.9) 38 | # input[0,0,0,0] = 0.99 39 | # loss=edgeBCE_Dice_loss(input,target,target*255) 40 | # print(loss) 41 | 42 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/loss/__init__.py -------------------------------------------------------------------------------- /loss/__pycache__/Edge_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/loss/__pycache__/Edge_loss.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/Edge_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/loss/__pycache__/Edge_loss.cpython-38.pyc -------------------------------------------------------------------------------- /loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/loss/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mm/attention/SEAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class SEAttention(nn.Module): 9 | 10 | def __init__(self, channel,reduction=16): 11 | super().__init__() 12 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 13 | self.fc = nn.Sequential( 14 | nn.Linear(channel, channel // reduction, bias=False), 15 | nn.ReLU(inplace=True), 16 | nn.Linear(channel // reduction, channel, bias=False), 17 | nn.Sigmoid() 18 | ) 19 | 20 | 21 | def init_weights(self): 22 | for m in self.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | init.kaiming_normal_(m.weight, mode='fan_out') 25 | if m.bias is not None: 26 | init.constant_(m.bias, 0) 27 | elif isinstance(m, nn.BatchNorm2d): 28 | init.constant_(m.weight, 1) 29 | init.constant_(m.bias, 0) 30 | elif isinstance(m, nn.Linear): 31 | init.normal_(m.weight, std=0.001) 32 | if m.bias is not None: 33 | init.constant_(m.bias, 0) 34 | 35 | def forward(self, x): 36 | b, c, _, _ = x.size() 37 | y = self.avg_pool(x).view(b, c) 38 | y = self.fc(y).view(b, c, 1, 1) 39 | return x * y.expand_as(x) 40 | 41 | 42 | if __name__ == '__main__': 43 | input=torch.randn(50,512,7,7) 44 | se 
= SEAttention(channel=512,reduction=8) 45 | output=se(input) 46 | print(output.shape) 47 | 48 | -------------------------------------------------------------------------------- /mm/attention/__pycache__/SEAttention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/mm/attention/__pycache__/SEAttention.cpython-36.pyc -------------------------------------------------------------------------------- /mm/attention/__pycache__/SEAttention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/mm/attention/__pycache__/SEAttention.cpython-38.pyc -------------------------------------------------------------------------------- /model/ACM/ACM_no_sigmoid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.nn import BatchNorm2d 5 | from torchvision.models.resnet import BasicBlock 6 | from .fusion import AsymBiChaFuse 7 | import torch.nn.functional as F 8 | # from model.utils import init_weights, count_param 9 | 10 | class ACM_No_Sigmoid(nn.Module): 11 | def __init__(self, in_channels=3, layers=[3,3,3], channels=[8,16,32,64], fuse_mode='AsymBi', tiny=False, classes=1, 12 | norm_layer=BatchNorm2d,groups=1, norm_kwargs=None, **kwargs): 13 | super(ACM_No_Sigmoid, self).__init__() 14 | self.layer_num = len(layers) 15 | self.tiny = tiny 16 | self._norm_layer = norm_layer 17 | self.groups = groups 18 | self.momentum=0.9 19 | stem_width = int(channels[0]) ##channels: 8 16 32 64 20 | # self.stem.add(norm_layer(scale=False, center=False,**({} if norm_kwargs is None else norm_kwargs))) 21 | if tiny: # 默认是False 22 | self.stem = nn.Sequential( 23 | norm_layer(in_channels,self.momentum), 24 | nn.Conv2d(in_channels, out_channels=stem_width * 2, kernel_size=3, stride=1,padding=1, bias=False), 25 | norm_layer(stem_width * 2, momentum=self.momentum), 26 | nn.ReLU(inplace=True) 27 | ) 28 | else: 29 | self.stem = nn.Sequential( 30 | # self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=2, 31 | # padding=1, use_bias=False)) 32 | # self.stem.add(norm_layer(in_channels=stem_width*2)) 33 | # self.stem.add(nn.Activation('relu')) 34 | # self.stem.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) 35 | norm_layer(in_channels, momentum=self.momentum), 36 | nn.Conv2d(in_channels=in_channels,out_channels=stem_width, kernel_size=3, stride=2,padding=1, bias=False), 37 | norm_layer(stem_width,momentum=self.momentum), 38 | nn.ReLU(inplace=True), 39 | nn.Conv2d(in_channels=stem_width,out_channels=stem_width, kernel_size=3, stride=1,padding=1, bias=False), 40 | norm_layer(stem_width,momentum=self.momentum), 41 | nn.ReLU(inplace=True), 42 | nn.Conv2d(in_channels=stem_width,out_channels=stem_width * 2, kernel_size=3, stride=1,padding=1, bias=False), 43 | norm_layer(stem_width * 2,momentum=self.momentum), 44 | nn.ReLU(inplace=True), 45 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 46 | ) 47 | 48 | self.layer1 = self._make_layer(block=BasicBlock, blocks=layers[0], 49 | out_channels=channels[1], 50 | in_channels=channels[1], stride=1) 51 | 52 | self.layer2 = self._make_layer(block=BasicBlock, blocks=layers[1], 53 | out_channels=channels[2], stride=2, 54 | in_channels=channels[1]) 55 | # 56 | self.layer3 = self._make_layer(block=BasicBlock, 
blocks=layers[2], 57 | out_channels=channels[3], stride=2, 58 | in_channels=channels[2]) 59 | 60 | self.deconv2 = nn.ConvTranspose2d(in_channels=channels[3] ,out_channels=channels[2], kernel_size=(4, 4), ##channels: 8 16 32 64 61 | stride=2, padding=1) 62 | self.uplayer2 = self._make_layer(block=BasicBlock, blocks=layers[1], 63 | out_channels=channels[2], stride=1, 64 | in_channels=channels[2]) 65 | self.fuse2 = self._fuse_layer(fuse_mode, channels=channels[2]) 66 | 67 | self.deconv1 = nn.ConvTranspose2d(in_channels=channels[2] ,out_channels=channels[1], kernel_size=(4, 4), 68 | stride=2, padding=1) 69 | self.uplayer1 = self._make_layer(block=BasicBlock, blocks=layers[0], 70 | out_channels=channels[1], stride=1, 71 | in_channels=channels[1]) 72 | self.fuse1 = self._fuse_layer(fuse_mode, channels=channels[1]) 73 | 74 | self.head = _FCNHead(in_channels=channels[1], channels=classes, momentum=self.momentum) 75 | 76 | 77 | def _make_layer(self, block, out_channels, in_channels, blocks, stride): 78 | 79 | norm_layer = self._norm_layer 80 | downsample = None 81 | 82 | if stride != 1 or out_channels != in_channels: 83 | downsample = nn.Sequential( 84 | conv1x1(in_channels, out_channels , stride), 85 | norm_layer(out_channels * block.expansion, momentum=self.momentum), 86 | ) 87 | 88 | layers = [] 89 | layers.append(block(in_channels, out_channels, stride, downsample, self.groups, norm_layer=norm_layer)) 90 | self.inplanes = out_channels * block.expansion 91 | for _ in range(1, blocks): 92 | layers.append(block(self.inplanes, out_channels, self.groups, norm_layer=norm_layer)) 93 | return nn.Sequential(*layers) 94 | 95 | def _fuse_layer(self, fuse_mode, channels): 96 | 97 | if fuse_mode == 'AsymBi': 98 | fuse_layer = AsymBiChaFuse(channels=channels) 99 | else: 100 | raise ValueError('Unknown fuse_mode') 101 | return fuse_layer 102 | 103 | def forward(self, x): 104 | 105 | _, _, hei, wid = x.shape 106 | 107 | x = self.stem(x) # (4,16,120,120) 108 | c1 = self.layer1(x) # (4,16,120,120) 109 | c2 = self.layer2(c1) # (4,32, 60, 60) 110 | c3 = self.layer3(c2) # (4,64, 30, 30) 111 | 112 | deconvc2 = self.deconv2(c3) # (4,32, 60, 60) 113 | fusec2 = self.fuse2(deconvc2, c2) # (4,32, 60, 60) 114 | upc2 = self.uplayer2(fusec2) # (4,32, 60, 60) 115 | 116 | deconvc1 = self.deconv1(upc2) # (4,16,120,120) 117 | fusec1 = self.fuse1(deconvc1, c1) # (4,16,120,120) 118 | upc1 = self.uplayer1(fusec1) # (4,16,120,120) 119 | 120 | pred = self.head(upc1) # (4,1,120,120) 121 | 122 | if self.tiny: 123 | out = pred 124 | else: 125 | # out = F.contrib.BilinearResize2D(pred, height=hei, width=wid) # down 4 126 | out = F.interpolate(pred, scale_factor=4, mode='bilinear') # down 4 # (4,1,480,480) 127 | 128 | return out 129 | 130 | def evaluate(self, x): 131 | """evaluating network with inputs and targets""" 132 | return self.forward(x) 133 | 134 | 135 | class _FCNHead(nn.Module): 136 | # pylint: disable=redefined-outer-name 137 | def __init__(self, in_channels, channels, momentum, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 138 | super(_FCNHead, self).__init__() 139 | inter_channels = in_channels // 4 140 | self.block = nn.Sequential( 141 | nn.Conv2d(in_channels=in_channels, out_channels=inter_channels,kernel_size=3, padding=1, bias=False), 142 | norm_layer(inter_channels, momentum=momentum), 143 | nn.ReLU(inplace=True), 144 | nn.Dropout(0.1), 145 | nn.Conv2d(in_channels=inter_channels, out_channels=channels,kernel_size=1) 146 | ) 147 | # pylint: disable=arguments-differ 148 | def forward(self, x): 149 | return 
self.block(x) 150 | 151 | def conv1x1(in_planes, out_planes, stride=1): 152 | """1x1 convolution""" 153 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 154 | 155 | 156 | 157 | # ######################################################### 158 | # ###2.测试ASKCResUNet 159 | # if __name__ == '__main__': 160 | # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 让torch判断是否使用GPU,建议使用GPU环境,因为会快很多 161 | # layers = [3] * 3 162 | # channels = [x * 1 for x in [8, 16, 32, 64]] 163 | # in_channels = 3 164 | # model= ASKCResUNet(in_channels, layers=layers, channels=channels, fuse_mode='AsymBi',tiny=False, classes=1) 165 | # 166 | # model=model.cuda() 167 | # DATA = torch.randn(8,3,480,480).to(DEVICE) 168 | # 169 | # output=model(DATA) 170 | # print("output:",np.shape(output)) 171 | # ########################################################## -------------------------------------------------------------------------------- /model/ACM/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/ACM/__pycache__/ACM_no_sigmoid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ACM/__pycache__/ACM_no_sigmoid.cpython-36.pyc -------------------------------------------------------------------------------- /model/ACM/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ACM/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/ACM/__pycache__/fusion.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ACM/__pycache__/fusion.cpython-36.pyc -------------------------------------------------------------------------------- /model/ACM/fusion.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import BatchNorm2d 3 | import torch 4 | from torch.nn import GroupNorm 5 | class AsymBiChaFuse(nn.Module): 6 | def __init__(self, channels=64, r=4): 7 | super(AsymBiChaFuse, self).__init__() 8 | self.channels = channels 9 | self.bottleneck_channels = int(channels // r) 10 | 11 | 12 | self.topdown = nn.Sequential( 13 | nn.AdaptiveAvgPool2d(1), 14 | nn.Conv2d(in_channels=self.channels, out_channels=self.bottleneck_channels, kernel_size=1, stride=1,padding=0), 15 | # nn.BatchNorm2d(self.bottleneck_channels,momentum=0.9), 16 | nn.GroupNorm(1, self.bottleneck_channels), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(in_channels=self.bottleneck_channels,out_channels=self.channels, kernel_size=1, stride=1,padding=0), 19 | #nn.BatchNorm2d(self.channels,momentum=0.9), 20 | nn.GroupNorm(1, self.channels), 21 | nn.Sigmoid() 22 | ) 23 | 24 | self.bottomup = nn.Sequential( 25 | nn.Conv2d(in_channels=self.channels,out_channels=self.bottleneck_channels, kernel_size=1, stride=1,padding=0), 26 | #nn.BatchNorm2d(self.bottleneck_channels,momentum=0.9), 27 | nn.GroupNorm(1, self.bottleneck_channels), 28 | nn.ReLU(inplace=True), 29 | 
nn.Conv2d(in_channels=self.bottleneck_channels,out_channels=self.channels, kernel_size=1, stride=1,padding=0), 30 | #nn.BatchNorm2d(self.channels,momentum=0.9), 31 | nn.GroupNorm(1, self.channels), 32 | nn.Sigmoid() 33 | ) 34 | 35 | self.post = nn.Sequential( 36 | nn.Conv2d(in_channels=channels,out_channels=channels, kernel_size=3, stride=1, padding=1, dilation=1), 37 | #nn.BatchNorm2d(channels,momentum=0.9), 38 | nn.GroupNorm(1, self.channels), 39 | nn.ReLU(inplace=True) 40 | ) 41 | 42 | def forward(self, xh, xl): 43 | 44 | topdown_wei = self.topdown(xh) 45 | bottomup_wei = self.bottomup(xl) 46 | xs = 2 * torch.mul(xl, topdown_wei) + 2 * torch.mul(xh, bottomup_wei) 47 | xs = self.post(xs) 48 | return xs 49 | 50 | # from mxnet.gluon.block import HybridBlock 51 | # topdown = HybridBlock(prefix='topdown') 52 | # print(topdown) -------------------------------------------------------------------------------- /model/ALC/ALC_no_sigmoid.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | from torch.nn.modules import module 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import BatchNorm2d 7 | from .fusion import AsymBiChaFuse 8 | 9 | # from mxnet import nd 10 | from torchvision import transforms 11 | from torchvision.models.resnet import BasicBlock 12 | 13 | class _FCNHead(nn.Module): 14 | # pylint: disable=redefined-outer-name 15 | def __init__(self, in_channels, channels, momentum, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 16 | super(_FCNHead, self).__init__() 17 | inter_channels = in_channels // 4 18 | self.block = nn.Sequential( 19 | nn.Conv2d(in_channels=in_channels, out_channels=inter_channels,kernel_size=3, padding=1, bias=False), 20 | norm_layer(inter_channels, momentum=momentum), 21 | nn.ReLU(inplace=True), 22 | nn.Dropout(0.1), 23 | nn.Conv2d(in_channels=inter_channels, out_channels=channels,kernel_size=1) 24 | ) 25 | # pylint: disable=arguments-differ 26 | def forward(self, x): 27 | return self.block(x) 28 | 29 | def conv1x1(in_planes, out_planes, stride=1): 30 | """1x1 convolution""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 32 | 33 | 34 | class ALC_No_Sigmoid(nn.Module): 35 | def __init__(self, in_channels=3, layers=[4,4,4], channels=[8,16,32,64], fuse_mode='AsymBi', act_dilation=16, classes=1, tinyFlag=False, 36 | norm_layer=BatchNorm2d,groups=1,norm_kwargs=None, **kwargs): 37 | super(ALC_No_Sigmoid, self).__init__() 38 | 39 | self.layer_num = len(layers) 40 | self.tinyFlag = tinyFlag 41 | self.groups = groups 42 | self._norm_layer = norm_layer 43 | stem_width = int(channels[0]) 44 | self.momentum=0.9 45 | if tinyFlag: 46 | self.stem = nn.Sequential( 47 | norm_layer(in_channels, self.momentum), 48 | nn.Conv2d(in_channels, out_channels=stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False), 49 | norm_layer(stem_width * 2, momentum=self.momentum), 50 | nn.ReLU(inplace=True) 51 | ) 52 | 53 | else: 54 | self.stem = nn.Sequential( 55 | # self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=2, 56 | # padding=1, use_bias=False)) 57 | # self.stem.add(norm_layer(in_channels=stem_width*2)) 58 | # self.stem.add(nn.Activation('relu')) 59 | # self.stem.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) 60 | norm_layer(in_channels, momentum=self.momentum), 61 | nn.Conv2d(in_channels=in_channels, out_channels=stem_width, kernel_size=3, stride=2, padding=1, bias=False), 62 | norm_layer(stem_width, 
momentum=self.momentum), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(in_channels=stem_width, out_channels=stem_width, kernel_size=3, stride=1, padding=1, 65 | bias=False), 66 | norm_layer(stem_width, momentum=self.momentum), 67 | nn.ReLU(inplace=True), 68 | nn.Conv2d(in_channels=stem_width, out_channels=stem_width * 2, kernel_size=3, stride=1, padding=1, 69 | bias=False), 70 | norm_layer(stem_width * 2, momentum=self.momentum), 71 | nn.ReLU(inplace=True), 72 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 73 | ) 74 | 75 | 76 | # self.head1 = _FCNHead(in_channels=channels[1], channels=classes) 77 | # self.head2 = _FCNHead(in_channels=channels[2], channels=classes) 78 | # self.head3 = _FCNHead(in_channels=channels[3], channels=classes) 79 | # self.head4 = _FCNHead(in_channels=channels[4], channels=classes) 80 | 81 | self.head = _FCNHead(in_channels=channels[0], channels=classes, momentum=self.momentum) 82 | 83 | self.layer1 = self._make_layer(block=BasicBlock, blocks=layers[0], 84 | out_channels=channels[1], 85 | in_channels=channels[1], stride=1) 86 | 87 | self.layer2 = self._make_layer(block=BasicBlock, blocks=layers[1], 88 | out_channels=channels[2], stride=2, 89 | in_channels=channels[1]) 90 | # 91 | self.layer3 = self._make_layer(block=BasicBlock, blocks=layers[2], 92 | out_channels=channels[3], stride=2, 93 | in_channels=channels[2]) 94 | self.deconv2 = nn.ConvTranspose2d(in_channels=channels[3], out_channels=channels[2], kernel_size=(4, 4), 95 | ##channels: 8 16 32 64 96 | stride=2, padding=1) 97 | self.uplayer2 = self._make_layer(block=BasicBlock, blocks=layers[1], 98 | out_channels=channels[2], stride=1, 99 | in_channels=channels[2]) 100 | 101 | 102 | self.deconv1 = nn.ConvTranspose2d(in_channels=channels[2], out_channels=channels[1], kernel_size=(4, 4), 103 | stride=2, padding=1) 104 | 105 | self.deconv0 = nn.ConvTranspose2d(in_channels=channels[1], out_channels=channels[0], kernel_size=(4, 4), 106 | stride=2, padding=1) 107 | 108 | self.uplayer1 = self._make_layer(block=BasicBlock, blocks=layers[0], 109 | out_channels=channels[1], stride=1, 110 | in_channels=channels[1]) 111 | 112 | 113 | if self.layer_num == 4: 114 | self.layer4 = self._make_layer(block=BasicBlock, blocks=layers[3], 115 | out_channels=channels[3], stride=2, 116 | in_channels=channels[3]) 117 | 118 | if self.layer_num == 4: 119 | self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[3]) # channels[4] 120 | 121 | self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[2]) # 64 122 | self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[1]) # 32 123 | 124 | # if fuse_order == 'reverse': 125 | # self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[2]) # channels[2] 126 | # self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[3]) # channels[3] 127 | # self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4] 128 | # elif fuse_order == 'normal': 129 | # self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4] 130 | # self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4] 131 | # self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4] 132 | 133 | def _make_layer(self, block, out_channels, in_channels, blocks, stride): 134 | 135 | norm_layer = self._norm_layer 136 | downsample = None 137 | 138 | if stride != 1 or out_channels != in_channels: 139 | downsample = nn.Sequential( 140 | conv1x1(in_channels, out_channels , stride), 141 | norm_layer(out_channels * block.expansion, momentum=self.momentum), 142 | ) 
143 | 144 | layers = [] 145 | layers.append(block(in_channels, out_channels, stride, downsample, self.groups, norm_layer=norm_layer)) 146 | self.inplanes = out_channels * block.expansion 147 | for _ in range(1, blocks): 148 | layers.append(block(self.inplanes, out_channels, groups=self.groups, norm_layer=norm_layer))  # pass groups by keyword; positionally it would land in BasicBlock's stride slot 149 | return nn.Sequential(*layers) 150 | 151 | def _fuse_layer(self, fuse_mode, channels): 152 | 153 | if fuse_mode == 'AsymBi': 154 | fuse_layer = AsymBiChaFuse(channels=channels) 155 | else: 156 | raise ValueError('Unknown fuse_mode') 157 | return fuse_layer 158 | 159 | def forward(self, x): 160 | 161 | _, _, hei, wid = x.shape  # e.g. 1024, 1024 162 | 163 | x = self.stem(x) #torch.Size([8, 16, 256, 256]) 164 | c1 = self.layer1(x) # torch.Size([8, 16, 256, 256]) 165 | c2 = self.layer2(c1) # torch.Size([8, 32, 128, 128]) 166 | 167 | out = self.layer3(c2) # (8,64, 64, 64) 168 | 169 | if self.layer_num == 4: 170 | c4 = self.layer4(out) # torch.Size([8,64, 32, 32]) 171 | if self.tinyFlag: 172 | c4 = transforms.Resize([hei//4, wid//4])(c4) # down 4 173 | else: 174 | c4 = transforms.Resize([hei//16, wid//16])(c4) # down 16 torch.Size([8, 64, 64, 64]) 175 | out = self.fuse34(c4, out) # torch.Size([8, 64, 128, 128]) 176 | 177 | if self.tinyFlag: 178 | out = transforms.Resize([hei//2, wid//2])(out) # down 16 torch.Size([8, 64, 64, 64]) 179 | else: 180 | out = transforms.Resize([hei//16, wid//16])(out) # down 16 torch.Size([8, 64, 64, 64]) 181 | 182 | out = self.deconv2(out) # torch.Size([8, 32, 128, 128]) 183 | out = self.fuse23(out, c2) # torch.Size([8, 32, 128, 128]) 184 | if self.tinyFlag: 185 | out = transforms.Resize([hei, wid])(out) # down 1 186 | else: 187 | out = transforms.Resize([hei//8, wid//8])(out) # (4,16,120,120) 188 | 189 | out = self.deconv1(out) # torch.Size([8, 16, 256, 256]) 190 | out = self.fuse12(out, c1) # torch.Size([8, 16, 256, 256]) 191 | 192 | out = self.deconv0(out) # torch.Size([8, 8, 512, 512]) 193 | pred = self.head(out) # torch.Size([8, 1, 512, 512]): the head maps channels[0] down to `classes` channels 194 | 195 | 196 | if self.tinyFlag: 197 | out = pred 198 | else: 199 | out = transforms.Resize([hei, wid])(pred) # back to the input resolution 200 | 201 | ######### reverse order ########## 202 | # up_c2 = F.contrib.BilinearResize2D(c2, height=hei//4, width=wid//4) # down 4 203 | # fuse2 = self.fuse12(up_c2, c1) # down 4, channels[2] 204 | # 205 | # up_c3 = F.contrib.BilinearResize2D(c3, height=hei//4, width=wid//4) # down 4 206 | # fuse3 = self.fuse23(up_c3, fuse2) # down 4, channels[3] 207 | # 208 | # up_c4 = F.contrib.BilinearResize2D(c4, height=hei//4, width=wid//4) # down 4 209 | # fuse4 = self.fuse34(up_c4, fuse3) # down 4, channels[4] 210 | # 211 | 212 | ######### normal order ########## 213 | # out = F.contrib.BilinearResize2D(c4, height=hei//16, width=wid//16) 214 | # out = self.fuse34(out, c3) 215 | # out = F.contrib.BilinearResize2D(out, height=hei//8, width=wid//8) 216 | # out = self.fuse23(out, c2) 217 | # out = F.contrib.BilinearResize2D(out, height=hei//4, width=wid//4) 218 | # out = self.fuse12(out, c1) 219 | # out = self.head(out) 220 | # out = F.contrib.BilinearResize2D(out, height=hei, width=wid) 221 | 222 | 223 | return out 224 | 225 | def evaluate(self, x): 226 | """evaluating network with inputs and targets""" 227 | return self.forward(x) 228 | 229 | 230 | 231 | -------------------------------------------------------------------------------- /model/ALC/__init__.py: -------------------------------------------------------------------------------- 1 | 2 |
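Note that, unlike ALCL_no_sigmoid.py and MLCL_no_sigmoid.py below, ALC_no_sigmoid.py ships without a `__main__` smoke test. The following is a minimal sketch of one, assuming the repository root is on PYTHONPATH so that `model.ALC.ALC_no_sigmoid` is importable; the batch size and the 512x512 input resolution are illustrative assumptions, not values prescribed by the repository.

# Minimal smoke test for ALC_No_Sigmoid (a sketch: constructor arguments are the
# defaults from the file above, the input shape is an illustrative assumption).
import torch
from model.ALC.ALC_no_sigmoid import ALC_No_Sigmoid

if __name__ == '__main__':
    model = ALC_No_Sigmoid(in_channels=3, layers=[4, 4, 4], channels=[8, 16, 32, 64],
                           fuse_mode='AsymBi', classes=1, tinyFlag=False)
    x = torch.rand(2, 3, 512, 512)  # dummy NCHW batch
    out = model(x)                  # raw logits: as the file name says, no sigmoid is applied
    print(out.size())               # expected: torch.Size([2, 1, 512, 512])

As with the other *_no_sigmoid models in this repository, the output is a logit map, so a sigmoid (or a loss such as BCEWithLogitsLoss) is expected downstream.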
-------------------------------------------------------------------------------- /model/ALC/__pycache__/ALC_no_sigmoid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALC/__pycache__/ALC_no_sigmoid.cpython-36.pyc -------------------------------------------------------------------------------- /model/ALC/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALC/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/ALC/__pycache__/fusion.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALC/__pycache__/fusion.cpython-36.pyc -------------------------------------------------------------------------------- /model/ALC/fusion.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import BatchNorm2d 3 | import torch 4 | from torch.nn import GroupNorm 5 | class AsymBiChaFuse(nn.Module): 6 | def __init__(self, channels=64, r=4): 7 | super(AsymBiChaFuse, self).__init__() 8 | self.channels = channels 9 | self.bottleneck_channels = int(channels // r) 10 | 11 | 12 | self.topdown = nn.Sequential( 13 | nn.AdaptiveAvgPool2d(1), 14 | nn.Conv2d(in_channels=self.channels, out_channels=self.bottleneck_channels, kernel_size=1, stride=1,padding=0), 15 | # nn.BatchNorm2d(self.bottleneck_channels,momentum=0.9), 16 | nn.GroupNorm(1, self.bottleneck_channels), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(in_channels=self.bottleneck_channels,out_channels=self.channels, kernel_size=1, stride=1,padding=0), 19 | #nn.BatchNorm2d(self.channels,momentum=0.9), 20 | nn.GroupNorm(1, self.channels), 21 | nn.Sigmoid() 22 | ) 23 | 24 | self.bottomup = nn.Sequential( 25 | nn.Conv2d(in_channels=self.channels,out_channels=self.bottleneck_channels, kernel_size=1, stride=1,padding=0), 26 | #nn.BatchNorm2d(self.bottleneck_channels,momentum=0.9), 27 | nn.GroupNorm(1, self.bottleneck_channels), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(in_channels=self.bottleneck_channels,out_channels=self.channels, kernel_size=1, stride=1,padding=0), 30 | #nn.BatchNorm2d(self.channels,momentum=0.9), 31 | nn.GroupNorm(1, self.channels), 32 | nn.Sigmoid() 33 | ) 34 | 35 | self.post = nn.Sequential( 36 | nn.Conv2d(in_channels=channels,out_channels=channels, kernel_size=3, stride=1, padding=1, dilation=1), 37 | #nn.BatchNorm2d(channels,momentum=0.9), 38 | nn.GroupNorm(1, self.channels), 39 | nn.ReLU(inplace=True) 40 | ) 41 | 42 | def forward(self, xh, xl): 43 | 44 | topdown_wei = self.topdown(xh) 45 | bottomup_wei = self.bottomup(xl) 46 | xs = 2 * torch.mul(xl, topdown_wei) + 2 * torch.mul(xh, bottomup_wei) 47 | xs = self.post(xs) 48 | return xs 49 | 50 | # from mxnet.gluon.block import HybridBlock 51 | # topdown = HybridBlock(prefix='topdown') 52 | # print(topdown) -------------------------------------------------------------------------------- /model/ALCL/ALCL_no_sigmoid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms.functional as TF 4 | 5 | 6 | # from 
torchinfo import summary 7 | 8 | class Resnet1(nn.Module): 9 | def __init__(self, in_channel, out_channel): 10 | super(Resnet1, self).__init__() 11 | self.layer = nn.Sequential( 12 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 13 | nn.BatchNorm2d(out_channel), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 16 | nn.BatchNorm2d(out_channel) 17 | ) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.layer.apply(weights_init) 20 | 21 | def forward(self, x): 22 | identity = x 23 | out = self.layer(x) 24 | out += identity 25 | return self.relu(out) 26 | 27 | 28 | class Resnet2(nn.Module): 29 | def __init__(self, in_channel, out_channel): 30 | super(Resnet2, self).__init__() 31 | self.layer1 = nn.Sequential( 32 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 33 | nn.BatchNorm2d(out_channel), 34 | nn.ReLU(inplace=True), 35 | nn.MaxPool2d(kernel_size=2, stride=2), 36 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 37 | nn.BatchNorm2d(out_channel) 38 | ) 39 | self.layer2 = nn.Sequential( 40 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=2), 41 | nn.BatchNorm2d(out_channel), 42 | nn.ReLU(inplace=True) 43 | ) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.layer1.apply(weights_init) 46 | self.layer2.apply(weights_init) 47 | 48 | def forward(self, x): 49 | identity = x 50 | out = self.layer1(x) 51 | identity = self.layer2(identity) 52 | out += identity 53 | return self.relu(out) 54 | 55 | 56 | class Stage(nn.Module): 57 | def __init__(self): 58 | super(Stage, self).__init__() 59 | self.layer1 = nn.Sequential( 60 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1), 61 | nn.BatchNorm2d(16), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1), 64 | nn.BatchNorm2d(16), 65 | nn.ReLU(inplace=True) 66 | ) 67 | self.resnet1_1 = Resnet1(in_channel=16, out_channel=16) 68 | self.resnet1_2 = Resnet1(in_channel=16, out_channel=16) 69 | self.resnet1_3 = Resnet1(in_channel=16, out_channel=16) 70 | self.resnet2_1 = Resnet2(in_channel=16, out_channel=32) 71 | self.resnet2_2 = Resnet1(in_channel=32, out_channel=32) 72 | self.resnet2_3 = Resnet1(in_channel=32, out_channel=32) 73 | self.resnet3_1 = Resnet2(in_channel=32, out_channel=64) 74 | self.resnet3_2 = Resnet1(in_channel=64, out_channel=64) 75 | self.resnet3_3 = Resnet1(in_channel=64, out_channel=64) 76 | self.resnet4_1 = Resnet2(in_channel=64, out_channel=128) 77 | self.resnet4_2 = Resnet1(in_channel=128, out_channel=128) 78 | self.resnet4_3 = Resnet1(in_channel=128, out_channel=128) 79 | self.resnet5_1 = Resnet2(in_channel=128, out_channel=256) 80 | self.resnet5_2 = Resnet1(in_channel=256, out_channel=256) 81 | self.resnet5_3 = Resnet1(in_channel=256, out_channel=256) 82 | self.layer1.apply(weights_init) 83 | 84 | def forward(self, x): 85 | outs = [] 86 | out = self.layer1(x) 87 | out = self.resnet1_1(out) 88 | out = self.resnet1_2(out) 89 | out = self.resnet1_3(out) 90 | # print("-------") 91 | # print(out.size()) 92 | outs.append(out) 93 | out = self.resnet2_1(out) 94 | out = self.resnet2_2(out) 95 | out = self.resnet2_3(out) 96 | # print(out.size()) 97 | outs.append(out) 98 | out = self.resnet3_1(out) 99 | out = self.resnet3_2(out) 100 | out = self.resnet3_3(out) 101 | # print(out.size()) 102 | 
outs.append(out) 103 | out = self.resnet4_1(out) 104 | out = self.resnet4_2(out) 105 | out = self.resnet4_3(out) 106 | # print(out.size()) 107 | outs.append(out) 108 | out = self.resnet5_1(out) 109 | out = self.resnet5_2(out) 110 | out = self.resnet5_3(out) 111 | # print(out.size()) 112 | # print("-------") 113 | outs.append(out) 114 | return outs 115 | 116 | 117 | class LCL(nn.Module): 118 | def __init__(self, in_channel, out_channel): 119 | super(LCL, self).__init__() 120 | self.layer1 = nn.Sequential( 121 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, padding=0, stride=1), 122 | nn.ReLU(inplace=True), 123 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1, dilation=1), 124 | #nn.BatchNorm2d(out_channel), 125 | nn.ReLU(inplace=True) 126 | ) 127 | self.layer1.apply(weights_init) 128 | def forward(self, x): 129 | out = self.layer1(x) 130 | # print("-----") 131 | # print(out.size()) 132 | # print("-----") 133 | return out 134 | 135 | 136 | class Sbam(nn.Module): 137 | def __init__(self, in_channel, out_channel): 138 | super(Sbam, self).__init__() 139 | self.hl_layer = nn.Sequential( 140 | nn.UpsamplingBilinear2d(scale_factor=2), 141 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1), 142 | nn.BatchNorm2d(out_channel), 143 | nn.ReLU(inplace=True) 144 | ) 145 | self.ll_layer = nn.Sequential( 146 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=1), 147 | nn.BatchNorm2d(out_channel), 148 | nn.Sigmoid() # ll = torch.sigmoid(ll) 149 | ) 150 | self.hl_layer.apply(weights_init) 151 | self.ll_layer.apply(weights_init) 152 | def forward(self, hl,ll): 153 | hl = self.hl_layer(hl) 154 | # print(hl.size()) 155 | ll_1 = ll 156 | ll = self.ll_layer(ll) 157 | # print(ll.size()) 158 | hl_1 = hl*ll 159 | out = ll_1+hl_1 160 | return out 161 | 162 | class ALCL_No_Sigmoid(nn.Module): 163 | def __init__(self): 164 | super(ALCL_No_Sigmoid, self).__init__() 165 | self.stage = Stage() 166 | self.lcl5 = LCL(256, 256) 167 | self.lcl4 = LCL(128, 128) 168 | self.lcl3 = LCL(64, 64) 169 | self.lcl2 = LCL(32, 32) 170 | self.lcl1 = LCL(16, 16) 171 | self.sbam4 = Sbam(256, 128) 172 | self.sbam3 = Sbam(128, 64) 173 | self.sbam2 = Sbam(64, 32) 174 | self.sbam1 = Sbam(32, 16) 175 | 176 | self.layer = nn.Sequential( 177 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1), 178 | nn.ReLU(inplace=True), 179 | nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1), 180 | # nn.Sigmoid() 181 | ) 182 | self.layer.apply(weights_init) 183 | 184 | def forward(self, x): 185 | outs = self.stage(x) 186 | out5 = self.lcl5(outs[4]) 187 | # print(out5.size()) 188 | out4 = self.lcl4(outs[3]) 189 | # print(out4.size()) 190 | out3 = self.lcl3(outs[2]) 191 | # print(out3.size()) 192 | out2 = self.lcl2(outs[1]) 193 | # print(out2.size()) 194 | out1 = self.lcl1(outs[0]) 195 | # print(out1.size()) 196 | out4_2 = self.sbam4(out5, out4) 197 | out3_2 = self.sbam3(out4_2, out3) 198 | out2_2 = self.sbam2(out3_2, out2) 199 | out1_2 = self.sbam1(out2_2, out1) 200 | out = self.layer(out1_2) 201 | 202 | return out 203 | def weights_init(m): 204 | if isinstance(m, nn.Conv2d): 205 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 206 | if m.bias is not None: 207 | nn.init.constant_(m.bias, 0) 208 | elif isinstance(m, nn.BatchNorm2d): 209 | nn.init.constant_(m.weight, 1) 210 | nn.init.constant_(m.bias, 0) 211 | elif isinstance(m, nn.Linear): 212 | nn.init.xavier_uniform_(m.weight) 213 | 
nn.init.constant_(m.bias, 0) 214 | return 215 | 216 | if __name__ == '__main__': 217 | model = ALCL_No_Sigmoid() 218 | x = torch.rand(8, 3, 512, 512) 219 | outs = model(x) 220 | print(outs.size()) 221 | -------------------------------------------------------------------------------- /model/ALCL/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALCL/__init__.py -------------------------------------------------------------------------------- /model/ALCL/__pycache__/ALCL_no_sigmoid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALCL/__pycache__/ALCL_no_sigmoid.cpython-36.pyc -------------------------------------------------------------------------------- /model/ALCL/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALCL/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/DNA/DNA_no_sigmoid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class VGG_CBAM_Block(nn.Module): 6 | def __init__(self, in_channels, out_channels): 7 | super().__init__() 8 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1) 9 | self.bn1 = nn.BatchNorm2d(out_channels) 10 | self.relu = nn.ReLU(inplace=True) 11 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1) 12 | self.bn2 = nn.BatchNorm2d(out_channels) 13 | self.relu = nn.ReLU(inplace=True) 14 | self.ca = ChannelAttention(out_channels) 15 | self.sa = SpatialAttention() 16 | 17 | def forward(self, x): 18 | out = self.conv1(x) 19 | out = self.bn1(out) 20 | out = self.relu(out) 21 | out = self.conv2(out) 22 | out = self.bn2(out) 23 | out = self.ca(out) * out 24 | out = self.sa(out) * out 25 | out = self.relu(out) 26 | return out 27 | 28 | class ChannelAttention(nn.Module): 29 | def __init__(self, in_planes, ratio=16): 30 | super(ChannelAttention, self).__init__() 31 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 32 | self.max_pool = nn.AdaptiveMaxPool2d(1) 33 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)  # use the ratio argument rather than a hard-coded 16 34 | self.relu1 = nn.ReLU() 35 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 36 | self.sigmoid = nn.Sigmoid() 37 | def forward(self, x): 38 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 39 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 40 | out = avg_out + max_out 41 | return self.sigmoid(out) 42 | 43 | class SpatialAttention(nn.Module): 44 | def __init__(self, kernel_size=7): 45 | super(SpatialAttention, self).__init__() 46 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 47 | padding = 3 if kernel_size == 7 else 1 48 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 49 | self.sigmoid = nn.Sigmoid() 50 | def forward(self, x): 51 | avg_out = torch.mean(x, dim=1, keepdim=True) 52 | max_out, _ = torch.max(x, dim=1, keepdim=True) 53 | x = torch.cat([avg_out, max_out], dim=1) 54 | x = self.conv1(x) 55 | return self.sigmoid(x) 56 | 57 | class Res_CBAM_block(nn.Module): 58 | def __init__(self, in_channels, out_channels, stride = 1): 59 |
super(Res_CBAM_block, self).__init__() 60 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1) 61 | self.bn1 = nn.BatchNorm2d(out_channels) 62 | self.relu = nn.ReLU(inplace = True) 63 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1) 64 | self.bn2 = nn.BatchNorm2d(out_channels) 65 | if stride != 1 or out_channels != in_channels: 66 | self.shortcut = nn.Sequential( 67 | nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = stride), 68 | nn.BatchNorm2d(out_channels)) 69 | else: 70 | self.shortcut = None 71 | 72 | self.ca = ChannelAttention(out_channels) 73 | self.sa = SpatialAttention() 74 | 75 | def forward(self, x): 76 | residual = x 77 | if self.shortcut is not None: 78 | residual = self.shortcut(x) 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.ca(out) * out 85 | out = self.sa(out) * out 86 | out += residual 87 | out = self.relu(out) 88 | return out 89 | 90 | class DNA_No_Sigmoid(nn.Module): 91 | def __init__(self, num_classes=1,input_channels=3, block=Res_CBAM_block, num_blocks=[2, 2, 2, 2], nb_filter=[16, 32, 64, 128, 256], deep_supervision=True, mode='test'): 92 | super(DNA_No_Sigmoid, self).__init__() 93 | self.mode = mode 94 | self.relu = nn.ReLU(inplace = True) 95 | self.deep_supervision = deep_supervision 96 | self.pool = nn.MaxPool2d(2, 2) 97 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 98 | self.down = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True) 99 | 100 | self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 101 | self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 102 | self.up_16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True) 103 | 104 | self.conv0_0 = self._make_layer(block, input_channels, nb_filter[0]) 105 | self.conv1_0 = self._make_layer(block, nb_filter[0], nb_filter[1], num_blocks[0]) 106 | self.conv2_0 = self._make_layer(block, nb_filter[1], nb_filter[2], num_blocks[1]) 107 | self.conv3_0 = self._make_layer(block, nb_filter[2], nb_filter[3], num_blocks[2]) 108 | self.conv4_0 = self._make_layer(block, nb_filter[3], nb_filter[4], num_blocks[3]) 109 | 110 | self.conv0_1 = self._make_layer(block, nb_filter[0] + nb_filter[1], nb_filter[0]) 111 | self.conv1_1 = self._make_layer(block, nb_filter[1] + nb_filter[2] + nb_filter[0], nb_filter[1], num_blocks[0]) 112 | self.conv2_1 = self._make_layer(block, nb_filter[2] + nb_filter[3] + nb_filter[1], nb_filter[2], num_blocks[1]) 113 | self.conv3_1 = self._make_layer(block, nb_filter[3] + nb_filter[4] + nb_filter[2], nb_filter[3], num_blocks[2]) 114 | 115 | self.conv0_2 = self._make_layer(block, nb_filter[0]*2 + nb_filter[1], nb_filter[0]) 116 | self.conv1_2 = self._make_layer(block, nb_filter[1]*2 + nb_filter[2]+ nb_filter[0], nb_filter[1], num_blocks[0]) 117 | self.conv2_2 = self._make_layer(block, nb_filter[2]*2 + nb_filter[3]+ nb_filter[1], nb_filter[2], num_blocks[1]) 118 | 119 | self.conv0_3 = self._make_layer(block, nb_filter[0]*3 + nb_filter[1], nb_filter[0]) 120 | self.conv1_3 = self._make_layer(block, nb_filter[1]*3 + nb_filter[2]+ nb_filter[0], nb_filter[1], num_blocks[0]) 121 | 122 | self.conv0_4 = self._make_layer(block, nb_filter[0]*4 + nb_filter[1], nb_filter[0]) 123 | 124 | self.conv0_4_final = self._make_layer(block, nb_filter[0]*5, nb_filter[0]) 125 | 126 | self.conv0_4_1x1 = nn.Conv2d(nb_filter[4], 
nb_filter[0], kernel_size=1, stride=1) 127 | self.conv0_3_1x1 = nn.Conv2d(nb_filter[3], nb_filter[0], kernel_size=1, stride=1) 128 | self.conv0_2_1x1 = nn.Conv2d(nb_filter[2], nb_filter[0], kernel_size=1, stride=1) 129 | self.conv0_1_1x1 = nn.Conv2d(nb_filter[1], nb_filter[0], kernel_size=1, stride=1) 130 | 131 | if self.deep_supervision: 132 | self.final1 = nn.Conv2d (nb_filter[0], num_classes, kernel_size=1) 133 | self.final2 = nn.Conv2d (nb_filter[0], num_classes, kernel_size=1) 134 | self.final3 = nn.Conv2d (nb_filter[0], num_classes, kernel_size=1) 135 | self.final4 = nn.Conv2d (nb_filter[0], num_classes, kernel_size=1) 136 | else: 137 | self.final = nn.Conv2d (nb_filter[0], num_classes, kernel_size=1) 138 | 139 | def _make_layer(self, block, input_channels, output_channels, num_blocks=1): 140 | layers = [] 141 | layers.append(block(input_channels, output_channels)) 142 | for i in range(num_blocks-1): 143 | layers.append(block(output_channels, output_channels)) 144 | return nn.Sequential(*layers) 145 | 146 | def forward(self, input): 147 | x0_0 = self.conv0_0(input) 148 | x1_0 = self.conv1_0(self.pool(x0_0)) 149 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) 150 | 151 | x2_0 = self.conv2_0(self.pool(x1_0)) 152 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0),self.down(x0_1)], 1)) 153 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) 154 | 155 | x3_0 = self.conv3_0(self.pool(x2_0)) 156 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0),self.down(x1_1)], 1)) 157 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1),self.down(x0_2)], 1)) 158 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1)) 159 | 160 | x4_0 = self.conv4_0(self.pool(x3_0)) 161 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0),self.down(x2_1)], 1)) 162 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1),self.down(x1_2)], 1)) 163 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2),self.down(x0_3)], 1)) 164 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1)) 165 | 166 | Final_x0_4 = self.conv0_4_final( 167 | torch.cat([self.up_16(self.conv0_4_1x1(x4_0)),self.up_8(self.conv0_3_1x1(x3_1)), 168 | self.up_4 (self.conv0_2_1x1(x2_2)),self.up (self.conv0_1_1x1(x1_3)), x0_4], 1)) 169 | 170 | if self.deep_supervision: 171 | # output1 = self.final1(x0_1).sigmoid() 172 | # output2 = self.final2(x0_2).sigmoid() 173 | # output3 = self.final3(x0_3).sigmoid() 174 | # output4 = self.final4(Final_x0_4).sigmoid() 175 | output1 = self.final1(x0_1) 176 | output2 = self.final2(x0_2) 177 | output3 = self.final3(x0_3) 178 | output4 = self.final4(Final_x0_4) 179 | if self.mode == 'train': 180 | return [output1, output2, output3, output4] 181 | else: 182 | return output4 183 | else: 184 | # output = self.final(Final_x0_4).sigmoid() 185 | output = self.final(Final_x0_4) 186 | return output 187 | 188 | 189 | -------------------------------------------------------------------------------- /model/DNA/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /model/DNA/__pycache__/DNA_no_sigmoid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/DNA/__pycache__/DNA_no_sigmoid.cpython-36.pyc -------------------------------------------------------------------------------- 
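DNA_No_Sigmoid above returns different structures depending on its `mode` flag: with deep_supervision=True it yields four logit maps (one per supervised head) in 'train' mode and only the final fused map otherwise. A minimal usage sketch follows, assuming the repository root is on PYTHONPATH; the batch size and the 256x256 resolution are illustrative (the input only needs to be divisible by 16 so the four pooling stages upsample back cleanly).

# Sketch of the two return conventions of DNA_No_Sigmoid (input shape is an
# illustrative assumption; the mode/deep_supervision semantics come from the file above).
import torch
from model.DNA.DNA_no_sigmoid import DNA_No_Sigmoid

if __name__ == '__main__':
    x = torch.rand(2, 3, 256, 256)
    model = DNA_No_Sigmoid(deep_supervision=True, mode='train')
    outs = model(x)                   # list of four logit maps, one per supervised head
    print(len(outs), outs[0].size())  # 4 torch.Size([2, 1, 256, 256])
    model.mode = 'test'               # forward() reads self.mode, so it can be switched in place
    out = model(x)                    # only the final fused map
    print(out.size())                 # torch.Size([2, 1, 256, 256])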
/model/DNA/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/DNA/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/GGL/GGL_no_sigmoid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | 6 | def gradient_1order(x,h_x=None,w_x=None): 7 | if h_x is None and w_x is None: 8 | h_x = x.size()[2] 9 | w_x = x.size()[3] 10 | r = F.pad(x, (0, 1, 0, 0))[:, :, :, 1:] 11 | l = F.pad(x, (1, 0, 0, 0))[:, :, :, :w_x] 12 | t = F.pad(x, (0, 0, 1, 0))[:, :, :h_x, :] 13 | b = F.pad(x, (0, 0, 0, 1))[:, :, 1:, :] 14 | xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5) 15 | return xgrad 16 | 17 | class SEAttention(nn.Module): 18 | 19 | def __init__(self, channel,reduction=16): 20 | super().__init__() 21 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 22 | self.fc = nn.Sequential( 23 | nn.Linear(channel, channel // reduction, bias=False), 24 | nn.ReLU(inplace=True), 25 | nn.Linear(channel // reduction, channel, bias=False), 26 | nn.Sigmoid() 27 | ) 28 | 29 | 30 | def init_weights(self): 31 | for m in self.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | init.kaiming_normal_(m.weight, mode='fan_out') 34 | if m.bias is not None: 35 | init.constant_(m.bias, 0) 36 | elif isinstance(m, nn.BatchNorm2d): 37 | init.constant_(m.weight, 1) 38 | init.constant_(m.bias, 0) 39 | elif isinstance(m, nn.Linear): 40 | init.normal_(m.weight, std=0.001) 41 | if m.bias is not None: 42 | init.constant_(m.bias, 0) 43 | 44 | def forward(self, x): 45 | b, c, _, _ = x.size() 46 | y = self.avg_pool(x).view(b, c) 47 | y = self.fc(y).view(b, c, 1, 1) 48 | return x * y.expand_as(x) 49 | 50 | 51 | class ChannelAttention(nn.Module): 52 | def __init__(self, channel, reduction=8): 53 | super().__init__() 54 | self.maxpool = nn.AdaptiveMaxPool2d(1) 55 | self.avgpool = nn.AdaptiveAvgPool2d(1) 56 | self.se = nn.Sequential( 57 | nn.Conv2d(channel, channel // reduction, 1, bias=False), 58 | nn.ReLU(), 59 | nn.Conv2d(channel // reduction, channel, 1, bias=False) 60 | ) 61 | self.sigmoid = nn.Sigmoid() 62 | 63 | def init_weights(self): 64 | for m in self.modules(): 65 | if isinstance(m, nn.Conv2d): 66 | init.kaiming_normal_(m.weight, mode='fan_out') 67 | if m.bias is not None: 68 | init.constant_(m.bias, 0) 69 | elif isinstance(m, nn.BatchNorm2d): 70 | init.constant_(m.weight, 1) 71 | init.constant_(m.bias, 0) 72 | elif isinstance(m, nn.Linear): 73 | init.normal_(m.weight, std=0.001) 74 | if m.bias is not None: 75 | init.constant_(m.bias, 0) 76 | def forward(self, x): 77 | max_result = self.maxpool(x) 78 | avg_result = self.avgpool(x) 79 | max_out = self.se(max_result) 80 | avg_out = self.se(avg_result) 81 | output = self.sigmoid(max_out + avg_out) 82 | return output 83 | 84 | class SpatialAttention(nn.Module): 85 | def __init__(self, kernel_size=7): 86 | super().__init__() 87 | self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2) 88 | self.sigmoid = nn.Sigmoid() 89 | 90 | def init_weights(self): 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | init.kaiming_normal_(m.weight, mode='fan_out') 94 | if m.bias is not None: 95 | init.constant_(m.bias, 0) 96 | elif isinstance(m, nn.BatchNorm2d): 97 | 
init.constant_(m.weight, 1) 98 | init.constant_(m.bias, 0) 99 | elif isinstance(m, nn.Linear): 100 | init.normal_(m.weight, std=0.001) 101 | if m.bias is not None: 102 | init.constant_(m.bias, 0) 103 | def forward(self, x): 104 | max_result, _ = torch.max(x, dim=1, keepdim=True) 105 | avg_result = torch.mean(x, dim=1, keepdim=True) 106 | result = torch.cat([max_result, avg_result], 1) 107 | output = self.conv(result) 108 | output = self.sigmoid(output) 109 | return output 110 | 111 | class Resnet1(nn.Module): 112 | def __init__(self, in_channel, out_channel): 113 | super(Resnet1, self).__init__() 114 | self.layer = nn.Sequential( 115 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 116 | nn.BatchNorm2d(out_channel), 117 | nn.ReLU(inplace=True), 118 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 119 | nn.BatchNorm2d(out_channel) 120 | ) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.layer.apply(weights_init) 123 | 124 | def forward(self, x): 125 | identity = x 126 | out = self.layer(x) 127 | out += identity 128 | return self.relu(out) 129 | 130 | 131 | # layer2_1 #layer3_1#layer4_1#layer5_1 132 | class Resnet2(nn.Module): 133 | def __init__(self, in_channel, out_channel): 134 | super(Resnet2, self).__init__() 135 | self.layer1 = nn.Sequential( 136 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 137 | nn.BatchNorm2d(out_channel), 138 | nn.ReLU(inplace=True), 139 | nn.MaxPool2d(kernel_size=2, stride=2), 140 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 141 | nn.BatchNorm2d(out_channel) 142 | ) 143 | self.layer2 = nn.Sequential( 144 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=2), 145 | nn.BatchNorm2d(out_channel), 146 | nn.ReLU(inplace=True) 147 | ) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.layer1.apply(weights_init) 150 | self.layer2.apply(weights_init) 151 | 152 | def forward(self, x): 153 | identity = x 154 | out = self.layer1(x) 155 | identity = self.layer2(identity) 156 | out += identity 157 | return self.relu(out) 158 | 159 | 160 | class Resnet3(nn.Module): 161 | def __init__(self, in_channel, out_channel): 162 | super(Resnet3, self).__init__() 163 | self.layer1 = nn.Sequential( 164 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 165 | nn.BatchNorm2d(out_channel), 166 | nn.ReLU(inplace=True), 167 | 168 | ) 169 | self.layer2 = nn.Sequential( 170 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 171 | nn.BatchNorm2d(out_channel) 172 | ) 173 | self.SEAttention = SEAttention(channel=out_channel, reduction=4) 174 | self.relu = nn.ReLU(inplace=True) 175 | self.layer1.apply(weights_init) 176 | self.layer2.apply(weights_init) 177 | 178 | def forward(self, x): 179 | identity = x 180 | out = self.layer1(x) 181 | out = self.SEAttention(out) 182 | out = self.layer2(out) 183 | out += identity 184 | return self.relu(out) 185 | 186 | class Res(nn.Module): 187 | def __init__(self, befor_channel,after_channel): 188 | super(Res, self).__init__() 189 | 190 | self.layer1 = nn.Sequential( 191 | nn.Conv2d(in_channels=befor_channel, out_channels=after_channel, kernel_size=3, padding=1, stride=1), 192 | nn.BatchNorm2d(after_channel), 193 | nn.ReLU(inplace=True), 194 | SEAttention(channel=after_channel, reduction=4), 195 | 
nn.Conv2d(in_channels=after_channel, out_channels=after_channel, kernel_size=3, padding=1, stride=1), 196 | nn.BatchNorm2d(after_channel), 197 | nn.ReLU(inplace=True), 198 | ) 199 | self.layer2 = nn.Sequential( 200 | nn.Conv2d(in_channels=2*after_channel, out_channels=after_channel, kernel_size=3, padding=1, stride=1), 201 | nn.BatchNorm2d(after_channel), 202 | nn.ReLU(inplace=True), 203 | SEAttention(channel=after_channel, reduction=4), 204 | nn.Conv2d(in_channels=after_channel, out_channels=after_channel, kernel_size=3, padding=1, stride=1), 205 | nn.BatchNorm2d(after_channel), 206 | ) 207 | self.relu = nn.ReLU(inplace=True) 208 | self.layer1.apply(weights_init) 209 | self.layer2.apply(weights_init) 210 | 211 | def forward(self, x, x1): 212 | x1 = self.layer1(x1) 213 | # print(x1.size()) 214 | con = torch.cat([x, x1], 1) 215 | identity = x 216 | out = self.layer2(con) 217 | out = out + identity 218 | return self.relu(out) 219 | 220 | class Stage(nn.Module): 221 | def __init__(self): 222 | super(Stage, self).__init__() 223 | self.layer1 = nn.Sequential( 224 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1), 225 | nn.BatchNorm2d(16), 226 | nn.ReLU(inplace=True), 227 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1), 228 | nn.BatchNorm2d(16), 229 | nn.ReLU(inplace=True) 230 | ) 231 | self.resnet1_1 = Resnet1(in_channel=16, out_channel=16) 232 | self.resnet1_2 = Resnet3(in_channel=16, out_channel=16) 233 | self.resnet1_3 = Resnet3(in_channel=16, out_channel=16) 234 | self.Res1 = Res(befor_channel=3, after_channel=16) 235 | self.resnet2_1 = Resnet2(in_channel=16, out_channel=32) 236 | self.resnet2_2 = Resnet3(in_channel=32, out_channel=32) 237 | self.resnet2_3 = Resnet3(in_channel=32, out_channel=32) 238 | self.Res2 = Res(befor_channel=3, after_channel=32) 239 | self.resnet3_1 = Resnet2(in_channel=32, out_channel=64) 240 | self.resnet3_2 = Resnet3(in_channel=64, out_channel=64) 241 | self.resnet3_3 = Resnet3(in_channel=64, out_channel=64) 242 | self.Res3 = Res(befor_channel=3, after_channel=64) 243 | self.resnet4_1 = Resnet2(in_channel=64, out_channel=128) 244 | self.resnet4_2 = Resnet3(in_channel=128, out_channel=128) 245 | self.resnet4_3 = Resnet3(in_channel=128, out_channel=128) 246 | self.Res4 = Res(befor_channel=3, after_channel=128) 247 | self.resnet5_1 = Resnet2(in_channel=128, out_channel=256) 248 | self.resnet5_2 = Resnet3(in_channel=256, out_channel=256) 249 | self.resnet5_3 = Resnet3(in_channel=256, out_channel=256) 250 | self.Res5 = Res(befor_channel=3, after_channel=256) 251 | self.pool = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), ) 252 | self.layer1.apply(weights_init) 253 | 254 | def forward(self, x): 255 | x_g = gradient_1order(x) 256 | outs = [] 257 | out = self.layer1(x) 258 | # print(out_1.size()) 259 | out = self.resnet1_1(out) 260 | out = self.resnet1_2(out) 261 | out = self.resnet1_3(out) 262 | # print(out.size()) 263 | #out_1 = self.pool(out_1) 264 | out = self.Res1(out, x_g) 265 | # print("-------") 266 | # print(out.size()) 267 | outs.append(out) 268 | out = self.resnet2_1(out) 269 | out = self.resnet2_2(out) 270 | out = self.resnet2_3(out) 271 | x1 = self.pool(x_g) 272 | out = self.Res2(out, x1) 273 | # print(out.size()) 274 | outs.append(out) 275 | out = self.resnet3_1(out) 276 | out = self.resnet3_2(out) 277 | out = self.resnet3_3(out) 278 | # print(out.size()) 279 | x2 = self.pool(x1) 280 | # print(x2.size()) 281 | out = self.Res3(out, x2) 282 | # print(out.size()) 283 | outs.append(out) 284 | 
out = self.resnet4_1(out) 285 | out = self.resnet4_2(out) 286 | out = self.resnet4_3(out) 287 | # print(out.size()) 288 | x3 = self.pool(x2) 289 | # print(x3.size()) 290 | out = self.Res4(out, x3) 291 | # print(out.size()) 292 | outs.append(out) 293 | out = self.resnet5_1(out) 294 | out = self.resnet5_2(out) 295 | out = self.resnet5_3(out) 296 | x4 = self.pool(x3) 297 | out = self.Res5(out, x4) 298 | # print(out.size()) 299 | # print("-------") 300 | outs.append(out) 301 | return outs 302 | 303 | class LCL(nn.Module): 304 | def __init__(self, in_channel, out_channel): 305 | super(LCL, self).__init__() 306 | self.layer1 = nn.Sequential( 307 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, padding=0, stride=1), 308 | nn.ReLU(inplace=True), 309 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, padding=1, stride=1, dilation=1), 310 | #nn.BatchNorm2d(out_channel), 311 | nn.ReLU(inplace=True) 312 | ) 313 | self.layer1.apply(weights_init) 314 | def forward(self, x): 315 | out = self.layer1(x) 316 | # print("-----") 317 | # print(out.size()) 318 | # print("-----") 319 | return out 320 | 321 | class Sbam(nn.Module): 322 | def __init__(self, in_channel, out_channel): 323 | super(Sbam, self).__init__() 324 | self.hl_layer = nn.Sequential( 325 | nn.UpsamplingBilinear2d(scale_factor=2), 326 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1), 327 | nn.BatchNorm2d(out_channel), 328 | nn.ReLU(inplace=True) 329 | ) 330 | self.hl_layer_2 = ChannelAttention(out_channel) 331 | self.ll_layer = SpatialAttention() 332 | # self.ll_layer = nn.Sequential( 333 | # nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=1), 334 | # nn.BatchNorm2d(out_channel), 335 | # nn.Sigmoid() # ll = torch.sigmoid(ll) 336 | # ) 337 | self.hl_layer.apply(weights_init) 338 | def forward(self, hl,ll): 339 | hl = self.hl_layer(hl) 340 | # print(hl.size()) 341 | ll_1 =ll * self.hl_layer_2(hl) 342 | 343 | ll = self.ll_layer(ll) 344 | # print(ll.size()) 345 | hl_1 = hl * ll 346 | out = ll_1 + hl_1 347 | return out 348 | 349 | class GGL_No_Sigmoid(nn.Module): 350 | def __init__(self): 351 | super(GGL_No_Sigmoid, self).__init__() 352 | self.stage = Stage() 353 | self.lcl5 = LCL(256, 256) 354 | self.lcl4 = LCL(128, 128) 355 | self.lcl3 = LCL(64, 64) 356 | self.lcl2 = LCL(32, 32) 357 | self.lcl1 = LCL(16, 16) 358 | self.sbam4 = Sbam(256, 128) 359 | self.sbam3 = Sbam(128, 64) 360 | self.sbam2 = Sbam(64, 32) 361 | self.sbam1 = Sbam(32, 16) 362 | 363 | self.layer = nn.Sequential( 364 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1), 365 | nn.ReLU(inplace=True), 366 | nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1), 367 | # nn.Sigmoid() 368 | ) 369 | self.layer.apply(weights_init) 370 | 371 | def forward(self, x): 372 | outs = self.stage(x) 373 | out5 = self.lcl5(outs[4]) 374 | # print(out5.size()) 375 | out4 = self.lcl4(outs[3]) 376 | # print(out4.size()) 377 | out3 = self.lcl3(outs[2]) 378 | # print(out3.size()) 379 | out2 = self.lcl2(outs[1]) 380 | # print(out2.size()) 381 | out1 = self.lcl1(outs[0]) 382 | # print(out1.size()) 383 | out4_2 = self.sbam4(out5, out4) 384 | out3_2 = self.sbam3(out4_2, out3) 385 | out2_2 = self.sbam2(out3_2, out2) 386 | out1_2 = self.sbam1(out2_2, out1) 387 | out = self.layer(out1_2) 388 | 389 | return out 390 | 391 | 392 | 393 | def weights_init(m): 394 | if isinstance(m, nn.Conv2d): 395 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 396 | if m.bias is not None: 397 | 
nn.init.constant_(m.bias, 0) 398 | elif isinstance(m, nn.BatchNorm2d): 399 | nn.init.constant_(m.weight, 1) 400 | nn.init.constant_(m.bias, 0) 401 | # elif isinstance(m, nn.Linear): 402 | # nn.init.xavier_uniform_(m.weight) 403 | # nn.init.constant_(m.bias, 0) 404 | return 405 | 406 | 407 | # if __name__ == '__main__': 408 | #     model = GGL_No_Sigmoid() 409 | #     x = torch.rand(8, 3, 512, 512) 410 | #     # the first-order gradient is computed inside Stage.forward, so only x is passed 411 | #     outs = model(x) 412 | #     print(outs.size()) 413 | -------------------------------------------------------------------------------- /model/GGL/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/GGL/__init__.py -------------------------------------------------------------------------------- /model/GGL/__pycache__/GGL_no_sigmoid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/GGL/__pycache__/GGL_no_sigmoid.cpython-36.pyc -------------------------------------------------------------------------------- /model/GGL/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/GGL/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/MLCL/MLCL_no_sigmoid.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: yuchuang,zhaojinmiao 3 | @time: 4 | @desc: This is the base version of MLCL-Net (consistent with the paper), i.e. three blocks are used per stage. paper:"Infrared small target detection based on multiscale local contrast learning networks" 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | class Resnet1(nn.Module): 10 | def __init__(self, in_channel, out_channel): 11 | super(Resnet1, self).__init__() 12 | self.layer = nn.Sequential( 13 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 14 | nn.BatchNorm2d(out_channel), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 17 | nn.BatchNorm2d(out_channel) 18 | ) 19 | self.relu = nn.ReLU(inplace=True) 20 | self.layer.apply(weights_init) 21 | 22 | def forward(self, x): 23 | identity = x 24 | out = self.layer(x) 25 | out += identity 26 | return self.relu(out) 27 | 28 | 29 | class Resnet2(nn.Module): 30 | def __init__(self, in_channel, out_channel): 31 | super(Resnet2, self).__init__() 32 | self.layer1 = nn.Sequential( 33 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 34 | nn.BatchNorm2d(out_channel), 35 | nn.ReLU(inplace=True), 36 | nn.MaxPool2d(kernel_size=2, stride=2), 37 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 38 | nn.BatchNorm2d(out_channel) 39 | ) 40 | self.layer2 = nn.Sequential( 41 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=2), 42 | nn.BatchNorm2d(out_channel), 43 | nn.ReLU(inplace=True) 44 | ) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.layer1.apply(weights_init) 47 | self.layer2.apply(weights_init) 48 | 49 | def forward(self, x): 50 | identity = x 51 | out = self.layer1(x) 52 | identity = self.layer2(identity) 53 | out +=
identity 54 | return self.relu(out) 55 | 56 | 57 | class Stage(nn.Module): 58 | def __init__(self): 59 | super(Stage, self).__init__() 60 | self.layer1 = nn.Sequential( 61 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1), 62 | nn.BatchNorm2d(16), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1), 65 | nn.BatchNorm2d(16), 66 | nn.ReLU(inplace=True) 67 | ) 68 | self.resnet1_1 = Resnet1(in_channel=16, out_channel=16) 69 | self.resnet1_2 = Resnet1(in_channel=16, out_channel=16) 70 | self.resnet1_3 = Resnet1(in_channel=16, out_channel=16) 71 | self.resnet2_1 = Resnet2(in_channel=16, out_channel=32) 72 | self.resnet2_2 = Resnet1(in_channel=32, out_channel=32) 73 | self.resnet2_3 = Resnet1(in_channel=32, out_channel=32) 74 | self.resnet3_1 = Resnet2(in_channel=32, out_channel=64) 75 | self.resnet3_2 = Resnet1(in_channel=64, out_channel=64) 76 | self.resnet3_3 = Resnet1(in_channel=64, out_channel=64) 77 | 78 | def forward(self, x): 79 | outs = [] 80 | out = self.layer1(x) 81 | out = self.resnet1_1(out) 82 | out = self.resnet1_2(out) 83 | out = self.resnet1_3(out) 84 | outs.append(out) 85 | out = self.resnet2_1(out) 86 | out = self.resnet2_2(out) 87 | out = self.resnet2_3(out) 88 | outs.append(out) 89 | out = self.resnet3_1(out) 90 | out = self.resnet3_2(out) 91 | out = self.resnet3_3(out) 92 | outs.append(out) 93 | return outs 94 | 95 | 96 | class MLCL(nn.Module): 97 | def __init__(self, in_channel, out_channel): 98 | super(MLCL, self).__init__() 99 | self.layer1 = nn.Sequential( 100 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, padding=0, stride=1), 101 | nn.ReLU(inplace=True), 102 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1, dilation=1), 103 | nn.ReLU(inplace=True) 104 | ) 105 | self.layer2 = nn.Sequential( 106 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=3, stride=1, dilation=3), 109 | nn.ReLU(inplace=True) 110 | ) 111 | self.layer3 = nn.Sequential( 112 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=5, padding=2, stride=1), 113 | nn.ReLU(inplace=True), 114 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=5, stride=1, dilation=5), 115 | nn.ReLU(inplace=True) 116 | ) 117 | # self.layer4 = nn.Sequential( 118 | # nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=7, padding=2, stride=1), 119 | # nn.ReLU(inplace=True), 120 | # nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=7, stride=1,dilation=7), 121 | # nn.ReLU(inplace=True) 122 | # ) 123 | self.conv = nn.Conv2d(in_channels=out_channel * 3, out_channels=out_channel, kernel_size=1) 124 | self.layer1.apply(weights_init) 125 | self.layer2.apply(weights_init) 126 | self.layer3.apply(weights_init) 127 | self.conv.apply(weights_init) 128 | def forward(self, x): 129 | x1 = x 130 | x2 = x 131 | x3 = x 132 | out1 = self.layer1(x1) 133 | out2 = self.layer2(x2) 134 | out3 = self.layer3(x3) 135 | outs = torch.cat((out1, out2, out3), dim=1) 136 | return self.conv(outs) 137 | 138 | 139 | class MLCL_No_Sigmoid(nn.Module): 140 | def __init__(self): 141 | super(MLCL_No_Sigmoid, self).__init__() 142 | self.stage = Stage() 143 | self.mlcl3 = MLCL(64, 64) 144 | self.mlcl2 = MLCL(32, 32) 145 | self.mlcl1 
= MLCL(16, 16) 146 | self.up3 = nn.UpsamplingBilinear2d(scale_factor=2) 147 | self.conv3 = nn.Sequential( 148 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1), 149 | nn.ReLU(inplace=True) 150 | ) 151 | self.up2 = nn.UpsamplingBilinear2d(scale_factor=2) 152 | self.conv2 = nn.Sequential( 153 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1), 154 | nn.ReLU(inplace=True) 155 | ) 156 | self.conv1 = nn.Sequential( 157 | nn.Conv2d(in_channels=16, out_channels=64, kernel_size=1), 158 | nn.ReLU(inplace=True) 159 | ) 160 | self.layer = nn.Sequential( 161 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1), 162 | nn.ReLU(inplace=True), 163 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1), 164 | #nn.Sigmoid() 165 | ) 166 | 167 | def forward(self, x): 168 | outs = self.stage(x) 169 | out3 = self.mlcl3(outs[2]) 170 | out2 = self.mlcl2(outs[1]) 171 | out1 = self.mlcl1(outs[0]) 172 | out3 = self.conv3(out3) # 128*128 64 173 | out3 = self.up3(out3) # 256*256 64 174 | out2 = self.conv2(out2) # 256*256 64 175 | out = out3 + out2 176 | out = self.up2(out) 177 | out1 = self.conv1(out1) 178 | out = out + out1 179 | out = self.layer(out) 180 | return out 181 | 182 | def weights_init(m): 183 | if isinstance(m, nn.Conv2d): 184 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 185 | if m.bias is not None: 186 | nn.init.constant_(m.bias, 0) 187 | elif isinstance(m, nn.BatchNorm2d): 188 | nn.init.constant_(m.weight, 1) 189 | nn.init.constant_(m.bias, 0) 190 | elif isinstance(m, nn.Linear): 191 | nn.init.xavier_uniform_(m.weight) 192 | nn.init.constant_(m.bias, 0) 193 | return 194 | 195 | if __name__ == '__main__': 196 | model = MLCL_No_Sigmoid() 197 | x = torch.rand(8, 3, 512, 512) 198 | outs = model(x) 199 | print(outs.size()) 200 | 201 | -------------------------------------------------------------------------------- /model/MLCL/MLCL_small_no_sigmoid.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: yuchuang,zhaojinmiao 3 | @time: 4 | @desc: This is the lightweight version of MLCL-Net, i.e. only two blocks are used per stage. paper:"Infrared small target detection based on multiscale local contrast learning networks" 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class Resnet1(nn.Module): 11 | def __init__(self, in_channel, out_channel): 12 | super(Resnet1, self).__init__() 13 | self.layer = nn.Sequential( 14 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 15 | nn.BatchNorm2d(out_channel), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 18 | nn.BatchNorm2d(out_channel) 19 | ) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.layer.apply(weights_init) 22 | 23 | def forward(self, x): 24 | identity = x 25 | out = self.layer(x) 26 | out += identity 27 | return self.relu(out) 28 | 29 | 30 | class Resnet2(nn.Module): 31 | def __init__(self, in_channel, out_channel): 32 | super(Resnet2, self).__init__() 33 | self.layer1 = nn.Sequential( 34 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 35 | nn.BatchNorm2d(out_channel), 36 | nn.ReLU(inplace=True), 37 | nn.MaxPool2d(kernel_size=2, stride=2), 38 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 39 | nn.BatchNorm2d(out_channel) 40 | ) 41 | self.layer2 = nn.Sequential( 42 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1,
stride=2), 43 | nn.BatchNorm2d(out_channel), 44 | nn.ReLU(inplace=True) 45 | ) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.layer1.apply(weights_init) 48 | self.layer2.apply(weights_init) 49 | 50 | def forward(self, x): 51 | identity = x 52 | out = self.layer1(x) 53 | identity = self.layer2(identity) 54 | out += identity 55 | return self.relu(out) 56 | 57 | 58 | class Stage(nn.Module): 59 | def __init__(self): 60 | super(Stage, self).__init__() 61 | self.layer1 = nn.Sequential( 62 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1), 63 | nn.BatchNorm2d(16), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1), 66 | nn.BatchNorm2d(16), 67 | nn.ReLU(inplace=True) 68 | ) 69 | self.resnet1_1 = Resnet1(in_channel=16, out_channel=16) 70 | self.resnet1_2 = Resnet1(in_channel=16, out_channel=16) 71 | self.resnet2_1 = Resnet2(in_channel=16, out_channel=32) 72 | self.resnet2_2 = Resnet1(in_channel=32, out_channel=32) 73 | self.resnet3_1 = Resnet2(in_channel=32, out_channel=64) 74 | self.resnet3_2 = Resnet1(in_channel=64, out_channel=64) 75 | 76 | def forward(self, x): 77 | outs = [] 78 | out = self.layer1(x) 79 | out = self.resnet1_1(out) 80 | out = self.resnet1_2(out) 81 | outs.append(out) 82 | out = self.resnet2_1(out) 83 | out = self.resnet2_2(out) 84 | outs.append(out) 85 | out = self.resnet3_1(out) 86 | out = self.resnet3_2(out) 87 | outs.append(out) 88 | return outs 89 | 90 | 91 | class MLCL(nn.Module): 92 | def __init__(self, in_channel, out_channel): 93 | super(MLCL, self).__init__() 94 | self.layer1 = nn.Sequential( 95 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, padding=0, stride=1), 96 | nn.ReLU(inplace=True), 97 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1, dilation=1), 98 | nn.ReLU(inplace=True) 99 | ) 100 | self.layer2 = nn.Sequential( 101 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 102 | nn.ReLU(inplace=True), 103 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=3, stride=1, dilation=3), 104 | nn.ReLU(inplace=True) 105 | ) 106 | self.layer3 = nn.Sequential( 107 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=5, padding=2, stride=1), 108 | nn.ReLU(inplace=True), 109 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=5, stride=1, dilation=5), 110 | nn.ReLU(inplace=True) 111 | ) 112 | # self.layer4 = nn.Sequential( 113 | # nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=7, padding=2, stride=1), 114 | # nn.ReLU(inplace=True), 115 | # nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=7, stride=1,dilation=7), 116 | # nn.ReLU(inplace=True) 117 | # ) 118 | self.conv = nn.Conv2d(in_channels=out_channel * 3, out_channels=out_channel, kernel_size=1) 119 | self.layer1.apply(weights_init) 120 | self.layer2.apply(weights_init) 121 | self.layer3.apply(weights_init) 122 | self.conv.apply(weights_init) 123 | def forward(self, x): 124 | x1 = x 125 | x2 = x 126 | x3 = x 127 | out1 = self.layer1(x1) 128 | out2 = self.layer2(x2) 129 | out3 = self.layer3(x3) 130 | outs = torch.cat((out1, out2, out3), dim=1) 131 | return self.conv(outs) 132 | 133 | 134 | class MLCL_small_No_Sigmoid(nn.Module): 135 | def __init__(self): 136 | super(MLCL_small_No_Sigmoid, self).__init__() 137 | self.stage = Stage() 138 | self.mlcl3 = MLCL(64, 64) 
139 | self.mlcl2 = MLCL(32, 32) 140 | self.mlcl1 = MLCL(16, 16) 141 | self.up3 = nn.UpsamplingBilinear2d(scale_factor=2) 142 | self.conv3 = nn.Sequential( 143 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1), 144 | nn.ReLU(inplace=True) 145 | ) 146 | self.up2 = nn.UpsamplingBilinear2d(scale_factor=2) 147 | self.conv2 = nn.Sequential( 148 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1), 149 | nn.ReLU(inplace=True) 150 | ) 151 | self.conv1 = nn.Sequential( 152 | nn.Conv2d(in_channels=16, out_channels=64, kernel_size=1), 153 | nn.ReLU(inplace=True) 154 | ) 155 | self.layer = nn.Sequential( 156 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1), 157 | nn.ReLU(inplace=True), 158 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1), 159 | #nn.Sigmoid() 160 | ) 161 | 162 | def forward(self, x): 163 | outs = self.stage(x) 164 | out3 = self.mlcl3(outs[2]) 165 | out2 = self.mlcl2(outs[1]) 166 | out1 = self.mlcl1(outs[0]) 167 | out3 = self.conv3(out3) # 128*128 64 168 | out3 = self.up3(out3) # 256*256 64 169 | out2 = self.conv2(out2) # 256*256 64 170 | out = out3 + out2 171 | out = self.up2(out) 172 | out1 = self.conv1(out1) 173 | out = out + out1 174 | out = self.layer(out) 175 | return out 176 | 177 | def weights_init(m): 178 | if isinstance(m, nn.Conv2d): 179 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 180 | if m.bias is not None: 181 | nn.init.constant_(m.bias, 0) 182 | elif isinstance(m, nn.BatchNorm2d): 183 | nn.init.constant_(m.weight, 1) 184 | nn.init.constant_(m.bias, 0) 185 | elif isinstance(m, nn.Linear): 186 | nn.init.xavier_uniform_(m.weight) 187 | nn.init.constant_(m.bias, 0) 188 | return 189 | 190 | if __name__ == '__main__': 191 | model = MLCL_small_No_Sigmoid() 192 | x = torch.rand(8, 3, 512, 512) 193 | outs = model(x) 194 | print(outs.size()) 195 | 196 | -------------------------------------------------------------------------------- /model/MLCL/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MLCL/__init__.py -------------------------------------------------------------------------------- /model/MLCL/__pycache__/MLCL_no_sigmoid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MLCL/__pycache__/MLCL_no_sigmoid.cpython-36.pyc -------------------------------------------------------------------------------- /model/MLCL/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MLCL/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/MSDA/MSDA_no_sigmoid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mm.attention.SEAttention import SEAttention 4 | import pywt 5 | from torch.nn import init 6 | from torch.autograd import Function 7 | 8 | 9 | class SpatialAttention(nn.Module): 10 | def __init__(self, kernel_size=7): 11 | super().__init__() 12 | self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2) 13 | self.sigmoid = nn.Sigmoid() 14 | 15 | def forward(self, x): 16 | max_result, _ = torch.max(x, dim=1, 
keepdim=True) 17 | avg_result = torch.mean(x, dim=1, keepdim=True) 18 | result = torch.cat([max_result, avg_result], 1) 19 | output = self.conv(result) 20 | output = self.sigmoid(output) 21 | return output 22 | 23 | 24 | ################################################################################################ 25 | class LH_DWT_2D_attation(nn.Module): 26 | def __init__(self, wave): 27 | super(LH_DWT_2D_attation, self).__init__() 28 | w = pywt.Wavelet(wave) 29 | dec_hi = torch.Tensor(w.dec_hi[::-1]) 30 | dec_lo = torch.Tensor(w.dec_lo[::-1]) 31 | w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1) 32 | self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0)) 33 | self.w_lh = self.w_lh.to(dtype=torch.float32) 34 | 35 | def forward(self, x): 36 | return LH_DWT_Function_attation.apply(x, self.w_lh) 37 | 38 | 39 | class LH_DWT_Function_attation(Function): 40 | @staticmethod 41 | def forward(ctx, x, w_lh): 42 | x = x.contiguous() 43 | ctx.save_for_backward(w_lh) 44 | ctx.shape = x.shape 45 | dim = x.shape[1] 46 | x = torch.nn.functional.pad(x, (1, 0, 1, 0)) 47 | x_lh = torch.nn.functional.conv2d(x, w_lh.expand(dim, -1, -1, -1), stride=1, padding=0, groups=dim) 48 | x = x_lh 49 | return x 50 | 51 | @staticmethod 52 | def backward(ctx, dx): 53 | if ctx.needs_input_grad[0]: 54 | w_lh = ctx.saved_tensors[0] 55 | B, C, H, W = ctx.shape 56 | dx = dx.view(B, 1, -1, H, W) 57 | dx = dx.transpose(1, 2).reshape(B, -1, H, W) 58 | filters = w_lh # torch.Size([3, 1, 2, 2]) 59 | filters = filters.repeat(C, 1, 1, 1).to(dtype=torch.float16) 60 | dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=1, groups=C) 61 | dx = torch.cat((dx[:, :, :0], dx[:, :, 1:]), dim=2) 62 | dx = torch.cat((dx[:, :, :, :0], dx[:, :, :, 1:]), dim=3) 63 | return dx, None 64 | 65 | 66 | ######################################################################################## 67 | class HL_DWT_2D_attation(nn.Module): 68 | def __init__(self, wave): 69 | super(HL_DWT_2D_attation, self).__init__() 70 | w = pywt.Wavelet(wave) 71 | dec_hi = torch.Tensor(w.dec_hi[::-1]) 72 | dec_lo = torch.Tensor(w.dec_lo[::-1]) 73 | 74 | w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1) 75 | 76 | self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0)) 77 | self.w_hl = self.w_hl.to(dtype=torch.float32) 78 | 79 | def forward(self, x): 80 | return HL_DWT_Function_attation.apply(x, self.w_hl) 81 | 82 | 83 | class HL_DWT_Function_attation(Function): 84 | @staticmethod 85 | def forward(ctx, x, w_hl): 86 | x = x.contiguous() 87 | ctx.save_for_backward(w_hl) 88 | ctx.shape = x.shape 89 | dim = x.shape[1] 90 | x = torch.nn.functional.pad(x, (1, 0, 1, 0)) 91 | x_hl = torch.nn.functional.conv2d(x, w_hl.expand(dim, -1, -1, -1), stride=1, padding=0, groups=dim) 92 | x = x_hl 93 | return x 94 | 95 | @staticmethod 96 | def backward(ctx, dx): 97 | if ctx.needs_input_grad[0]: 98 | w_hl = ctx.saved_tensors[0] 99 | B, C, H, W = ctx.shape 100 | dx = dx.view(B, 1, -1, H, W) 101 | dx = dx.transpose(1, 2).reshape(B, -1, H, W) 102 | filters = w_hl 103 | filters = filters.repeat(C, 1, 1, 1).to(dtype=torch.float16) 104 | dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=1, groups=C) 105 | dx = torch.cat((dx[:, :, :0], dx[:, :, 1:]), dim=2) 106 | dx = torch.cat((dx[:, :, :, :0], dx[:, :, :, 1:]), dim=3) 107 | return dx, None 108 | 109 | 110 | ####################################################################################### 111 | class HH_DWT_2D_attation(nn.Module): 112 | def __init__(self, wave): 113 | super(HH_DWT_2D_attation, self).__init__() 114 
| w = pywt.Wavelet(wave) 115 | dec_hi = torch.Tensor(w.dec_hi[::-1]) 116 | dec_lo = torch.Tensor(w.dec_lo[::-1]) 117 | w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1) 118 | self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0)) 119 | self.w_hh = self.w_hh.to(dtype=torch.float32) 120 | 121 | def forward(self, x): 122 | return HH_DWT_Function_attation.apply(x, self.w_hh) 123 | 124 | 125 | class HH_DWT_Function_attation(Function): 126 | @staticmethod 127 | def forward(ctx, x, w_hh): 128 | x = x.contiguous() 129 | ctx.save_for_backward(w_hh) 130 | ctx.shape = x.shape 131 | dim = x.shape[1] 132 | x = torch.nn.functional.pad(x, (1, 0, 1, 0)) 133 | x_hh = torch.nn.functional.conv2d(x, w_hh.expand(dim, -1, -1, -1), stride=1, padding=0, 134 | groups=dim) 135 | x = x_hh 136 | return x 137 | 138 | @staticmethod 139 | def backward(ctx, dx): 140 | if ctx.needs_input_grad[0]: 141 | w_hh = ctx.saved_tensors[0] 142 | B, C, H, W = ctx.shape 143 | dx = dx.view(B, 1, -1, H, W) 144 | dx = dx.transpose(1, 2).reshape(B, -1, H, W) 145 | filters = w_hh # torch.Size([3, 1, 2, 2]) 146 | filters = filters.repeat(C, 1, 1, 1).to(dtype=torch.float16) 147 | dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=1, groups=C) 148 | dx = torch.cat((dx[:, :, :0], dx[:, :, 1:]), dim=2) 149 | dx = torch.cat((dx[:, :, :, :0], dx[:, :, :, 1:]), dim=3) 150 | return dx, None 151 | 152 | 153 | ####################################################################################### 154 | class LL_DWT_2D_attation(nn.Module): 155 | def __init__(self, wave): 156 | super(LL_DWT_2D_attation, self).__init__() 157 | w = pywt.Wavelet(wave) 158 | dec_hi = torch.Tensor(w.dec_hi[::-1]) 159 | dec_lo = torch.Tensor(w.dec_lo[::-1]) 160 | w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1) 161 | self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0)) 162 | self.w_ll = self.w_ll.to(dtype=torch.float32) 163 | 164 | def forward(self, x): 165 | return LL_DWT_Function_attation.apply(x, self.w_ll) 166 | 167 | 168 | class LL_DWT_Function_attation(Function): 169 | @staticmethod 170 | def forward(ctx, x, w_ll): 171 | x = x.contiguous() 172 | ctx.save_for_backward(w_ll) 173 | ctx.shape = x.shape 174 | dim = x.shape[1] 175 | x = torch.nn.functional.pad(x, (1, 0, 1, 0)) 176 | x_ll = torch.nn.functional.conv2d(x, w_ll.expand(dim, -1, -1, -1), stride=1, padding=0, 177 | groups=dim) 178 | x = x_ll 179 | return x 180 | 181 | @staticmethod 182 | def backward(ctx, dx): 183 | if ctx.needs_input_grad[0]: 184 | w_ll = ctx.saved_tensors[0] 185 | B, C, H, W = ctx.shape 186 | dx = dx.view(B, 1, -1, H, W) 187 | dx = dx.transpose(1, 2).reshape(B, -1, H, W) 188 | filters = w_ll # torch.Size([3, 1, 2, 2]) 189 | filters = filters.repeat(C, 1, 1, 1).to(dtype=torch.float16) 190 | dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=1, groups=C) 191 | dx = torch.cat((dx[:, :, :0], dx[:, :, 1:]), dim=2) 192 | dx = torch.cat((dx[:, :, :, :0], dx[:, :, :, 1:]), dim=3) 193 | return dx, None 194 | 195 | 196 | ######################################################################################### 197 | 198 | class FrequencyAttention(nn.Module): # 199 | 200 | def __init__(self, in_channel): 201 | super().__init__() 202 | # self.HH_reduce = nn.Sequential( 203 | # nn.Conv2d(in_channel, 1, kernel_size=3, padding=1, stride=1), 204 | # nn.Sigmoid(), 205 | # ) 206 | self.HH_reduce = SpatialAttention() 207 | # self.LH_reduce = nn.Sequential( 208 | # nn.Conv2d(in_channel, 1, kernel_size=3, padding=1, stride=1), 209 | # nn.Sigmoid(), 210 | # ) 211 | self.LH_reduce = 
SpatialAttention() 212 | # self.HL_reduce = nn.Sequential( 213 | # nn.Conv2d(in_channel, 1, kernel_size=3, padding=1, stride=1), 214 | # nn.Sigmoid(), 215 | # ) 216 | self.HL_reduce = SpatialAttention() 217 | # self.LL_reduce = nn.Sequential( 218 | # nn.Conv2d(in_channel, 1, kernel_size=3, padding=1, stride=1), 219 | # nn.Sigmoid(), 220 | # ) 221 | self.LL_reduce = SpatialAttention() 222 | self.conv_res = nn.Sequential( 223 | nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1, stride=1), 224 | nn.BatchNorm2d(in_channel), 225 | ) 226 | self.relu = nn.ReLU(inplace=False) 227 | # self.fc =nn.Linear(in_channel*3, in_channel, bias=False) 228 | self.LH_DWT_2D_attation = LH_DWT_2D_attation('haar') 229 | self.HL_DWT_2D_attation = HL_DWT_2D_attation('haar') 230 | self.HH_DWT_2D_attation = HH_DWT_2D_attation('haar') 231 | self.LL_DWT_2D_attation = LL_DWT_2D_attation('haar') 232 | self.init_weights() 233 | 234 | def forward(self, x): 235 | out_LH = self.LH_DWT_2D_attation(x) 236 | out_LH = self.LH_reduce(out_LH) 237 | x_LH = x * out_LH 238 | # print(x_LH.size()) 239 | out_HL = self.HL_DWT_2D_attation(x) 240 | out_HL = self.HL_reduce(out_HL) 241 | x_HL = x * out_HL 242 | out_HH = self.HH_DWT_2D_attation(x) 243 | out_HH = self.HH_reduce(out_HH) 244 | x_HH = x * out_HH 245 | out_LL = self.LL_DWT_2D_attation(x) 246 | out_LL = self.LL_reduce(out_LL) 247 | x_LL = x * out_LL 248 | x_out = x_LH + x_HL + x_HH + x_LL 249 | x_out = self.relu(x + self.conv_res(x_out)) 250 | return x_out 251 | 252 | def init_weights(self): 253 | for m in self.modules(): 254 | if isinstance(m, nn.Conv2d): 255 | init.kaiming_normal_(m.weight, mode='fan_out') 256 | if m.bias is not None: 257 | init.constant_(m.bias, 0) 258 | elif isinstance(m, nn.BatchNorm2d): 259 | init.constant_(m.weight, 1) 260 | init.constant_(m.bias, 0) 261 | elif isinstance(m, nn.Linear): 262 | init.normal_(m.weight, std=0.001) 263 | if m.bias is not None: 264 | init.constant_(m.bias, 0) 265 | 266 | 267 | ############################################################################################ 268 | 269 | class DWT_2D(nn.Module): 270 | def __init__(self, wave): 271 | super(DWT_2D, self).__init__() 272 | w = pywt.Wavelet(wave) 273 | dec_hi = torch.Tensor(w.dec_hi[::-1]) 274 | dec_lo = torch.Tensor(w.dec_lo[::-1]) 275 | 276 | w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1) 277 | w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1) 278 | w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1) 279 | w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1) 280 | 281 | self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0)) 282 | self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0)) 283 | self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0)) 284 | self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0)) 285 | 286 | self.w_ll = self.w_ll.to(dtype=torch.float32) 287 | self.w_lh = self.w_lh.to(dtype=torch.float32) 288 | self.w_hl = self.w_hl.to(dtype=torch.float32) 289 | self.w_hh = self.w_hh.to(dtype=torch.float32) 290 | 291 | def forward(self, x): 292 | return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh) 293 | 294 | 295 | class DWT_Function(Function): 296 | @staticmethod 297 | def forward(ctx, x, w_ll, w_lh, w_hl, w_hh): 298 | x = x.contiguous() 299 | ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh) 300 | ctx.shape = x.shape 301 | dim = x.shape[1] 302 | 303 | x_ll = torch.nn.functional.conv2d(x, w_ll.expand(dim, -1, -1, -1), stride=2, padding=0, groups=dim) 304 | x_lh = torch.nn.functional.conv2d(x, w_lh.expand(dim, -1, -1, -1), 
stride=2, padding=0, groups=dim) 305 | x_hl = torch.nn.functional.conv2d(x, w_hl.expand(dim, -1, -1, -1), stride=2, padding=0, groups=dim) 306 | x_hh = torch.nn.functional.conv2d(x, w_hh.expand(dim, -1, -1, -1), stride=2, padding=0, 307 | groups=dim) 308 | x = torch.cat([x_lh, x_hl, x_hh], dim=1) 309 | # x = x_hh 310 | return x 311 | 312 | @staticmethod 313 | def backward(ctx, dx): 314 | if ctx.needs_input_grad[0]: 315 | w_ll, w_lh, w_hl, w_hh = ctx.saved_tensors 316 | B, C, H, W = ctx.shape 317 | 318 | dx = dx.view(B, 3, -1, H // 2, W // 2) 319 | dx = dx.transpose(1, 2).reshape(B, -1, H // 2, W // 2) 320 | filters = torch.cat([w_lh, w_hl, w_hh], dim=0) 321 | filters = filters.repeat(C, 1, 1, 1).to(dtype=torch.float16) 322 | dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=2, groups=C) 323 | 324 | return dx, None, None, None, None 325 | 326 | 327 | class Hfrequencyfeature(nn.Module): # 328 | 329 | def __init__(self): 330 | super().__init__() 331 | 332 | self.DWT_2D = DWT_2D('haar') 333 | 334 | def forward(self, x): # x:torch.Size([8, 16, 512, 512]) 335 | out = self.DWT_2D(x) # torch.Size([8, 48, 512, 512]) 336 | return out 337 | 338 | 339 | def weights_init(m): 340 | if isinstance(m, nn.Conv2d): 341 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 342 | if m.bias is not None: 343 | nn.init.constant_(m.bias, 0) 344 | elif isinstance(m, nn.BatchNorm2d): 345 | nn.init.constant_(m.weight, 1) 346 | nn.init.constant_(m.bias, 0) 347 | 348 | return 349 | 350 | 351 | class BasicConv2d(nn.Module): 352 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 353 | super(BasicConv2d, self).__init__() 354 | self.conv = nn.Conv2d(in_planes, out_planes, 355 | kernel_size=kernel_size, stride=stride, 356 | padding=padding, dilation=dilation, bias=False) 357 | self.bn = nn.BatchNorm2d(out_planes) 358 | self.relu = nn.ReLU(inplace=False) 359 | 360 | def forward(self, x): 361 | x = self.conv(x) 362 | x = self.bn(x) 363 | x = self.relu(x) 364 | return x 365 | 366 | 367 | ############################################################################## 368 | class RFB_modified(nn.Module): 369 | def __init__(self, in_channel, out_channel): 370 | super(RFB_modified, self).__init__() 371 | self.relu = nn.ReLU(inplace=False) 372 | self.branch0 = nn.Sequential( 373 | BasicConv2d(in_channel, out_channel, 1), 374 | ) 375 | self.branch1 = nn.Sequential( 376 | BasicConv2d(in_channel, out_channel, 1), 377 | BasicConv2d(out_channel, out_channel, 3, padding=1, dilation=1) 378 | ) 379 | 380 | self.branch2 = nn.Sequential( 381 | BasicConv2d(in_channel, out_channel, kernel_size=(3, 3), padding=1), 382 | BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3) 383 | ) 384 | self.branch3 = nn.Sequential( 385 | BasicConv2d(in_channel, out_channel, kernel_size=(5, 5), padding=2), 386 | BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5) 387 | ) 388 | 389 | self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1) 390 | self.conv_res = BasicConv2d(in_channel, out_channel, 3, padding=1) 391 | self.FrequencyAttention = FrequencyAttention(in_channel=out_channel) 392 | self.SEAttention = SEAttention(channel=out_channel, reduction=4) 393 | 394 | def forward(self, x): 395 | x0 = self.branch0(x) 396 | x1 = self.branch1(x) 397 | x2 = self.branch2(x) 398 | x3 = self.branch3(x) 399 | x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1)) 400 | x_cat = self.FrequencyAttention(x_cat) 401 | # print(x_cat.shape) 402 | x_cat = 
self.SEAttention(x_cat) 403 | x = self.relu(x+ self.conv_res(x_cat)) 404 | return x 405 | 406 | 407 | class RFB_modified_LCL(nn.Module): 408 | def __init__(self, in_channel, out_channel): 409 | super(RFB_modified_LCL, self).__init__() 410 | self.relu = nn.ReLU(inplace=False) 411 | self.branch0 = nn.Sequential( 412 | BasicConv2d(in_channel, out_channel, 1), 413 | ) 414 | self.branch1 = nn.Sequential( 415 | BasicConv2d(in_channel, out_channel, kernel_size=(3, 3), padding=1), 416 | BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3) 417 | ) 418 | self.branch2 = nn.Sequential( 419 | BasicConv2d(in_channel, out_channel, kernel_size=(5, 5), padding=2), 420 | BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5) 421 | ) 422 | 423 | self.conv_cat = BasicConv2d(3 * out_channel, out_channel, 3, padding=1) 424 | self.conv_res = BasicConv2d(in_channel, out_channel, 1) 425 | 426 | def forward(self, x): 427 | x0 = self.branch0(x) 428 | x1 = self.branch1(x) 429 | x2 = self.branch2(x) 430 | # x3 = self.branch3(x) 431 | x_cat = self.conv_cat(torch.cat((x0, x1, x2), 1)) 432 | # print(x_cat.size()) 433 | x_cat = self.relu(x_cat) 434 | return x_cat 435 | 436 | 437 | ############################################################################################################# 438 | class Resnet1(nn.Module): 439 | def __init__(self, in_channel, out_channel): 440 | super(Resnet1, self).__init__() 441 | self.layer = nn.Sequential( 442 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 443 | nn.BatchNorm2d(out_channel), 444 | nn.ReLU(inplace=False), 445 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 446 | nn.BatchNorm2d(out_channel) 447 | ) 448 | self.relu = nn.ReLU(inplace=False) 449 | 450 | self.layer.apply(weights_init) 451 | 452 | def forward(self, x): 453 | identity = x 454 | out = self.layer(x) 455 | out += identity 456 | return self.relu(out) 457 | 458 | 459 | # layer2_1 #layer3_1#layer4_1#layer5_1 460 | class Resnet2(nn.Module): 461 | def __init__(self, in_channel, out_channel): 462 | super(Resnet2, self).__init__() 463 | self.layer1 = nn.Sequential( 464 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 465 | nn.BatchNorm2d(out_channel), 466 | nn.ReLU(inplace=False), 467 | nn.MaxPool2d(kernel_size=2, stride=2), 468 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1), 469 | nn.BatchNorm2d(out_channel) 470 | ) 471 | self.layer2 = nn.Sequential( 472 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=2), 473 | nn.BatchNorm2d(out_channel), 474 | nn.ReLU(inplace=False) 475 | ) 476 | self.relu = nn.ReLU(inplace=False) 477 | 478 | self.layer1.apply(weights_init) 479 | self.layer2.apply(weights_init) 480 | 481 | def forward(self, x): 482 | identity = x 483 | out = self.layer1(x) 484 | identity = self.layer2(identity) 485 | # print("out:",out.size()) 486 | # print("identity:",identity.size()) 487 | out += identity 488 | return self.relu(out) 489 | 490 | 491 | class Hfrequency(nn.Module): 492 | def __init__(self): 493 | super(Hfrequency, self).__init__() 494 | self.Hfrequencyfeature = Hfrequencyfeature() 495 | 496 | def forward(self, x, out_2): 497 | out_1 = self.Hfrequencyfeature(x) 498 | out = torch.cat([out_1, out_2], dim=1) 499 | return out 500 | 501 | 502 | class Stage(nn.Module): 503 | def __init__(self): 504 | super(Stage, self).__init__() 505 | 
self.layer1 = nn.Sequential( 506 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1), 507 | nn.BatchNorm2d(16), 508 | nn.ReLU(inplace=False), 509 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1), 510 | nn.BatchNorm2d(16), 511 | nn.ReLU(inplace=False) 512 | ) 513 | self.resnet1_1 = Resnet1(in_channel=16, out_channel=16) 514 | self.resnet1_2 = RFB_modified(in_channel=16, out_channel=16) 515 | self.resnet1_3 = RFB_modified(in_channel=16, out_channel=16) 516 | self.layer1_4 = nn.Sequential( 517 | nn.MaxPool2d(kernel_size=2, stride=2), 518 | nn.MaxPool2d(kernel_size=2, stride=2), 519 | nn.MaxPool2d(kernel_size=2, stride=2), 520 | nn.MaxPool2d(kernel_size=2, stride=2), 521 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1), 522 | nn.ReLU(inplace=False), 523 | ) 524 | 525 | 526 | self.resnet2_1 = Resnet2(in_channel=16, out_channel=23) 527 | self.hfrequency = Hfrequency() 528 | self.resnet2_2 = RFB_modified(in_channel=32, out_channel=32) 529 | self.resnet2_3 = RFB_modified(in_channel=32, out_channel=32) 530 | self.layer2_4 = nn.Sequential( 531 | nn.MaxPool2d(kernel_size=2, stride=2), 532 | nn.MaxPool2d(kernel_size=2, stride=2), 533 | nn.MaxPool2d(kernel_size=2, stride=2), 534 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=1), 535 | nn.ReLU(inplace=False), 536 | ) 537 | 538 | 539 | self.resnet3_1 = Resnet2(in_channel=32, out_channel=64) 540 | self.resnet3_2 = RFB_modified(in_channel=64, out_channel=64) 541 | self.resnet3_3 = RFB_modified(in_channel=64, out_channel=64) 542 | self.layer3_4 = nn.Sequential( 543 | nn.MaxPool2d(kernel_size=2, stride=2), 544 | nn.MaxPool2d(kernel_size=2, stride=2), 545 | nn.Conv2d(in_channels=64, out_channels=16, kernel_size=1), 546 | nn.ReLU(inplace=False), 547 | ) 548 | 549 | 550 | self.resnet4_1 = Resnet2(in_channel=64, out_channel=64) 551 | self.resnet4_2 = RFB_modified(in_channel=64, out_channel=64) 552 | self.resnet4_3 = RFB_modified(in_channel=64, out_channel=64) 553 | self.layer4_4 = nn.Sequential( 554 | nn.MaxPool2d(kernel_size=2, stride=2), 555 | nn.Conv2d(in_channels=64, out_channels=16, kernel_size=1), 556 | nn.ReLU(inplace=False), 557 | ) 558 | 559 | 560 | self.resnet5_1 = Resnet2(in_channel=64, out_channel=64) 561 | self.resnet5_2 = RFB_modified(in_channel=64, out_channel=64) 562 | self.resnet5_3 = RFB_modified(in_channel=64, out_channel=64) 563 | self.layer5_4 = nn.Sequential( 564 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1), 565 | nn.ReLU(inplace=False), 566 | nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1), 567 | nn.ReLU(inplace=False), 568 | 569 | ) 570 | 571 | 572 | 573 | self.layer1.apply(weights_init) 574 | self.layer1_4.apply(weights_init) 575 | self.layer2_4.apply(weights_init) 576 | self.layer3_4.apply(weights_init) 577 | self.layer4_4.apply(weights_init) 578 | self.layer5_4.apply(weights_init) 579 | 580 | def forward(self, x): 581 | outs = [] 582 | out = self.layer1(x) 583 | 584 | # print(out_1.size()) 585 | out = self.resnet1_1(out) 586 | out = self.resnet1_2(out) 587 | out = self.resnet1_3(out) 588 | out_1 = self.layer1_4(out) 589 | 590 | outs.append(out) 591 | out = self.resnet2_1(out) 592 | out = self.hfrequency(x,out) 593 | out = self.resnet2_2(out) 594 | out = self.resnet2_3(out) 595 | out_2 = self.layer2_4(out) 596 | 597 | outs.append(out) 598 | out = self.resnet3_1(out) 599 | out = self.resnet3_2(out) 600 | out = self.resnet3_3(out) 601 | out_3 = self.layer3_4(out) 602 | 603 | outs.append(out) 604 | out = self.resnet4_1(out) 605 | 
out = self.resnet4_2(out) 606 | out = self.resnet4_3(out) 607 | out_4 = self.layer4_4(out) 608 | 609 | outs.append(out) 610 | out = self.resnet5_1(out) 611 | out = self.resnet5_2(out) 612 | out = self.resnet5_3(out) 613 | out = torch.cat([out, out_4, out_3,out_2,out_1], dim=1) 614 | out = self.layer5_4(out) 615 | 616 | outs.append(out) 617 | return outs 618 | 619 | class Sbam(nn.Module): 620 | def __init__(self, in_channel, out_channel): 621 | super(Sbam, self).__init__() 622 | self.hl_up = nn.Sequential( 623 | nn.UpsamplingBilinear2d(scale_factor=2), 624 | ) 625 | self.hl_layer = nn.Sequential( 626 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1), 627 | nn.BatchNorm2d(out_channel), 628 | nn.ReLU(inplace=False) 629 | ) 630 | self.concat_layer = nn.Sequential( 631 | nn.Conv2d(in_channels=in_channel+out_channel, out_channels=in_channel, kernel_size=1), 632 | nn.BatchNorm2d(in_channel), 633 | nn.ReLU(inplace=False), 634 | nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=3, padding=1, stride=1), 635 | nn.BatchNorm2d(in_channel), 636 | nn.ReLU(inplace=False), 637 | ) 638 | 639 | 640 | self.hl_layer.apply(weights_init) 641 | self.concat_layer.apply(weights_init) 642 | 643 | def forward(self, hl, ll): 644 | hl = self.hl_up(hl) 645 | concat = torch.cat((hl, ll), 1) 646 | k = self.concat_layer(concat) 647 | hl = hl + k 648 | hl = self.hl_layer(hl) 649 | out = ll + hl 650 | return out 651 | 652 | 653 | class MSDA_No_Sigmoid(nn.Module): 654 | def __init__(self): 655 | super(MSDA_No_Sigmoid, self).__init__() 656 | self.stage = Stage() 657 | self.mlcl5 = RFB_modified_LCL(64, 64) 658 | self.mlcl4 = RFB_modified_LCL(64, 64) 659 | self.mlcl3 = RFB_modified_LCL(64, 64) 660 | self.mlcl2 = RFB_modified_LCL(32, 32) 661 | self.mlcl1 = RFB_modified_LCL(16, 16) 662 | 663 | self.sbam4 = Sbam(64, 64) 664 | self.sbam3 = Sbam(64, 64) 665 | self.sbam2 = Sbam(64, 32) 666 | self.sbam1 = Sbam(32, 16) 667 | 668 | self.layer = nn.Sequential( 669 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1), 670 | nn.ReLU(inplace=False), 671 | nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1), 672 | # nn.Sigmoid() 673 | ) 674 | 675 | self.layer.apply(weights_init) 676 | 677 | def forward(self, x): 678 | outs = self.stage(x) 679 | 680 | out5 = self.mlcl5(outs[4]) 681 | out4 = self.mlcl4(outs[3]) 682 | out3 = self.mlcl3(outs[2]) 683 | out2 = self.mlcl2(outs[1]) 684 | out1 = self.mlcl1(outs[0]) 685 | 686 | out4_2 = self.sbam4(out5, out4) 687 | out3_2 = self.sbam3(out4_2, out3) 688 | out2_2 = self.sbam2(out3_2, out2) 689 | out1_2 = self.sbam1(out2_2, out1) 690 | out = self.layer(out1_2) 691 | 692 | return out 693 | 694 | 695 | def weights_init(m): 696 | if isinstance(m, nn.Conv2d): 697 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 698 | if m.bias is not None: 699 | nn.init.constant_(m.bias, 0) 700 | elif isinstance(m, nn.BatchNorm2d): 701 | nn.init.constant_(m.weight, 1) 702 | nn.init.constant_(m.bias, 0) 703 | # elif isinstance(m, nn.Linear): 704 | # nn.init.xavier_uniform_(m.weight) 705 | # nn.init.constant_(m.bias, 0) 706 | return 707 | 708 | 709 | if __name__ == '__main__': 710 | model = MSDA_No_Sigmoid() 711 | x = torch.rand(8, 3, 512, 512) 712 | outs = model(x) 713 | print(outs.size()) 714 | -------------------------------------------------------------------------------- /model/MSDA/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MSDA/__init__.py -------------------------------------------------------------------------------- /model/MSDA/__pycache__/MSDA_no_sigmoid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MSDA/__pycache__/MSDA_no_sigmoid.cpython-36.pyc -------------------------------------------------------------------------------- /model/MSDA/__pycache__/MSDA_no_sigmoid.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MSDA/__pycache__/MSDA_no_sigmoid.cpython-38.pyc -------------------------------------------------------------------------------- /model/MSDA/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MSDA/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/MSDA/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MSDA/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/UIU/UIU_no_sigmoid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | ############################################# 5 | from .fusion import * 6 | ################################################## 7 | 8 | class REBNCONV(nn.Module): 9 | def __init__(self,in_ch=3,out_ch=3,dirate=1): 10 | super(REBNCONV,self).__init__() 11 | 12 | self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate) 13 | self.bn_s1 = nn.BatchNorm2d(out_ch) 14 | self.relu_s1 = nn.ReLU(inplace=True) 15 | 16 | def forward(self,x): 17 | 18 | hx = x 19 | xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) 20 | 21 | return xout 22 | 23 | def _upsample_like(src,tar): 24 | 25 | src = F.upsample(src, size=tar.shape[2:], mode='bilinear') 26 | 27 | return src 28 | 29 | 30 | ### RSU-7 ### 31 | class RSU7(nn.Module):#UNet07DRES(nn.Module): 32 | 33 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 34 | super(RSU7,self).__init__() 35 | 36 | self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) 37 | 38 | self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) 39 | self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 40 | 41 | self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) 42 | self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 43 | 44 | self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) 45 | self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 46 | 47 | self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) 48 | self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 49 | 50 | self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) 51 | self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 52 | 53 | self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1) 54 | 55 | self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2) 56 | 57 | self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 58 | self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 
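# Decoder half of the RSU-7 block: each rebnconv*d layer (6d above through 1d
# below) consumes the channel-wise concatenation of the upsampled deeper
# feature with the matching encoder feature (hence the mid_ch*2 input width),
# and forward() ends with the residual sum hx1d + hxin, which is why the
# block keeps a fixed output width of out_ch.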
59 | self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 60 | self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 61 | self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 62 | self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) 63 | 64 | def forward(self,x): 65 | 66 | hx = x 67 | hxin = self.rebnconvin(hx) 68 | 69 | hx1 = self.rebnconv1(hxin) 70 | hx = self.pool1(hx1) 71 | 72 | hx2 = self.rebnconv2(hx) 73 | hx = self.pool2(hx2) 74 | 75 | hx3 = self.rebnconv3(hx) 76 | hx = self.pool3(hx3) 77 | 78 | hx4 = self.rebnconv4(hx) 79 | hx = self.pool4(hx4) 80 | 81 | hx5 = self.rebnconv5(hx) 82 | hx = self.pool5(hx5) 83 | 84 | hx6 = self.rebnconv6(hx) 85 | 86 | hx7 = self.rebnconv7(hx6) 87 | 88 | hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1)) 89 | hx6dup = _upsample_like(hx6d,hx5) 90 | 91 | hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1)) 92 | hx5dup = _upsample_like(hx5d,hx4) 93 | 94 | hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) 95 | hx4dup = _upsample_like(hx4d,hx3) 96 | 97 | hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) 98 | hx3dup = _upsample_like(hx3d,hx2) 99 | 100 | hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) 101 | hx2dup = _upsample_like(hx2d,hx1) 102 | 103 | hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) 104 | 105 | return hx1d + hxin 106 | 107 | ### RSU-6 ### 108 | class RSU6(nn.Module): 109 | 110 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 111 | super(RSU6,self).__init__() 112 | 113 | self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) 114 | 115 | self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) 116 | self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 117 | 118 | self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) 119 | self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 120 | 121 | self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) 122 | self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 123 | 124 | self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) 125 | self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 126 | 127 | self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) 128 | 129 | self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2) 130 | 131 | self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 132 | self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 133 | self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 134 | self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 135 | self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) 136 | 137 | def forward(self,x): 138 | 139 | hx = x 140 | 141 | hxin = self.rebnconvin(hx) 142 | 143 | hx1 = self.rebnconv1(hxin) 144 | hx = self.pool1(hx1) 145 | 146 | hx2 = self.rebnconv2(hx) 147 | hx = self.pool2(hx2) 148 | 149 | hx3 = self.rebnconv3(hx) 150 | hx = self.pool3(hx3) 151 | 152 | hx4 = self.rebnconv4(hx) 153 | hx = self.pool4(hx4) 154 | 155 | hx5 = self.rebnconv5(hx) 156 | 157 | hx6 = self.rebnconv6(hx5) 158 | 159 | hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1)) 160 | hx5dup = _upsample_like(hx5d,hx4) 161 | 162 | hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) 163 | hx4dup = _upsample_like(hx4d,hx3) 164 | 165 | hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) 166 | hx3dup = _upsample_like(hx3d,hx2) 167 | 168 | hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) 169 | hx2dup = _upsample_like(hx2d,hx1) 170 | 171 | hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) 172 | 173 | return hx1d + hxin 174 | 175 | ### RSU-5 ### 176 | class RSU5(nn.Module): 177 | 178 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 179 | super(RSU5,self).__init__() 180 | 181 | self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) 182 | 183 | 
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) 184 | self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 185 | 186 | self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) 187 | self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 188 | 189 | self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) 190 | self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 191 | 192 | self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) 193 | 194 | self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2) 195 | 196 | self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 197 | self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 198 | self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 199 | self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) 200 | 201 | def forward(self,x): 202 | 203 | hx = x 204 | 205 | hxin = self.rebnconvin(hx) 206 | 207 | hx1 = self.rebnconv1(hxin) 208 | hx = self.pool1(hx1) 209 | 210 | hx2 = self.rebnconv2(hx) 211 | hx = self.pool2(hx2) 212 | 213 | hx3 = self.rebnconv3(hx) 214 | hx = self.pool3(hx3) 215 | 216 | hx4 = self.rebnconv4(hx) 217 | 218 | hx5 = self.rebnconv5(hx4) 219 | 220 | hx4d = self.rebnconv4d(torch.cat((hx5,hx4), 1)) 221 | hx4dup = _upsample_like(hx4d, hx3) 222 | 223 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 224 | hx3dup = _upsample_like(hx3d, hx2) 225 | 226 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 227 | hx2dup = _upsample_like(hx2d, hx1) 228 | 229 | hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) 230 | 231 | return hx1d + hxin 232 | 233 | ### RSU-4 ### 234 | class RSU4(nn.Module): 235 | 236 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 237 | super(RSU4,self).__init__() 238 | 239 | self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) 240 | 241 | self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) 242 | self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 243 | 244 | self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) 245 | self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 246 | 247 | self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) 248 | 249 | self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2) 250 | 251 | self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 252 | self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) 253 | self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) 254 | 255 | def forward(self,x): 256 | 257 | hx = x 258 | 259 | hxin = self.rebnconvin(hx) 260 | 261 | hx1 = self.rebnconv1(hxin) 262 | hx = self.pool1(hx1) 263 | 264 | hx2 = self.rebnconv2(hx) 265 | hx = self.pool2(hx2) 266 | 267 | hx3 = self.rebnconv3(hx) 268 | 269 | hx4 = self.rebnconv4(hx3) 270 | 271 | hx3d = self.rebnconv3d(torch.cat((hx4,hx3), 1)) 272 | hx3dup = _upsample_like(hx3d, hx2) 273 | 274 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 275 | hx2dup = _upsample_like(hx2d, hx1) 276 | 277 | hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) 278 | 279 | return hx1d + hxin 280 | 281 | ### RSU-4F ### 282 | class RSU4F(nn.Module): 283 | 284 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 285 | super(RSU4F,self).__init__() 286 | 287 | self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) 288 | 289 | self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) 290 | self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2) 291 | self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4) 292 | 293 | self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8) 294 | 295 | self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4) 296 | self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2) 297 | self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) 298 | 299 | def forward(self,x): 300 | 301 | hx = x 302 | 303 | hxin = 
self.rebnconvin(hx) 304 | 305 | hx1 = self.rebnconv1(hxin) 306 | hx2 = self.rebnconv2(hx1) 307 | hx3 = self.rebnconv3(hx2) 308 | 309 | hx4 = self.rebnconv4(hx3) 310 | 311 | hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) 312 | hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1)) 313 | hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1)) 314 | 315 | return hx1d + hxin 316 | 317 | 318 | ##### UIU-net #### 319 | class UIU_No_Sigmoid(nn.Module): 320 | 321 | def __init__(self, in_ch=3, out_ch=1,mode='test'): 322 | super(UIU_No_Sigmoid, self).__init__() 323 | self.mode = mode 324 | 325 | self.stage1 = RSU7(in_ch,32,64) 326 | self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 327 | 328 | self.stage2 = RSU6(64,32,128) 329 | self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 330 | 331 | self.stage3 = RSU5(128,64,256) 332 | self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 333 | 334 | self.stage4 = RSU4(256,128,512) 335 | self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 336 | 337 | self.stage5 = RSU4F(512,256,512) 338 | self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True) 339 | 340 | self.stage6 = RSU4F(512,256,512) 341 | 342 | # decoder 343 | self.stage5d = RSU4F(1024,256,512) 344 | self.stage4d = RSU4(1024,128,256) 345 | self.stage3d = RSU5(512,64,128) 346 | self.stage2d = RSU6(256,32,64) 347 | self.stage1d = RSU7(128,16,64) 348 | 349 | self.side1 = nn.Conv2d(64,out_ch,3,padding=1) 350 | self.side2 = nn.Conv2d(64,out_ch,3,padding=1) 351 | self.side3 = nn.Conv2d(128,out_ch,3,padding=1) 352 | self.side4 = nn.Conv2d(256,out_ch,3,padding=1) 353 | self.side5 = nn.Conv2d(512,out_ch,3,padding=1) 354 | self.side6 = nn.Conv2d(512,out_ch,3,padding=1) 355 | 356 | self.outconv = nn.Conv2d(6*out_ch,out_ch,1) 357 | 358 | #self.fuse6 = self._fuse_layer(512, 512, 512, fuse_mode='AsymBi') 359 | self.fuse5 = self._fuse_layer(512, 512, 512, fuse_mode='AsymBi') 360 | self.fuse4 = self._fuse_layer(512, 512, 512, fuse_mode='AsymBi') 361 | self.fuse3 = self._fuse_layer(256, 256, 256, fuse_mode='AsymBi') 362 | self.fuse2 = self._fuse_layer(128, 128, 128, fuse_mode='AsymBi') 363 | 364 | 365 | def _fuse_layer(self, in_high_channels, in_low_channels, out_channels,fuse_mode='AsymBi'):#fuse_mode='AsymBi' 366 | # assert fuse_mode in ['BiLocal', 'AsymBi', 'BiGlobal'] 367 | # if fuse_mode == 'BiLocal': 368 | # fuse_layer = BiLocalChaFuseReduce(in_high_channels, in_low_channels, out_channels) 369 | # el 370 | if fuse_mode == 'AsymBi': 371 | fuse_layer = AsymBiChaFuseReduce(in_high_channels, in_low_channels, out_channels) 372 | # elif fuse_mode == 'BiGlobal': 373 | # fuse_layer = BiGlobalChaFuseReduce(in_high_channels, in_low_channels, out_channels) 374 | else: 375 | NameError 376 | return fuse_layer 377 | 378 | 379 | def forward(self, x): 380 | 381 | hx = x 382 | 383 | #stage 1 384 | hx1 = self.stage1(hx) 385 | hx = self.pool12(hx1) 386 | 387 | #stage 2 388 | hx2 = self.stage2(hx) 389 | hx = self.pool23(hx2) 390 | 391 | #stage 3 392 | hx3 = self.stage3(hx) 393 | hx = self.pool34(hx3) 394 | 395 | #stage 4 396 | hx4 = self.stage4(hx) 397 | hx = self.pool45(hx4) 398 | 399 | #stage 5 400 | hx5 = self.stage5(hx) 401 | hx = self.pool56(hx5) 402 | 403 | #stage 6 404 | hx6 = self.stage6(hx) 405 | hx6up = _upsample_like(hx6,hx5) 406 | 407 | #-------------------- decoder -------------------- 408 | 409 | fusec51,fusec52 = self.fuse5(hx6up, hx5) 410 | hx5d = self.stage5d(torch.cat((fusec51, fusec52),1)) 411 | hx5dup = _upsample_like(hx5d,hx4) 412 | 413 | fusec41,fusec42 = self.fuse4(hx5dup, hx4) 414 | hx4d = 
self.stage4d(torch.cat((fusec41,fusec42),1)) 415 | hx4dup = _upsample_like(hx4d,hx3) 416 | 417 | 418 | fusec31,fusec32 = self.fuse3(hx4dup, hx3) 419 | hx3d = self.stage3d(torch.cat((fusec31,fusec32),1)) 420 | hx3dup = _upsample_like(hx3d,hx2) 421 | 422 | fusec21, fusec22 = self.fuse2(hx3dup, hx2) 423 | hx2d = self.stage2d(torch.cat((fusec21, fusec22), 1)) 424 | hx2dup = _upsample_like(hx2d,hx1) 425 | 426 | 427 | hx1d = self.stage1d(torch.cat((hx2dup,hx1),1)) 428 | 429 | 430 | #side output 431 | d1 = self.side1(hx1d) 432 | 433 | d22 = self.side2(hx2d) 434 | d2 = _upsample_like(d22,d1) 435 | 436 | d32 = self.side3(hx3d) 437 | d3 = _upsample_like(d32,d1) 438 | 439 | d42 = self.side4(hx4d) 440 | d4 = _upsample_like(d42,d1) 441 | 442 | d52 = self.side5(hx5d) 443 | d5 = _upsample_like(d52,d1) 444 | 445 | d62 = self.side6(hx6) 446 | d6 = _upsample_like(d62,d1) 447 | 448 | d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) 449 | 450 | # if self.mode == 'train': 451 | # return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6) 452 | # else: 453 | # return F.sigmoid(d0) 454 | if self.mode == 'train': 455 | return d0, d1, d2, d3, d4, d5, d6 456 | else: 457 | return d0 458 | -------------------------------------------------------------------------------- /model/UIU/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /model/UIU/__pycache__/UIU_no_sigmoid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/UIU_no_sigmoid.cpython-36.pyc -------------------------------------------------------------------------------- /model/UIU/__pycache__/UIU_no_sigmoid.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/UIU_no_sigmoid.cpython-38.pyc -------------------------------------------------------------------------------- /model/UIU/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/UIU/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/UIU/__pycache__/fusion.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/fusion.cpython-36.pyc -------------------------------------------------------------------------------- /model/UIU/__pycache__/fusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/fusion.cpython-38.pyc 
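The `fusion.py` module below implements the asymmetric bidirectional fusion (`AsymBiChaFuseReduce`) that `UIU_No_Sigmoid._fuse_layer` instantiates at each decoder level. A minimal shape-check sketch (hypothetical, not part of the repository; it assumes you run it from the repository root so that `model.UIU.fusion` resolves, and that the two inputs already share a spatial size, as they do after `_upsample_like`):

import torch
from model.UIU.fusion import AsymBiChaFuseReduce

fuse = AsymBiChaFuseReduce(in_high_channels=512, in_low_channels=512, out_channels=512)
xh = torch.rand(2, 512, 16, 16)  # high-level features, already upsampled to match xl
xl = torch.rand(2, 512, 16, 16)  # low-level features
out1, out2 = fuse(xh, xl)        # top-down and bottom-up re-weighted streams
print(out1.shape, out2.shape)    # both torch.Size([2, 512, 16, 16])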
-------------------------------------------------------------------------------- /model/UIU/fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AsymBiChaFuseReduce(nn.Module): 5 | def __init__(self, in_high_channels, in_low_channels, out_channels=64, r=4): 6 | super(AsymBiChaFuseReduce, self).__init__() 7 | assert in_low_channels == out_channels 8 | self.high_channels = in_high_channels 9 | self.low_channels = in_low_channels 10 | self.out_channels = out_channels 11 | self.bottleneck_channels = int(out_channels // r) 12 | 13 | self.feature_high = nn.Sequential( 14 | nn.Conv2d(self.high_channels, self.out_channels, 1, 1, 0), 15 | #nn.BatchNorm2d(out_channels), 16 | nn.GroupNorm(1, out_channels), 17 | nn.ReLU(True), 18 | )##512 19 | 20 | self.topdown = nn.Sequential( 21 | nn.AdaptiveAvgPool2d((1, 1)), 22 | nn.Conv2d(self.out_channels, self.bottleneck_channels, 1, 1, 0), 23 | #nn.BatchNorm2d(self.bottleneck_channels), 24 | nn.GroupNorm(1, self.bottleneck_channels), 25 | nn.ReLU(True), 26 | 27 | nn.Conv2d(self.bottleneck_channels, self.out_channels, 1, 1, 0), 28 | #nn.BatchNorm2d(self.out_channels), 29 | nn.GroupNorm(1, out_channels), 30 | nn.Sigmoid(), 31 | )#512 32 | 33 | ##############add spatial attention ###Cross UtU############ 34 | self.bottomup = nn.Sequential( 35 | nn.Conv2d(self.low_channels, self.bottleneck_channels, 1, 1, 0), 36 | #nn.BatchNorm2d(self.bottleneck_channels), 37 | nn.GroupNorm(1, self.bottleneck_channels), 38 | nn.ReLU(True), 39 | # nn.Sigmoid(), 40 | 41 | SpatialAttention(kernel_size=3), 42 | # nn.Conv2d(self.bottleneck_channels, 2, 3, 1, 0), 43 | # nn.Conv2d(2, 1, 1, 1, 0), 44 | #nn.BatchNorm2d(self.out_channels), 45 | nn.Sigmoid() 46 | ) 47 | 48 | self.post = nn.Sequential( 49 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), 50 | #nn.BatchNorm2d(self.out_channels), 51 | nn.GroupNorm(1, out_channels), 52 | nn.ReLU(True), 53 | )#512 54 | 55 | def forward(self, xh, xl): 56 | xh = self.feature_high(xh) 57 | 58 | topdown_wei = self.topdown(xh) 59 | bottomup_wei = self.bottomup(xl * topdown_wei) 60 | xs1 = 2 * xl * topdown_wei #1 61 | out1 = self.post(xs1) 62 | 63 | xs2 = 2 * xh * bottomup_wei #1 64 | out2 = self.post(xs2) 65 | return out1,out2 66 | 67 | ############################## 68 | class SpatialAttention(nn.Module): 69 | def __init__(self, kernel_size=3): 70 | super(SpatialAttention, self).__init__() 71 | 72 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 73 | padding = 3 if kernel_size == 7 else 1 74 | 75 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 76 | 77 | def forward(self, x): 78 | avg_out = torch.mean(x, dim=1, keepdim=True) 79 | max_out, _ = torch.max(x, dim=1, keepdim=True) 80 | x = torch.cat([avg_out, max_out], dim=1) 81 | x = self.conv1(x) 82 | return x 83 | -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding = gbk 3 | """ 4 | @Author : yuchuang 5 | @Time : 6 | @desc: 7 | """ 8 | import os 9 | import albumentations as A 10 | from albumentations.pytorch import ToTensorV2 11 | from torch.utils.data import DataLoader 12 | import numpy as np 13 | import torch 14 | import cv2 15 | from torch.utils.data import Dataset 16 | from PIL import Image 17 | # from model.MSDA.MSDA_no_sigmoid import MSDANet_No_Sigmoid 18 | from skimage import measure 19 | import 
torch.nn.functional as F 20 | from torch.autograd import Variable 21 | import math 22 | from components.cal_mean_std import Calculate_mean_std 23 | from utilts import access_model 24 | 25 | def read_txt(txt_path): 26 | with open(txt_path, 'r') as file: 27 | lines = file.readlines() 28 | image_out_list = [line.strip() + '.png' for line in lines] 29 | return image_out_list 30 | 31 | 32 | def make_dir(path): 33 | if os.path.exists(path) == False: 34 | os.makedirs(path) 35 | 36 | 37 | ############################################## 38 | choose_model = 'MSDA' ##choose model in [ACM, ALC, MLCL, ALCL, DNA, GGL, UIU, MSDA] 39 | model_func = access_model(choose_model) 40 | choose_dataset = 'SIRST3' ## choose dataset in [SIRST3, IRSTD_1K_point, NUDT_SIRST_1_1_point, SIRST_1_1_point_new] 41 | test_dir_name = '********' ## Replace with the folder name where the corresponding test model is located, such as 'MSDA__SIRST3__masks_coarse__2024-12-13_13-30-35'. Since the timestamps are unique, you need the folder name you generated. 42 | test_model_name = 'best_mIoU_checkpoint_' + test_dir_name + ".pth.tar" 43 | ################################################ 44 | 45 | 46 | # Hyperparameters etc. 47 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0' 48 | DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" 49 | TEST_BATCH_SIZE = 1 50 | NUM_WORKERS = 4 51 | PIN_MEMORY = True 52 | LOAD_MODEL = False 53 | patch_size_test = 1024 54 | TEST_PATCH_BATCH_SIZE = 32 55 | root_path = os.path.abspath('.') 56 | dataset_path = os.path.join(root_path,'dataset', choose_dataset) 57 | test_dataset_path = os.path.join(dataset_path,'val') 58 | input_path = os.path.join(test_dataset_path, 'img') 59 | output_path = os.path.join(test_dataset_path, 'pre_results') 60 | make_dir(output_path) 61 | # TEST_NUM = len(os.listdir(input_path)) 62 | # txt_path = os.path.join(root_path, 'img_idx', 'test.txt') 63 | img_list = os.listdir(input_path) 64 | 65 | test_model_path = os.path.join(root_path,'work_dirs',test_dir_name,test_model_name) 66 | 67 | def test_pred(img, net, batch_size, patch_size): 68 | b, c, h, w = img.shape 69 | # print(img.shape) 70 | patch_size = patch_size 71 | stride = patch_size 72 | 73 | if h > patch_size and w > patch_size: 74 | # Unfold the image into patches 75 | img_unfold = F.unfold(img, kernel_size=patch_size, stride=stride) 76 | img_unfold = img_unfold.reshape(b, c, patch_size, patch_size, -1).permute(0, 4, 1, 2, 3) 77 | # print(img_unfold.shape) 78 | patch_num = img_unfold.size(1) 79 | 80 | preds_list = [] 81 | for i in range(0, patch_num, batch_size): 82 | end = min(i + batch_size, patch_num) 83 | batch_patches = img_unfold[:, i:end, :, :, :].reshape(-1, c, patch_size, patch_size) 84 | batch_patches = Variable(batch_patches.float()) 85 | batch_preds = net.forward(batch_patches) 86 | preds_list.append(batch_preds) 87 | # Concatenate all the patch predictions 88 | preds_unfold = torch.cat(preds_list, dim=0).permute(1, 2, 3, 0) 89 | preds_unfold = preds_unfold.reshape(b, -1, patch_num) 90 | preds = F.fold(preds_unfold, kernel_size=patch_size, stride=stride, output_size=(h, w)) 91 | else: 92 | preds = net.forward(img) 93 | 94 | return preds 95 | 96 | 97 | class SirstDataset(Dataset): 98 | def __init__(self, image_dir, patch_size, transform=None, mode='None'): 99 | self.image_dir = image_dir 100 | self.transform = transform 101 | self.images = np.sort(os.listdir(image_dir)) 102 | self.mode = mode 103 | self.patch_size = patch_size 104 | 105 | def __len__(self): 106 | return len(self.images) 107 | 108 | def 
__getitem__(self, index):
109 | img_path = os.path.join(self.image_dir, self.images[index])
110 | image = np.array(Image.open(img_path).convert("RGB"))
111 |
112 | if self.mode == 'test':
113 | times = 32
114 | h, w, c = image.shape
115 | # pad the height and width so that both are divisible by 32
116 | pad_height = math.ceil(h / times) * times - h
117 | pad_width = math.ceil(w / times) * times - w
118 | # zero-pad the image on the bottom/right
119 | image = np.pad(image, ((0, pad_height), (0, pad_width), (0, 0)), mode='constant')
120 | if self.transform is not None:
121 | augmentations = self.transform(image=image)
122 | image = augmentations["image"]
123 | return image, self.images[index], h, w
124 | else:
125 | print("Invalid mode: this dataset only supports mode='test'!")
126 |
127 |
128 | def main():
129 | origin_img_dir = dataset_path + "/origin/img"
130 | cal_mean, cal_std = Calculate_mean_std(origin_img_dir)
131 | test_transforms = A.Compose(
132 | [
133 | # A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
134 | A.Normalize(
135 | mean=cal_mean,
136 | std=cal_std,
137 | max_pixel_value=255.0,
138 | ),
139 | ToTensorV2(),
140 | ],
141 | )
142 | test_ds = SirstDataset(
143 | image_dir=input_path,
144 | patch_size=patch_size_test,
145 | transform=test_transforms,
146 | mode='test'
147 | )
148 |
149 | test_loader = DataLoader(
150 | test_ds,
151 | batch_size=TEST_BATCH_SIZE,
152 | num_workers=NUM_WORKERS,
153 | pin_memory=PIN_MEMORY,
154 | shuffle=False,
155 | )
156 | model = model_func().to(DEVICE)
157 |
158 | model.load_state_dict({k.replace('module.', ''): v for k, v in
159 | torch.load(test_model_path, map_location=DEVICE)[
160 | 'state_dict'].items()})
161 | model.eval()
162 |
163 | temp_num = 0
164 |
165 | for idx, (img, name, h, w) in enumerate(test_loader):
166 | print(idx)
167 | img = img.to(device=DEVICE)
168 | with torch.no_grad():
169 | image_1 = img
170 |
171 | output_1 = test_pred(image_1, model, batch_size=TEST_PATCH_BATCH_SIZE, patch_size=patch_size_test)
172 | output_1 = torch.sigmoid(output_1)
173 | output_1 = output_1[:, :, :h, :w]
174 | output_1 = output_1.cpu().data.numpy()
175 |
176 | for i in range(output_1.shape[0]):
177 | print(name[i])
178 | temp_num = temp_num + 1
179 | pred = output_1[i]
180 | pred = pred[0]
181 | pred_target = np.where(pred > 0.5, 255, 0)
182 | pred_target = np.array(pred_target, dtype='uint8')
183 |
184 | cv2.imwrite(os.path.join(output_path, name[i]), pred_target)
185 |
186 |
187 | if __name__ == "__main__":
188 | main()
189 |
--------------------------------------------------------------------------------
/tools/centroid_anno.m:
--------------------------------------------------------------------------------
1 | data_dir = './masks/';
2 | save_dir = './masks_centroid/';
3 | mkdir(save_dir);
4 | data_list = dir(data_dir);
5 | for i = 3:length(data_list)
6 | img = imread([data_dir,data_list(i).name]);
7 | Ilabel = bwlabel(img);
8 | Area_I = regionprops(Ilabel,'centroid');
9 | img_centroid = zeros(size(img));
10 | for x = 1: numel(Area_I)
11 | img_centroid(floor(Area_I(x).Centroid(2)),floor(Area_I(x).Centroid(1))) = 255;
12 | end
13 | imwrite(uint8(img_centroid),[save_dir,data_list(i).name])
14 | end
--------------------------------------------------------------------------------
/tools/coarse_anno.m:
--------------------------------------------------------------------------------
1 | data_dir = './masks/';
2 | save_dir = './masks_coarse/';
3 | mkdir(save_dir);
4 | data_list = dir(data_dir);
5 | for i = 3:length(data_list)
6 | img = double(imread([data_dir,data_list(i).name]));
7 | Ilabel = bwlabel(img);
8 | BoundingBoxs =
regionprops(Ilabel,'BoundingBox'); 9 | centroids = regionprops(Ilabel,'centroid'); 10 | img_centroid = zeros(size(img)); 11 | for j = 1: numel(BoundingBoxs) 12 | img_temp = zeros(size(img)); 13 | while sum(sum(img_temp .* img)) == 0 14 | gaussian_num_x = normrnd(0,1/4,1,1) ; 15 | gaussian_num_y = normrnd(0,1/4,1,1) ; 16 | x = floor(centroids(j).Centroid(1)+BoundingBoxs(j).BoundingBox(3)/2*gaussian_num_x); 17 | y = floor(centroids(j).Centroid(2)+BoundingBoxs(j).BoundingBox(4)/2*gaussian_num_y); 18 | [y_max, x_max] = size(img); 19 | if x <1 20 | x= 1; 21 | end 22 | if y <1 23 | y= 1; 24 | end 25 | if x>x_max 26 | x = x_max; 27 | end 28 | if y>y_max 29 | y = y_max; 30 | end 31 | img_temp(y,x) = 255; 32 | end 33 | img_centroid(y,x) = 255; 34 | end 35 | imwrite(uint8(img_centroid),[save_dir,data_list(i).name]) 36 | end 37 | 38 | function [y] = Gaussian(x,mu,sigma) 39 | y = exp(-(x-mu).^2/(2*sigma^2));% 40 | end -------------------------------------------------------------------------------- /utilts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding = gbk 3 | """ 4 | @Author : yuchuang 5 | @Time : 6 | @desc: 7 | """ 8 | 9 | import os 10 | 11 | import numpy as np 12 | 13 | import cv2 14 | import importlib 15 | import random 16 | import shutil 17 | 18 | import sys 19 | import torch 20 | # import pydensecrf.densecrf as dcrf 21 | # from pydensecrf.utils import unary_from_softmax, create_pairwise_gaussian, create_pairwise_bilateral 22 | 23 | def make_dir(path): 24 | if os.path.exists(path)==False: 25 | os.makedirs(path) 26 | 27 | def access_model(choose_model): 28 | choose_model_dir_name = choose_model + '_no_sigmoid' 29 | model_function = choose_model + '_No_Sigmoid' 30 | module_name = f"model.{choose_model}.{choose_model_dir_name}" 31 | module = importlib.import_module(module_name) 32 | model_func = getattr(module, model_function) 33 | return model_func 34 | 35 | 36 | def check_path(path): 37 | if os.path.exists(path) == True: 38 | print("Error: The workspace of the training pool already exists. 
Please manually and carefully delete the existing directory, or add a suffix to the name of the folder being created!!!")
39 | print("The conflicting directories are:", path)
40 | print("Method 3 is recommended, as it generates a unique training-pool folder name!!!")
41 | sys.exit(0)
42 |
43 |
44 | def center_point_inside_contour(center_point, target_mask):
45 | (center_y, center_x) = center_point
46 | target_contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
47 | #center_index_XmergeY = {center_y * 1.0 + center_x * 0.0001,center_y-1 * 1.0 + center_x * 0.0001,center_y * 1.0 + center_x-1 * 0.0001,center_y+1 * 1.0 + center_x * 0.0001,center_y * 1.0 + center_x+1 * 0.0001}
48 | center_index_XmergeY = {center_y * 1.0 + center_x * 0.0001}
49 | temp_contour_mask = np.zeros(target_mask.shape, np.uint8)
50 |
51 | overlap_found = False
52 | for target_contour in target_contours:
53 | target_contour_mask = np.zeros(target_mask.shape, np.uint8)
54 | cv2.fillPoly(target_contour_mask, [target_contour], (255))
55 | target_index = np.where(target_contour_mask == 255)
56 | target_index_XmergeY = set(target_index[0] * 1.0 + target_index[1] * 0.0001)
57 | if not center_index_XmergeY.isdisjoint(target_index_XmergeY):
58 | area = cv2.contourArea(target_contour)
59 | if area > 50:
60 | break
61 | else:
62 | overlap_found = True
63 | #print("True")
64 | cv2.fillPoly(temp_contour_mask, [target_contour], (255))
65 | break
66 |
67 | # if no small enclosing contour was found, temp_contour_mask stays all-zero
68 |
69 | return temp_contour_mask, overlap_found
70 |
71 |
72 | def process_image(y_and_x, y1_and_y2_and_x1_and_x2, img_shape, image, low_threshold=50, high_threshold=150, kernel_size=(3, 3), sigma=0):
73 | blurred_image = cv2.GaussianBlur(image, kernel_size, sigma)
74 | high_pass = cv2.subtract(image, blurred_image)
75 | edges = cv2.Canny(high_pass, low_threshold, high_threshold)
76 | kernel = np.ones((3, 3), np.uint8)
77 | sparse_edges = cv2.dilate(edges, kernel, iterations=1)
78 | sparse_edges = cv2.erode(sparse_edges, kernel, iterations=1)
79 |
80 | temp_contour_mask_2 = np.zeros(img_shape, np.uint8)
81 | (y1, y2, x1, x2) = y1_and_y2_and_x1_and_x2
82 | temp_contour_mask_2[y1:y2, x1:x2] = sparse_edges
83 | (y, x) = y_and_x
84 |
85 |
86 | refine_mask, flag = center_point_inside_contour((y, x), temp_contour_mask_2)
87 | refine_mask_out = refine_mask[y1:y2, x1:x2]
88 | return refine_mask_out, flag
89 |
90 |
91 |
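# Hypothetical usage sketch of process_image, for illustration only: a 5x5
# bright square stands in for a small target, and the point (y, x) = (32, 32)
# lies inside it, mimicking how data_inital_make_add_points below crops a
# window around each point label.
def _demo_process_image():
    demo = np.zeros((64, 64), np.uint8)
    demo[30:35, 30:35] = 200                           # synthetic small target
    y, x = 32, 32                                      # single-point label
    y1, y2, x1, x2 = y - 10, y + 10, x - 10, x + 10    # crop window (crop_size=10)
    roi = demo[y1:y2, x1:x2]
    refined, ok = process_image((y, x), (y1, y2, x1, x2), demo.shape, roi,
                                low_threshold=20, high_threshold=40)
    # For a blob this small the enclosing edge contour should have area <= 50,
    # so ok is expected to be True and refined to be the 20x20 filled-contour
    # crop; the exact flag depends on the Canny edge response.
    print(ok, refined.shape)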
89 | 
90 | 
91 | 
92 | def data_inital_make_add_points(origin_img_dir, origin_points_dir, TRAIN_IMG_DIR, TRAIN_MASK_DIR, train_points_dir, nc_img_dir, nc_mask_dir, nc_points_dir, crop_size=10):
93 |     input_image_path = origin_img_dir
94 |     input_points_path = origin_points_dir
95 | 
96 |     output_image_path = TRAIN_IMG_DIR
97 |     output_masks_path = TRAIN_MASK_DIR
98 |     output_points_path = train_points_dir
99 | 
100 | 
101 |     no_choose_output_image_path = nc_img_dir
102 |     no_choose_output_masks_path = nc_mask_dir
103 |     no_choose_output_points_path = nc_points_dir
104 | 
105 | 
106 |     input_img_list = os.listdir(input_image_path)
107 | 
108 |     for i in range(len(input_img_list)):
109 |         #print(f"Processing image: {input_img_list[i]}")
110 |         img_path = os.path.join(input_image_path, input_img_list[i])
111 |         points_path = os.path.join(input_points_path, input_img_list[i])
112 | 
113 |         out_img_path = os.path.join(output_image_path, input_img_list[i])
114 |         out_mask_path = os.path.join(output_masks_path, input_img_list[i])
115 |         out_points_path = os.path.join(output_points_path, input_img_list[i])
116 | 
117 |         no_choose_out_img_path = os.path.join(no_choose_output_image_path, input_img_list[i])
118 |         no_choose_out_mask_path = os.path.join(no_choose_output_masks_path, input_img_list[i])
119 |         no_choose_out_points_path = os.path.join(no_choose_output_points_path, input_img_list[i])
120 | 
121 |         img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
122 |         mask = cv2.imread(points_path, cv2.IMREAD_GRAYSCALE)
123 | 
124 |         if img is None or mask is None:
125 |             print(f"Failed to read image or mask: {input_img_list[i]}")
126 |             continue
127 | 
128 |         points = np.where(mask == 255)
129 |         if len(points[0]) == 0:
130 |             #print(f"No labeled points found in mask: {input_img_list[i]}")
131 |             cv2.imwrite(out_img_path, img)
132 |             cv2.imwrite(out_mask_path, mask)
133 |             cv2.imwrite(out_points_path, mask)
134 |             continue
135 | 
136 |         merged_result = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
137 | 
138 |         correct_point = 0
139 | 
140 |         for j in range(len(points[0])):  # renamed from i to avoid shadowing the outer loop index
141 |             center_y, center_x = points[0][j], points[1][j]
142 | 
143 |             half_size = crop_size
144 |             x1 = max(center_x - half_size, 0)
145 |             y1 = max(center_y - half_size, 0)
146 |             x2 = min(center_x + half_size, img.shape[1])
147 |             y2 = min(center_y + half_size, img.shape[0])
148 | 
149 |             roi = img[y1:y2, x1:x2]
150 | 
151 |             processed_roi, flag = process_image((center_y, center_x), (y1, y2, x1, x2), mask.shape, roi,
152 |                                                 low_threshold=20, high_threshold=40, kernel_size=(3, 3), sigma=0)
153 | 
154 |             merged_result[y1:y2, x1:x2] = merged_result[y1:y2, x1:x2] + processed_roi
155 | 
156 |             if flag:
157 |                 correct_point = correct_point + 1
158 | 
159 | 
160 | 
161 |         if (correct_point / len(points[0])) >= 0.8:
162 |             merged_result = merged_result + mask
163 |             merged_result = np.where(merged_result > 0, 255, 0)
164 |             cv2.imwrite(out_img_path, img)
165 |             cv2.imwrite(out_mask_path, merged_result)
166 |             cv2.imwrite(out_points_path, mask)
167 |         else:
168 |             # print(no_choose_out_img_path)
169 |             cv2.imwrite(no_choose_out_img_path, img)
170 |             #cv2.imwrite(no_choose_out_mask_path, mask)
171 |             cv2.imwrite(no_choose_out_points_path, mask)
172 | 
173 |     print("Initial data generated; number of samples:", len(os.listdir(output_image_path)))
174 | 
175 | 
176 | 
177 | 
178 | 
179 | 
180 | 
181 | ### copy_mask : the point labels
182 | ### target_mask: the prediction result
183 | def nc_pred_mask(copy_mask, target_mask, lose_point_ratio=0.2, alarm_point_ration=0.2):
184 |     copy_contours, _ = cv2.findContours(copy_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
185 |     target_contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
186 | 
187 |     overwrite_contours = []
188 |     un_overwrite_contours = []
189 | 
190 |     target_index_sets = []
191 |     for target_contour in target_contours:
192 |         target_contour_mask = np.zeros(copy_mask.shape, np.uint8)
193 |         cv2.fillPoly(target_contour_mask, [target_contour], (255))
194 |         target_index = np.where(target_contour_mask == 255)
195 |         target_index_XmergeY = set(target_index[0] * 1.0 + target_index[1] * 0.0001)
196 |         target_index_sets.append(target_index_XmergeY)
197 | 
198 |     for copy_contour in copy_contours:
199 |         copy_contour_mask = np.zeros(copy_mask.shape, np.uint8)
200 |         cv2.fillPoly(copy_contour_mask, [copy_contour], (255))
201 |         copy_index = np.where(copy_contour_mask == 255)
202 |         copy_index_XmergeY = set(copy_index[0] * 1.0 + copy_index[1] * 0.0001)
203 | 
204 |         overlap_found = False
205 |         for target_index_XmergeY in target_index_sets:
206 |             if not copy_index_XmergeY.isdisjoint(target_index_XmergeY):
207 |                 overwrite_contours.append(copy_contour)
208 |                 overlap_found = True
209 |                 break
210 | 
211 |         if not overlap_found:
212 |             un_overwrite_contours.append(copy_contour)
213 | 
214 |     flag = False
215 |     if len(un_overwrite_contours) / len(copy_contours) > lose_point_ratio or ((len(target_contours) - len(overwrite_contours)) / len(copy_contours)) > alarm_point_ration:
216 |         flag = False  # too many missed points or too many false-alarm regions: reject this prediction
217 |     else:
218 |         flag = True
219 |     return flag
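# ---------------------------------------------------------------------------
# Added worked example (illustration only; not part of the original file).
# nc_pred_mask() accepts a prediction when (a) the fraction of labeled points
# with no overlapping predicted region is at most lose_point_ratio and (b)
# the fraction of surplus predicted regions is at most alarm_point_ration.
# With 10 labeled points, 9 matched, and 12 predicted regions: the miss rate
# 1/10 = 0.1 passes, but the surplus (12 - 9)/10 = 0.3 > 0.2 fails, so the
# sample stays in the "not chosen" pool. The function below restates that
# arithmetic under the assumption of one-to-one point/region matches.
# ---------------------------------------------------------------------------
def _demo_nc_pred_mask_rule(num_points=10, num_matched=9, num_pred=12,
                            lose_point_ratio=0.2, alarm_point_ration=0.2):
    miss_rate = (num_points - num_matched) / num_points    # 0.1 here
    alarm_rate = (num_pred - num_matched) / num_points     # 0.3 here
    return miss_rate <= lose_point_ratio and alarm_rate <= alarm_point_ration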
220 | 
221 | 
222 | 
223 | def deal_pred_mask_and_true_point_in(nc_img_path, nc_pred_mask_path, nc_points_path, c_img_path, c_mask_path, c_points_path, lose_point_ratio=0.2, alarm_point_ration=0.2):
224 |     nc_img_list = os.listdir(nc_img_path)
225 |     new_choose_list = []
226 |     for i in range(len(nc_img_list)):
227 |         img_path = os.path.join(nc_img_path, nc_img_list[i])
228 |         points_path = os.path.join(nc_points_path, nc_img_list[i])
229 |         pred_path = os.path.join(nc_pred_mask_path, nc_img_list[i])
230 | 
231 |         img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
232 |         mask = cv2.imread(points_path, cv2.IMREAD_GRAYSCALE)
233 |         pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
234 | 
235 |         if img is None or mask is None:
236 |             print(f"Failed to read image or mask: {nc_img_list[i]}")
237 |             continue
238 | 
239 |         points = np.where(mask == 255)
240 |         if len(points[0]) == 0:
241 |             new_choose_list.append(nc_img_list[i])
242 |             continue
243 | 
244 |         flag = nc_pred_mask(mask, pred, lose_point_ratio=lose_point_ratio, alarm_point_ration=alarm_point_ration)
245 |         if flag:
246 |             new_choose_list.append(nc_img_list[i])
247 |     return new_choose_list
248 | 
249 | 
250 | 
251 | ### Note: the roles here are the reverse of nc_pred_mask
252 | ### copy_mask : the prediction result
253 | ### target_mask: the point labels
254 | def nc_correct_pred_mask(copy_mask, target_mask):
255 |     copy_contours, _ = cv2.findContours(copy_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
256 |     target_contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
257 | 
258 |     overwrite_contours = []
259 |     un_overwrite_contours = []
260 | 
261 |     target_index_sets = []
262 |     for target_contour in target_contours:
263 |         target_contour_mask = np.zeros(copy_mask.shape, np.uint8)
264 |         cv2.fillPoly(target_contour_mask, [target_contour], (255))
265 |         target_index = np.where(target_contour_mask == 255)
266 |         target_index_XmergeY = set(target_index[0] * 1.0 + target_index[1] * 0.0001)
267 |         target_index_sets.append(target_index_XmergeY)
268 | 
269 |     for copy_contour in copy_contours:
270 |         copy_contour_mask = np.zeros(copy_mask.shape, np.uint8)
271 |         cv2.fillPoly(copy_contour_mask, [copy_contour], (255))
272 |         copy_index = np.where(copy_contour_mask == 255)
273 |         copy_index_XmergeY = set(copy_index[0] * 1.0 + copy_index[1] * 0.0001)
274 | 
275 |         overlap_found = False
276 |         for target_index_XmergeY in target_index_sets:
277 |             if not copy_index_XmergeY.isdisjoint(target_index_XmergeY):
278 |                 overwrite_contours.append(copy_contour)
279 |                 overlap_found = True
280 |                 break
281 | 
282 |         if not overlap_found:
283 |             un_overwrite_contours.append(copy_contour)
284 | 
285 |     copy_contour_mask_out = np.zeros(copy_mask.shape, np.uint8)
286 |     for i in range(len(overwrite_contours)):
287 |         cv2.fillPoly(copy_contour_mask_out, [overwrite_contours[i]], (255))  # keep only regions that contain a point
288 | 
289 |     copy_contour_mask_out = copy_contour_mask_out + target_mask  # re-add the point labels themselves
290 |     copy_contour_mask_out = np.where(copy_contour_mask_out > 0, 255, 0)
291 | 
292 |     return copy_contour_mask_out
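# ---------------------------------------------------------------------------
# Added usage sketch (illustration only; not part of the original file).
# nc_correct_pred_mask() erases false-alarm blobs: it keeps only the
# predicted regions that contain a labeled point and re-adds the points.
# The synthetic 64x64 example below draws one true blob covering the point
# and one stray blob; only the true blob survives.
# ---------------------------------------------------------------------------
def _demo_nc_correct_pred_mask():
    pred = np.zeros((64, 64), np.uint8)
    cv2.rectangle(pred, (10, 10), (14, 14), 255, -1)   # blob around the point
    cv2.rectangle(pred, (40, 40), (44, 44), 255, -1)   # stray false alarm
    points = np.zeros((64, 64), np.uint8)
    points[12, 12] = 255                               # single point label
    cleaned = nc_correct_pred_mask(pred, points)
    assert cleaned[12, 12] == 255 and cleaned[42, 42] == 0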
293 | 
294 | 
295 | 
296 | def deal_gen_mask_error_aera(nc_pred_mask_path, nc_points_path, new_choose_list):
297 |     for i in range(len(new_choose_list)):
298 |         pred_mask_path = os.path.join(nc_pred_mask_path, new_choose_list[i])
299 |         points_path = os.path.join(nc_points_path, new_choose_list[i])
300 | 
301 |         pred_mask = cv2.imread(pred_mask_path, cv2.IMREAD_GRAYSCALE)
302 |         points = cv2.imread(points_path, cv2.IMREAD_GRAYSCALE)
303 | 
304 |         pred_mask_out = nc_correct_pred_mask(pred_mask, points)
305 |         cv2.imwrite(pred_mask_path, pred_mask_out)
306 |     print("Refinement of the generated labels is complete!")
307 | 
308 | 
309 | 
310 | 
311 | 
312 | 
313 | def move_files(file_list, src_path, dst_path):
314 |     for file_name in file_list:
315 |         #print(file_name)
316 |         src_file = os.path.join(src_path, file_name)
317 |         #print(src_file)
318 |         dst_file = os.path.join(dst_path, file_name)
319 |         if os.path.exists(src_file):
320 |             try:
321 |                 shutil.copy(src_file, dst_file)  # copy then delete, i.e. a manual move
22 |                 os.remove(src_file)
323 |             except PermissionError as e:
324 |                 print(f"PermissionError: {e}")
325 | 
326 |             # shutil.move(src_file, dst_file)
327 |         else:
328 |             print(f"File {src_file} does not exist.")
329 | 
330 | 
331 | 
332 | 
333 | def hard_sample_in(nc_img_path, nc_pred_mask_path, nc_points_path, c_img_path, c_mask_path, c_points_path, new_choose_list):
334 |     move_files(new_choose_list, nc_img_path, c_img_path)
335 |     move_files(new_choose_list, nc_pred_mask_path, c_mask_path)
336 |     move_files(new_choose_list, nc_points_path, c_points_path)
337 | 
338 | 
339 | 
340 | 
341 | 
342 | def update_gt_update_degen_corr(pred, gt_masks, thresh_Tb, thresh_k, size, degen=0.9):
343 | 
344 |     update_gt_masks = gt_masks.copy()
345 | 
346 |     background_length = 33
347 |     target_length = 3
348 | 
349 |     num_labels, label_image = cv2.connectedComponents((gt_masks > 0.5).astype(np.uint8))
350 | 
351 |     background_kernel = np.ones((background_length, background_length), np.uint8)
352 |     target_kernel = np.ones((target_length, target_length), np.uint8)
353 | 
354 |     pred_max = pred.max()
355 |     max_limitation = size[0] * size[1] * 0.0015
356 | 
357 |     combined_thresh_mask = np.zeros_like(pred, dtype=np.float32)
358 | 
359 |     for region_num in range(1, num_labels):
360 |         region_coords = np.argwhere(label_image == region_num)
361 |         centroid = np.mean(region_coords, axis=0).astype(int)
362 | 
363 |         cur_point_mask = np.zeros_like(pred, dtype=np.uint8)
364 |         cur_point_mask[centroid[0], centroid[1]] = 1
365 | 
366 |         nbr_mask = cv2.dilate(cur_point_mask, background_kernel) > 0  # large neighborhood around the centroid
367 |         targets_mask = cv2.dilate(cur_point_mask, target_kernel) > 0  # small window around the centroid
368 | 
369 |         region_size_ratio = len(region_coords) / max_limitation
370 |         threshold_start = (pred * nbr_mask).max() * thresh_Tb
371 |         threshold_delta = thresh_k * ((pred * nbr_mask).max() - threshold_start) * region_size_ratio
372 |         threshold = threshold_start + threshold_delta
373 |         threshold = threshold.cpu().numpy() if isinstance(threshold, torch.Tensor) else threshold
374 | 
375 |         thresh_mask = (pred * nbr_mask > threshold).astype(np.float32)
376 | 
377 |         num_labels_thresh, label_image_thresh = cv2.connectedComponents(thresh_mask.astype(np.uint8))
378 |         for num_cur in range(1, num_labels_thresh):
379 |             curr_mask = (label_image_thresh == num_cur).astype(np.float32)
380 |             if np.sum(curr_mask * targets_mask) == 0:
381 |                 thresh_mask -= curr_mask  # drop components not connected to the current target
382 | 
383 |         combined_thresh_mask = np.maximum(combined_thresh_mask, thresh_mask)
384 | 
385 |     target_patch = (update_gt_masks * combined_thresh_mask + pred * combined_thresh_mask) / 2  # blend GT and prediction inside accepted regions
386 |     background_patch = update_gt_masks * (1 - combined_thresh_mask) * degen  # decay the background to suppress stale labels
387 |     update_gt_masks = background_patch + target_patch
388 | 
389 |     update_gt_masks = np.maximum(update_gt_masks, (gt_masks == 1).astype(np.float32))
390 | 
391 |     return update_gt_masks
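# ---------------------------------------------------------------------------
# Added worked example (illustration only; not part of the original file).
# The adaptive threshold above interpolates between thresh_Tb times the local
# maximum and the local maximum itself, driven by the current region size:
# larger regions get a stricter cut. With a local max of 0.8, thresh_Tb=0.5,
# thresh_k=0.5 and a region at half of max_limitation (ratio 0.5):
#     start = 0.8 * 0.5 = 0.4;  delta = 0.5 * (0.8 - 0.4) * 0.5 = 0.1
#     threshold = 0.5
# The parameter values here are placeholders, not the paper's settings.
# ---------------------------------------------------------------------------
def _demo_adaptive_threshold(local_max=0.8, thresh_Tb=0.5, thresh_k=0.5, region_size_ratio=0.5):
    threshold_start = local_max * thresh_Tb
    threshold_delta = thresh_k * (local_max - threshold_start) * region_size_ratio
    return threshold_start + threshold_delta   # 0.5 for the defaults above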
392 | 
393 | 
394 | 
395 | 
396 | 
397 | 
398 | 
399 | 
400 | 
401 | 
--------------------------------------------------------------------------------