├── README.md
├── cal_PD_and_Fa.py
├── cal_mIoU_and_nIoU.py
├── components
│   ├── __pycache__
│   │   ├── cal_mean_std.cpython-36.pyc
│   │   ├── cal_mean_std.cpython-38.pyc
│   │   ├── dataset_final_edge_copy_paste_final_2_img_path.cpython-36.pyc
│   │   ├── dataset_final_edge_copy_paste_final_2_img_path.cpython-38.pyc
│   │   ├── edges.cpython-36.pyc
│   │   ├── edges.cpython-38.pyc
│   │   ├── metric_new_crop.cpython-36.pyc
│   │   ├── metric_new_crop.cpython-38.pyc
│   │   ├── utils_all_edge_copy_paste_final_2_img_path.cpython-36.pyc
│   │   └── utils_all_edge_copy_paste_final_2_img_path.cpython-38.pyc
│   ├── cal_mean_std.py
│   ├── dataset_final_edge_copy_paste_final_2_img_path.py
│   ├── edges.py
│   ├── gardient.py
│   ├── metric_new_crop.py
│   └── utils_all_edge_copy_paste_final_2_img_path.py
├── imgs
│   ├── Main results.png
│   ├── PAL framework.png
│   ├── Results on the SIRST3 with centroid point label.png
│   ├── Results on the SIRST3 with coarse point label.png
│   ├── Results on the three separate dataset with centroid point label.png
│   ├── Results on the three separate dataset with coarse point label.png
│   ├── Visualization on the SIRST3 with centroid point label.png
│   └── Visualization on the SIRST3 with coarse point label.png
├── loss
│   ├── Edge_loss.py
│   ├── __init__.py
│   └── __pycache__
│       ├── Edge_loss.cpython-36.pyc
│       ├── Edge_loss.cpython-38.pyc
│       ├── __init__.cpython-36.pyc
│       └── __init__.cpython-38.pyc
├── mm
│   └── attention
│       ├── SEAttention.py
│       └── __pycache__
│           ├── SEAttention.cpython-36.pyc
│           └── SEAttention.cpython-38.pyc
├── model
│   ├── ACM
│   │   ├── ACM_no_sigmoid.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── ACM_no_sigmoid.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── fusion.cpython-36.pyc
│   │   └── fusion.py
│   ├── ALC
│   │   ├── ALC_no_sigmoid.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── ALC_no_sigmoid.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── fusion.cpython-36.pyc
│   │   └── fusion.py
│   ├── ALCL
│   │   ├── ALCL_no_sigmoid.py
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       ├── ALCL_no_sigmoid.cpython-36.pyc
│   │       └── __init__.cpython-36.pyc
│   ├── DNA
│   │   ├── DNA_no_sigmoid.py
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       ├── DNA_no_sigmoid.cpython-36.pyc
│   │       └── __init__.cpython-36.pyc
│   ├── GGL
│   │   ├── GGL_no_sigmoid.py
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       ├── GGL_no_sigmoid.cpython-36.pyc
│   │       └── __init__.cpython-36.pyc
│   ├── MLCL
│   │   ├── MLCL_no_sigmoid.py
│   │   ├── MLCL_small_no_sigmoid.py
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       ├── MLCL_no_sigmoid.cpython-36.pyc
│   │       └── __init__.cpython-36.pyc
│   ├── MSDA
│   │   ├── MSDA_no_sigmoid.py
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       ├── MSDA_no_sigmoid.cpython-36.pyc
│   │       ├── MSDA_no_sigmoid.cpython-38.pyc
│   │       ├── __init__.cpython-36.pyc
│   │       └── __init__.cpython-38.pyc
│   └── UIU
│       ├── UIU_no_sigmoid.py
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── UIU_no_sigmoid.cpython-36.pyc
│       │   ├── UIU_no_sigmoid.cpython-38.pyc
│       │   ├── __init__.cpython-36.pyc
│       │   ├── __init__.cpython-38.pyc
│       │   ├── fusion.cpython-36.pyc
│       │   └── fusion.cpython-38.pyc
│       └── fusion.py
├── test_model.py
├── tools
│   ├── centroid_anno.m
│   └── coarse_anno.m
├── train_model.py
└── utilts.py
/README.md:
--------------------------------------------------------------------------------
1 | ## The official complete code for the paper "From Easy to Hard: Progressive Active Learning Framework for Infrared Small Target Detection with Single Point Supervision" [[Paper/arXiv](https://arxiv.org/abs/2412.11154)]
2 |
3 |
8 |
9 | In this project demo, we have integrated multiple SIRST detection networks ([**ACM**](https://arxiv.org/abs/2009.14530), [**ALC**](https://arxiv.org/abs/2012.08573), [**MLCL-Net**](https://doi.org/10.1016/j.infrared.2022.104107), [**ALCL-Net**](https://ieeexplore.ieee.org/document/9785618), [**DNANet**](https://arxiv.org/abs/2106.00487), [**GGL-Net**](https://ieeexplore.ieee.org/abstract/document/10230271), [**UIUNet**](https://arxiv.org/abs/2212.00968), [**MSDA-Net**](https://arxiv.org/abs/2406.02037)), label forms (full supervision, coarse single-point supervision, centroid single-point supervision), and datasets ([**SIRST**](https://ieeexplore.ieee.org/document/9423171), [**NUDT-SIRST**](https://ieeexplore.ieee.org/document/9864119), [**IRSTD-1k**](https://ieeexplore.ieee.org/document/9880295) and [**SIRST3**](https://arxiv.org/pdf/2304.01484)). More networks and functions can be integrated into the project later. We hope this work contributes to the development of the field.
10 |
11 |
12 |
13 |
14 |
15 |
16 | Comparison of different methods on the SIRST3 dataset. CNN Full, CNN Coarse, and CNN Centroid denote CNN-based methods under full supervision, coarse point supervision, and centroid point supervision, respectively.
17 |
18 |
19 |
20 |
25 |
26 | ## Overview
27 |
28 | We consider that an excellent learning process should be from easy to hard and take into account the learning ability of the current learner (model) rather than directly treating all tasks (samples) equally. Inspired by organisms gradually adapting to the environment and continuously accumulating knowledge, we first propose an innovative progressive active learning idea, which emphasizes that the network progressively and actively recognizes and learns more hard samples to achieve continuous performance enhancement. For details, please see [[Paper/arXiv](https://arxiv.org/abs/2412.11154)].
29 |
30 |
31 |
32 |
33 |
34 |
35 | ## Datasets
36 | 1. Original datasets
37 | * **NUDT-SIRST** [[Original dataset](https://pan.baidu.com/s/1WdA_yOHDnIiyj4C9SbW_Kg?pwd=nudt)] [[paper](https://ieeexplore.ieee.org/document/9864119)]
38 | * **SIRST** [[Original dataset](https://github.com/YimianDai/sirst)] [[paper](https://ieeexplore.ieee.org/document/9423171)]
39 | * **IRSTD-1k** [[Original dataset](https://drive.google.com/file/d/1JoGDGF96v4CncKZprDnoIor0k1opaLZa/view)] [[paper](https://ieeexplore.ieee.org/document/9880295)]
40 | * **SIRST3** [[Original dataset](https://github.com/XinyiYing/LESPS)] [[paper](https://arxiv.org/pdf/2304.01484)]
41 |
42 | 2. The labels are processed with the "coarse_anno.m" and "centroid_anno.m" scripts in the "tools" folder to produce coarse point labels and centroid point labels; a Python sketch of the idea follows this list. (**You can also skip this step and use the complete dataset in step 3 directly.**)
43 |
44 | 3. The datasets we created from original datasets (**can be used directly in our demo**)
45 |
46 | * [💎 Download the dataset required by our code!!!](https://pan.baidu.com/s/1_QIs9zUM_7MqJgwzO2aC0Q?pwd=1234)
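
Below is a minimal Python sketch of the centroid labeling idea. It is an illustration only: the official tools are the MATLAB scripts above, and the file names here are hypothetical (`cv2`, `numpy`, and `skimage` assumed, all of which the project already uses).

```
import cv2
import numpy as np
from skimage import measure

# Hypothetical file names; the real annotation is done by tools/centroid_anno.m.
full_mask = cv2.imread('mask/000001.png', cv2.IMREAD_GRAYSCALE)
binary = (full_mask > 127).astype(np.uint8)
point_label = np.zeros_like(binary)
for region in measure.regionprops(measure.label(binary, connectivity=2)):
    r, c = map(int, region.centroid)  # centroid of one target region
    point_label[r, c] = 255           # single-pixel centroid point label
cv2.imwrite('mask_centroid/000001.png', point_label)
```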
47 |
48 |
49 | ## How to use our code
50 | 1. Download the dataset
51 |
52 | Click [download datasets](https://pan.baidu.com/s/1_QIs9zUM_7MqJgwzO2aC0Q?pwd=1234)
53 |
54 | Unzip the downloaded compressed package to the root directory of the project.
55 |
56 | 2. Create an Anaconda virtual environment
57 |
58 | ```
59 | conda create -n PAL python=3.8
60 | conda activate PAL
61 | ```
62 | 3. Configure the running environment
63 |
64 | ```
65 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
66 | pip install segmentation_models_pytorch -i https://pypi.tuna.tsinghua.edu.cn/simple
67 | pip install PyWavelets -i https://pypi.tuna.tsinghua.edu.cn/simple
68 | pip install scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple
69 | pip install albumentations==1.3.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
70 | pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs torchinfo
71 | ```
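
Optionally, a quick sanity check that the CUDA build of PyTorch was installed:

```
import torch
print(torch.__version__)          # expect 1.13.1+cu116
print(torch.cuda.is_available())  # expect True on a CUDA-capable machine
```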
72 | 4. Training the model
73 |
74 | The default model, dataset, and label form are MSDA-Net, SIRST3, and coarse point labels. To train other combinations, simply change the corresponding settings in "train_model.py"; the available options are explained at the beginning of that file.
75 | ```
76 | python train_model.py
77 | ```
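
As an illustration only (the variable names below are hypothetical; the authoritative list is at the top of "train_model.py"), the settings to edit look like:

```
# Hypothetical setting names for illustration; see the top of train_model.py.
model_name = 'MSDA-Net'  # e.g. 'ACM', 'ALC', 'ALCL', 'DNA', 'GGL', 'MLCL', 'UIU'
dataset_name = 'SIRST3'  # or 'SIRST', 'NUDT-SIRST', 'IRSTD-1k'
label_form = 'coarse'    # or 'centroid', 'full'
```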
78 | 5. Testing the model
79 |
80 | The default model, dataset, and label form are MSDA-Net, SIRST3, and coarse point labels. To test other combinations, modify the corresponding settings in "test_model.py". Notably, you must also assign the name of the folder holding the weight file to the "test_dir_name" variable so that the program can find the corresponding model weights; details are at the beginning of "test_model.py".
81 | ```
82 | python test_model.py
83 | ```
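
For example (the folder name below is hypothetical; use the run folder that actually holds your trained weights):

```
test_dir_name = 'MSDA-Net_SIRST3_coarse'  # hypothetical run-folder name
```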
84 | 6. Performance evaluation
85 | Use "cal_mIoU_and_nIoU.py" and "cal_PD_and_Fa.py" for performance evaluation. Notably, the corresponding folder path should be replaced. default:SIRST3.
86 | ```
87 | python cal_mIoU_and_nIoU.py
88 | python cal_PD_and_Fa.py
89 | ```
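
For reference, the two scripts compute the metrics as follows, where $I_i$ and $U_i$ are the intersection and union of the binarized prediction and mask on image $i$ (of $N$ images), $T$ is the number of ground-truth targets, $T_{hit}$ is the number of targets whose centroid lies within 3 pixels of a predicted region's centroid, and $A_{FA}$ is the total area of unmatched predicted regions:

$$
\mathrm{mIoU}=\frac{\sum_{i} I_i}{\sum_{i} U_i},\qquad
\mathrm{nIoU}=\frac{1}{N}\sum_{i}\frac{I_i}{U_i},\qquad
P_d=\frac{T_{hit}}{T},\qquad
F_a=\frac{A_{FA}}{\text{total pixels}}
$$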
90 |
91 | ## Results
92 |
93 | * **Quantitative results on the SIRST3 dataset with Coarse point labels:**
94 |
95 |
96 |
97 |
98 | * **Quantitative results on the three individual datasets with Coarse point labels:**
99 |
100 |
101 |
102 |
103 | * **Quantitative results on the SIRST3 dataset with Centroid point labels:**
104 |
105 |
106 |
107 |
108 | * **Quantitative results on the three individual datasets with Centroid point labels:**
109 |
110 |
111 |
112 |
113 | * **Qualitative results on the SIRST3 dataset with Coarse point labels:** (Red denotes the correct detections, blue denotes the false detections, and yellow denotes the missed detections.)
114 |
115 |
116 |
117 |
118 | * **Qualitative results on the SIRST3 dataset with Centroid point labels:** (Red denotes the correct detections, blue denotes the false detections, and yellow denotes the missed detections.)
119 |
120 |
121 |
122 |
123 |
124 |
125 | ## Citation
126 |
127 | If you find this repo helpful, please give us a 🤩**star**🤩. Please consider citing **PAL** if it benefits your project.
128 |
129 | The BibTeX reference is as follows.
130 | ```
131 | @misc{yu2024easyhardprogressiveactive,
132 | title={From Easy to Hard: Progressive Active Learning Framework for Infrared Small Target Detection with Single Point Supervision},
133 | author={Chuang Yu and Jinmiao Zhao and Yunpeng Liu and Sicheng Zhao and Yimian Dai and Xiangyu Yue},
134 | year={2024},
135 | eprint={2412.11154},
136 | archivePrefix={arXiv},
137 | primaryClass={cs.CV},
138 | url={https://arxiv.org/abs/2412.11154},
139 | }
140 | ```
141 |
142 | The plain-text reference is as follows.
143 | ```
144 | Chuang Yu, Jinmiao Zhao, Yunpeng Liu, Sicheng Zhao, Yimian Dai, Xiangyu Yue. From Easy to Hard: Progressive Active Learning Framework for Infrared Small Target Detection with Single Point Supervision. arXiv preprint arXiv:2412.11154, 2024.
145 | ```
146 |
147 |
148 |
149 | ## Other links
150 |
151 | 1. My homepage: [[YuChuang](https://github.com/YuChuang1205)]
152 | 2. "MSDA-Net" demo: [[Link](https://github.com/YuChuang1205/MSDA-Net)]
153 |
--------------------------------------------------------------------------------
/cal_PD_and_Fa.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | """
4 | @Author : yuchuang,zhaojinmiao
5 | @Time :
6 | @desc:
7 | """
8 | import os
9 | import cv2
10 | import numpy as np
11 | import sys
12 | from skimage import measure
13 | import os
14 | from os.path import join, isfile
15 |
16 |
17 | def cal_Pd_Fa(input_pred, input_true):
18 | image = measure.label(input_pred, connectivity=2)
19 | coord_image = measure.regionprops(image)
20 | label = measure.label(input_true, connectivity=2)
21 | coord_label = measure.regionprops(label)
22 | target = len(coord_label)
23 | image_area_total = []
24 | image_area_match = []
25 | distance_match = []
26 | dismatch = []
27 | for K in range(len(coord_image)):
28 | area_image = np.array(coord_image[K].area)
29 | #print(area_image)
30 | image_area_total.append(area_image)
31 |
32 | for i in range(len(coord_label)):
33 | centroid_label = np.array(list(coord_label[i].centroid))
34 | for m in range(len(coord_image)):
35 | centroid_image = np.array(list(coord_image[m].centroid))
36 | distance = np.linalg.norm(centroid_image - centroid_label)
37 | area_image = np.array(coord_image[m].area)
38 | if distance < 3:
39 | distance_match.append(distance)
40 | image_area_match.append(area_image)
41 | del coord_image[m]
42 | break
43 |
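# FA: predicted area left unmatched (false-alarm pixels);
# PD: number of ground-truth targets matched within 3 pixels of a predicted centroid.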
44 | FA = np.sum(image_area_total) - np.sum(image_area_match)
45 | PD = len(distance_match)
46 | return target, FA, PD
47 |
48 | # IMAGE_SIZE = 512
49 | root_path = os.path.abspath('.')
50 | # input_path = os.path.join(root_path, 'input')
51 | dataset_path = os.path.join(root_path,'dataset','SIRST3')
52 | test_dataset_path = os.path.join(dataset_path,'val')
53 |
54 | input_pred_path = os.path.join(test_dataset_path, 'pre_results')
55 | input_true_path = os.path.join(test_dataset_path, 'mask')
56 |
57 | # img_num = count_images_in_folder(input_pred_path)
58 | # print(img_num)
59 | input_pred_list = os.listdir(input_true_path)  # prediction and mask folders share the same file names
60 | img_num = len(input_pred_list)
61 | print(img_num)
62 | input_pred_list.sort()
63 | target_all = 0
64 | FA_all = 0
65 | PD_all = 0
66 | all_pixel = 0
67 | for i in range(len(input_pred_list)):
68 | print("正在处理:", input_pred_list[i])
69 | img_name = input_pred_list[i]
70 | input_pred_img_path = os.path.join(input_pred_path, img_name)
71 | input_true_img_path = os.path.join(input_true_path, img_name)
72 | input_pred = cv2.imread(input_pred_img_path, cv2.IMREAD_GRAYSCALE)
73 | print(input_pred.shape)
74 | all_pixel += input_pred.shape[0] * input_pred.shape[1]
75 | input_pred = np.where(input_pred>127,255,0)
76 | input_true = cv2.imread(input_true_img_path, cv2.IMREAD_GRAYSCALE)
77 | target, FA, PD = cal_Pd_Fa(input_pred, input_true)
78 | #print(FA)
79 | target_all = target_all + target
80 | FA_all = FA_all + FA
81 | PD_all = PD_all + PD
82 | Pd = PD_all / target_all
83 | Fa = FA_all / all_pixel
84 |
85 | print("Pd为:", Pd)
86 | print("Fa为:", Fa*1000000)
87 | print("Done!!!!!!!!!!")
88 |
89 |
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/cal_mIoU_and_nIoU.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | """
4 | @Author : yuchuang,zhaojinmiao
5 | @Time :
6 | @desc:
7 | """
8 | import os
9 | import cv2
10 | import numpy as np
11 | import sys
12 |
13 |
14 | def cal_iou(input_pred, input_true):
15 |
16 | input_pred = input_pred/255
17 | input_true = input_true/255
18 |
19 | inter_count = np.sum(input_pred*input_true)
20 | outer_count = np.sum(input_pred) + np.sum(input_true) - inter_count
21 |
22 |
23 | return inter_count, outer_count
24 |
25 | def sigmoid(x):
26 | s = 1 / (1 + np.exp(-x))
27 | return s
28 |
29 | root_path = os.path.abspath('.')
30 | # input_path = os.path.join(root_path, 'input')
31 |
32 | dataset_path = os.path.join(root_path,'dataset','SIRST3')
33 | test_dataset_path = os.path.join(dataset_path,'val')
34 |
35 | input_pred_path = os.path.join(test_dataset_path, 'pre_results')
36 | input_true_path = os.path.join(test_dataset_path, 'mask')
37 |
38 | input_pred_list = os.listdir(input_true_path)
39 |
40 | input_pred_list.sort()
41 |
42 | inter_count_all = 0
43 | outer_count_all = 0
44 | niou_list = []
45 | for i in range(len(input_pred_list)):
46 | print("正在处理:", input_pred_list[i])
47 | img_name = input_pred_list[i]
48 | input_pred_img_path = os.path.join(input_pred_path, img_name)
49 | input_true_img_path = os.path.join(input_true_path, img_name)
50 | input_pred = cv2.imread(input_pred_img_path, cv2.IMREAD_GRAYSCALE)
51 | #input_pred = sigmoid(input_pred)
52 | input_pred = np.where(input_pred>127,255,0)
53 | #print(input_pred)
54 | input_true = cv2.imread(input_true_img_path, cv2.IMREAD_GRAYSCALE)
55 | inter_count, outer_count = cal_iou(input_pred, input_true)
56 | inter_count_all = inter_count_all + inter_count
57 | outer_count_all = outer_count_all + outer_count
58 | niou_list.append(inter_count/(outer_count+1e-6))
59 |
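# nIoU averages the per-image IoU; mIoU pools intersections and unions over the whole test set.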
60 | niou = np.mean(niou_list)
61 | miou = inter_count_all/outer_count_all
62 |
63 | print("moiu为:", miou)
64 | print("niou为:", niou)
65 | print("Done!!!!!!!!!!")
66 |
67 |
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/components/__pycache__/cal_mean_std.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/cal_mean_std.cpython-36.pyc
--------------------------------------------------------------------------------
/components/__pycache__/cal_mean_std.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/cal_mean_std.cpython-38.pyc
--------------------------------------------------------------------------------
/components/__pycache__/dataset_final_edge_copy_paste_final_2_img_path.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/dataset_final_edge_copy_paste_final_2_img_path.cpython-36.pyc
--------------------------------------------------------------------------------
/components/__pycache__/dataset_final_edge_copy_paste_final_2_img_path.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/dataset_final_edge_copy_paste_final_2_img_path.cpython-38.pyc
--------------------------------------------------------------------------------
/components/__pycache__/edges.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/edges.cpython-36.pyc
--------------------------------------------------------------------------------
/components/__pycache__/edges.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/edges.cpython-38.pyc
--------------------------------------------------------------------------------
/components/__pycache__/metric_new_crop.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/metric_new_crop.cpython-36.pyc
--------------------------------------------------------------------------------
/components/__pycache__/metric_new_crop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/metric_new_crop.cpython-38.pyc
--------------------------------------------------------------------------------
/components/__pycache__/utils_all_edge_copy_paste_final_2_img_path.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/utils_all_edge_copy_paste_final_2_img_path.cpython-36.pyc
--------------------------------------------------------------------------------
/components/__pycache__/utils_all_edge_copy_paste_final_2_img_path.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/components/__pycache__/utils_all_edge_copy_paste_final_2_img_path.cpython-38.pyc
--------------------------------------------------------------------------------
/components/cal_mean_std.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import os
4 | #
5 |
6 | def Calculate_mean_std(img_dir):
7 | img_list = os.listdir(img_dir)
8 |
9 | mean_list = []
10 | std_list = []
11 |
12 | for i in range(len(img_list)):
13 | #print(i)
14 | img_path = os.path.join(img_dir,img_list[i])
15 | img = np.array(Image.open(img_path).convert("L"))
16 | mean_list.append(img.mean())
17 | std_list.append(img.std())
18 | mean_out = np.mean(mean_list)/255
19 | std_out = np.mean(std_list)/255
20 | print("路径为:", img_dir)
21 | print("数据集均值为:", mean_out)
22 | print("数据集方差为:", std_out)
23 |
24 | return mean_out, std_out
25 |
26 |
--------------------------------------------------------------------------------
/components/dataset_final_edge_copy_paste_final_2_img_path.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torch.utils.data import Dataset
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import albumentations as A
7 | from albumentations.pytorch import ToTensorV2
8 | from torch.utils.data import DataLoader
9 | import time
10 | from components.edges import onehot_to_binary_edges, mask_to_onehot
11 | import cv2
12 | import random
13 | import math
14 |
15 |
16 |
17 | def random_crop(img, mask, patch_size):
18 | h, w, c = img.shape
19 | mh, mw = mask.shape
20 |
21 | assert (h, w) == (mh, mw), "Image and mask must have the same height and width"
22 |
23 | if min(h, w) < patch_size:
24 | img = np.pad(img, ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w), (0, 0)), mode='constant')
25 | mask = np.pad(mask, ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w)), mode='constant')
26 | h, w, _ = img.shape
27 |
28 | h_start = random.randint(0, h - patch_size)
29 | h_end = h_start + patch_size
30 | w_start = random.randint(0, w - patch_size)
31 | w_end = w_start + patch_size
32 |
33 | img_patch = img[h_start:h_end, w_start:w_end, :]
34 | mask_patch = mask[h_start:h_end, w_start:w_end]
35 |
36 | return img_patch, mask_patch
37 |
38 |
39 | class SirstDataset(Dataset):
40 | def __init__(self, image_dir, mask_dir, patch_size, transform=None, mode='None'):
41 | self.image_dir = image_dir
42 | self.mask_dir = mask_dir
43 | self.transform = transform
44 | self.images = np.sort(os.listdir(image_dir))
45 | self.mode = mode
46 | self.patch_size = patch_size
47 |
48 | def __len__(self):
49 | return len(self.images)
50 |
51 | def __getitem__(self, index):
52 | img_path = os.path.join(self.image_dir, self.images[index])
53 | mask_path = os.path.join(self.mask_dir, self.images[index])
54 | image = np.array(Image.open(img_path).convert("RGB"))
55 | # print(image.shape)
56 | mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
57 | mask = (mask > 127.5).astype(float)
58 |
59 | if (self.mode == 'train'):
60 | image_patch, mask_patch = random_crop(image, mask, self.patch_size)
61 | if self.transform is not None:
62 | augmentations = self.transform(image=image_patch, mask=mask_patch)
63 | image = augmentations["image"]
64 | mask = augmentations["mask"]
65 | mask_2 = mask.numpy()
66 | mask_2 = mask_2.astype(np.int64)
67 | oneHot_label = mask_to_onehot(mask_2, 2)
68 | edge = onehot_to_binary_edges(oneHot_label, 1, 2)
69 | edge[:1, :] = 0  # zero out edge responses on the image borders (top row)
70 | edge[-1:, :] = 0
71 | edge[:, :1] = 0
72 | edge[:, -1:] = 0
73 | edge = np.expand_dims(edge, axis=0).astype(np.int64)
74 | return image, mask, edge
75 |
76 | elif (self.mode == 'val'):
77 | times = 32
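# pad H and W up to multiples of 32 so repeated downsampling inside the networks stays aligned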
78 | h, w, c = image.shape
79 | pad_height = math.ceil(h / times) * times - h
80 | pad_width = math.ceil(w / times) * times - w
81 | image = np.pad(image, ((0, pad_height), (0, pad_width), (0, 0)), mode='constant')
82 | mask = np.pad(mask, ((0, pad_height), (0, pad_width)), mode='constant')
83 | if self.transform is not None:
84 | augmentations = self.transform(image=image, mask=mask)
85 | image = augmentations["image"]
86 | mask = augmentations["mask"]
87 |
88 | return image, mask, h, w
89 |
--------------------------------------------------------------------------------
/components/edges.py:
--------------------------------------------------------------------------------
1 | from scipy.ndimage import distance_transform_edt  # scipy.ndimage.morphology is deprecated
2 | import numpy as np
3 | import torch
4 | def onehot_to_multiclass_edges(mask, radius, num_classes):
5 | """
6 | Converts a segmentation mask (K,H,W) to an edgemap (K,H,W)
7 | """
8 | if radius < 0:
9 | return mask
10 |
11 | # We need to pad the borders for boundary conditions
12 | mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0)
13 |
14 | channels = []
15 | for i in range(num_classes):
16 | dist = distance_transform_edt(mask_pad[i, :]) + distance_transform_edt(1.0 - mask_pad[i, :])
17 | dist = dist[1:-1, 1:-1]
18 | dist[dist > radius] = 0
19 | dist = (dist > 0).astype(np.uint8)
20 | channels.append(dist)
21 |
22 | return np.array(channels)
23 |
24 |
25 | def onehot_to_binary_edges(mask, radius, num_classes):
26 | """
27 | Converts a segmentation mask (K,H,W) to a binary edgemap (H,W)
28 | """
29 |
30 | if radius < 0:
31 | return mask
32 |
33 | # We need to pad the borders for boundary conditions
34 | mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0)
35 |
36 | edgemap = np.zeros(mask.shape[1:])
37 | for i in range(num_classes):
38 | # extract the contour of class i via distance transforms
39 | dist = distance_transform_edt(mask_pad[i, :]) + distance_transform_edt(1.0 - mask_pad[i, :])
40 | dist = dist[1:-1, 1:-1]
41 | dist[dist > radius] = 0
42 | edgemap += dist
43 | edgemap = (edgemap > 0).astype(np.uint8)*255
44 | return edgemap
45 |
46 |
47 | def mask_to_onehot(mask, num_classes):
48 | """
49 | Converts a segmentation mask (H,W) to (K,H,W) where the last dim is a one
50 | hot encoding vector
51 | """
52 | _mask = [mask == (i) for i in range(num_classes)]
53 | return np.array(_mask).astype(np.uint8)
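# Hypothetical usage sketch (added for illustration): turn a {0,1} mask into a
# binary edge map, mirroring the training pipeline in the dataset module.
if __name__ == '__main__':
    demo_mask = np.zeros((16, 16), dtype=np.int64)
    demo_mask[4:9, 4:9] = 1                      # one small square target
    onehot = mask_to_onehot(demo_mask, 2)        # (2, 16, 16) one-hot
    edge = onehot_to_binary_edges(onehot, 1, 2)  # (16, 16), values 0 or 255
    print(edge.shape, edge.max())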
--------------------------------------------------------------------------------
/components/gardient.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | def gradient_1order(x,h_x=None,w_x=None):
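# first-order gradient magnitude via central differences of zero-padded shifts (r/l along width, t/b along height)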
4 | if h_x is None and w_x is None:
5 | h_x = x.size()[2]
6 | w_x = x.size()[3]
7 | r = F.pad(x, (0, 1, 0, 0))[:, :, :, 1:]
8 | l = F.pad(x, (1, 0, 0, 0))[:, :, :, :w_x]
9 | t = F.pad(x, (0, 0, 1, 0))[:, :, :h_x, :]
10 | b = F.pad(x, (0, 0, 0, 1))[:, :, 1:, :]
11 | xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5)
12 | return xgrad
13 | if __name__ == '__main__':
14 | input=torch.randn(50,512,7,7)
15 | output = gradient_1order(input)
16 | print(output.shape)
--------------------------------------------------------------------------------
/components/metric_new_crop.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from skimage import measure
7 |
8 | TEST_BATCH_SIZE = 1
9 | class SigmoidMetric():
10 | def __init__(self, score_thresh=0.5):
11 | self.score_thresh = score_thresh
12 | self.reset()
13 |
14 | def update(self, pred, labels):
15 | correct, labeled = self.batch_pix_accuracy(pred, labels)
16 | inter, union = self.batch_intersection_union(pred, labels)
17 |
18 | self.total_correct += correct
19 | self.total_label += labeled
20 | self.total_inter += inter
21 | self.total_union += union
22 |
23 | def get(self):
24 | """Gets the current evaluation result."""
25 | pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
26 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
27 | mIoU = IoU.mean()
28 | return pixAcc, mIoU
29 |
30 | def reset(self):
31 | """Resets the internal evaluation result to initial state."""
32 | self.total_inter = 0
33 | self.total_union = 0
34 | self.total_correct = 0
35 | self.total_label = 0
36 |
37 | def batch_pix_accuracy(self, output, target):
38 | assert output.shape == target.shape
39 | output = output.cpu().detach().numpy()
40 | target = target.cpu().detach().numpy()
41 |
42 | predict = (output > self.score_thresh).astype('int64') # P
43 | target = (target > self.score_thresh).astype('int64')
44 | # inputs must already be binarized to 0/1 before this point
45 | pixel_labeled = np.sum(target > 0) # T
46 | pixel_correct = np.sum((predict == target) * (target > 0)) # TP
47 | assert pixel_correct <= pixel_labeled
48 | return pixel_correct, pixel_labeled
49 |
50 | def batch_intersection_union(self, output, target):
51 | mini = 1
52 | maxi = 1 # nclass
53 | nbins = 1 # nclass
54 | predict = (output.cpu().detach().numpy() > self.score_thresh).astype('int64') # P
55 | target = (target.cpu().detach().numpy() > self.score_thresh).astype('int64') # T
56 | # target = target.cpu().numpy().astype('int64') # T
57 | intersection = predict * (predict == target) # TP
58 |
59 |
60 | # areas of intersection and union
61 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))  # count pixels equal to 1 in the binarized map
62 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
63 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
64 | area_union = area_pred + area_lab - area_inter
65 | assert (area_inter <= area_union).all()
66 | return area_inter, area_union
67 |
68 |
69 | class SamplewiseSigmoidMetric():
70 | def __init__(self, nclass, score_thresh=0.5):
71 | self.nclass = nclass
72 | self.score_thresh = score_thresh
73 | self.reset()
74 |
75 | def update(self, preds, labels):
76 | """Updates the internal evaluation result."""
77 | inter_arr, union_arr = self.batch_intersection_union(preds, labels)
78 | self.total_inter = np.append(self.total_inter, inter_arr)
79 | self.total_union = np.append(self.total_union, union_arr)
80 |
81 | def get(self):
82 | """Gets the current evaluation result."""
83 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
84 | mIoU = IoU.mean()
85 | return IoU, mIoU
86 |
87 | def reset(self):
88 | """Resets the internal evaluation result to initial state."""
89 | self.total_inter = np.array([])
90 | self.total_union = np.array([])
91 | self.total_correct = np.array([])
92 | self.total_label = np.array([])
93 |
94 | def batch_intersection_union(self, output, target):
95 | """nIoU"""
96 | # inputs are tensor
97 | # the category 0 is ignored class, typically for background / boundary
98 | mini = 1
99 | maxi = 1 # nclass
100 | nbins = 1 # nclass
101 |
102 | predict = (output.cpu().detach().numpy() > self.score_thresh).astype('int64') # P
103 | target = (target.cpu().detach().numpy() > self.score_thresh).astype('int64') # T
104 | # target = target.cpu().detach().numpy().astype('int64') # T
105 | intersection = predict * (predict == target) # TP
106 |
107 | num_sample = intersection.shape[0]
108 | area_inter_arr = np.zeros(num_sample)
109 | area_pred_arr = np.zeros(num_sample)
110 | area_lab_arr = np.zeros(num_sample)
111 | area_union_arr = np.zeros(num_sample)
112 |
113 | for b in range(num_sample):
114 | # areas of intersection and union
115 | area_inter, _ = np.histogram(intersection[b], bins=nbins, range=(mini, maxi))
116 | area_inter_arr[b] = area_inter
117 |
118 | area_pred, _ = np.histogram(predict[b], bins=nbins, range=(mini, maxi))
119 | area_pred_arr[b] = area_pred
120 |
121 | area_lab, _ = np.histogram(target[b], bins=nbins, range=(mini, maxi))
122 | area_lab_arr[b] = area_lab
123 |
124 | area_union = area_pred + area_lab - area_inter
125 | area_union_arr[b] = area_union
126 |
127 | assert (area_inter <= area_union).all()
128 |
129 | return area_inter_arr, area_union_arr
130 |
131 |
132 | class PD_FA_2():
133 | def __init__(self, nclass):
134 | super(PD_FA_2, self).__init__()
135 | self.nclass = nclass
136 | self.image_area_total = []
137 | self.image_area_match = []
138 | self.FA = 0
139 | self.PD = 0
140 | self.target = 0
141 | self.all_pixel = 0
142 | def update(self, preds, labels):
143 |
144 |
145 | predits = np.array((preds > 0.5).cpu()).astype('int64')
146 |
147 | for i in range(predits.shape[0]):
148 | self.image_h = predits.shape[-2]
149 | self.image_w = predits.shape[-1]
150 | self.all_pixel += self.image_h * self.image_w
151 |
152 |
153 | labelss = np.array((labels > 0.5).cpu()).astype('int64') # P
154 |
155 | image = measure.label(predits, connectivity=2)  # connected components; connectivity=2 means 8-connectivity in 2-D
156 | # print('image.size', image.shape)
157 | #print(image)
158 | coord_image = measure.regionprops(image)  # per-component properties (area, centroid, ...)
159 | #print(coord_image)
160 | label = measure.label(labelss, connectivity=2)
161 | coord_label = measure.regionprops(label)
162 |
163 | self.target += len(coord_label)  # total number of ground-truth targets
164 | self.image_area_total = []
165 | self.image_area_match = []
166 | self.distance_match = []
167 | self.dismatch = []
168 |
169 | for K in range(len(coord_image)):
170 | area_image = np.array(coord_image[K].area)  # pixel count of the K-th predicted component
171 | self.image_area_total.append(area_image)  # areas of all predicted components
172 | # For each label centroid, find a predicted centroid within distance 3; record its area and distance, then delete the matched predicted component.
173 | for i in range(len(coord_label)):
174 | centroid_label = np.array(list(coord_label[i].centroid))  # centroid of label component i
175 | for m in range(len(coord_image)):
176 | centroid_image = np.array(list(coord_image[m].centroid))  # centroid of predicted component m
177 | distance = np.linalg.norm(centroid_image - centroid_label)  # Euclidean distance between the two centroids
178 | area_image = np.array(coord_image[m].area)  # area of predicted component m
179 | if distance < 3:  # match threshold of 3 pixels (adjustable)
180 | self.distance_match.append(distance)  # record the matched distance
181 | self.image_area_match.append(area_image)  # record the matched predicted area
182 |
183 | del coord_image[m]  # remove the matched component so it cannot match twice
184 | break
185 |
186 | self.dismatch = np.sum(self.image_area_total)-np.sum(self.image_area_match)  # unmatched predicted area
187 | self.FA += self.dismatch
188 | self.PD += len(self.distance_match)  # number of detected targets
189 |
190 | def get(self,img_num):
191 | # print("imgae_w:", self.image_w)
192 | # print("imgae_h:", self.image_h)
193 | Final_FA = self.FA / self.all_pixel
194 | Final_PD = self.PD / self.target
195 | #print("预测的目标点PD",self.PD)
196 | #print("预测的目标点target",self.target)
197 |
198 | return Final_FA,Final_PD
199 |
200 |
201 | def reset(self):
202 | self.FA = 0
203 | self.PD = 0
204 | self.target = 0
205 | self.all_pixel = 0
206 |
207 | if __name__ == '__main__':
208 | pred = torch.rand(8, 1, 512, 512)
209 | target = torch.rand(8, 1, 512, 512)
210 | m1 = SigmoidMetric()
211 | m2 = SamplewiseSigmoidMetric(nclass=1, score_thresh=0.5)
212 | m1.update(pred, target)
213 | m2.update(pred, target)
214 | pixAcc, mIoU = m1.get()
215 | _, nIoU = m2.get()
216 |
--------------------------------------------------------------------------------
/components/utils_all_edge_copy_paste_final_2_img_path.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from components.dataset_final_edge_copy_paste_final_2_img_path import SirstDataset
4 | from torch.utils.data import DataLoader
5 | import os
6 |
7 | def make_dir(path):
8 | if not os.path.exists(path):
9 | os.makedirs(path)
10 |
11 |
12 | def get_loaders(
13 | train_dir,
14 | train_maskdir,
15 | val_dir,
16 | val_maskdir,
17 | patch_size,
18 | train_batch_size,
19 | test_batch_size,
20 | train_transform,
21 | val_transform,
22 | num_workers=4,
23 | pin_memory=True,
24 | ):
25 | train_ds = SirstDataset(
26 | image_dir=train_dir,
27 | mask_dir=train_maskdir,
28 | patch_size=patch_size,
29 | transform=train_transform,
30 | mode='train',
31 | )
32 |
33 | train_loader = DataLoader(
34 | train_ds,
35 | batch_size=train_batch_size,
36 | num_workers=num_workers,
37 | pin_memory=pin_memory,
38 | shuffle=True,
39 | )
40 |
41 | val_ds = SirstDataset(
42 | image_dir=val_dir,
43 | mask_dir=val_maskdir,
44 | patch_size=patch_size,
45 | transform=val_transform,
46 | mode='val',
47 | )
48 |
49 | val_loader = DataLoader(
50 | val_ds,
51 | batch_size=test_batch_size,
52 | num_workers=num_workers,
53 | pin_memory=pin_memory,
54 | shuffle=False,
55 | )
56 |
57 | return train_loader, val_loader
58 |
59 |
--------------------------------------------------------------------------------
/imgs/Main results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Main results.png
--------------------------------------------------------------------------------
/imgs/PAL framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/PAL framework.png
--------------------------------------------------------------------------------
/imgs/Results on the SIRST3 with centroid point label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Results on the SIRST3 with centroid point label.png
--------------------------------------------------------------------------------
/imgs/Results on the SIRST3 with coarse point label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Results on the SIRST3 with coarse point label.png
--------------------------------------------------------------------------------
/imgs/Results on the three separate dataset with centroid point label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Results on the three separate dataset with centroid point label.png
--------------------------------------------------------------------------------
/imgs/Results on the three separate dataset with coarse point label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Results on the three separate dataset with coarse point label.png
--------------------------------------------------------------------------------
/imgs/Visualization on the SIRST3 with centroid point label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Visualization on the SIRST3 with centroid point label.png
--------------------------------------------------------------------------------
/imgs/Visualization on the SIRST3 with coarse point label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/imgs/Visualization on the SIRST3 with coarse point label.png
--------------------------------------------------------------------------------
/loss/Edge_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding = gbk
3 | """
4 | @Author : yuchuang
5 | @Time : 2024/3/30 22:07
6 | @desc:
7 | """
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import numpy as np
13 | from segmentation_models_pytorch.losses import DiceLoss, SoftCrossEntropyLoss, LovaszLoss,SoftBCEWithLogitsLoss
14 |
15 | def edgeSCE_loss(pred, target, edge):
16 |
17 | BinaryCrossEntropy_fn = SoftBCEWithLogitsLoss(smooth_factor=None, reduction='none')
18 |
19 | edge_weight = 4.
20 | loss_sce = BinaryCrossEntropy_fn(pred, target)
21 | #print(loss_sce.size())
22 | #print(edge.size())
23 | edge = edge.clone()
24 | edge[edge == 0] = 1.
25 | edge[edge > 0] = edge_weight
26 | loss_sce *= edge
27 |
28 | loss_sce_, ind = loss_sce.contiguous().view(-1).sort()
29 | min_value = loss_sce_[int(0.5 * loss_sce.numel())]
30 | loss_sce = loss_sce[loss_sce >= min_value]
31 | loss_sce = loss_sce.mean()
32 | loss = loss_sce
33 | return loss
34 |
35 | # if __name__ == '__main__':
36 | # target=torch.ones((2,1,256,256),dtype=torch.float32)
37 | # input=(torch.ones((2,1,256,256))*0.9)
38 | # input[0,0,0,0] = 0.99
39 | # loss = edgeSCE_loss(input, target, target*255)
40 | # print(loss)
41 |
42 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/loss/__init__.py
--------------------------------------------------------------------------------
/loss/__pycache__/Edge_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/loss/__pycache__/Edge_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/Edge_loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/loss/__pycache__/Edge_loss.cpython-38.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/loss/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/loss/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mm/attention/SEAttention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.nn import init
5 |
6 |
7 |
8 | class SEAttention(nn.Module):
9 |
10 | def __init__(self, channel,reduction=16):
11 | super().__init__()
12 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
13 | self.fc = nn.Sequential(
14 | nn.Linear(channel, channel // reduction, bias=False),
15 | nn.ReLU(inplace=True),
16 | nn.Linear(channel // reduction, channel, bias=False),
17 | nn.Sigmoid()
18 | )
19 |
20 |
21 | def init_weights(self):
22 | for m in self.modules():
23 | if isinstance(m, nn.Conv2d):
24 | init.kaiming_normal_(m.weight, mode='fan_out')
25 | if m.bias is not None:
26 | init.constant_(m.bias, 0)
27 | elif isinstance(m, nn.BatchNorm2d):
28 | init.constant_(m.weight, 1)
29 | init.constant_(m.bias, 0)
30 | elif isinstance(m, nn.Linear):
31 | init.normal_(m.weight, std=0.001)
32 | if m.bias is not None:
33 | init.constant_(m.bias, 0)
34 |
35 | def forward(self, x):
36 | b, c, _, _ = x.size()
37 | y = self.avg_pool(x).view(b, c)
38 | y = self.fc(y).view(b, c, 1, 1)
39 | return x * y.expand_as(x)
40 |
41 |
42 | if __name__ == '__main__':
43 | input=torch.randn(50,512,7,7)
44 | se = SEAttention(channel=512,reduction=8)
45 | output=se(input)
46 | print(output.shape)
47 |
48 |
--------------------------------------------------------------------------------
/mm/attention/__pycache__/SEAttention.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/mm/attention/__pycache__/SEAttention.cpython-36.pyc
--------------------------------------------------------------------------------
/mm/attention/__pycache__/SEAttention.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/mm/attention/__pycache__/SEAttention.cpython-38.pyc
--------------------------------------------------------------------------------
/model/ACM/ACM_no_sigmoid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.nn import BatchNorm2d
5 | from torchvision.models.resnet import BasicBlock
6 | from .fusion import AsymBiChaFuse
7 | import torch.nn.functional as F
8 | # from model.utils import init_weights, count_param
9 |
10 | class ACM_No_Sigmoid(nn.Module):
11 | def __init__(self, in_channels=3, layers=[3,3,3], channels=[8,16,32,64], fuse_mode='AsymBi', tiny=False, classes=1,
12 | norm_layer=BatchNorm2d,groups=1, norm_kwargs=None, **kwargs):
13 | super(ACM_No_Sigmoid, self).__init__()
14 | self.layer_num = len(layers)
15 | self.tiny = tiny
16 | self._norm_layer = norm_layer
17 | self.groups = groups
18 | self.momentum=0.9
19 | stem_width = int(channels[0]) ##channels: 8 16 32 64
20 | # self.stem.add(norm_layer(scale=False, center=False,**({} if norm_kwargs is None else norm_kwargs)))
21 | if tiny:  # defaults to False
22 | self.stem = nn.Sequential(
23 | norm_layer(in_channels, momentum=self.momentum),  # pass momentum by keyword (the second positional arg of BatchNorm2d is eps)
24 | nn.Conv2d(in_channels, out_channels=stem_width * 2, kernel_size=3, stride=1,padding=1, bias=False),
25 | norm_layer(stem_width * 2, momentum=self.momentum),
26 | nn.ReLU(inplace=True)
27 | )
28 | else:
29 | self.stem = nn.Sequential(
30 | # self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=2,
31 | # padding=1, use_bias=False))
32 | # self.stem.add(norm_layer(in_channels=stem_width*2))
33 | # self.stem.add(nn.Activation('relu'))
34 | # self.stem.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))
35 | norm_layer(in_channels, momentum=self.momentum),
36 | nn.Conv2d(in_channels=in_channels,out_channels=stem_width, kernel_size=3, stride=2,padding=1, bias=False),
37 | norm_layer(stem_width,momentum=self.momentum),
38 | nn.ReLU(inplace=True),
39 | nn.Conv2d(in_channels=stem_width,out_channels=stem_width, kernel_size=3, stride=1,padding=1, bias=False),
40 | norm_layer(stem_width,momentum=self.momentum),
41 | nn.ReLU(inplace=True),
42 | nn.Conv2d(in_channels=stem_width,out_channels=stem_width * 2, kernel_size=3, stride=1,padding=1, bias=False),
43 | norm_layer(stem_width * 2,momentum=self.momentum),
44 | nn.ReLU(inplace=True),
45 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
46 | )
47 |
48 | self.layer1 = self._make_layer(block=BasicBlock, blocks=layers[0],
49 | out_channels=channels[1],
50 | in_channels=channels[1], stride=1)
51 |
52 | self.layer2 = self._make_layer(block=BasicBlock, blocks=layers[1],
53 | out_channels=channels[2], stride=2,
54 | in_channels=channels[1])
55 | #
56 | self.layer3 = self._make_layer(block=BasicBlock, blocks=layers[2],
57 | out_channels=channels[3], stride=2,
58 | in_channels=channels[2])
59 |
60 | self.deconv2 = nn.ConvTranspose2d(in_channels=channels[3] ,out_channels=channels[2], kernel_size=(4, 4), ##channels: 8 16 32 64
61 | stride=2, padding=1)
62 | self.uplayer2 = self._make_layer(block=BasicBlock, blocks=layers[1],
63 | out_channels=channels[2], stride=1,
64 | in_channels=channels[2])
65 | self.fuse2 = self._fuse_layer(fuse_mode, channels=channels[2])
66 |
67 | self.deconv1 = nn.ConvTranspose2d(in_channels=channels[2] ,out_channels=channels[1], kernel_size=(4, 4),
68 | stride=2, padding=1)
69 | self.uplayer1 = self._make_layer(block=BasicBlock, blocks=layers[0],
70 | out_channels=channels[1], stride=1,
71 | in_channels=channels[1])
72 | self.fuse1 = self._fuse_layer(fuse_mode, channels=channels[1])
73 |
74 | self.head = _FCNHead(in_channels=channels[1], channels=classes, momentum=self.momentum)
75 |
76 |
77 | def _make_layer(self, block, out_channels, in_channels, blocks, stride):
78 |
79 | norm_layer = self._norm_layer
80 | downsample = None
81 |
82 | if stride != 1 or out_channels != in_channels:
83 | downsample = nn.Sequential(
84 | conv1x1(in_channels, out_channels , stride),
85 | norm_layer(out_channels * block.expansion, momentum=self.momentum),
86 | )
87 |
88 | layers = []
89 | layers.append(block(in_channels, out_channels, stride, downsample, self.groups, norm_layer=norm_layer))
90 | self.inplanes = out_channels * block.expansion
91 | for _ in range(1, blocks):
92 | layers.append(block(self.inplanes, out_channels, self.groups, norm_layer=norm_layer))
93 | return nn.Sequential(*layers)
94 |
95 | def _fuse_layer(self, fuse_mode, channels):
96 |
97 | if fuse_mode == 'AsymBi':
98 | fuse_layer = AsymBiChaFuse(channels=channels)
99 | else:
100 | raise ValueError('Unknown fuse_mode')
101 | return fuse_layer
102 |
103 | def forward(self, x):
104 |
105 | _, _, hei, wid = x.shape
106 |
107 | x = self.stem(x) # (4,16,120,120)
108 | c1 = self.layer1(x) # (4,16,120,120)
109 | c2 = self.layer2(c1) # (4,32, 60, 60)
110 | c3 = self.layer3(c2) # (4,64, 30, 30)
111 |
112 | deconvc2 = self.deconv2(c3) # (4,32, 60, 60)
113 | fusec2 = self.fuse2(deconvc2, c2) # (4,32, 60, 60)
114 | upc2 = self.uplayer2(fusec2) # (4,32, 60, 60)
115 |
116 | deconvc1 = self.deconv1(upc2) # (4,16,120,120)
117 | fusec1 = self.fuse1(deconvc1, c1) # (4,16,120,120)
118 | upc1 = self.uplayer1(fusec1) # (4,16,120,120)
119 |
120 | pred = self.head(upc1) # (4,1,120,120)
121 |
122 | if self.tiny:
123 | out = pred
124 | else:
125 | # out = F.contrib.BilinearResize2D(pred, height=hei, width=wid) # down 4
126 | out = F.interpolate(pred, scale_factor=4, mode='bilinear') # down 4 # (4,1,480,480)
127 |
128 | return out
129 |
130 | def evaluate(self, x):
131 | """evaluating network with inputs and targets"""
132 | return self.forward(x)
133 |
134 |
135 | class _FCNHead(nn.Module):
136 | # pylint: disable=redefined-outer-name
137 | def __init__(self, in_channels, channels, momentum, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
138 | super(_FCNHead, self).__init__()
139 | inter_channels = in_channels // 4
140 | self.block = nn.Sequential(
141 | nn.Conv2d(in_channels=in_channels, out_channels=inter_channels,kernel_size=3, padding=1, bias=False),
142 | norm_layer(inter_channels, momentum=momentum),
143 | nn.ReLU(inplace=True),
144 | nn.Dropout(0.1),
145 | nn.Conv2d(in_channels=inter_channels, out_channels=channels,kernel_size=1)
146 | )
147 | # pylint: disable=arguments-differ
148 | def forward(self, x):
149 | return self.block(x)
150 |
151 | def conv1x1(in_planes, out_planes, stride=1):
152 | """1x1 convolution"""
153 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
154 |
155 |
156 |
157 | # #########################################################
158 | # ### 2. Test ACM_No_Sigmoid (ASKCResUNet)
159 | # if __name__ == '__main__':
160 | # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU when available; it is much faster
161 | # layers = [3] * 3
162 | # channels = [x * 1 for x in [8, 16, 32, 64]]
163 | # in_channels = 3
164 | # model = ACM_No_Sigmoid(in_channels, layers=layers, channels=channels, fuse_mode='AsymBi', tiny=False, classes=1)
165 | #
166 | # model=model.cuda()
167 | # DATA = torch.randn(8,3,480,480).to(DEVICE)
168 | #
169 | # output=model(DATA)
170 | # print("output:",np.shape(output))
171 | # ##########################################################
--------------------------------------------------------------------------------
/model/ACM/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/ACM/__pycache__/ACM_no_sigmoid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ACM/__pycache__/ACM_no_sigmoid.cpython-36.pyc
--------------------------------------------------------------------------------
/model/ACM/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ACM/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/ACM/__pycache__/fusion.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ACM/__pycache__/fusion.cpython-36.pyc
--------------------------------------------------------------------------------
/model/ACM/fusion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn import BatchNorm2d
3 | import torch
4 | from torch.nn import GroupNorm
5 | class AsymBiChaFuse(nn.Module):
6 | def __init__(self, channels=64, r=4):
7 | super(AsymBiChaFuse, self).__init__()
8 | self.channels = channels
9 | self.bottleneck_channels = int(channels // r)
10 |
11 |
12 | self.topdown = nn.Sequential(
13 | nn.AdaptiveAvgPool2d(1),
14 | nn.Conv2d(in_channels=self.channels, out_channels=self.bottleneck_channels, kernel_size=1, stride=1,padding=0),
15 | # nn.BatchNorm2d(self.bottleneck_channels,momentum=0.9),
16 | nn.GroupNorm(1, self.bottleneck_channels),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(in_channels=self.bottleneck_channels,out_channels=self.channels, kernel_size=1, stride=1,padding=0),
19 | #nn.BatchNorm2d(self.channels,momentum=0.9),
20 | nn.GroupNorm(1, self.channels),
21 | nn.Sigmoid()
22 | )
23 |
24 | self.bottomup = nn.Sequential(
25 | nn.Conv2d(in_channels=self.channels,out_channels=self.bottleneck_channels, kernel_size=1, stride=1,padding=0),
26 | #nn.BatchNorm2d(self.bottleneck_channels,momentum=0.9),
27 | nn.GroupNorm(1, self.bottleneck_channels),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(in_channels=self.bottleneck_channels,out_channels=self.channels, kernel_size=1, stride=1,padding=0),
30 | #nn.BatchNorm2d(self.channels,momentum=0.9),
31 | nn.GroupNorm(1, self.channels),
32 | nn.Sigmoid()
33 | )
34 |
35 | self.post = nn.Sequential(
36 | nn.Conv2d(in_channels=channels,out_channels=channels, kernel_size=3, stride=1, padding=1, dilation=1),
37 | #nn.BatchNorm2d(channels,momentum=0.9),
38 | nn.GroupNorm(1, self.channels),
39 | nn.ReLU(inplace=True)
40 | )
41 |
42 | def forward(self, xh, xl):
43 |
44 | topdown_wei = self.topdown(xh)
45 | bottomup_wei = self.bottomup(xl)
46 | xs = 2 * torch.mul(xl, topdown_wei) + 2 * torch.mul(xh, bottomup_wei)
47 | xs = self.post(xs)
48 | return xs
49 |
50 | # from mxnet.gluon.block import HybridBlock
51 | # topdown = HybridBlock(prefix='topdown')
52 | # print(topdown)
--------------------------------------------------------------------------------
/model/ALC/ALC_no_sigmoid.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | from torch.nn.modules import module
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import BatchNorm2d
7 | from .fusion import AsymBiChaFuse
8 |
9 | # from mxnet import nd
10 | from torchvision import transforms
11 | from torchvision.models.resnet import BasicBlock
12 |
13 | class _FCNHead(nn.Module):
14 | # pylint: disable=redefined-outer-name
15 | def __init__(self, in_channels, channels, momentum, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
16 | super(_FCNHead, self).__init__()
17 | inter_channels = in_channels // 4
18 | self.block = nn.Sequential(
19 | nn.Conv2d(in_channels=in_channels, out_channels=inter_channels,kernel_size=3, padding=1, bias=False),
20 | norm_layer(inter_channels, momentum=momentum),
21 | nn.ReLU(inplace=True),
22 | nn.Dropout(0.1),
23 | nn.Conv2d(in_channels=inter_channels, out_channels=channels,kernel_size=1)
24 | )
25 | # pylint: disable=arguments-differ
26 | def forward(self, x):
27 | return self.block(x)
28 |
29 | def conv1x1(in_planes, out_planes, stride=1):
30 | """1x1 convolution"""
31 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
32 |
33 |
34 | class ALC_No_Sigmoid(nn.Module):
35 | def __init__(self, in_channels=3, layers=[4,4,4], channels=[8,16,32,64], fuse_mode='AsymBi', act_dilation=16, classes=1, tinyFlag=False,
36 | norm_layer=BatchNorm2d,groups=1,norm_kwargs=None, **kwargs):
37 | super(ALC_No_Sigmoid, self).__init__()
38 |
39 | self.layer_num = len(layers)
40 | self.tinyFlag = tinyFlag
41 | self.groups = groups
42 | self._norm_layer = norm_layer
43 | stem_width = int(channels[0])
44 | self.momentum=0.9
45 | if tinyFlag:
46 | self.stem = nn.Sequential(
47 | norm_layer(in_channels, momentum=self.momentum),  # keyword arg: passed positionally, 0.9 would bind to eps
48 | nn.Conv2d(in_channels, out_channels=stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False),
49 | norm_layer(stem_width * 2, momentum=self.momentum),
50 | nn.ReLU(inplace=True)
51 | )
52 |
53 | else:
54 | self.stem = nn.Sequential(
55 | # self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=2,
56 | # padding=1, use_bias=False))
57 | # self.stem.add(norm_layer(in_channels=stem_width*2))
58 | # self.stem.add(nn.Activation('relu'))
59 | # self.stem.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))
60 | norm_layer(in_channels, momentum=self.momentum),
61 | nn.Conv2d(in_channels=in_channels, out_channels=stem_width, kernel_size=3, stride=2, padding=1, bias=False),
62 | norm_layer(stem_width, momentum=self.momentum),
63 | nn.ReLU(inplace=True),
64 | nn.Conv2d(in_channels=stem_width, out_channels=stem_width, kernel_size=3, stride=1, padding=1,
65 | bias=False),
66 | norm_layer(stem_width, momentum=self.momentum),
67 | nn.ReLU(inplace=True),
68 | nn.Conv2d(in_channels=stem_width, out_channels=stem_width * 2, kernel_size=3, stride=1, padding=1,
69 | bias=False),
70 | norm_layer(stem_width * 2, momentum=self.momentum),
71 | nn.ReLU(inplace=True),
72 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
73 | )
74 |
75 |
76 | # self.head1 = _FCNHead(in_channels=channels[1], channels=classes)
77 | # self.head2 = _FCNHead(in_channels=channels[2], channels=classes)
78 | # self.head3 = _FCNHead(in_channels=channels[3], channels=classes)
79 | # self.head4 = _FCNHead(in_channels=channels[4], channels=classes)
80 |
81 | self.head = _FCNHead(in_channels=channels[0], channels=classes, momentum=self.momentum)
82 |
83 | self.layer1 = self._make_layer(block=BasicBlock, blocks=layers[0],
84 | out_channels=channels[1],
85 | in_channels=channels[1], stride=1)
86 |
87 | self.layer2 = self._make_layer(block=BasicBlock, blocks=layers[1],
88 | out_channels=channels[2], stride=2,
89 | in_channels=channels[1])
90 | #
91 | self.layer3 = self._make_layer(block=BasicBlock, blocks=layers[2],
92 | out_channels=channels[3], stride=2,
93 | in_channels=channels[2])
94 | self.deconv2 = nn.ConvTranspose2d(in_channels=channels[3], out_channels=channels[2], kernel_size=(4, 4),
95 | ##channels: 8 16 32 64
96 | stride=2, padding=1)
97 | self.uplayer2 = self._make_layer(block=BasicBlock, blocks=layers[1],
98 | out_channels=channels[2], stride=1,
99 | in_channels=channels[2])
100 |
101 |
102 | self.deconv1 = nn.ConvTranspose2d(in_channels=channels[2], out_channels=channels[1], kernel_size=(4, 4),
103 | stride=2, padding=1)
104 |
105 | self.deconv0 = nn.ConvTranspose2d(in_channels=channels[1], out_channels=channels[0], kernel_size=(4, 4),
106 | stride=2, padding=1)
107 |
108 | self.uplayer1 = self._make_layer(block=BasicBlock, blocks=layers[0],
109 | out_channels=channels[1], stride=1,
110 | in_channels=channels[1])
111 |
112 |
113 | if self.layer_num == 4:
114 | self.layer4 = self._make_layer(block=BasicBlock, blocks=layers[3],
115 | out_channels=channels[3], stride=2,
116 | in_channels=channels[3])
117 |
118 | if self.layer_num == 4:
119 | self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[3]) # channels[4]
120 |
121 | self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[2]) # 64
122 | self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[1]) # 32
123 |
124 | # if fuse_order == 'reverse':
125 | # self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[2]) # channels[2]
126 | # self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[3]) # channels[3]
127 | # self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4]
128 | # elif fuse_order == 'normal':
129 | # self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4]
130 | # self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4]
131 | # self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4]
132 |
133 | def _make_layer(self, block, out_channels, in_channels, blocks, stride):
134 |
135 | norm_layer = self._norm_layer
136 | downsample = None
137 |
138 | if stride != 1 or out_channels != in_channels:
139 | downsample = nn.Sequential(
140 | conv1x1(in_channels, out_channels, stride),
141 | norm_layer(out_channels * block.expansion, momentum=self.momentum),
142 | )
143 |
144 | layers = []
145 | layers.append(block(in_channels, out_channels, stride, downsample, self.groups, norm_layer=norm_layer))
146 | self.inplanes = out_channels * block.expansion
147 | for _ in range(1, blocks):
148 | layers.append(block(self.inplanes, out_channels, groups=self.groups, norm_layer=norm_layer))  # keyword arg: passed positionally, groups would bind to stride
149 | return nn.Sequential(*layers)
150 |
151 | def _fuse_layer(self, fuse_mode, channels):
152 |
153 | if fuse_mode == 'AsymBi':
154 | fuse_layer = AsymBiChaFuse(channels=channels)
155 | else:
156 | raise ValueError('Unknown fuse_mode')
157 | return fuse_layer
158 |
159 | def forward(self, x):
160 |
161 | _, _, hei, wid = x.shape  # e.g. 1024 x 1024
162 |
163 | x = self.stem(x) #torch.Size([8, 16, 256, 256])
164 | c1 = self.layer1(x) # torch.Size([8, 16, 256, 256])
165 | c2 = self.layer2(c1) # torch.Size([8, 32, 128, 128])
166 |
167 | out = self.layer3(c2) # (8,64, 64, 64)
168 |
169 | if self.layer_num == 4:
170 | c4 = self.layer4(out) # torch.Size([8,64, 32, 32])
171 | if self.tinyFlag:
172 | c4 = transforms.Resize([hei//4, wid//4])(c4) # down 4
173 | else:
174 | c4 = transforms.Resize([hei//16, wid//16])(c4) # down 16 torch.Size([8, 64, 64, 64])
175 | out = self.fuse34(c4, out)  # torch.Size([8, 64, 64, 64])
176 |
177 | if self.tinyFlag:
178 | out = transforms.Resize([hei//2, wid//2])(out)  # down 2
179 | else:
180 | out = transforms.Resize([hei//16, wid//16])(out)  # down 16  torch.Size([8, 64, 64, 64])
181 |
182 | out = self.deconv2(out) # torch.Size([8, 32, 128, 128])
183 | out = self.fuse23(out, c2) # torch.Size([8, 32, 128, 128])
184 | if self.tinyFlag:
185 | out = transforms.Resize([hei, wid])(out) # down 1
186 | else:
187 | out = transforms.Resize([hei//8, wid//8])(out)  # down 8
188 |
189 | out = self.deconv1(out) # torch.Size([8, 16, 256, 256])
190 | out = self.fuse12(out, c1) # torch.Size([8, 16, 256, 256])
191 |
192 | out = self.deconv0(out) # torch.Size([8, 8, 512, 512])
193 | pred = self.head(out)  # torch.Size([8, 1, 512, 512])
194 |
195 |
196 | if self.tinyFlag:
197 | out = pred
198 | else:
199 | out = transforms.Resize([hei, wid])(pred)  # restore input resolution
200 |
201 | ######### reverse order ##########
202 | # up_c2 = F.contrib.BilinearResize2D(c2, height=hei//4, width=wid//4) # down 4
203 | # fuse2 = self.fuse12(up_c2, c1) # down 4, channels[2]
204 | #
205 | # up_c3 = F.contrib.BilinearResize2D(c3, height=hei//4, width=wid//4) # down 4
206 | # fuse3 = self.fuse23(up_c3, fuse2) # down 4, channels[3]
207 | #
208 | # up_c4 = F.contrib.BilinearResize2D(c4, height=hei//4, width=wid//4) # down 4
209 | # fuse4 = self.fuse34(up_c4, fuse3) # down 4, channels[4]
210 | #
211 |
212 | ######### normal order ##########
213 | # out = F.contrib.BilinearResize2D(c4, height=hei//16, width=wid//16)
214 | # out = self.fuse34(out, c3)
215 | # out = F.contrib.BilinearResize2D(out, height=hei//8, width=wid//8)
216 | # out = self.fuse23(out, c2)
217 | # out = F.contrib.BilinearResize2D(out, height=hei//4, width=wid//4)
218 | # out = self.fuse12(out, c1)
219 | # out = self.head(out)
220 | # out = F.contrib.BilinearResize2D(out, height=hei, width=wid)
221 |
222 |
223 | return out
224 |
225 | def evaluate(self, x):
226 | """evaluating network with inputs and targets"""
227 | return self.forward(x)
228 |
229 |
230 |
231 |
--------------------------------------------------------------------------------
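
A quick forward-pass check of the network above. This is a sketch assuming the default configuration (`layers=[4,4,4]`, so the `layer4`/`fuse34` branch is skipped) and that the repo root is importable; the no-sigmoid variant returns raw logits at input resolution:

```python
import torch
from model.ALC.ALC_no_sigmoid import ALC_No_Sigmoid  # import path assumed

model = ALC_No_Sigmoid(in_channels=3, layers=[4, 4, 4], channels=[8, 16, 32, 64]).eval()
with torch.no_grad():
    x = torch.rand(2, 3, 512, 512)
    logits = model(x)  # raw logits; apply torch.sigmoid() externally if needed
print(logits.shape)    # torch.Size([2, 1, 512, 512])
```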
/model/ALC/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/ALC/__pycache__/ALC_no_sigmoid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALC/__pycache__/ALC_no_sigmoid.cpython-36.pyc
--------------------------------------------------------------------------------
/model/ALC/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALC/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/ALC/__pycache__/fusion.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALC/__pycache__/fusion.cpython-36.pyc
--------------------------------------------------------------------------------
/model/ALC/fusion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn import BatchNorm2d
3 | import torch
4 | from torch.nn import GroupNorm
5 | class AsymBiChaFuse(nn.Module):
6 | def __init__(self, channels=64, r=4):
7 | super(AsymBiChaFuse, self).__init__()
8 | self.channels = channels
9 | self.bottleneck_channels = int(channels // r)
10 |
11 |
12 | self.topdown = nn.Sequential(
13 | nn.AdaptiveAvgPool2d(1),
14 | nn.Conv2d(in_channels=self.channels, out_channels=self.bottleneck_channels, kernel_size=1, stride=1,padding=0),
15 | # nn.BatchNorm2d(self.bottleneck_channels,momentum=0.9),
16 | nn.GroupNorm(1, self.bottleneck_channels),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(in_channels=self.bottleneck_channels,out_channels=self.channels, kernel_size=1, stride=1,padding=0),
19 | #nn.BatchNorm2d(self.channels,momentum=0.9),
20 | nn.GroupNorm(1, self.channels),
21 | nn.Sigmoid()
22 | )
23 |
24 | self.bottomup = nn.Sequential(
25 | nn.Conv2d(in_channels=self.channels,out_channels=self.bottleneck_channels, kernel_size=1, stride=1,padding=0),
26 | #nn.BatchNorm2d(self.bottleneck_channels,momentum=0.9),
27 | nn.GroupNorm(1, self.bottleneck_channels),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(in_channels=self.bottleneck_channels,out_channels=self.channels, kernel_size=1, stride=1,padding=0),
30 | #nn.BatchNorm2d(self.channels,momentum=0.9),
31 | nn.GroupNorm(1, self.channels),
32 | nn.Sigmoid()
33 | )
34 |
35 | self.post = nn.Sequential(
36 | nn.Conv2d(in_channels=channels,out_channels=channels, kernel_size=3, stride=1, padding=1, dilation=1),
37 | #nn.BatchNorm2d(channels,momentum=0.9),
38 | nn.GroupNorm(1, self.channels),
39 | nn.ReLU(inplace=True)
40 | )
41 |
42 | def forward(self, xh, xl):
43 |
44 | topdown_wei = self.topdown(xh)
45 | bottomup_wei = self.bottomup(xl)
46 | xs = 2 * torch.mul(xl, topdown_wei) + 2 * torch.mul(xh, bottomup_wei)
47 | xs = self.post(xs)
48 | return xs
49 |
50 | # from mxnet.gluon.block import HybridBlock
51 | # topdown = HybridBlock(prefix='topdown')
52 | # print(topdown)
--------------------------------------------------------------------------------
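
Note the commented-out `BatchNorm2d` lines in both copies of `fusion.py`: they were replaced with `GroupNorm(1, C)`, which normalizes each sample over all of its channels and pixels, so the statistics do not depend on batch size. A small check of that property (illustrative sizes):

```python
import torch
import torch.nn as nn

gn = nn.GroupNorm(1, 16)  # one group == per-sample normalization over C*H*W
x = torch.rand(4, 16, 8, 8)
y_batch = gn(x)
y_single = gn(x[:1])
# Unlike BatchNorm, the result for sample 0 does not depend on the rest of the batch.
print(torch.allclose(y_batch[:1], y_single, atol=1e-6))  # True
```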
/model/ALCL/ALCL_no_sigmoid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms.functional as TF
4 |
5 |
6 | # from torchinfo import summary
7 |
8 | class Resnet1(nn.Module):
9 | def __init__(self, in_channel, out_channel):
10 | super(Resnet1, self).__init__()
11 | self.layer = nn.Sequential(
12 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
13 | nn.BatchNorm2d(out_channel),
14 | nn.ReLU(inplace=True),
15 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
16 | nn.BatchNorm2d(out_channel)
17 | )
18 | self.relu = nn.ReLU(inplace=True)
19 | self.layer.apply(weights_init)
20 |
21 | def forward(self, x):
22 | identity = x
23 | out = self.layer(x)
24 | out += identity
25 | return self.relu(out)
26 |
27 |
28 | class Resnet2(nn.Module):
29 | def __init__(self, in_channel, out_channel):
30 | super(Resnet2, self).__init__()
31 | self.layer1 = nn.Sequential(
32 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
33 | nn.BatchNorm2d(out_channel),
34 | nn.ReLU(inplace=True),
35 | nn.MaxPool2d(kernel_size=2, stride=2),
36 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
37 | nn.BatchNorm2d(out_channel)
38 | )
39 | self.layer2 = nn.Sequential(
40 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=2),
41 | nn.BatchNorm2d(out_channel),
42 | nn.ReLU(inplace=True)
43 | )
44 | self.relu = nn.ReLU(inplace=True)
45 | self.layer1.apply(weights_init)
46 | self.layer2.apply(weights_init)
47 |
48 | def forward(self, x):
49 | identity = x
50 | out = self.layer1(x)
51 | identity = self.layer2(identity)
52 | out += identity
53 | return self.relu(out)
54 |
55 |
56 | class Stage(nn.Module):
57 | def __init__(self):
58 | super(Stage, self).__init__()
59 | self.layer1 = nn.Sequential(
60 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1),
61 | nn.BatchNorm2d(16),
62 | nn.ReLU(inplace=True),
63 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1),
64 | nn.BatchNorm2d(16),
65 | nn.ReLU(inplace=True)
66 | )
67 | self.resnet1_1 = Resnet1(in_channel=16, out_channel=16)
68 | self.resnet1_2 = Resnet1(in_channel=16, out_channel=16)
69 | self.resnet1_3 = Resnet1(in_channel=16, out_channel=16)
70 | self.resnet2_1 = Resnet2(in_channel=16, out_channel=32)
71 | self.resnet2_2 = Resnet1(in_channel=32, out_channel=32)
72 | self.resnet2_3 = Resnet1(in_channel=32, out_channel=32)
73 | self.resnet3_1 = Resnet2(in_channel=32, out_channel=64)
74 | self.resnet3_2 = Resnet1(in_channel=64, out_channel=64)
75 | self.resnet3_3 = Resnet1(in_channel=64, out_channel=64)
76 | self.resnet4_1 = Resnet2(in_channel=64, out_channel=128)
77 | self.resnet4_2 = Resnet1(in_channel=128, out_channel=128)
78 | self.resnet4_3 = Resnet1(in_channel=128, out_channel=128)
79 | self.resnet5_1 = Resnet2(in_channel=128, out_channel=256)
80 | self.resnet5_2 = Resnet1(in_channel=256, out_channel=256)
81 | self.resnet5_3 = Resnet1(in_channel=256, out_channel=256)
82 | self.layer1.apply(weights_init)
83 |
84 | def forward(self, x):
85 | outs = []
86 | out = self.layer1(x)
87 | out = self.resnet1_1(out)
88 | out = self.resnet1_2(out)
89 | out = self.resnet1_3(out)
90 | # print("-------")
91 | # print(out.size())
92 | outs.append(out)
93 | out = self.resnet2_1(out)
94 | out = self.resnet2_2(out)
95 | out = self.resnet2_3(out)
96 | # print(out.size())
97 | outs.append(out)
98 | out = self.resnet3_1(out)
99 | out = self.resnet3_2(out)
100 | out = self.resnet3_3(out)
101 | # print(out.size())
102 | outs.append(out)
103 | out = self.resnet4_1(out)
104 | out = self.resnet4_2(out)
105 | out = self.resnet4_3(out)
106 | # print(out.size())
107 | outs.append(out)
108 | out = self.resnet5_1(out)
109 | out = self.resnet5_2(out)
110 | out = self.resnet5_3(out)
111 | # print(out.size())
112 | # print("-------")
113 | outs.append(out)
114 | return outs
115 |
116 |
117 | class LCL(nn.Module):
118 | def __init__(self, in_channel, out_channel):
119 | super(LCL, self).__init__()
120 | self.layer1 = nn.Sequential(
121 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, padding=0, stride=1),
122 | nn.ReLU(inplace=True),
123 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1, dilation=1),
124 | #nn.BatchNorm2d(out_channel),
125 | nn.ReLU(inplace=True)
126 | )
127 | self.layer1.apply(weights_init)
128 | def forward(self, x):
129 | out = self.layer1(x)
130 | # print("-----")
131 | # print(out.size())
132 | # print("-----")
133 | return out
134 |
135 |
136 | class Sbam(nn.Module):
137 | def __init__(self, in_channel, out_channel):
138 | super(Sbam, self).__init__()
139 | self.hl_layer = nn.Sequential(
140 | nn.UpsamplingBilinear2d(scale_factor=2),
141 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1),
142 | nn.BatchNorm2d(out_channel),
143 | nn.ReLU(inplace=True)
144 | )
145 | self.ll_layer = nn.Sequential(
146 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=1),
147 | nn.BatchNorm2d(out_channel),
148 | nn.Sigmoid() # ll = torch.sigmoid(ll)
149 | )
150 | self.hl_layer.apply(weights_init)
151 | self.ll_layer.apply(weights_init)
152 | def forward(self, hl,ll):
153 | hl = self.hl_layer(hl)
154 | # print(hl.size())
155 | ll_1 = ll
156 | ll = self.ll_layer(ll)
157 | # print(ll.size())
158 | hl_1 = hl*ll
159 | out = ll_1+hl_1
160 | return out
161 |
162 | class ALCL_No_Sigmoid(nn.Module):
163 | def __init__(self):
164 | super(ALCL_No_Sigmoid, self).__init__()
165 | self.stage = Stage()
166 | self.lcl5 = LCL(256, 256)
167 | self.lcl4 = LCL(128, 128)
168 | self.lcl3 = LCL(64, 64)
169 | self.lcl2 = LCL(32, 32)
170 | self.lcl1 = LCL(16, 16)
171 | self.sbam4 = Sbam(256, 128)
172 | self.sbam3 = Sbam(128, 64)
173 | self.sbam2 = Sbam(64, 32)
174 | self.sbam1 = Sbam(32, 16)
175 |
176 | self.layer = nn.Sequential(
177 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1),
178 | nn.ReLU(inplace=True),
179 | nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1),
180 | # nn.Sigmoid()
181 | )
182 | self.layer.apply(weights_init)
183 |
184 | def forward(self, x):
185 | outs = self.stage(x)
186 | out5 = self.lcl5(outs[4])
187 | # print(out5.size())
188 | out4 = self.lcl4(outs[3])
189 | # print(out4.size())
190 | out3 = self.lcl3(outs[2])
191 | # print(out3.size())
192 | out2 = self.lcl2(outs[1])
193 | # print(out2.size())
194 | out1 = self.lcl1(outs[0])
195 | # print(out1.size())
196 | out4_2 = self.sbam4(out5, out4)
197 | out3_2 = self.sbam3(out4_2, out3)
198 | out2_2 = self.sbam2(out3_2, out2)
199 | out1_2 = self.sbam1(out2_2, out1)
200 | out = self.layer(out1_2)
201 |
202 | return out
203 | def weights_init(m):
204 | if isinstance(m, nn.Conv2d):
205 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
206 | if m.bias is not None:
207 | nn.init.constant_(m.bias, 0)
208 | elif isinstance(m, nn.BatchNorm2d):
209 | nn.init.constant_(m.weight, 1)
210 | nn.init.constant_(m.bias, 0)
211 | elif isinstance(m, nn.Linear):
212 | nn.init.xavier_uniform_(m.weight)
213 | nn.init.constant_(m.bias, 0)
214 | return
215 |
216 | if __name__ == '__main__':
217 | model = ALCL_No_Sigmoid()
218 | x = torch.rand(8, 3, 512, 512)
219 | outs = model(x)
220 | print(outs.size())
221 |
--------------------------------------------------------------------------------
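
In `Sbam`, the high-level branch is upsampled and projected to the low-level width, then a sigmoid map computed from the low-level branch gates it before the residual sum. A usage sketch with illustrative shapes, assuming the repo import path:

```python
import torch
from model.ALCL.ALCL_no_sigmoid import Sbam  # import path assumed

sbam = Sbam(in_channel=32, out_channel=16)
hl = torch.rand(2, 32, 64, 64)    # deeper features at half resolution
ll = torch.rand(2, 16, 128, 128)  # shallower features at full resolution
out = sbam(hl, ll)                # ll + upsample(hl) * sigmoid_gate(ll)
print(out.shape)                  # torch.Size([2, 16, 128, 128])
```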
/model/ALCL/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALCL/__init__.py
--------------------------------------------------------------------------------
/model/ALCL/__pycache__/ALCL_no_sigmoid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALCL/__pycache__/ALCL_no_sigmoid.cpython-36.pyc
--------------------------------------------------------------------------------
/model/ALCL/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/ALCL/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/DNA/DNA_no_sigmoid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class VGG_CBAM_Block(nn.Module):
6 | def __init__(self, in_channels, out_channels):
7 | super().__init__()
8 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
9 | self.bn1 = nn.BatchNorm2d(out_channels)
10 | self.relu = nn.ReLU(inplace=True)
11 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
12 | self.bn2 = nn.BatchNorm2d(out_channels)
13 | self.relu = nn.ReLU(inplace=True)
14 | self.ca = ChannelAttention(out_channels)
15 | self.sa = SpatialAttention()
16 |
17 | def forward(self, x):
18 | out = self.conv1(x)
19 | out = self.bn1(out)
20 | out = self.relu(out)
21 | out = self.conv2(out)
22 | out = self.bn2(out)
23 | out = self.ca(out) * out
24 | out = self.sa(out) * out
25 | out = self.relu(out)
26 | return out
27 |
28 | class ChannelAttention(nn.Module):
29 | def __init__(self, in_planes, ratio=16):
30 | super(ChannelAttention, self).__init__()
31 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
32 | self.max_pool = nn.AdaptiveMaxPool2d(1)
33 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)  # use the ratio argument (was hard-coded 16)
34 | self.relu1 = nn.ReLU()
35 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
36 | self.sigmoid = nn.Sigmoid()
37 | def forward(self, x):
38 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
39 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
40 | out = avg_out + max_out
41 | return self.sigmoid(out)
42 |
43 | class SpatialAttention(nn.Module):
44 | def __init__(self, kernel_size=7):
45 | super(SpatialAttention, self).__init__()
46 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
47 | padding = 3 if kernel_size == 7 else 1
48 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
49 | self.sigmoid = nn.Sigmoid()
50 | def forward(self, x):
51 | avg_out = torch.mean(x, dim=1, keepdim=True)
52 | max_out, _ = torch.max(x, dim=1, keepdim=True)
53 | x = torch.cat([avg_out, max_out], dim=1)
54 | x = self.conv1(x)
55 | return self.sigmoid(x)
56 |
57 | class Res_CBAM_block(nn.Module):
58 | def __init__(self, in_channels, out_channels, stride = 1):
59 | super(Res_CBAM_block, self).__init__()
60 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1)
61 | self.bn1 = nn.BatchNorm2d(out_channels)
62 | self.relu = nn.ReLU(inplace = True)
63 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1)
64 | self.bn2 = nn.BatchNorm2d(out_channels)
65 | if stride != 1 or out_channels != in_channels:
66 | self.shortcut = nn.Sequential(
67 | nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = stride),
68 | nn.BatchNorm2d(out_channels))
69 | else:
70 | self.shortcut = None
71 |
72 | self.ca = ChannelAttention(out_channels)
73 | self.sa = SpatialAttention()
74 |
75 | def forward(self, x):
76 | residual = x
77 | if self.shortcut is not None:
78 | residual = self.shortcut(x)
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.ca(out) * out
85 | out = self.sa(out) * out
86 | out += residual
87 | out = self.relu(out)
88 | return out
89 |
90 | class DNA_No_Sigmoid(nn.Module):
91 | def __init__(self, num_classes=1,input_channels=3, block=Res_CBAM_block, num_blocks=[2, 2, 2, 2], nb_filter=[16, 32, 64, 128, 256], deep_supervision=True, mode='test'):
92 | super(DNA_No_Sigmoid, self).__init__()
93 | self.mode = mode
94 | self.relu = nn.ReLU(inplace = True)
95 | self.deep_supervision = deep_supervision
96 | self.pool = nn.MaxPool2d(2, 2)
97 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
98 | self.down = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
99 |
100 | self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
101 | self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
102 | self.up_16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
103 |
104 | self.conv0_0 = self._make_layer(block, input_channels, nb_filter[0])
105 | self.conv1_0 = self._make_layer(block, nb_filter[0], nb_filter[1], num_blocks[0])
106 | self.conv2_0 = self._make_layer(block, nb_filter[1], nb_filter[2], num_blocks[1])
107 | self.conv3_0 = self._make_layer(block, nb_filter[2], nb_filter[3], num_blocks[2])
108 | self.conv4_0 = self._make_layer(block, nb_filter[3], nb_filter[4], num_blocks[3])
109 |
110 | self.conv0_1 = self._make_layer(block, nb_filter[0] + nb_filter[1], nb_filter[0])
111 | self.conv1_1 = self._make_layer(block, nb_filter[1] + nb_filter[2] + nb_filter[0], nb_filter[1], num_blocks[0])
112 | self.conv2_1 = self._make_layer(block, nb_filter[2] + nb_filter[3] + nb_filter[1], nb_filter[2], num_blocks[1])
113 | self.conv3_1 = self._make_layer(block, nb_filter[3] + nb_filter[4] + nb_filter[2], nb_filter[3], num_blocks[2])
114 |
115 | self.conv0_2 = self._make_layer(block, nb_filter[0]*2 + nb_filter[1], nb_filter[0])
116 | self.conv1_2 = self._make_layer(block, nb_filter[1]*2 + nb_filter[2]+ nb_filter[0], nb_filter[1], num_blocks[0])
117 | self.conv2_2 = self._make_layer(block, nb_filter[2]*2 + nb_filter[3]+ nb_filter[1], nb_filter[2], num_blocks[1])
118 |
119 | self.conv0_3 = self._make_layer(block, nb_filter[0]*3 + nb_filter[1], nb_filter[0])
120 | self.conv1_3 = self._make_layer(block, nb_filter[1]*3 + nb_filter[2]+ nb_filter[0], nb_filter[1], num_blocks[0])
121 |
122 | self.conv0_4 = self._make_layer(block, nb_filter[0]*4 + nb_filter[1], nb_filter[0])
123 |
124 | self.conv0_4_final = self._make_layer(block, nb_filter[0]*5, nb_filter[0])
125 |
126 | self.conv0_4_1x1 = nn.Conv2d(nb_filter[4], nb_filter[0], kernel_size=1, stride=1)
127 | self.conv0_3_1x1 = nn.Conv2d(nb_filter[3], nb_filter[0], kernel_size=1, stride=1)
128 | self.conv0_2_1x1 = nn.Conv2d(nb_filter[2], nb_filter[0], kernel_size=1, stride=1)
129 | self.conv0_1_1x1 = nn.Conv2d(nb_filter[1], nb_filter[0], kernel_size=1, stride=1)
130 |
131 | if self.deep_supervision:
132 | self.final1 = nn.Conv2d (nb_filter[0], num_classes, kernel_size=1)
133 | self.final2 = nn.Conv2d (nb_filter[0], num_classes, kernel_size=1)
134 | self.final3 = nn.Conv2d (nb_filter[0], num_classes, kernel_size=1)
135 | self.final4 = nn.Conv2d (nb_filter[0], num_classes, kernel_size=1)
136 | else:
137 | self.final = nn.Conv2d (nb_filter[0], num_classes, kernel_size=1)
138 |
139 | def _make_layer(self, block, input_channels, output_channels, num_blocks=1):
140 | layers = []
141 | layers.append(block(input_channels, output_channels))
142 | for i in range(num_blocks-1):
143 | layers.append(block(output_channels, output_channels))
144 | return nn.Sequential(*layers)
145 |
146 | def forward(self, input):
147 | x0_0 = self.conv0_0(input)
148 | x1_0 = self.conv1_0(self.pool(x0_0))
149 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
150 |
151 | x2_0 = self.conv2_0(self.pool(x1_0))
152 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0),self.down(x0_1)], 1))
153 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
154 |
155 | x3_0 = self.conv3_0(self.pool(x2_0))
156 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0),self.down(x1_1)], 1))
157 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1),self.down(x0_2)], 1))
158 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
159 |
160 | x4_0 = self.conv4_0(self.pool(x3_0))
161 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0),self.down(x2_1)], 1))
162 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1),self.down(x1_2)], 1))
163 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2),self.down(x0_3)], 1))
164 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
165 |
166 | Final_x0_4 = self.conv0_4_final(
167 | torch.cat([self.up_16(self.conv0_4_1x1(x4_0)),self.up_8(self.conv0_3_1x1(x3_1)),
168 | self.up_4 (self.conv0_2_1x1(x2_2)),self.up (self.conv0_1_1x1(x1_3)), x0_4], 1))
169 |
170 | if self.deep_supervision:
171 | # output1 = self.final1(x0_1).sigmoid()
172 | # output2 = self.final2(x0_2).sigmoid()
173 | # output3 = self.final3(x0_3).sigmoid()
174 | # output4 = self.final4(Final_x0_4).sigmoid()
175 | output1 = self.final1(x0_1)
176 | output2 = self.final2(x0_2)
177 | output3 = self.final3(x0_3)
178 | output4 = self.final4(Final_x0_4)
179 | if self.mode == 'train':
180 | return [output1, output2, output3, output4]
181 | else:
182 | return output4
183 | else:
184 | # output = self.final(Final_x0_4).sigmoid()
185 | output = self.final(Final_x0_4)
186 | return output
187 |
188 |
189 |
--------------------------------------------------------------------------------
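
`DNA_No_Sigmoid` switches its return type on `self.mode`: with deep supervision enabled, `'train'` yields the four intermediate logit maps for the auxiliary losses, while any other mode returns only the final map. A sketch assuming the repo import path (input sides should be divisible by 16 so the x16 upsampling matches `x0_4` in the final concatenation):

```python
import torch
from model.DNA.DNA_no_sigmoid import DNA_No_Sigmoid  # import path assumed

net = DNA_No_Sigmoid(deep_supervision=True, mode='train')
x = torch.rand(2, 3, 256, 256)    # side divisible by 16
outs = net(x)
print(len(outs), outs[-1].shape)  # 4 torch.Size([2, 1, 256, 256])

net.mode = 'test'                 # mode is read at forward time
print(net(x).shape)               # torch.Size([2, 1, 256, 256])
```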
/model/DNA/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/model/DNA/__pycache__/DNA_no_sigmoid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/DNA/__pycache__/DNA_no_sigmoid.cpython-36.pyc
--------------------------------------------------------------------------------
/model/DNA/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/DNA/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/GGL/GGL_no_sigmoid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 |
6 | def gradient_1order(x,h_x=None,w_x=None):
7 | if h_x is None and w_x is None:
8 | h_x = x.size()[2]
9 | w_x = x.size()[3]
10 | r = F.pad(x, (0, 1, 0, 0))[:, :, :, 1:]
11 | l = F.pad(x, (1, 0, 0, 0))[:, :, :, :w_x]
12 | t = F.pad(x, (0, 0, 1, 0))[:, :, :h_x, :]
13 | b = F.pad(x, (0, 0, 0, 1))[:, :, 1:, :]
14 | xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5)
15 | return xgrad
16 |
17 | class SEAttention(nn.Module):
18 |
19 | def __init__(self, channel,reduction=16):
20 | super().__init__()
21 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
22 | self.fc = nn.Sequential(
23 | nn.Linear(channel, channel // reduction, bias=False),
24 | nn.ReLU(inplace=True),
25 | nn.Linear(channel // reduction, channel, bias=False),
26 | nn.Sigmoid()
27 | )
28 |
29 |
30 | def init_weights(self):
31 | for m in self.modules():
32 | if isinstance(m, nn.Conv2d):
33 | init.kaiming_normal_(m.weight, mode='fan_out')
34 | if m.bias is not None:
35 | init.constant_(m.bias, 0)
36 | elif isinstance(m, nn.BatchNorm2d):
37 | init.constant_(m.weight, 1)
38 | init.constant_(m.bias, 0)
39 | elif isinstance(m, nn.Linear):
40 | init.normal_(m.weight, std=0.001)
41 | if m.bias is not None:
42 | init.constant_(m.bias, 0)
43 |
44 | def forward(self, x):
45 | b, c, _, _ = x.size()
46 | y = self.avg_pool(x).view(b, c)
47 | y = self.fc(y).view(b, c, 1, 1)
48 | return x * y.expand_as(x)
49 |
50 |
51 | class ChannelAttention(nn.Module):
52 | def __init__(self, channel, reduction=8):
53 | super().__init__()
54 | self.maxpool = nn.AdaptiveMaxPool2d(1)
55 | self.avgpool = nn.AdaptiveAvgPool2d(1)
56 | self.se = nn.Sequential(
57 | nn.Conv2d(channel, channel // reduction, 1, bias=False),
58 | nn.ReLU(),
59 | nn.Conv2d(channel // reduction, channel, 1, bias=False)
60 | )
61 | self.sigmoid = nn.Sigmoid()
62 |
63 | def init_weights(self):
64 | for m in self.modules():
65 | if isinstance(m, nn.Conv2d):
66 | init.kaiming_normal_(m.weight, mode='fan_out')
67 | if m.bias is not None:
68 | init.constant_(m.bias, 0)
69 | elif isinstance(m, nn.BatchNorm2d):
70 | init.constant_(m.weight, 1)
71 | init.constant_(m.bias, 0)
72 | elif isinstance(m, nn.Linear):
73 | init.normal_(m.weight, std=0.001)
74 | if m.bias is not None:
75 | init.constant_(m.bias, 0)
76 | def forward(self, x):
77 | max_result = self.maxpool(x)
78 | avg_result = self.avgpool(x)
79 | max_out = self.se(max_result)
80 | avg_out = self.se(avg_result)
81 | output = self.sigmoid(max_out + avg_out)
82 | return output
83 |
84 | class SpatialAttention(nn.Module):
85 | def __init__(self, kernel_size=7):
86 | super().__init__()
87 | self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
88 | self.sigmoid = nn.Sigmoid()
89 |
90 | def init_weights(self):
91 | for m in self.modules():
92 | if isinstance(m, nn.Conv2d):
93 | init.kaiming_normal_(m.weight, mode='fan_out')
94 | if m.bias is not None:
95 | init.constant_(m.bias, 0)
96 | elif isinstance(m, nn.BatchNorm2d):
97 | init.constant_(m.weight, 1)
98 | init.constant_(m.bias, 0)
99 | elif isinstance(m, nn.Linear):
100 | init.normal_(m.weight, std=0.001)
101 | if m.bias is not None:
102 | init.constant_(m.bias, 0)
103 | def forward(self, x):
104 | max_result, _ = torch.max(x, dim=1, keepdim=True)
105 | avg_result = torch.mean(x, dim=1, keepdim=True)
106 | result = torch.cat([max_result, avg_result], 1)
107 | output = self.conv(result)
108 | output = self.sigmoid(output)
109 | return output
110 |
111 | class Resnet1(nn.Module):
112 | def __init__(self, in_channel, out_channel):
113 | super(Resnet1, self).__init__()
114 | self.layer = nn.Sequential(
115 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
116 | nn.BatchNorm2d(out_channel),
117 | nn.ReLU(inplace=True),
118 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
119 | nn.BatchNorm2d(out_channel)
120 | )
121 | self.relu = nn.ReLU(inplace=True)
122 | self.layer.apply(weights_init)
123 |
124 | def forward(self, x):
125 | identity = x
126 | out = self.layer(x)
127 | out += identity
128 | return self.relu(out)
129 |
130 |
131 | # layer2_1 #layer3_1#layer4_1#layer5_1
132 | class Resnet2(nn.Module):
133 | def __init__(self, in_channel, out_channel):
134 | super(Resnet2, self).__init__()
135 | self.layer1 = nn.Sequential(
136 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
137 | nn.BatchNorm2d(out_channel),
138 | nn.ReLU(inplace=True),
139 | nn.MaxPool2d(kernel_size=2, stride=2),
140 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
141 | nn.BatchNorm2d(out_channel)
142 | )
143 | self.layer2 = nn.Sequential(
144 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=2),
145 | nn.BatchNorm2d(out_channel),
146 | nn.ReLU(inplace=True)
147 | )
148 | self.relu = nn.ReLU(inplace=True)
149 | self.layer1.apply(weights_init)
150 | self.layer2.apply(weights_init)
151 |
152 | def forward(self, x):
153 | identity = x
154 | out = self.layer1(x)
155 | identity = self.layer2(identity)
156 | out += identity
157 | return self.relu(out)
158 |
159 |
160 | class Resnet3(nn.Module):
161 | def __init__(self, in_channel, out_channel):
162 | super(Resnet3, self).__init__()
163 | self.layer1 = nn.Sequential(
164 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
165 | nn.BatchNorm2d(out_channel),
166 | nn.ReLU(inplace=True),
167 |
168 | )
169 | self.layer2 = nn.Sequential(
170 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
171 | nn.BatchNorm2d(out_channel)
172 | )
173 | self.SEAttention = SEAttention(channel=out_channel, reduction=4)
174 | self.relu = nn.ReLU(inplace=True)
175 | self.layer1.apply(weights_init)
176 | self.layer2.apply(weights_init)
177 |
178 | def forward(self, x):
179 | identity = x
180 | out = self.layer1(x)
181 | out = self.SEAttention(out)
182 | out = self.layer2(out)
183 | out += identity
184 | return self.relu(out)
185 |
186 | class Res(nn.Module):
187 | def __init__(self, befor_channel,after_channel):
188 | super(Res, self).__init__()
189 |
190 | self.layer1 = nn.Sequential(
191 | nn.Conv2d(in_channels=befor_channel, out_channels=after_channel, kernel_size=3, padding=1, stride=1),
192 | nn.BatchNorm2d(after_channel),
193 | nn.ReLU(inplace=True),
194 | SEAttention(channel=after_channel, reduction=4),
195 | nn.Conv2d(in_channels=after_channel, out_channels=after_channel, kernel_size=3, padding=1, stride=1),
196 | nn.BatchNorm2d(after_channel),
197 | nn.ReLU(inplace=True),
198 | )
199 | self.layer2 = nn.Sequential(
200 | nn.Conv2d(in_channels=2*after_channel, out_channels=after_channel, kernel_size=3, padding=1, stride=1),
201 | nn.BatchNorm2d(after_channel),
202 | nn.ReLU(inplace=True),
203 | SEAttention(channel=after_channel, reduction=4),
204 | nn.Conv2d(in_channels=after_channel, out_channels=after_channel, kernel_size=3, padding=1, stride=1),
205 | nn.BatchNorm2d(after_channel),
206 | )
207 | self.relu = nn.ReLU(inplace=True)
208 | self.layer1.apply(weights_init)
209 | self.layer2.apply(weights_init)
210 |
211 | def forward(self, x, x1):
212 | x1 = self.layer1(x1)
213 | # print(x1.size())
214 | con = torch.cat([x, x1], 1)
215 | identity = x
216 | out = self.layer2(con)
217 | out = out + identity
218 | return self.relu(out)
219 |
220 | class Stage(nn.Module):
221 | def __init__(self):
222 | super(Stage, self).__init__()
223 | self.layer1 = nn.Sequential(
224 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1),
225 | nn.BatchNorm2d(16),
226 | nn.ReLU(inplace=True),
227 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1),
228 | nn.BatchNorm2d(16),
229 | nn.ReLU(inplace=True)
230 | )
231 | self.resnet1_1 = Resnet1(in_channel=16, out_channel=16)
232 | self.resnet1_2 = Resnet3(in_channel=16, out_channel=16)
233 | self.resnet1_3 = Resnet3(in_channel=16, out_channel=16)
234 | self.Res1 = Res(befor_channel=3, after_channel=16)
235 | self.resnet2_1 = Resnet2(in_channel=16, out_channel=32)
236 | self.resnet2_2 = Resnet3(in_channel=32, out_channel=32)
237 | self.resnet2_3 = Resnet3(in_channel=32, out_channel=32)
238 | self.Res2 = Res(befor_channel=3, after_channel=32)
239 | self.resnet3_1 = Resnet2(in_channel=32, out_channel=64)
240 | self.resnet3_2 = Resnet3(in_channel=64, out_channel=64)
241 | self.resnet3_3 = Resnet3(in_channel=64, out_channel=64)
242 | self.Res3 = Res(befor_channel=3, after_channel=64)
243 | self.resnet4_1 = Resnet2(in_channel=64, out_channel=128)
244 | self.resnet4_2 = Resnet3(in_channel=128, out_channel=128)
245 | self.resnet4_3 = Resnet3(in_channel=128, out_channel=128)
246 | self.Res4 = Res(befor_channel=3, after_channel=128)
247 | self.resnet5_1 = Resnet2(in_channel=128, out_channel=256)
248 | self.resnet5_2 = Resnet3(in_channel=256, out_channel=256)
249 | self.resnet5_3 = Resnet3(in_channel=256, out_channel=256)
250 | self.Res5 = Res(befor_channel=3, after_channel=256)
251 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
252 | self.layer1.apply(weights_init)
253 |
254 | def forward(self, x):
255 | x_g = gradient_1order(x)
256 | outs = []
257 | out = self.layer1(x)
258 | # print(out_1.size())
259 | out = self.resnet1_1(out)
260 | out = self.resnet1_2(out)
261 | out = self.resnet1_3(out)
262 | # print(out.size())
263 | #out_1 = self.pool(out_1)
264 | out = self.Res1(out, x_g)
265 | # print("-------")
266 | # print(out.size())
267 | outs.append(out)
268 | out = self.resnet2_1(out)
269 | out = self.resnet2_2(out)
270 | out = self.resnet2_3(out)
271 | x1 = self.pool(x_g)
272 | out = self.Res2(out, x1)
273 | # print(out.size())
274 | outs.append(out)
275 | out = self.resnet3_1(out)
276 | out = self.resnet3_2(out)
277 | out = self.resnet3_3(out)
278 | # print(out.size())
279 | x2 = self.pool(x1)
280 | # print(x2.size())
281 | out = self.Res3(out, x2)
282 | # print(out.size())
283 | outs.append(out)
284 | out = self.resnet4_1(out)
285 | out = self.resnet4_2(out)
286 | out = self.resnet4_3(out)
287 | # print(out.size())
288 | x3 = self.pool(x2)
289 | # print(x3.size())
290 | out = self.Res4(out, x3)
291 | # print(out.size())
292 | outs.append(out)
293 | out = self.resnet5_1(out)
294 | out = self.resnet5_2(out)
295 | out = self.resnet5_3(out)
296 | x4 = self.pool(x3)
297 | out = self.Res5(out, x4)
298 | # print(out.size())
299 | # print("-------")
300 | outs.append(out)
301 | return outs
302 |
303 | class LCL(nn.Module):
304 | def __init__(self, in_channel, out_channel):
305 | super(LCL, self).__init__()
306 | self.layer1 = nn.Sequential(
307 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, padding=0, stride=1),
308 | nn.ReLU(inplace=True),
309 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, padding=1, stride=1, dilation=1),
310 | #nn.BatchNorm2d(out_channel),
311 | nn.ReLU(inplace=True)
312 | )
313 | self.layer1.apply(weights_init)
314 | def forward(self, x):
315 | out = self.layer1(x)
316 | # print("-----")
317 | # print(out.size())
318 | # print("-----")
319 | return out
320 |
321 | class Sbam(nn.Module):
322 | def __init__(self, in_channel, out_channel):
323 | super(Sbam, self).__init__()
324 | self.hl_layer = nn.Sequential(
325 | nn.UpsamplingBilinear2d(scale_factor=2),
326 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1),
327 | nn.BatchNorm2d(out_channel),
328 | nn.ReLU(inplace=True)
329 | )
330 | self.hl_layer_2 = ChannelAttention(out_channel)
331 | self.ll_layer = SpatialAttention()
332 | # self.ll_layer = nn.Sequential(
333 | # nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=1),
334 | # nn.BatchNorm2d(out_channel),
335 | # nn.Sigmoid() # ll = torch.sigmoid(ll)
336 | # )
337 | self.hl_layer.apply(weights_init)
338 | def forward(self, hl,ll):
339 | hl = self.hl_layer(hl)
340 | # print(hl.size())
341 | ll_1 =ll * self.hl_layer_2(hl)
342 |
343 | ll = self.ll_layer(ll)
344 | # print(ll.size())
345 | hl_1 = hl * ll
346 | out = ll_1 + hl_1
347 | return out
348 |
349 | class GGL_No_Sigmoid(nn.Module):
350 | def __init__(self):
351 | super(GGL_No_Sigmoid, self).__init__()
352 | self.stage = Stage()
353 | self.lcl5 = LCL(256, 256)
354 | self.lcl4 = LCL(128, 128)
355 | self.lcl3 = LCL(64, 64)
356 | self.lcl2 = LCL(32, 32)
357 | self.lcl1 = LCL(16, 16)
358 | self.sbam4 = Sbam(256, 128)
359 | self.sbam3 = Sbam(128, 64)
360 | self.sbam2 = Sbam(64, 32)
361 | self.sbam1 = Sbam(32, 16)
362 |
363 | self.layer = nn.Sequential(
364 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1),
365 | nn.ReLU(inplace=True),
366 | nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1),
367 | # nn.Sigmoid()
368 | )
369 | self.layer.apply(weights_init)
370 |
371 | def forward(self, x):
372 | outs = self.stage(x)
373 | out5 = self.lcl5(outs[4])
374 | # print(out5.size())
375 | out4 = self.lcl4(outs[3])
376 | # print(out4.size())
377 | out3 = self.lcl3(outs[2])
378 | # print(out3.size())
379 | out2 = self.lcl2(outs[1])
380 | # print(out2.size())
381 | out1 = self.lcl1(outs[0])
382 | # print(out1.size())
383 | out4_2 = self.sbam4(out5, out4)
384 | out3_2 = self.sbam3(out4_2, out3)
385 | out2_2 = self.sbam2(out3_2, out2)
386 | out1_2 = self.sbam1(out2_2, out1)
387 | out = self.layer(out1_2)
388 |
389 | return out
390 |
391 |
392 |
393 | def weights_init(m):
394 | if isinstance(m, nn.Conv2d):
395 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
396 | if m.bias is not None:
397 | nn.init.constant_(m.bias, 0)
398 | elif isinstance(m, nn.BatchNorm2d):
399 | nn.init.constant_(m.weight, 1)
400 | nn.init.constant_(m.bias, 0)
401 | # elif isinstance(m, nn.Linear):
402 | # nn.init.xavier_uniform_(m.weight)
403 | # nn.init.constant_(m.bias, 0)
404 | return
405 |
406 |
407 | # if __name__ == '__main__':
408 | # model = GGL_No_Sigmoid()
409 | # x = torch.rand(8, 3, 512, 512)
410 | # x_g = gradient_1order(x)
411 | # outs = model(x,x_g)
412 | # print(outs.size())
413 |
--------------------------------------------------------------------------------
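
`gradient_1order` shifts the map one pixel in each direction via zero padding and combines the central differences into a gradient magnitude, i.e. sqrt(((r-l)/2)^2 + ((t-b)/2)^2). A quick check on a horizontal ramp, where the interior response should be exactly 1 (border columns differ because of the zero padding):

```python
import torch
from model.GGL.GGL_no_sigmoid import gradient_1order  # import path assumed

x = torch.arange(8.0).repeat(1, 1, 8, 1)  # [1, 1, 8, 8], x[..., j] = j
g = gradient_1order(x)
print(g[0, 0, 4, 1:-1])  # tensor([1., 1., 1., 1., 1., 1.])
```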
/model/GGL/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/GGL/__init__.py
--------------------------------------------------------------------------------
/model/GGL/__pycache__/GGL_no_sigmoid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/GGL/__pycache__/GGL_no_sigmoid.cpython-36.pyc
--------------------------------------------------------------------------------
/model/GGL/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/GGL/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/MLCL/MLCL_no_sigmoid.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: yuchuang,zhaojinmiao
3 | @time:
4 | @desc: This is the base version of MLCL-Net (consistent with the paper): each stage uses 3 blocks. paper: "Infrared small target detection based on multiscale local contrast learning networks"
5 | """
6 | import torch
7 | import torch.nn as nn
8 |
9 | class Resnet1(nn.Module):
10 | def __init__(self, in_channel, out_channel):
11 | super(Resnet1, self).__init__()
12 | self.layer = nn.Sequential(
13 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
14 | nn.BatchNorm2d(out_channel),
15 | nn.ReLU(inplace=True),
16 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
17 | nn.BatchNorm2d(out_channel)
18 | )
19 | self.relu = nn.ReLU(inplace=True)
20 | self.layer.apply(weights_init)
21 |
22 | def forward(self, x):
23 | identity = x
24 | out = self.layer(x)
25 | out += identity
26 | return self.relu(out)
27 |
28 |
29 | class Resnet2(nn.Module):
30 | def __init__(self, in_channel, out_channel):
31 | super(Resnet2, self).__init__()
32 | self.layer1 = nn.Sequential(
33 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
34 | nn.BatchNorm2d(out_channel),
35 | nn.ReLU(inplace=True),
36 | nn.MaxPool2d(kernel_size=2, stride=2),
37 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
38 | nn.BatchNorm2d(out_channel)
39 | )
40 | self.layer2 = nn.Sequential(
41 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=2),
42 | nn.BatchNorm2d(out_channel),
43 | nn.ReLU(inplace=True)
44 | )
45 | self.relu = nn.ReLU(inplace=True)
46 | self.layer1.apply(weights_init)
47 | self.layer2.apply(weights_init)
48 |
49 | def forward(self, x):
50 | identity = x
51 | out = self.layer1(x)
52 | identity = self.layer2(identity)
53 | out += identity
54 | return self.relu(out)
55 |
56 |
57 | class Stage(nn.Module):
58 | def __init__(self):
59 | super(Stage, self).__init__()
60 | self.layer1 = nn.Sequential(
61 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1),
62 | nn.BatchNorm2d(16),
63 | nn.ReLU(inplace=True),
64 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1),
65 | nn.BatchNorm2d(16),
66 | nn.ReLU(inplace=True)
67 | )
68 | self.resnet1_1 = Resnet1(in_channel=16, out_channel=16)
69 | self.resnet1_2 = Resnet1(in_channel=16, out_channel=16)
70 | self.resnet1_3 = Resnet1(in_channel=16, out_channel=16)
71 | self.resnet2_1 = Resnet2(in_channel=16, out_channel=32)
72 | self.resnet2_2 = Resnet1(in_channel=32, out_channel=32)
73 | self.resnet2_3 = Resnet1(in_channel=32, out_channel=32)
74 | self.resnet3_1 = Resnet2(in_channel=32, out_channel=64)
75 | self.resnet3_2 = Resnet1(in_channel=64, out_channel=64)
76 | self.resnet3_3 = Resnet1(in_channel=64, out_channel=64)
77 |
78 | def forward(self, x):
79 | outs = []
80 | out = self.layer1(x)
81 | out = self.resnet1_1(out)
82 | out = self.resnet1_2(out)
83 | out = self.resnet1_3(out)
84 | outs.append(out)
85 | out = self.resnet2_1(out)
86 | out = self.resnet2_2(out)
87 | out = self.resnet2_3(out)
88 | outs.append(out)
89 | out = self.resnet3_1(out)
90 | out = self.resnet3_2(out)
91 | out = self.resnet3_3(out)
92 | outs.append(out)
93 | return outs
94 |
95 |
96 | class MLCL(nn.Module):
97 | def __init__(self, in_channel, out_channel):
98 | super(MLCL, self).__init__()
99 | self.layer1 = nn.Sequential(
100 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, padding=0, stride=1),
101 | nn.ReLU(inplace=True),
102 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1, dilation=1),
103 | nn.ReLU(inplace=True)
104 | )
105 | self.layer2 = nn.Sequential(
106 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
107 | nn.ReLU(inplace=True),
108 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=3, stride=1, dilation=3),
109 | nn.ReLU(inplace=True)
110 | )
111 | self.layer3 = nn.Sequential(
112 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=5, padding=2, stride=1),
113 | nn.ReLU(inplace=True),
114 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=5, stride=1, dilation=5),
115 | nn.ReLU(inplace=True)
116 | )
117 | # self.layer4 = nn.Sequential(
118 | # nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=7, padding=2, stride=1),
119 | # nn.ReLU(inplace=True),
120 | # nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=7, stride=1,dilation=7),
121 | # nn.ReLU(inplace=True)
122 | # )
123 | self.conv = nn.Conv2d(in_channels=out_channel * 3, out_channels=out_channel, kernel_size=1)
124 | self.layer1.apply(weights_init)
125 | self.layer2.apply(weights_init)
126 | self.layer3.apply(weights_init)
127 | self.conv.apply(weights_init)
128 | def forward(self, x):
129 | x1 = x
130 | x2 = x
131 | x3 = x
132 | out1 = self.layer1(x1)
133 | out2 = self.layer2(x2)
134 | out3 = self.layer3(x3)
135 | outs = torch.cat((out1, out2, out3), dim=1)
136 | return self.conv(outs)
137 |
138 |
139 | class MLCL_No_Sigmoid(nn.Module):
140 | def __init__(self):
141 | super(MLCL_No_Sigmoid, self).__init__()
142 | self.stage = Stage()
143 | self.mlcl3 = MLCL(64, 64)
144 | self.mlcl2 = MLCL(32, 32)
145 | self.mlcl1 = MLCL(16, 16)
146 | self.up3 = nn.UpsamplingBilinear2d(scale_factor=2)
147 | self.conv3 = nn.Sequential(
148 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
149 | nn.ReLU(inplace=True)
150 | )
151 | self.up2 = nn.UpsamplingBilinear2d(scale_factor=2)
152 | self.conv2 = nn.Sequential(
153 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1),
154 | nn.ReLU(inplace=True)
155 | )
156 | self.conv1 = nn.Sequential(
157 | nn.Conv2d(in_channels=16, out_channels=64, kernel_size=1),
158 | nn.ReLU(inplace=True)
159 | )
160 | self.layer = nn.Sequential(
161 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
162 | nn.ReLU(inplace=True),
163 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1),
164 | #nn.Sigmoid()
165 | )
166 |
167 | def forward(self, x):
168 | outs = self.stage(x)
169 | out3 = self.mlcl3(outs[2])
170 | out2 = self.mlcl2(outs[1])
171 | out1 = self.mlcl1(outs[0])
172 | out3 = self.conv3(out3) # 128*128 64
173 | out3 = self.up3(out3) # 256*256 64
174 | out2 = self.conv2(out2) # 256*256 64
175 | out = out3 + out2
176 | out = self.up2(out)
177 | out1 = self.conv1(out1)
178 | out = out + out1
179 | out = self.layer(out)
180 | return out
181 |
182 | def weights_init(m):
183 | if isinstance(m, nn.Conv2d):
184 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
185 | if m.bias is not None:
186 | nn.init.constant_(m.bias, 0)
187 | elif isinstance(m, nn.BatchNorm2d):
188 | nn.init.constant_(m.weight, 1)
189 | nn.init.constant_(m.bias, 0)
190 | elif isinstance(m, nn.Linear):
191 | nn.init.xavier_uniform_(m.weight)
192 | nn.init.constant_(m.bias, 0)
193 | return
194 |
195 | if __name__ == '__main__':
196 | model = MLCL_No_Sigmoid()
197 | x = torch.rand(8, 3, 512, 512)
198 | outs = model(x)
199 | print(outs.size())
200 |
201 |
--------------------------------------------------------------------------------
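
Each `MLCL` branch pairs a kernel with a dilation whose padding preserves the spatial size (3x3 with pad 1, dilated 3x3 with pad 3 and pad 5), so the three local-contrast scales can be concatenated channel-wise and fused with a 1x1 conv. A shape check with illustrative sizes, assuming the repo import path:

```python
import torch
from model.MLCL.MLCL_no_sigmoid import MLCL  # import path assumed

m = MLCL(32, 32)
x = torch.rand(2, 32, 64, 64)
print(m(x).shape)  # torch.Size([2, 32, 64, 64]) -- all branches keep H x W
```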
/model/MLCL/MLCL_small_no_sigmoid.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: yuchuang,zhaojinmiao
3 | @time:
4 | @desc: This is the lightweight version of MLCL-Net: each stage uses only 2 blocks. paper: "Infrared small target detection based on multiscale local contrast learning networks"
5 | """
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class Resnet1(nn.Module):
11 | def __init__(self, in_channel, out_channel):
12 | super(Resnet1, self).__init__()
13 | self.layer = nn.Sequential(
14 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
15 | nn.BatchNorm2d(out_channel),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
18 | nn.BatchNorm2d(out_channel)
19 | )
20 | self.relu = nn.ReLU(inplace=True)
21 | self.layer.apply(weights_init)
22 |
23 | def forward(self, x):
24 | identity = x
25 | out = self.layer(x)
26 | out += identity
27 | return self.relu(out)
28 |
29 |
30 | class Resnet2(nn.Module):
31 | def __init__(self, in_channel, out_channel):
32 | super(Resnet2, self).__init__()
33 | self.layer1 = nn.Sequential(
34 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
35 | nn.BatchNorm2d(out_channel),
36 | nn.ReLU(inplace=True),
37 | nn.MaxPool2d(kernel_size=2, stride=2),
38 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
39 | nn.BatchNorm2d(out_channel)
40 | )
41 | self.layer2 = nn.Sequential(
42 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=2),
43 | nn.BatchNorm2d(out_channel),
44 | nn.ReLU(inplace=True)
45 | )
46 | self.relu = nn.ReLU(inplace=True)
47 | self.layer1.apply(weights_init)
48 | self.layer2.apply(weights_init)
49 |
50 | def forward(self, x):
51 | identity = x
52 | out = self.layer1(x)
53 | identity = self.layer2(identity)
54 | out += identity
55 | return self.relu(out)
56 |
57 |
58 | class Stage(nn.Module):
59 | def __init__(self):
60 | super(Stage, self).__init__()
61 | self.layer1 = nn.Sequential(
62 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1),
63 | nn.BatchNorm2d(16),
64 | nn.ReLU(inplace=True),
65 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1),
66 | nn.BatchNorm2d(16),
67 | nn.ReLU(inplace=True)
68 | )
69 | self.resnet1_1 = Resnet1(in_channel=16, out_channel=16)
70 | self.resnet1_2 = Resnet1(in_channel=16, out_channel=16)
71 | self.resnet2_1 = Resnet2(in_channel=16, out_channel=32)
72 | self.resnet2_2 = Resnet1(in_channel=32, out_channel=32)
73 | self.resnet3_1 = Resnet2(in_channel=32, out_channel=64)
74 | self.resnet3_2 = Resnet1(in_channel=64, out_channel=64)
75 |
76 | def forward(self, x):
77 | outs = []
78 | out = self.layer1(x)
79 | out = self.resnet1_1(out)
80 | out = self.resnet1_2(out)
81 | outs.append(out)
82 | out = self.resnet2_1(out)
83 | out = self.resnet2_2(out)
84 | outs.append(out)
85 | out = self.resnet3_1(out)
86 | out = self.resnet3_2(out)
87 | outs.append(out)
88 | return outs
89 |
90 |
91 | class MLCL(nn.Module):
92 | def __init__(self, in_channel, out_channel):
93 | super(MLCL, self).__init__()
94 | self.layer1 = nn.Sequential(
95 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, padding=0, stride=1),
96 | nn.ReLU(inplace=True),
97 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1, dilation=1),
98 | nn.ReLU(inplace=True)
99 | )
100 | self.layer2 = nn.Sequential(
101 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
102 | nn.ReLU(inplace=True),
103 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=3, stride=1, dilation=3),
104 | nn.ReLU(inplace=True)
105 | )
106 | self.layer3 = nn.Sequential(
107 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=5, padding=2, stride=1),
108 | nn.ReLU(inplace=True),
109 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=5, stride=1, dilation=5),
110 | nn.ReLU(inplace=True)
111 | )
112 | # self.layer4 = nn.Sequential(
113 | # nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=7, padding=2, stride=1),
114 | # nn.ReLU(inplace=True),
115 | # nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=7, stride=1,dilation=7),
116 | # nn.ReLU(inplace=True)
117 | # )
118 | self.conv = nn.Conv2d(in_channels=out_channel * 3, out_channels=out_channel, kernel_size=1)
119 | self.layer1.apply(weights_init)
120 | self.layer2.apply(weights_init)
121 | self.layer3.apply(weights_init)
122 | self.conv.apply(weights_init)
123 | def forward(self, x):
124 | x1 = x
125 | x2 = x
126 | x3 = x
127 | out1 = self.layer1(x1)
128 | out2 = self.layer2(x2)
129 | out3 = self.layer3(x3)
130 | outs = torch.cat((out1, out2, out3), dim=1)
131 | return self.conv(outs)
132 |
133 |
134 | class MLCL_small_No_Sigmoid(nn.Module):
135 | def __init__(self):
136 | super(MLCL_small_No_Sigmoid, self).__init__()
137 | self.stage = Stage()
138 | self.mlcl3 = MLCL(64, 64)
139 | self.mlcl2 = MLCL(32, 32)
140 | self.mlcl1 = MLCL(16, 16)
141 | self.up3 = nn.UpsamplingBilinear2d(scale_factor=2)
142 | self.conv3 = nn.Sequential(
143 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
144 | nn.ReLU(inplace=True)
145 | )
146 | self.up2 = nn.UpsamplingBilinear2d(scale_factor=2)
147 | self.conv2 = nn.Sequential(
148 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1),
149 | nn.ReLU(inplace=True)
150 | )
151 | self.conv1 = nn.Sequential(
152 | nn.Conv2d(in_channels=16, out_channels=64, kernel_size=1),
153 | nn.ReLU(inplace=True)
154 | )
155 | self.layer = nn.Sequential(
156 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
157 | nn.ReLU(inplace=True),
158 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1),
159 | #nn.Sigmoid()
160 | )
161 |
162 | def forward(self, x):
163 | outs = self.stage(x)
164 | out3 = self.mlcl3(outs[2])
165 | out2 = self.mlcl2(outs[1])
166 | out1 = self.mlcl1(outs[0])
167 | out3 = self.conv3(out3) # 128*128 64
168 | out3 = self.up3(out3) # 256*256 64
169 | out2 = self.conv2(out2) # 256*256 64
170 | out = out3 + out2
171 | out = self.up2(out)
172 | out1 = self.conv1(out1)
173 | out = out + out1
174 | out = self.layer(out)
175 | return out
176 |
177 | def weights_init(m):
178 | if isinstance(m, nn.Conv2d):
179 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
180 | if m.bias is not None:
181 | nn.init.constant_(m.bias, 0)
182 | elif isinstance(m, nn.BatchNorm2d):
183 | nn.init.constant_(m.weight, 1)
184 | nn.init.constant_(m.bias, 0)
185 | elif isinstance(m, nn.Linear):
186 | nn.init.xavier_uniform_(m.weight)
187 | nn.init.constant_(m.bias, 0)
188 | return
189 |
190 | if __name__ == '__main__':
191 | model = MLCL_small_No_Sigmoid()
192 | x = torch.rand(8, 3, 512, 512)
193 | outs = model(x)
194 | print(outs.size())
195 |
196 |
--------------------------------------------------------------------------------
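Note on the MLCL block above: each of its three branches ends in a 3x3 conv with dilation 1/3/5 and padding chosen as dilation * (3 - 1) / 2 = dilation, which is exactly the "same"-size condition, so the three feature maps can be concatenated channel-wise. A minimal sanity check of that padding rule (illustrative only, not part of the repo):

import torch
import torch.nn as nn

# For a 3x3 conv with stride 1, padding == dilation keeps H and W unchanged:
# out = floor(H + 2*pad - dilation*(k - 1) - 1) + 1 = H when pad = dilation * (k - 1) // 2
x = torch.rand(1, 16, 64, 64)
for d in (1, 3, 5):
    conv = nn.Conv2d(16, 16, kernel_size=3, padding=d, dilation=d)
    assert conv(x).shape == x.shape  # all three MLCL branches keep the same spatial size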
/model/MLCL/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MLCL/__init__.py
--------------------------------------------------------------------------------
/model/MLCL/__pycache__/MLCL_no_sigmoid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MLCL/__pycache__/MLCL_no_sigmoid.cpython-36.pyc
--------------------------------------------------------------------------------
/model/MLCL/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MLCL/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/MSDA/MSDA_no_sigmoid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from mm.attention.SEAttention import SEAttention
4 | import pywt
5 | from torch.nn import init
6 | from torch.autograd import Function
7 |
8 |
9 | class SpatialAttention(nn.Module):
10 | def __init__(self, kernel_size=7):
11 | super().__init__()
12 | self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
13 | self.sigmoid = nn.Sigmoid()
14 |
15 | def forward(self, x):
16 | max_result, _ = torch.max(x, dim=1, keepdim=True)
17 | avg_result = torch.mean(x, dim=1, keepdim=True)
18 | result = torch.cat([max_result, avg_result], 1)
19 | output = self.conv(result)
20 | output = self.sigmoid(output)
21 | return output
22 |
23 |
24 | ################################################################################################
25 | class LH_DWT_2D_attation(nn.Module):
26 | def __init__(self, wave):
27 | super(LH_DWT_2D_attation, self).__init__()
28 | w = pywt.Wavelet(wave)
29 | dec_hi = torch.Tensor(w.dec_hi[::-1])
30 | dec_lo = torch.Tensor(w.dec_lo[::-1])
31 | w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1)
32 | self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0))
33 | self.w_lh = self.w_lh.to(dtype=torch.float32)
34 |
35 | def forward(self, x):
36 | return LH_DWT_Function_attation.apply(x, self.w_lh)
37 |
38 |
39 | class LH_DWT_Function_attation(Function):
40 | @staticmethod
41 | def forward(ctx, x, w_lh):
42 | x = x.contiguous()
43 | ctx.save_for_backward(w_lh)
44 | ctx.shape = x.shape
45 | dim = x.shape[1]
46 | x = torch.nn.functional.pad(x, (1, 0, 1, 0))
47 | x_lh = torch.nn.functional.conv2d(x, w_lh.expand(dim, -1, -1, -1), stride=1, padding=0, groups=dim)
48 | x = x_lh
49 | return x
50 |
51 | @staticmethod
52 | def backward(ctx, dx):
53 | if ctx.needs_input_grad[0]:
54 | w_lh = ctx.saved_tensors[0]
55 | B, C, H, W = ctx.shape
56 | dx = dx.view(B, 1, -1, H, W)
57 | dx = dx.transpose(1, 2).reshape(B, -1, H, W)
 58 |             filters = w_lh  # torch.Size([1, 1, 2, 2])
 59 |             filters = filters.repeat(C, 1, 1, 1).to(dtype=dx.dtype)  # match the incoming gradient dtype (the original hard-coded float16)
 60 |             dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=1, groups=C)
 61 |             dx = dx[:, :, 1:, :]  # drop the extra row introduced by the (1, 0, 1, 0) forward padding
 62 |             dx = dx[:, :, :, 1:]  # drop the extra column
63 | return dx, None
64 |
65 |
66 | ########################################################################################
67 | class HL_DWT_2D_attation(nn.Module):
68 | def __init__(self, wave):
69 | super(HL_DWT_2D_attation, self).__init__()
70 | w = pywt.Wavelet(wave)
71 | dec_hi = torch.Tensor(w.dec_hi[::-1])
72 | dec_lo = torch.Tensor(w.dec_lo[::-1])
73 |
74 | w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1)
75 |
76 | self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0))
77 | self.w_hl = self.w_hl.to(dtype=torch.float32)
78 |
79 | def forward(self, x):
80 | return HL_DWT_Function_attation.apply(x, self.w_hl)
81 |
82 |
83 | class HL_DWT_Function_attation(Function):
84 | @staticmethod
85 | def forward(ctx, x, w_hl):
86 | x = x.contiguous()
87 | ctx.save_for_backward(w_hl)
88 | ctx.shape = x.shape
89 | dim = x.shape[1]
90 | x = torch.nn.functional.pad(x, (1, 0, 1, 0))
91 | x_hl = torch.nn.functional.conv2d(x, w_hl.expand(dim, -1, -1, -1), stride=1, padding=0, groups=dim)
92 | x = x_hl
93 | return x
94 |
95 | @staticmethod
96 | def backward(ctx, dx):
97 | if ctx.needs_input_grad[0]:
98 | w_hl = ctx.saved_tensors[0]
99 | B, C, H, W = ctx.shape
100 | dx = dx.view(B, 1, -1, H, W)
101 | dx = dx.transpose(1, 2).reshape(B, -1, H, W)
102 |             filters = w_hl  # torch.Size([1, 1, 2, 2])
103 |             filters = filters.repeat(C, 1, 1, 1).to(dtype=dx.dtype)  # match the incoming gradient dtype (the original hard-coded float16)
104 |             dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=1, groups=C)
105 |             dx = dx[:, :, 1:, :]  # drop the extra row introduced by the (1, 0, 1, 0) forward padding
106 |             dx = dx[:, :, :, 1:]  # drop the extra column
107 | return dx, None
108 |
109 |
110 | #######################################################################################
111 | class HH_DWT_2D_attation(nn.Module):
112 | def __init__(self, wave):
113 | super(HH_DWT_2D_attation, self).__init__()
114 | w = pywt.Wavelet(wave)
115 | dec_hi = torch.Tensor(w.dec_hi[::-1])
116 | dec_lo = torch.Tensor(w.dec_lo[::-1])
117 | w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)
118 | self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0))
119 | self.w_hh = self.w_hh.to(dtype=torch.float32)
120 |
121 | def forward(self, x):
122 | return HH_DWT_Function_attation.apply(x, self.w_hh)
123 |
124 |
125 | class HH_DWT_Function_attation(Function):
126 | @staticmethod
127 | def forward(ctx, x, w_hh):
128 | x = x.contiguous()
129 | ctx.save_for_backward(w_hh)
130 | ctx.shape = x.shape
131 | dim = x.shape[1]
132 | x = torch.nn.functional.pad(x, (1, 0, 1, 0))
133 | x_hh = torch.nn.functional.conv2d(x, w_hh.expand(dim, -1, -1, -1), stride=1, padding=0,
134 | groups=dim)
135 | x = x_hh
136 | return x
137 |
138 | @staticmethod
139 | def backward(ctx, dx):
140 | if ctx.needs_input_grad[0]:
141 | w_hh = ctx.saved_tensors[0]
142 | B, C, H, W = ctx.shape
143 | dx = dx.view(B, 1, -1, H, W)
144 | dx = dx.transpose(1, 2).reshape(B, -1, H, W)
145 |             filters = w_hh  # torch.Size([1, 1, 2, 2])
146 |             filters = filters.repeat(C, 1, 1, 1).to(dtype=dx.dtype)  # match the incoming gradient dtype (the original hard-coded float16)
147 |             dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=1, groups=C)
148 |             dx = dx[:, :, 1:, :]  # drop the extra row introduced by the (1, 0, 1, 0) forward padding
149 |             dx = dx[:, :, :, 1:]  # drop the extra column
150 | return dx, None
151 |
152 |
153 | #######################################################################################
154 | class LL_DWT_2D_attation(nn.Module):
155 | def __init__(self, wave):
156 | super(LL_DWT_2D_attation, self).__init__()
157 | w = pywt.Wavelet(wave)
158 | dec_hi = torch.Tensor(w.dec_hi[::-1])
159 | dec_lo = torch.Tensor(w.dec_lo[::-1])
160 | w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1)
161 | self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0))
162 | self.w_ll = self.w_ll.to(dtype=torch.float32)
163 |
164 | def forward(self, x):
165 | return LL_DWT_Function_attation.apply(x, self.w_ll)
166 |
167 |
168 | class LL_DWT_Function_attation(Function):
169 | @staticmethod
170 | def forward(ctx, x, w_ll):
171 | x = x.contiguous()
172 | ctx.save_for_backward(w_ll)
173 | ctx.shape = x.shape
174 | dim = x.shape[1]
175 | x = torch.nn.functional.pad(x, (1, 0, 1, 0))
176 | x_ll = torch.nn.functional.conv2d(x, w_ll.expand(dim, -1, -1, -1), stride=1, padding=0,
177 | groups=dim)
178 | x = x_ll
179 | return x
180 |
181 | @staticmethod
182 | def backward(ctx, dx):
183 | if ctx.needs_input_grad[0]:
184 | w_ll = ctx.saved_tensors[0]
185 | B, C, H, W = ctx.shape
186 | dx = dx.view(B, 1, -1, H, W)
187 | dx = dx.transpose(1, 2).reshape(B, -1, H, W)
188 |             filters = w_ll  # torch.Size([1, 1, 2, 2])
189 |             filters = filters.repeat(C, 1, 1, 1).to(dtype=dx.dtype)  # match the incoming gradient dtype (the original hard-coded float16)
190 |             dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=1, groups=C)
191 |             dx = dx[:, :, 1:, :]  # drop the extra row introduced by the (1, 0, 1, 0) forward padding
192 |             dx = dx[:, :, :, 1:]  # drop the extra column
193 | return dx, None
194 |
195 |
196 | #########################################################################################
197 |
198 | class FrequencyAttention(nn.Module):
199 |
200 | def __init__(self, in_channel):
201 | super().__init__()
202 | # self.HH_reduce = nn.Sequential(
203 | # nn.Conv2d(in_channel, 1, kernel_size=3, padding=1, stride=1),
204 | # nn.Sigmoid(),
205 | # )
206 | self.HH_reduce = SpatialAttention()
207 | # self.LH_reduce = nn.Sequential(
208 | # nn.Conv2d(in_channel, 1, kernel_size=3, padding=1, stride=1),
209 | # nn.Sigmoid(),
210 | # )
211 | self.LH_reduce = SpatialAttention()
212 | # self.HL_reduce = nn.Sequential(
213 | # nn.Conv2d(in_channel, 1, kernel_size=3, padding=1, stride=1),
214 | # nn.Sigmoid(),
215 | # )
216 | self.HL_reduce = SpatialAttention()
217 | # self.LL_reduce = nn.Sequential(
218 | # nn.Conv2d(in_channel, 1, kernel_size=3, padding=1, stride=1),
219 | # nn.Sigmoid(),
220 | # )
221 | self.LL_reduce = SpatialAttention()
222 | self.conv_res = nn.Sequential(
223 | nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1, stride=1),
224 | nn.BatchNorm2d(in_channel),
225 | )
226 | self.relu = nn.ReLU(inplace=False)
227 | # self.fc =nn.Linear(in_channel*3, in_channel, bias=False)
228 | self.LH_DWT_2D_attation = LH_DWT_2D_attation('haar')
229 | self.HL_DWT_2D_attation = HL_DWT_2D_attation('haar')
230 | self.HH_DWT_2D_attation = HH_DWT_2D_attation('haar')
231 | self.LL_DWT_2D_attation = LL_DWT_2D_attation('haar')
232 | self.init_weights()
233 |
234 | def forward(self, x):
235 | out_LH = self.LH_DWT_2D_attation(x)
236 | out_LH = self.LH_reduce(out_LH)
237 | x_LH = x * out_LH
238 | # print(x_LH.size())
239 | out_HL = self.HL_DWT_2D_attation(x)
240 | out_HL = self.HL_reduce(out_HL)
241 | x_HL = x * out_HL
242 | out_HH = self.HH_DWT_2D_attation(x)
243 | out_HH = self.HH_reduce(out_HH)
244 | x_HH = x * out_HH
245 | out_LL = self.LL_DWT_2D_attation(x)
246 | out_LL = self.LL_reduce(out_LL)
247 | x_LL = x * out_LL
248 | x_out = x_LH + x_HL + x_HH + x_LL
249 | x_out = self.relu(x + self.conv_res(x_out))
250 | return x_out
251 |
252 | def init_weights(self):
253 | for m in self.modules():
254 | if isinstance(m, nn.Conv2d):
255 | init.kaiming_normal_(m.weight, mode='fan_out')
256 | if m.bias is not None:
257 | init.constant_(m.bias, 0)
258 | elif isinstance(m, nn.BatchNorm2d):
259 | init.constant_(m.weight, 1)
260 | init.constant_(m.bias, 0)
261 | elif isinstance(m, nn.Linear):
262 | init.normal_(m.weight, std=0.001)
263 | if m.bias is not None:
264 | init.constant_(m.bias, 0)
265 |
266 |
267 | ############################################################################################
268 |
269 | class DWT_2D(nn.Module):
270 | def __init__(self, wave):
271 | super(DWT_2D, self).__init__()
272 | w = pywt.Wavelet(wave)
273 | dec_hi = torch.Tensor(w.dec_hi[::-1])
274 | dec_lo = torch.Tensor(w.dec_lo[::-1])
275 |
276 | w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1)
277 | w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1)
278 | w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1)
279 | w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)
280 |
281 | self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0))
282 | self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0))
283 | self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0))
284 | self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0))
285 |
286 | self.w_ll = self.w_ll.to(dtype=torch.float32)
287 | self.w_lh = self.w_lh.to(dtype=torch.float32)
288 | self.w_hl = self.w_hl.to(dtype=torch.float32)
289 | self.w_hh = self.w_hh.to(dtype=torch.float32)
290 |
291 | def forward(self, x):
292 | return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh)
293 |
294 |
295 | class DWT_Function(Function):
296 | @staticmethod
297 | def forward(ctx, x, w_ll, w_lh, w_hl, w_hh):
298 | x = x.contiguous()
299 | ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh)
300 | ctx.shape = x.shape
301 | dim = x.shape[1]
302 |
303 | x_ll = torch.nn.functional.conv2d(x, w_ll.expand(dim, -1, -1, -1), stride=2, padding=0, groups=dim)
304 | x_lh = torch.nn.functional.conv2d(x, w_lh.expand(dim, -1, -1, -1), stride=2, padding=0, groups=dim)
305 | x_hl = torch.nn.functional.conv2d(x, w_hl.expand(dim, -1, -1, -1), stride=2, padding=0, groups=dim)
306 | x_hh = torch.nn.functional.conv2d(x, w_hh.expand(dim, -1, -1, -1), stride=2, padding=0,
307 | groups=dim)
308 | x = torch.cat([x_lh, x_hl, x_hh], dim=1)
309 | # x = x_hh
310 | return x
311 |
312 | @staticmethod
313 | def backward(ctx, dx):
314 | if ctx.needs_input_grad[0]:
315 | w_ll, w_lh, w_hl, w_hh = ctx.saved_tensors
316 | B, C, H, W = ctx.shape
317 |
318 | dx = dx.view(B, 3, -1, H // 2, W // 2)
319 | dx = dx.transpose(1, 2).reshape(B, -1, H // 2, W // 2)
320 | filters = torch.cat([w_lh, w_hl, w_hh], dim=0)
321 |             filters = filters.repeat(C, 1, 1, 1).to(dtype=dx.dtype)  # match the incoming gradient dtype (the original hard-coded float16)
322 | dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=2, groups=C)
323 |
324 | return dx, None, None, None, None
325 |
326 |
327 | class Hfrequencyfeature(nn.Module):
328 |
329 | def __init__(self):
330 | super().__init__()
331 |
332 | self.DWT_2D = DWT_2D('haar')
333 |
334 |     def forward(self, x):  # x: the raw network input, e.g. torch.Size([8, 3, 512, 512])
335 |         out = self.DWT_2D(x)  # LH/HL/HH sub-bands at half resolution, e.g. torch.Size([8, 9, 256, 256])
336 | return out
337 |
338 |
339 | def weights_init(m):
340 | if isinstance(m, nn.Conv2d):
341 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
342 | if m.bias is not None:
343 | nn.init.constant_(m.bias, 0)
344 | elif isinstance(m, nn.BatchNorm2d):
345 | nn.init.constant_(m.weight, 1)
346 | nn.init.constant_(m.bias, 0)
347 |
348 | return
349 |
350 |
351 | class BasicConv2d(nn.Module):
352 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
353 | super(BasicConv2d, self).__init__()
354 | self.conv = nn.Conv2d(in_planes, out_planes,
355 | kernel_size=kernel_size, stride=stride,
356 | padding=padding, dilation=dilation, bias=False)
357 | self.bn = nn.BatchNorm2d(out_planes)
358 | self.relu = nn.ReLU(inplace=False)
359 |
360 | def forward(self, x):
361 | x = self.conv(x)
362 | x = self.bn(x)
363 | x = self.relu(x)
364 | return x
365 |
366 |
367 | ##############################################################################
368 | class RFB_modified(nn.Module):
369 | def __init__(self, in_channel, out_channel):
370 | super(RFB_modified, self).__init__()
371 | self.relu = nn.ReLU(inplace=False)
372 | self.branch0 = nn.Sequential(
373 | BasicConv2d(in_channel, out_channel, 1),
374 | )
375 | self.branch1 = nn.Sequential(
376 | BasicConv2d(in_channel, out_channel, 1),
377 | BasicConv2d(out_channel, out_channel, 3, padding=1, dilation=1)
378 | )
379 |
380 | self.branch2 = nn.Sequential(
381 | BasicConv2d(in_channel, out_channel, kernel_size=(3, 3), padding=1),
382 | BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
383 | )
384 | self.branch3 = nn.Sequential(
385 | BasicConv2d(in_channel, out_channel, kernel_size=(5, 5), padding=2),
386 | BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
387 | )
388 |
389 | self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1)
390 | self.conv_res = BasicConv2d(in_channel, out_channel, 3, padding=1)
391 | self.FrequencyAttention = FrequencyAttention(in_channel=out_channel)
392 | self.SEAttention = SEAttention(channel=out_channel, reduction=4)
393 |
394 | def forward(self, x):
395 | x0 = self.branch0(x)
396 | x1 = self.branch1(x)
397 | x2 = self.branch2(x)
398 | x3 = self.branch3(x)
399 | x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
400 | x_cat = self.FrequencyAttention(x_cat)
401 | # print(x_cat.shape)
402 | x_cat = self.SEAttention(x_cat)
403 |         x = self.relu(x + self.conv_res(x_cat))
404 | return x
405 |
406 |
407 | class RFB_modified_LCL(nn.Module):
408 | def __init__(self, in_channel, out_channel):
409 | super(RFB_modified_LCL, self).__init__()
410 | self.relu = nn.ReLU(inplace=False)
411 | self.branch0 = nn.Sequential(
412 | BasicConv2d(in_channel, out_channel, 1),
413 | )
414 | self.branch1 = nn.Sequential(
415 | BasicConv2d(in_channel, out_channel, kernel_size=(3, 3), padding=1),
416 | BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
417 | )
418 | self.branch2 = nn.Sequential(
419 | BasicConv2d(in_channel, out_channel, kernel_size=(5, 5), padding=2),
420 | BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
421 | )
422 |
423 | self.conv_cat = BasicConv2d(3 * out_channel, out_channel, 3, padding=1)
424 | self.conv_res = BasicConv2d(in_channel, out_channel, 1)
425 |
426 | def forward(self, x):
427 | x0 = self.branch0(x)
428 | x1 = self.branch1(x)
429 | x2 = self.branch2(x)
430 | # x3 = self.branch3(x)
431 | x_cat = self.conv_cat(torch.cat((x0, x1, x2), 1))
432 | # print(x_cat.size())
433 | x_cat = self.relu(x_cat)
434 | return x_cat
435 |
436 |
437 | #############################################################################################################
438 | class Resnet1(nn.Module):
439 | def __init__(self, in_channel, out_channel):
440 | super(Resnet1, self).__init__()
441 | self.layer = nn.Sequential(
442 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
443 | nn.BatchNorm2d(out_channel),
444 | nn.ReLU(inplace=False),
445 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
446 | nn.BatchNorm2d(out_channel)
447 | )
448 | self.relu = nn.ReLU(inplace=False)
449 |
450 | self.layer.apply(weights_init)
451 |
452 | def forward(self, x):
453 | identity = x
454 | out = self.layer(x)
455 | out += identity
456 | return self.relu(out)
457 |
458 |
459 | # used by layer2_1 / layer3_1 / layer4_1 / layer5_1
460 | class Resnet2(nn.Module):
461 | def __init__(self, in_channel, out_channel):
462 | super(Resnet2, self).__init__()
463 | self.layer1 = nn.Sequential(
464 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
465 | nn.BatchNorm2d(out_channel),
466 | nn.ReLU(inplace=False),
467 | nn.MaxPool2d(kernel_size=2, stride=2),
468 | nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=1),
469 | nn.BatchNorm2d(out_channel)
470 | )
471 | self.layer2 = nn.Sequential(
472 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, stride=2),
473 | nn.BatchNorm2d(out_channel),
474 | nn.ReLU(inplace=False)
475 | )
476 | self.relu = nn.ReLU(inplace=False)
477 |
478 | self.layer1.apply(weights_init)
479 | self.layer2.apply(weights_init)
480 |
481 | def forward(self, x):
482 | identity = x
483 | out = self.layer1(x)
484 | identity = self.layer2(identity)
485 | # print("out:",out.size())
486 | # print("identity:",identity.size())
487 | out += identity
488 | return self.relu(out)
489 |
490 |
491 | class Hfrequency(nn.Module):
492 | def __init__(self):
493 | super(Hfrequency, self).__init__()
494 | self.Hfrequencyfeature = Hfrequencyfeature()
495 |
496 | def forward(self, x, out_2):
497 | out_1 = self.Hfrequencyfeature(x)
498 | out = torch.cat([out_1, out_2], dim=1)
499 | return out
500 |
501 |
502 | class Stage(nn.Module):
503 | def __init__(self):
504 | super(Stage, self).__init__()
505 | self.layer1 = nn.Sequential(
506 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1),
507 | nn.BatchNorm2d(16),
508 | nn.ReLU(inplace=False),
509 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1),
510 | nn.BatchNorm2d(16),
511 | nn.ReLU(inplace=False)
512 | )
513 | self.resnet1_1 = Resnet1(in_channel=16, out_channel=16)
514 | self.resnet1_2 = RFB_modified(in_channel=16, out_channel=16)
515 | self.resnet1_3 = RFB_modified(in_channel=16, out_channel=16)
516 | self.layer1_4 = nn.Sequential(
517 | nn.MaxPool2d(kernel_size=2, stride=2),
518 | nn.MaxPool2d(kernel_size=2, stride=2),
519 | nn.MaxPool2d(kernel_size=2, stride=2),
520 | nn.MaxPool2d(kernel_size=2, stride=2),
521 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1),
522 | nn.ReLU(inplace=False),
523 | )
524 |
525 |
526 |         self.resnet2_1 = Resnet2(in_channel=16, out_channel=23)  # 23 + 9 Haar sub-band channels from Hfrequency = 32 for resnet2_2
527 | self.hfrequency = Hfrequency()
528 | self.resnet2_2 = RFB_modified(in_channel=32, out_channel=32)
529 | self.resnet2_3 = RFB_modified(in_channel=32, out_channel=32)
530 | self.layer2_4 = nn.Sequential(
531 | nn.MaxPool2d(kernel_size=2, stride=2),
532 | nn.MaxPool2d(kernel_size=2, stride=2),
533 | nn.MaxPool2d(kernel_size=2, stride=2),
534 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=1),
535 | nn.ReLU(inplace=False),
536 | )
537 |
538 |
539 | self.resnet3_1 = Resnet2(in_channel=32, out_channel=64)
540 | self.resnet3_2 = RFB_modified(in_channel=64, out_channel=64)
541 | self.resnet3_3 = RFB_modified(in_channel=64, out_channel=64)
542 | self.layer3_4 = nn.Sequential(
543 | nn.MaxPool2d(kernel_size=2, stride=2),
544 | nn.MaxPool2d(kernel_size=2, stride=2),
545 | nn.Conv2d(in_channels=64, out_channels=16, kernel_size=1),
546 | nn.ReLU(inplace=False),
547 | )
548 |
549 |
550 | self.resnet4_1 = Resnet2(in_channel=64, out_channel=64)
551 | self.resnet4_2 = RFB_modified(in_channel=64, out_channel=64)
552 | self.resnet4_3 = RFB_modified(in_channel=64, out_channel=64)
553 | self.layer4_4 = nn.Sequential(
554 | nn.MaxPool2d(kernel_size=2, stride=2),
555 | nn.Conv2d(in_channels=64, out_channels=16, kernel_size=1),
556 | nn.ReLU(inplace=False),
557 | )
558 |
559 |
560 | self.resnet5_1 = Resnet2(in_channel=64, out_channel=64)
561 | self.resnet5_2 = RFB_modified(in_channel=64, out_channel=64)
562 | self.resnet5_3 = RFB_modified(in_channel=64, out_channel=64)
563 | self.layer5_4 = nn.Sequential(
564 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1),
565 | nn.ReLU(inplace=False),
566 | nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1),
567 | nn.ReLU(inplace=False),
568 |
569 | )
570 |
571 |
572 |
573 | self.layer1.apply(weights_init)
574 | self.layer1_4.apply(weights_init)
575 | self.layer2_4.apply(weights_init)
576 | self.layer3_4.apply(weights_init)
577 | self.layer4_4.apply(weights_init)
578 | self.layer5_4.apply(weights_init)
579 |
580 | def forward(self, x):
581 | outs = []
582 | out = self.layer1(x)
583 |
584 | # print(out_1.size())
585 | out = self.resnet1_1(out)
586 | out = self.resnet1_2(out)
587 | out = self.resnet1_3(out)
588 | out_1 = self.layer1_4(out)
589 |
590 | outs.append(out)
591 | out = self.resnet2_1(out)
592 | out = self.hfrequency(x,out)
593 | out = self.resnet2_2(out)
594 | out = self.resnet2_3(out)
595 | out_2 = self.layer2_4(out)
596 |
597 | outs.append(out)
598 | out = self.resnet3_1(out)
599 | out = self.resnet3_2(out)
600 | out = self.resnet3_3(out)
601 | out_3 = self.layer3_4(out)
602 |
603 | outs.append(out)
604 | out = self.resnet4_1(out)
605 | out = self.resnet4_2(out)
606 | out = self.resnet4_3(out)
607 | out_4 = self.layer4_4(out)
608 |
609 | outs.append(out)
610 | out = self.resnet5_1(out)
611 | out = self.resnet5_2(out)
612 | out = self.resnet5_3(out)
613 | out = torch.cat([out, out_4, out_3,out_2,out_1], dim=1)
614 | out = self.layer5_4(out)
615 |
616 | outs.append(out)
617 | return outs
618 |
619 | class Sbam(nn.Module):
620 | def __init__(self, in_channel, out_channel):
621 | super(Sbam, self).__init__()
622 | self.hl_up = nn.Sequential(
623 | nn.UpsamplingBilinear2d(scale_factor=2),
624 | )
625 | self.hl_layer = nn.Sequential(
626 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1),
627 | nn.BatchNorm2d(out_channel),
628 | nn.ReLU(inplace=False)
629 | )
630 | self.concat_layer = nn.Sequential(
631 | nn.Conv2d(in_channels=in_channel+out_channel, out_channels=in_channel, kernel_size=1),
632 | nn.BatchNorm2d(in_channel),
633 | nn.ReLU(inplace=False),
634 | nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=3, padding=1, stride=1),
635 | nn.BatchNorm2d(in_channel),
636 | nn.ReLU(inplace=False),
637 | )
638 |
639 |
640 | self.hl_layer.apply(weights_init)
641 | self.concat_layer.apply(weights_init)
642 |
643 | def forward(self, hl, ll):
644 | hl = self.hl_up(hl)
645 | concat = torch.cat((hl, ll), 1)
646 | k = self.concat_layer(concat)
647 | hl = hl + k
648 | hl = self.hl_layer(hl)
649 | out = ll + hl
650 | return out
651 |
652 |
653 | class MSDA_No_Sigmoid(nn.Module):
654 | def __init__(self):
655 | super(MSDA_No_Sigmoid, self).__init__()
656 | self.stage = Stage()
657 | self.mlcl5 = RFB_modified_LCL(64, 64)
658 | self.mlcl4 = RFB_modified_LCL(64, 64)
659 | self.mlcl3 = RFB_modified_LCL(64, 64)
660 | self.mlcl2 = RFB_modified_LCL(32, 32)
661 | self.mlcl1 = RFB_modified_LCL(16, 16)
662 |
663 | self.sbam4 = Sbam(64, 64)
664 | self.sbam3 = Sbam(64, 64)
665 | self.sbam2 = Sbam(64, 32)
666 | self.sbam1 = Sbam(32, 16)
667 |
668 | self.layer = nn.Sequential(
669 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1),
670 | nn.ReLU(inplace=False),
671 | nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1),
672 | # nn.Sigmoid()
673 | )
674 |
675 | self.layer.apply(weights_init)
676 |
677 | def forward(self, x):
678 | outs = self.stage(x)
679 |
680 | out5 = self.mlcl5(outs[4])
681 | out4 = self.mlcl4(outs[3])
682 | out3 = self.mlcl3(outs[2])
683 | out2 = self.mlcl2(outs[1])
684 | out1 = self.mlcl1(outs[0])
685 |
686 | out4_2 = self.sbam4(out5, out4)
687 | out3_2 = self.sbam3(out4_2, out3)
688 | out2_2 = self.sbam2(out3_2, out2)
689 | out1_2 = self.sbam1(out2_2, out1)
690 | out = self.layer(out1_2)
691 |
692 | return out
693 |
694 |
695 | def weights_init(m):
696 | if isinstance(m, nn.Conv2d):
697 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
698 | if m.bias is not None:
699 | nn.init.constant_(m.bias, 0)
700 | elif isinstance(m, nn.BatchNorm2d):
701 | nn.init.constant_(m.weight, 1)
702 | nn.init.constant_(m.bias, 0)
703 | # elif isinstance(m, nn.Linear):
704 | # nn.init.xavier_uniform_(m.weight)
705 | # nn.init.constant_(m.bias, 0)
706 | return
707 |
708 |
709 | if __name__ == '__main__':
710 | model = MSDA_No_Sigmoid()
711 | x = torch.rand(8, 3, 512, 512)
712 | outs = model(x)
713 | print(outs.size())
714 |
--------------------------------------------------------------------------------
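The four *_DWT_2D_attation modules and DWT_2D above all build their 2x2 filters as outer products of the 1-D Haar decomposition filters from pywt; the attention variants keep resolution (stride 1 plus one padded row/column), while DWT_2D downsamples by 2. A small sketch of the actual filter values (illustrative only, not part of the repo):

import pywt
import torch

w = pywt.Wavelet('haar')
dec_lo = torch.tensor(w.dec_lo[::-1])  # [0.7071, 0.7071]
dec_hi = torch.tensor(w.dec_hi[::-1])  # [0.7071, -0.7071]

w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1)  # [[0.5, 0.5], [0.5, 0.5]]   -> local average (LL)
w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)  # [[0.5, -0.5], [-0.5, 0.5]] -> diagonal detail (HH)
print(w_ll, w_hh, sep='\n')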
/model/MSDA/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MSDA/__init__.py
--------------------------------------------------------------------------------
/model/MSDA/__pycache__/MSDA_no_sigmoid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MSDA/__pycache__/MSDA_no_sigmoid.cpython-36.pyc
--------------------------------------------------------------------------------
/model/MSDA/__pycache__/MSDA_no_sigmoid.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MSDA/__pycache__/MSDA_no_sigmoid.cpython-38.pyc
--------------------------------------------------------------------------------
/model/MSDA/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MSDA/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/MSDA/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/MSDA/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/UIU/UIU_no_sigmoid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | #############################################
5 | from .fusion import *
6 | ##################################################
7 |
8 | class REBNCONV(nn.Module):
9 | def __init__(self,in_ch=3,out_ch=3,dirate=1):
10 | super(REBNCONV,self).__init__()
11 |
12 | self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
13 | self.bn_s1 = nn.BatchNorm2d(out_ch)
14 | self.relu_s1 = nn.ReLU(inplace=True)
15 |
16 | def forward(self,x):
17 |
18 | hx = x
19 | xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
20 |
21 | return xout
22 |
23 | def _upsample_like(src,tar):
24 |
25 |     src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')  # F.upsample is deprecated
26 |
27 | return src
28 |
29 |
30 | ### RSU-7 ###
31 | class RSU7(nn.Module):#UNet07DRES(nn.Module):
32 |
33 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
34 | super(RSU7,self).__init__()
35 |
36 | self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
37 |
38 | self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
39 | self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
40 |
41 | self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
42 | self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
43 |
44 | self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
45 | self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
46 |
47 | self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
48 | self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
49 |
50 | self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
51 | self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
52 |
53 | self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
54 |
55 | self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
56 |
57 | self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
58 | self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
59 | self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60 | self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
61 | self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
62 | self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
63 |
64 | def forward(self,x):
65 |
66 | hx = x
67 | hxin = self.rebnconvin(hx)
68 |
69 | hx1 = self.rebnconv1(hxin)
70 | hx = self.pool1(hx1)
71 |
72 | hx2 = self.rebnconv2(hx)
73 | hx = self.pool2(hx2)
74 |
75 | hx3 = self.rebnconv3(hx)
76 | hx = self.pool3(hx3)
77 |
78 | hx4 = self.rebnconv4(hx)
79 | hx = self.pool4(hx4)
80 |
81 | hx5 = self.rebnconv5(hx)
82 | hx = self.pool5(hx5)
83 |
84 | hx6 = self.rebnconv6(hx)
85 |
86 | hx7 = self.rebnconv7(hx6)
87 |
88 | hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
89 | hx6dup = _upsample_like(hx6d,hx5)
90 |
91 | hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
92 | hx5dup = _upsample_like(hx5d,hx4)
93 |
94 | hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
95 | hx4dup = _upsample_like(hx4d,hx3)
96 |
97 | hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
98 | hx3dup = _upsample_like(hx3d,hx2)
99 |
100 | hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
101 | hx2dup = _upsample_like(hx2d,hx1)
102 |
103 | hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
104 |
105 | return hx1d + hxin
106 |
107 | ### RSU-6 ###
108 | class RSU6(nn.Module):
109 |
110 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
111 | super(RSU6,self).__init__()
112 |
113 | self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
114 |
115 | self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
116 | self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
117 |
118 | self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
119 | self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
120 |
121 | self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
122 | self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
123 |
124 | self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
125 | self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
126 |
127 | self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
128 |
129 | self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
130 |
131 | self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
132 | self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
133 | self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
134 | self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
135 | self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
136 |
137 | def forward(self,x):
138 |
139 | hx = x
140 |
141 | hxin = self.rebnconvin(hx)
142 |
143 | hx1 = self.rebnconv1(hxin)
144 | hx = self.pool1(hx1)
145 |
146 | hx2 = self.rebnconv2(hx)
147 | hx = self.pool2(hx2)
148 |
149 | hx3 = self.rebnconv3(hx)
150 | hx = self.pool3(hx3)
151 |
152 | hx4 = self.rebnconv4(hx)
153 | hx = self.pool4(hx4)
154 |
155 | hx5 = self.rebnconv5(hx)
156 |
157 | hx6 = self.rebnconv6(hx5)
158 |
159 | hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
160 | hx5dup = _upsample_like(hx5d,hx4)
161 |
162 | hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
163 | hx4dup = _upsample_like(hx4d,hx3)
164 |
165 | hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
166 | hx3dup = _upsample_like(hx3d,hx2)
167 |
168 | hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
169 | hx2dup = _upsample_like(hx2d,hx1)
170 |
171 | hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
172 |
173 | return hx1d + hxin
174 |
175 | ### RSU-5 ###
176 | class RSU5(nn.Module):
177 |
178 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
179 | super(RSU5,self).__init__()
180 |
181 | self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
182 |
183 | self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
184 | self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
185 |
186 | self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
187 | self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
188 |
189 | self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
190 | self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
191 |
192 | self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
193 |
194 | self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
195 |
196 | self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
197 | self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
198 | self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
199 | self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
200 |
201 | def forward(self,x):
202 |
203 | hx = x
204 |
205 | hxin = self.rebnconvin(hx)
206 |
207 | hx1 = self.rebnconv1(hxin)
208 | hx = self.pool1(hx1)
209 |
210 | hx2 = self.rebnconv2(hx)
211 | hx = self.pool2(hx2)
212 |
213 | hx3 = self.rebnconv3(hx)
214 | hx = self.pool3(hx3)
215 |
216 | hx4 = self.rebnconv4(hx)
217 |
218 | hx5 = self.rebnconv5(hx4)
219 |
220 | hx4d = self.rebnconv4d(torch.cat((hx5,hx4), 1))
221 | hx4dup = _upsample_like(hx4d, hx3)
222 |
223 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
224 | hx3dup = _upsample_like(hx3d, hx2)
225 |
226 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
227 | hx2dup = _upsample_like(hx2d, hx1)
228 |
229 | hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
230 |
231 | return hx1d + hxin
232 |
233 | ### RSU-4 ###
234 | class RSU4(nn.Module):
235 |
236 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
237 | super(RSU4,self).__init__()
238 |
239 | self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
240 |
241 | self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
242 | self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
243 |
244 | self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
245 | self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
246 |
247 | self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
248 |
249 | self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
250 |
251 | self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
252 | self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
253 | self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
254 |
255 | def forward(self,x):
256 |
257 | hx = x
258 |
259 | hxin = self.rebnconvin(hx)
260 |
261 | hx1 = self.rebnconv1(hxin)
262 | hx = self.pool1(hx1)
263 |
264 | hx2 = self.rebnconv2(hx)
265 | hx = self.pool2(hx2)
266 |
267 | hx3 = self.rebnconv3(hx)
268 |
269 | hx4 = self.rebnconv4(hx3)
270 |
271 | hx3d = self.rebnconv3d(torch.cat((hx4,hx3), 1))
272 | hx3dup = _upsample_like(hx3d, hx2)
273 |
274 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
275 | hx2dup = _upsample_like(hx2d, hx1)
276 |
277 | hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
278 |
279 | return hx1d + hxin
280 |
281 | ### RSU-4F ###
282 | class RSU4F(nn.Module):
283 |
284 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
285 | super(RSU4F,self).__init__()
286 |
287 | self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
288 |
289 | self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
290 | self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
291 | self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
292 |
293 | self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
294 |
295 | self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
296 | self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
297 | self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
298 |
299 | def forward(self,x):
300 |
301 | hx = x
302 |
303 | hxin = self.rebnconvin(hx)
304 |
305 | hx1 = self.rebnconv1(hxin)
306 | hx2 = self.rebnconv2(hx1)
307 | hx3 = self.rebnconv3(hx2)
308 |
309 | hx4 = self.rebnconv4(hx3)
310 |
311 | hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
312 | hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
313 | hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
314 |
315 | return hx1d + hxin
316 |
317 |
318 | ##### UIU-net ####
319 | class UIU_No_Sigmoid(nn.Module):
320 |
321 | def __init__(self, in_ch=3, out_ch=1,mode='test'):
322 | super(UIU_No_Sigmoid, self).__init__()
323 | self.mode = mode
324 |
325 | self.stage1 = RSU7(in_ch,32,64)
326 | self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
327 |
328 | self.stage2 = RSU6(64,32,128)
329 | self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
330 |
331 | self.stage3 = RSU5(128,64,256)
332 | self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
333 |
334 | self.stage4 = RSU4(256,128,512)
335 | self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
336 |
337 | self.stage5 = RSU4F(512,256,512)
338 | self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
339 |
340 | self.stage6 = RSU4F(512,256,512)
341 |
342 | # decoder
343 | self.stage5d = RSU4F(1024,256,512)
344 | self.stage4d = RSU4(1024,128,256)
345 | self.stage3d = RSU5(512,64,128)
346 | self.stage2d = RSU6(256,32,64)
347 | self.stage1d = RSU7(128,16,64)
348 |
349 | self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
350 | self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
351 | self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
352 | self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
353 | self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
354 | self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
355 |
356 | self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
357 |
358 | #self.fuse6 = self._fuse_layer(512, 512, 512, fuse_mode='AsymBi')
359 | self.fuse5 = self._fuse_layer(512, 512, 512, fuse_mode='AsymBi')
360 | self.fuse4 = self._fuse_layer(512, 512, 512, fuse_mode='AsymBi')
361 | self.fuse3 = self._fuse_layer(256, 256, 256, fuse_mode='AsymBi')
362 | self.fuse2 = self._fuse_layer(128, 128, 128, fuse_mode='AsymBi')
363 |
364 |
365 | def _fuse_layer(self, in_high_channels, in_low_channels, out_channels,fuse_mode='AsymBi'):#fuse_mode='AsymBi'
366 | # assert fuse_mode in ['BiLocal', 'AsymBi', 'BiGlobal']
367 | # if fuse_mode == 'BiLocal':
368 | # fuse_layer = BiLocalChaFuseReduce(in_high_channels, in_low_channels, out_channels)
369 | # el
370 | if fuse_mode == 'AsymBi':
371 | fuse_layer = AsymBiChaFuseReduce(in_high_channels, in_low_channels, out_channels)
372 | # elif fuse_mode == 'BiGlobal':
373 | # fuse_layer = BiGlobalChaFuseReduce(in_high_channels, in_low_channels, out_channels)
374 | else:
375 |             raise NameError('Unknown fuse_mode: ' + str(fuse_mode))
376 | return fuse_layer
377 |
378 |
379 | def forward(self, x):
380 |
381 | hx = x
382 |
383 | #stage 1
384 | hx1 = self.stage1(hx)
385 | hx = self.pool12(hx1)
386 |
387 | #stage 2
388 | hx2 = self.stage2(hx)
389 | hx = self.pool23(hx2)
390 |
391 | #stage 3
392 | hx3 = self.stage3(hx)
393 | hx = self.pool34(hx3)
394 |
395 | #stage 4
396 | hx4 = self.stage4(hx)
397 | hx = self.pool45(hx4)
398 |
399 | #stage 5
400 | hx5 = self.stage5(hx)
401 | hx = self.pool56(hx5)
402 |
403 | #stage 6
404 | hx6 = self.stage6(hx)
405 | hx6up = _upsample_like(hx6,hx5)
406 |
407 | #-------------------- decoder --------------------
408 |
409 | fusec51,fusec52 = self.fuse5(hx6up, hx5)
410 | hx5d = self.stage5d(torch.cat((fusec51, fusec52),1))
411 | hx5dup = _upsample_like(hx5d,hx4)
412 |
413 | fusec41,fusec42 = self.fuse4(hx5dup, hx4)
414 | hx4d = self.stage4d(torch.cat((fusec41,fusec42),1))
415 | hx4dup = _upsample_like(hx4d,hx3)
416 |
417 |
418 | fusec31,fusec32 = self.fuse3(hx4dup, hx3)
419 | hx3d = self.stage3d(torch.cat((fusec31,fusec32),1))
420 | hx3dup = _upsample_like(hx3d,hx2)
421 |
422 | fusec21, fusec22 = self.fuse2(hx3dup, hx2)
423 | hx2d = self.stage2d(torch.cat((fusec21, fusec22), 1))
424 | hx2dup = _upsample_like(hx2d,hx1)
425 |
426 |
427 | hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
428 |
429 |
430 | #side output
431 | d1 = self.side1(hx1d)
432 |
433 | d22 = self.side2(hx2d)
434 | d2 = _upsample_like(d22,d1)
435 |
436 | d32 = self.side3(hx3d)
437 | d3 = _upsample_like(d32,d1)
438 |
439 | d42 = self.side4(hx4d)
440 | d4 = _upsample_like(d42,d1)
441 |
442 | d52 = self.side5(hx5d)
443 | d5 = _upsample_like(d52,d1)
444 |
445 | d62 = self.side6(hx6)
446 | d6 = _upsample_like(d62,d1)
447 |
448 | d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
449 |
450 | # if self.mode == 'train':
451 | # return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
452 | # else:
453 | # return F.sigmoid(d0)
454 | if self.mode == 'train':
455 | return d0, d1, d2, d3, d4, d5, d6
456 | else:
457 | return d0
458 |
--------------------------------------------------------------------------------
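In 'train' mode UIU_No_Sigmoid returns the fused map d0 together with the six side outputs, none of which passes through a sigmoid. A typical deep-supervision objective for such heads (a hedged sketch only; the repo's actual losses live under /loss and may differ) sums a logits-based BCE over all seven maps:

import torch.nn as nn

bce = nn.BCEWithLogitsLoss()  # the "_No_Sigmoid" heads return raw logits

def deep_supervision_loss(outputs, target):
    # outputs = (d0, d1, ..., d6) from UIU_No_Sigmoid(mode='train')
    return sum(bce(d, target) for d in outputs)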
/model/UIU/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/model/UIU/__pycache__/UIU_no_sigmoid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/UIU_no_sigmoid.cpython-36.pyc
--------------------------------------------------------------------------------
/model/UIU/__pycache__/UIU_no_sigmoid.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/UIU_no_sigmoid.cpython-38.pyc
--------------------------------------------------------------------------------
/model/UIU/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/UIU/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/UIU/__pycache__/fusion.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/fusion.cpython-36.pyc
--------------------------------------------------------------------------------
/model/UIU/__pycache__/fusion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuChuang1205/PAL/a9046c5fc65e7b2548d6e195b624239fa836cf3e/model/UIU/__pycache__/fusion.cpython-38.pyc
--------------------------------------------------------------------------------
/model/UIU/fusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class AsymBiChaFuseReduce(nn.Module):
5 | def __init__(self, in_high_channels, in_low_channels, out_channels=64, r=4):
6 | super(AsymBiChaFuseReduce, self).__init__()
7 | assert in_low_channels == out_channels
8 | self.high_channels = in_high_channels
9 | self.low_channels = in_low_channels
10 | self.out_channels = out_channels
11 | self.bottleneck_channels = int(out_channels // r)
12 |
13 | self.feature_high = nn.Sequential(
14 | nn.Conv2d(self.high_channels, self.out_channels, 1, 1, 0),
15 | #nn.BatchNorm2d(out_channels),
16 | nn.GroupNorm(1, out_channels),
17 | nn.ReLU(True),
18 | )##512
19 |
20 | self.topdown = nn.Sequential(
21 | nn.AdaptiveAvgPool2d((1, 1)),
22 | nn.Conv2d(self.out_channels, self.bottleneck_channels, 1, 1, 0),
23 | #nn.BatchNorm2d(self.bottleneck_channels),
24 | nn.GroupNorm(1, self.bottleneck_channels),
25 | nn.ReLU(True),
26 |
27 | nn.Conv2d(self.bottleneck_channels, self.out_channels, 1, 1, 0),
28 | #nn.BatchNorm2d(self.out_channels),
29 | nn.GroupNorm(1, out_channels),
30 | nn.Sigmoid(),
31 | )#512
32 |
33 | ##############add spatial attention ###Cross UtU############
34 | self.bottomup = nn.Sequential(
35 | nn.Conv2d(self.low_channels, self.bottleneck_channels, 1, 1, 0),
36 | #nn.BatchNorm2d(self.bottleneck_channels),
37 | nn.GroupNorm(1, self.bottleneck_channels),
38 | nn.ReLU(True),
39 | # nn.Sigmoid(),
40 |
41 | SpatialAttention(kernel_size=3),
42 | # nn.Conv2d(self.bottleneck_channels, 2, 3, 1, 0),
43 | # nn.Conv2d(2, 1, 1, 1, 0),
44 | #nn.BatchNorm2d(self.out_channels),
45 | nn.Sigmoid()
46 | )
47 |
48 | self.post = nn.Sequential(
49 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
50 | #nn.BatchNorm2d(self.out_channels),
51 | nn.GroupNorm(1, out_channels),
52 | nn.ReLU(True),
53 | )#512
54 |
55 | def forward(self, xh, xl):
56 | xh = self.feature_high(xh)
57 |
58 | topdown_wei = self.topdown(xh)
59 | bottomup_wei = self.bottomup(xl * topdown_wei)
60 | xs1 = 2 * xl * topdown_wei #1
61 | out1 = self.post(xs1)
62 |
63 | xs2 = 2 * xh * bottomup_wei #1
64 | out2 = self.post(xs2)
65 | return out1,out2
66 |
67 | ##############################
68 | class SpatialAttention(nn.Module):
69 | def __init__(self, kernel_size=3):
70 | super(SpatialAttention, self).__init__()
71 |
72 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
73 | padding = 3 if kernel_size == 7 else 1
74 |
75 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
76 |
77 | def forward(self, x):
78 | avg_out = torch.mean(x, dim=1, keepdim=True)
79 | max_out, _ = torch.max(x, dim=1, keepdim=True)
80 | x = torch.cat([avg_out, max_out], dim=1)
81 | x = self.conv1(x)
82 | return x
83 |
--------------------------------------------------------------------------------
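AsymBiChaFuseReduce expects the high-level feature already upsampled to the low-level feature's resolution (UIU does this with _upsample_like) and asserts in_low_channels == out_channels; it returns two reweighted maps that the decoder concatenates. A quick shape check (illustrative only; batch size and spatial size are arbitrary):

import torch
from model.UIU.fusion import AsymBiChaFuseReduce

fuse = AsymBiChaFuseReduce(in_high_channels=512, in_low_channels=512, out_channels=512)
xh = torch.rand(2, 512, 32, 32)  # decoder feature, already upsampled to match xl
xl = torch.rand(2, 512, 32, 32)  # encoder skip feature
out1, out2 = fuse(xh, xl)        # each (2, 512, 32, 32); UIU feeds torch.cat((out1, out2), 1) to the next RSU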
/test_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding = gbk
3 | """
4 | @Author : yuchuang
5 | @Time :
6 | @desc:
7 | """
8 | import os
9 | import albumentations as A
10 | from albumentations.pytorch import ToTensorV2
11 | from torch.utils.data import DataLoader
12 | import numpy as np
13 | import torch
14 | import cv2
15 | from torch.utils.data import Dataset
16 | from PIL import Image
17 | # from model.MSDA.MSDA_no_sigmoid import MSDANet_No_Sigmoid
18 | from skimage import measure
19 | import torch.nn.functional as F
20 | from torch.autograd import Variable
21 | import math
22 | from components.cal_mean_std import Calculate_mean_std
23 | from utilts import access_model
24 |
25 | def read_txt(txt_path):
26 | with open(txt_path, 'r') as file:
27 | lines = file.readlines()
28 | image_out_list = [line.strip() + '.png' for line in lines]
29 | return image_out_list
30 |
31 |
32 | def make_dir(path):
33 |     if not os.path.exists(path):
34 | os.makedirs(path)
35 |
36 |
37 | ##############################################
38 | choose_model = 'MSDA' ##choose model in [ACM, ALC, MLCL, ALCL, DNA, GGL, UIU, MSDA]
39 | model_func = access_model(choose_model)
40 | choose_dataset = 'SIRST3' ## choose dataset in [SIRST3, IRSTD_1K_point, NUDT_SIRST_1_1_point, SIRST_1_1_point_new]
41 | test_dir_name = '********' ## Replace with the folder name where the corresponding test model is located, such as 'MSDA__SIRST3__masks_coarse__2024-12-13_13-30-35'. Since the timestamps are unique, you need the folder name you generated.
42 | test_model_name = 'best_mIoU_checkpoint_' + test_dir_name + ".pth.tar"
43 | ################################################
44 |
45 |
46 | # Hyperparameters etc.
47 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
48 | DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
49 | TEST_BATCH_SIZE = 1
50 | NUM_WORKERS = 4
51 | PIN_MEMORY = True
52 | LOAD_MODEL = False
53 | patch_size_test = 1024
54 | TEST_PATCH_BATCH_SIZE = 32
55 | root_path = os.path.abspath('.')
56 | dataset_path = os.path.join(root_path,'dataset', choose_dataset)
57 | test_dataset_path = os.path.join(dataset_path,'val')
58 | input_path = os.path.join(test_dataset_path, 'img')
59 | output_path = os.path.join(test_dataset_path, 'pre_results')
60 | make_dir(output_path)
61 | # TEST_NUM = len(os.listdir(input_path))
62 | # txt_path = os.path.join(root_path, 'img_idx', 'test.txt')
63 | img_list = os.listdir(input_path)
64 |
65 | test_model_path = os.path.join(root_path,'work_dirs',test_dir_name,test_model_name)
66 |
67 | def test_pred(img, net, batch_size, patch_size):
68 | b, c, h, w = img.shape
69 | # print(img.shape)
70 |     # non-overlapping sliding window: the stride equals the patch size
71 |     stride = patch_size
72 |
73 | if h > patch_size and w > patch_size:
74 | # Unfold the image into patches
75 | img_unfold = F.unfold(img, kernel_size=patch_size, stride=stride)
76 | img_unfold = img_unfold.reshape(b, c, patch_size, patch_size, -1).permute(0, 4, 1, 2, 3)
77 | # print(img_unfold.shape)
78 | patch_num = img_unfold.size(1)
79 |
80 | preds_list = []
81 | for i in range(0, patch_num, batch_size):
82 | end = min(i + batch_size, patch_num)
83 | batch_patches = img_unfold[:, i:end, :, :, :].reshape(-1, c, patch_size, patch_size)
84 | batch_patches = Variable(batch_patches.float())
85 | batch_preds = net.forward(batch_patches)
86 | preds_list.append(batch_preds)
87 | # Concatenate all the patch predictions
88 | preds_unfold = torch.cat(preds_list, dim=0).permute(1, 2, 3, 0)
89 | preds_unfold = preds_unfold.reshape(b, -1, patch_num)
90 | preds = F.fold(preds_unfold, kernel_size=patch_size, stride=stride, output_size=(h, w))
91 | else:
92 | preds = net.forward(img)
93 |
94 | return preds
95 |
96 |
97 | class SirstDataset(Dataset):
98 | def __init__(self, image_dir, patch_size, transform=None, mode='None'):
99 | self.image_dir = image_dir
100 | self.transform = transform
101 | self.images = np.sort(os.listdir(image_dir))
102 | self.mode = mode
103 | self.patch_size = patch_size
104 |
105 | def __len__(self):
106 | return len(self.images)
107 |
108 | def __getitem__(self, index):
109 | img_path = os.path.join(self.image_dir, self.images[index])
110 | image = np.array(Image.open(img_path).convert("RGB"))
111 |
112 | if (self.mode == 'test'):
113 | times = 32
114 | h, w, c = image.shape
115 |             # pad height and width so both are divisible by 32
116 | pad_height = math.ceil(h / times) * times - h
117 | pad_width = math.ceil(w / times) * times - w
118 |             # zero-pad the image on the bottom and right
119 | image = np.pad(image, ((0, pad_height), (0, pad_width), (0, 0)), mode='constant')
120 | if self.transform is not None:
121 | augmentations = self.transform(image=image)
122 | image = augmentations["image"]
123 | return image, self.images[index], h, w
124 |         else:
125 |             raise ValueError("Invalid mode: this dataset only supports mode='test'")
126 |
127 |
128 | def main():
129 | origin_img_dir = dataset_path + "/origin/img"
130 | cal_mean, cal_std = Calculate_mean_std(origin_img_dir)
131 | test_transforms = A.Compose(
132 | [
133 | # A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
134 | A.Normalize(
135 | mean=cal_mean,
136 | std=cal_std,
137 | max_pixel_value=255.0,
138 | ),
139 | ToTensorV2(),
140 | ],
141 | )
142 | test_ds = SirstDataset(
143 | image_dir=input_path,
144 | patch_size=patch_size_test,
145 | transform=test_transforms,
146 | mode='test'
147 | )
148 |
149 | test_loader = DataLoader(
150 | test_ds,
151 | batch_size=TEST_BATCH_SIZE,
152 | num_workers=NUM_WORKERS,
153 | pin_memory=PIN_MEMORY,
154 | shuffle=False,
155 | )
156 | model = model_func().to(DEVICE)
157 |
158 | model.load_state_dict({k.replace('module.', ''): v for k, v in
159 | torch.load(test_model_path, map_location=DEVICE)[
160 | 'state_dict'].items()})
161 | model.eval()
162 |
163 | temp_num = 0
164 |
165 | for idx, (img, name, h, w) in enumerate(test_loader):
166 | print(idx)
167 | img = img.to(device=DEVICE)
168 | with torch.no_grad():
169 | image_1 = img
170 |
171 | output_1 = test_pred(image_1, model, batch_size=TEST_PATCH_BATCH_SIZE, patch_size=patch_size_test)
172 | output_1 = torch.sigmoid(output_1)
173 | output_1 = output_1[:, :, :h, :w]
174 |             output_1 = output_1.detach().cpu().numpy()  # .data is deprecated
175 |
176 | for i in range(output_1.shape[0]):
177 | print(name[i])
178 | temp_num = temp_num + 1
179 | pred = output_1[i]
180 | pred = pred[0]
181 | pred_target = np.where(pred > 0.5, 255, 0)
182 | pred_target = np.array(pred_target, dtype='uint8')
183 |
184 | cv2.imwrite(os.path.join(output_path, name[i]), pred_target)
185 |
186 |
187 | if __name__ == "__main__":
188 | main()
189 |
--------------------------------------------------------------------------------
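test_pred above tiles large inputs into non-overlapping patch_size x patch_size windows with F.unfold, runs the network window by window, and stitches the predictions back with F.fold. Because the stride equals the kernel size, the windows never overlap, so fold's summation is a lossless inverse rather than an average. A minimal round trip demonstrating this (illustrative only):

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 2048, 2048)
patches = F.unfold(x, kernel_size=1024, stride=1024)  # (1, 1024*1024, 4): four tiles, flattened
back = F.fold(patches, output_size=(2048, 2048), kernel_size=1024, stride=1024)
assert torch.equal(back, x)  # non-overlapping tiles reassemble exactly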
/tools/centroid_anno.m:
--------------------------------------------------------------------------------
1 | data_dir = './masks/';
2 | save_dir = './masks_centroid/';
3 | mkdir(save_dir);
4 | data_list = dir(data_dir);
5 | for i = 3:length(data_list)
6 | img = imread([data_dir,data_list(i).name]);
7 | Ilabel = bwlabel(img);
8 | Area_I = regionprops(Ilabel,'centroid');
9 | img_centroid = zeros(size(img));
10 | for x = 1: numel(Area_I)
11 | img_centroid(floor(Area_I(x).Centroid(2)),floor(Area_I(x).Centroid(1))) = 255;
12 | end
13 | imwrite(uint8(img_centroid),[save_dir,data_list(i).name])
14 | end
--------------------------------------------------------------------------------
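centroid_anno.m reduces every connected component of a full mask to a single centroid pixel. A rough Python equivalent (a sketch assuming OpenCV and scikit-image; directory names are placeholders):

import os
import cv2
import numpy as np
from skimage import measure

def centroid_labels(mask_dir, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    for name in os.listdir(mask_dir):
        mask = cv2.imread(os.path.join(mask_dir, name), cv2.IMREAD_GRAYSCALE)
        out = np.zeros_like(mask)
        for region in measure.regionprops(measure.label(mask > 0)):
            cy, cx = (int(c) for c in region.centroid)  # row/col of the target centroid
            out[cy, cx] = 255
        cv2.imwrite(os.path.join(save_dir, name), out)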
/tools/coarse_anno.m:
--------------------------------------------------------------------------------
1 | data_dir = './masks/';
2 | save_dir = './masks_coarse/';
3 | mkdir(save_dir);
4 | data_list = dir(data_dir);
5 | for i = 3:length(data_list)
6 | img = double(imread([data_dir,data_list(i).name]));
7 | Ilabel = bwlabel(img);
8 | BoundingBoxs = regionprops(Ilabel,'BoundingBox');
9 | centroids = regionprops(Ilabel,'centroid');
10 | img_centroid = zeros(size(img));
11 | for j = 1: numel(BoundingBoxs)
12 | img_temp = zeros(size(img));
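% Rejection sampling: draw Gaussian offsets around the centroid (scaled by half the
% bounding box size) until the sampled point falls inside the target region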
13 | while sum(sum(img_temp .* img)) == 0
14 | gaussian_num_x = normrnd(0,1/4,1,1) ;
15 | gaussian_num_y = normrnd(0,1/4,1,1) ;
16 | x = floor(centroids(j).Centroid(1)+BoundingBoxs(j).BoundingBox(3)/2*gaussian_num_x);
17 | y = floor(centroids(j).Centroid(2)+BoundingBoxs(j).BoundingBox(4)/2*gaussian_num_y);
18 | [y_max, x_max] = size(img);
19 | if x <1
20 | x= 1;
21 | end
22 | if y <1
23 | y= 1;
24 | end
25 | if x>x_max
26 | x = x_max;
27 | end
28 | if y>y_max
29 | y = y_max;
30 | end
31 | img_temp(y,x) = 255;
32 | end
33 | img_centroid(y,x) = 255;
34 | end
35 | imwrite(uint8(img_centroid),[save_dir,data_list(i).name])
36 | end
37 |
38 | function [y] = Gaussian(x,mu,sigma)  % unused helper, kept for reference
39 | y = exp(-(x-mu).^2/(2*sigma^2));
40 | end
--------------------------------------------------------------------------------
/utilts.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | """
4 | @Author : yuchuang
5 | @Time :
6 | @desc:
7 | """
8 |
9 | import os
10 |
11 | import numpy as np
12 |
13 | import cv2
14 | import importlib
15 | import random
16 | import shutil
17 |
18 | import sys
19 | import torch
20 | # import pydensecrf.densecrf as dcrf
21 | # from pydensecrf.utils import unary_from_softmax, create_pairwise_gaussian, create_pairwise_bilateral
22 |
23 | def make_dir(path):
24 |     if not os.path.exists(path):
25 | os.makedirs(path)
26 |
27 | def access_model(choose_model):
28 | choose_model_dir_name = choose_model + '_no_sigmoid'
29 | model_function = choose_model + '_No_Sigmoid'
30 | module_name = f"model.{choose_model}.{choose_model_dir_name}"
31 | module = importlib.import_module(module_name)
32 | model_func = getattr(module, model_function)
33 | return model_func
34 |
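# A minimal usage sketch for access_model; 'DNA' is one of the bundled models, and the
# class name DNA_No_Sigmoid is assumed from the repo's <Name>_No_Sigmoid convention:
#   model_func = access_model('DNA')   # imports model.DNA.DNA_no_sigmoid
#   model = model_func()               # instantiates DNA_No_Sigmoid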
35 |
36 | def check_path(path):
37 |     if os.path.exists(path):
38 |         print("Error: the workspace of the training pool already exists. Please carefully delete the existing directory manually, or add a suffix to the name of the created folder!")
39 |         print("The conflicting directories are:", path)
40 |         print("Method 3 is recommended, as it generates a unique training pool folder name!")
41 | sys.exit(0)
42 |
43 |
44 | def center_point_inside_contour(center_point, target_mask):
45 | (center_y, center_x) = center_point
46 | target_contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
47 | #center_index_XmergeY = {center_y * 1.0 + center_x * 0.0001,center_y-1 * 1.0 + center_x * 0.0001,center_y * 1.0 + center_x-1 * 0.0001,center_y+1 * 1.0 + center_x * 0.0001,center_y * 1.0 + center_x+1 * 0.0001}
48 | center_index_XmergeY = {center_y * 1.0 + center_x * 0.0001}
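    # Pixel coordinates are packed into a single float (y + 0.0001 * x) so that set
    # intersection can test pixel-level overlap; this assumes image widths below 10000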
49 | temp_contour_mask = np.zeros(target_mask.shape, np.uint8)
50 |
51 | overlap_found = False
52 | for target_contour in target_contours:
53 | target_contour_mask = np.zeros(target_mask.shape, np.uint8)
54 | cv2.fillPoly(target_contour_mask, [target_contour], (255))
55 | target_index = np.where(target_contour_mask == 255)
56 | target_index_XmergeY = set(target_index[0] * 1.0 + target_index[1] * 0.0001)
57 | if not center_index_XmergeY.isdisjoint(target_index_XmergeY):
58 |             area = cv2.contourArea(target_contour)
59 |             if area > 50:  # enclosing contour is too large for a small target; treat the point as unmatched
60 |                 break
61 | else:
62 | overlap_found = True
63 | #print("True")
64 | cv2.fillPoly(temp_contour_mask, [target_contour], (255))
65 | break
66 |
67 |     # if no overlapping contour was found, temp_contour_mask remains all zeros
68 |     # and overlap_found stays False
69 | return temp_contour_mask, overlap_found
70 |
71 |
72 | def process_image(y_and_x, y1_and_y2_and_x1_and_x2, img_shape,image, low_threshold=50, high_threshold=150, kernel_size=(3, 3), sigma=0):
73 | blurred_image = cv2.GaussianBlur(image, kernel_size, sigma)
74 | high_pass = cv2.subtract(image, blurred_image)
75 | edges = cv2.Canny(high_pass, low_threshold, high_threshold)
76 | kernel = np.ones((3, 3), np.uint8)
77 | sparse_edges = cv2.dilate(edges, kernel, iterations=1)
78 | sparse_edges = cv2.erode(sparse_edges, kernel, iterations=1)
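    # dilation followed by erosion is a morphological closing, bridging small gaps in the Canny edges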
79 |
80 | temp_contour_mask_2 = np.zeros(img_shape, np.uint8)
81 | (y1,y2, x1,x2) = y1_and_y2_and_x1_and_x2
82 | temp_contour_mask_2[y1:y2, x1:x2] = sparse_edges
83 | (y,x) = y_and_x
84 |
85 |
86 | refine_mask,flag = center_point_inside_contour((y,x), temp_contour_mask_2)
87 | refine_mask_out = refine_mask[y1:y2, x1:x2]
88 | return refine_mask_out,flag
89 |
90 |
91 |
92 | def data_inital_make_add_points(origin_img_dir, origin_points_dir,TRAIN_IMG_DIR,TRAIN_MASK_DIR,train_points_dir,nc_img_dir,nc_mask_dir,nc_points_dir, crop_size=10):
93 | input_image_path = origin_img_dir
94 | input_points_path = origin_points_dir
95 |
96 | output_image_path = TRAIN_IMG_DIR
97 | output_masks_path = TRAIN_MASK_DIR
98 | output_points_path = train_points_dir
99 |
100 |
101 | no_choose_output_image_path = nc_img_dir
102 | no_choose_output_masks_path = nc_mask_dir
103 | no_choose_output_points_path = nc_points_dir
104 |
105 |
106 | input_img_list = os.listdir(input_image_path)
107 |
108 | for i in range(len(input_img_list)):
109 |         # print(f"Processing image: {input_img_list[i]}")
110 | img_path = os.path.join(input_image_path, input_img_list[i])
111 | points_path = os.path.join(input_points_path, input_img_list[i])
112 |
113 | out_img_path = os.path.join(output_image_path, input_img_list[i])
114 | out_mask_path = os.path.join(output_masks_path, input_img_list[i])
115 | out_points_path = os.path.join(output_points_path, input_img_list[i])
116 |
117 | no_choose_out_img_path = os.path.join(no_choose_output_image_path, input_img_list[i])
118 | no_choose_out_mask_path = os.path.join(no_choose_output_masks_path, input_img_list[i])
119 | no_choose_out_points_path = os.path.join(no_choose_output_points_path, input_img_list[i])
120 |
121 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
122 | mask = cv2.imread(points_path, cv2.IMREAD_GRAYSCALE)
123 |
124 | if img is None or mask is None:
125 |             print(f"Failed to read image or mask: {input_img_list[i]}")
126 | continue
127 |
128 | points = np.where(mask == 255)
129 | if len(points[0]) == 0:
130 |             # print(f"No annotated points found in mask: {input_img_list[i]}")
131 | cv2.imwrite(out_img_path, img)
132 | cv2.imwrite(out_mask_path, mask)
133 | cv2.imwrite(out_points_path, mask)
134 | continue
135 |
136 | merged_result = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
137 |
138 | correct_point = 0
139 |
140 |         for j in range(len(points[0])):
141 |             center_y, center_x = points[0][j], points[1][j]
142 |
143 | half_size = crop_size
144 | x1 = max(center_x - half_size, 0)
145 | y1 = max(center_y - half_size, 0)
146 | x2 = min(center_x + half_size, img.shape[1])
147 | y2 = min(center_y + half_size, img.shape[0])
148 |
149 | roi = img[y1:y2, x1:x2]
150 |
151 | processed_roi, flag = process_image((center_y, center_x), (y1, y2, x1, x2), mask.shape, roi,
152 | low_threshold=20, high_threshold=40, kernel_size=(3, 3), sigma=0)
153 |
154 | merged_result[y1:y2, x1:x2] = merged_result[y1:y2, x1:x2] + processed_roi
155 |
156 |             if flag:
157 |                 correct_point = correct_point + 1
158 |             else:
159 |                 continue  # no valid mask was produced for this point
160 |
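        # Keep the sample only if at least 80% of its annotated points yielded a valid
        # edge-based pseudo-label; otherwise route it to the not-chosen (nc) pool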
161 | if (correct_point / len(points[0])) >= 0.8:
162 | merged_result = merged_result + mask
163 |             merged_result = np.where(merged_result > 0, 255, 0).astype(np.uint8)
164 | cv2.imwrite(out_img_path, img)
165 | cv2.imwrite(out_mask_path, merged_result)
166 | cv2.imwrite(out_points_path,mask)
167 | else:
168 | # print(no_choose_out_img_path)
169 | cv2.imwrite(no_choose_out_img_path, img)
170 | #cv2.imwrite(no_choose_out_mask_path, mask)
171 | cv2.imwrite(no_choose_out_points_path,mask)
172 |
173 |     print("Initial data generated; number of selected samples:", len(os.listdir(output_image_path)))
176 |
177 |
178 |
179 |
180 |
181 | ### copy_mask: the point-label mask
182 | ### target_mask: the predicted mask
183 | def nc_pred_mask(copy_mask, target_mask,lose_point_ratio = 0.2,alarm_point_ration=0.2):
184 | copy_contours, _ = cv2.findContours(copy_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
185 | target_contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
186 |
187 | overwrite_contours = []
188 | un_overwrite_contours = []
189 |
190 | target_index_sets = []
191 | for target_contour in target_contours:
192 | target_contour_mask = np.zeros(copy_mask.shape, np.uint8)
193 | cv2.fillPoly(target_contour_mask, [target_contour], (255))
194 | target_index = np.where(target_contour_mask == 255)
195 | target_index_XmergeY = set(target_index[0] * 1.0 + target_index[1] * 0.0001)
196 | target_index_sets.append(target_index_XmergeY)
197 |
198 | for copy_contour in copy_contours:
199 | copy_contour_mask = np.zeros(copy_mask.shape, np.uint8)
200 | cv2.fillPoly(copy_contour_mask, [copy_contour], (255))
201 | copy_index = np.where(copy_contour_mask == 255)
202 | copy_index_XmergeY = set(copy_index[0] * 1.0 + copy_index[1] * 0.0001)
203 |
204 | overlap_found = False
205 | for target_index_XmergeY in target_index_sets:
206 | if not copy_index_XmergeY.isdisjoint(target_index_XmergeY):
207 | overwrite_contours.append(copy_contour)
208 | overlap_found = True
209 | break
210 |
211 | if not overlap_found:
212 | un_overwrite_contours.append(copy_contour)
213 |
214 | flag = False
215 | if len(un_overwrite_contours) / len(copy_contours) > lose_point_ratio or ((len(target_contours) - len(overwrite_contours)) / len(copy_contours)) > alarm_point_ration:
216 |         flag = False  # too many missed points or too many false-alarm regions
217 | else:
218 | flag = True
219 | return flag
220 |
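# An illustrative check of the selection rule above (values are hypothetical): with 5
# labeled points and 4 predicted regions covering 4 of them, the missed-point ratio is
# 1/5 = 0.2 and the false-alarm ratio is (4 - 4)/5 = 0; with both thresholds at 0.2,
# neither is exceeded, so the sample is accepted (flag = True)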
221 |
222 |
223 | def deal_pred_mask_and_true_point_in(nc_img_path,nc_pred_mask_path,nc_points_path,c_img_path,c_mask_path,c_points_path,lose_point_ratio=0.2,alarm_point_ration=0.2):
224 | nc_img_list = os.listdir(nc_img_path)
225 | new_choose_list = []
226 | for i in range(len(nc_img_list)):
227 | img_path = os.path.join(nc_img_path, nc_img_list[i])
228 | points_path = os.path.join(nc_points_path, nc_img_list[i])
229 | pred_path = os.path.join(nc_pred_mask_path, nc_img_list[i])
230 |
231 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
232 | mask = cv2.imread(points_path, cv2.IMREAD_GRAYSCALE)
233 | pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
234 |
235 | if img is None or mask is None:
236 | print(f"图像或mask读取失败:{nc_img_list[i]}")
237 | continue
238 |
239 | points = np.where(mask == 255)
240 | if len(points[0]) == 0:
241 | new_choose_list.append(nc_img_list[i])
242 | continue
243 |
244 | flag = nc_pred_mask(mask, pred,lose_point_ratio =lose_point_ratio,alarm_point_ration=alarm_point_ration)
245 |         if flag:
246 | new_choose_list.append(nc_img_list[i])
247 | return new_choose_list
248 |
249 |
250 |
251 | ### Note: the roles here are the reverse of nc_pred_mask:
252 | ### copy_mask: the predicted mask
253 | ### target_mask: the point-label mask
254 | def nc_correct_pred_mask(copy_mask, target_mask):
255 | copy_contours, _ = cv2.findContours(copy_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
256 | target_contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
257 |
258 | overwrite_contours = []
259 | un_overwrite_contours = []
260 |
261 | target_index_sets = []
262 | for target_contour in target_contours:
263 | target_contour_mask = np.zeros(copy_mask.shape, np.uint8)
264 | cv2.fillPoly(target_contour_mask, [target_contour], (255))
265 | target_index = np.where(target_contour_mask == 255)
266 | target_index_XmergeY = set(target_index[0] * 1.0 + target_index[1] * 0.0001)
267 | target_index_sets.append(target_index_XmergeY)
268 |
269 | for copy_contour in copy_contours:
270 | copy_contour_mask = np.zeros(copy_mask.shape, np.uint8)
271 | cv2.fillPoly(copy_contour_mask, [copy_contour], (255))
272 | copy_index = np.where(copy_contour_mask == 255)
273 | copy_index_XmergeY = set(copy_index[0] * 1.0 + copy_index[1] * 0.0001)
274 |
275 | overlap_found = False
276 | for target_index_XmergeY in target_index_sets:
277 | if not copy_index_XmergeY.isdisjoint(target_index_XmergeY):
278 | overwrite_contours.append(copy_contour)
279 | overlap_found = True
280 | break
281 |
282 | if not overlap_found:
283 | un_overwrite_contours.append(copy_contour)
284 |
285 | copy_contour_mask_out = np.zeros(copy_mask.shape, np.uint8)
286 | for i in range(len(overwrite_contours)):
287 | cv2.fillPoly(copy_contour_mask_out, [overwrite_contours[i]], (255))
288 |
289 | copy_contour_mask_out = copy_contour_mask_out + target_mask
290 |     copy_contour_mask_out = np.where(copy_contour_mask_out > 0, 255, 0).astype(np.uint8)
291 |
292 | return copy_contour_mask_out
293 |
294 |
295 |
296 | def deal_gen_mask_error_aera(nc_pred_mask_path,nc_points_path,new_choose_list):
297 | for i in range(len(new_choose_list)):
298 | pred_mask_path = os.path.join(nc_pred_mask_path,new_choose_list[i])
299 | points_path = os.path.join(nc_points_path,new_choose_list[i])
300 |
301 | pred_mask = cv2.imread(pred_mask_path, cv2.IMREAD_GRAYSCALE)
302 | points = cv2.imread(points_path, cv2.IMREAD_GRAYSCALE)
303 |
304 | pred_mask_out = nc_correct_pred_mask(pred_mask,points)
305 | cv2.imwrite(pred_mask_path,pred_mask_out)
306 |     print("Refinement of the generated labels is complete!")
307 |
308 |
309 |
310 |
311 |
312 |
313 | def move_files(file_list, src_path, dst_path):
314 | for file_name in file_list:
315 | #print(file_name)
316 | src_file = os.path.join(src_path, file_name)
317 | #print(src_file)
318 | dst_file = os.path.join(dst_path, file_name)
319 | if os.path.exists(src_file):
320 | try:
321 | shutil.copy(src_file, dst_file)
322 | os.remove(src_file)
323 | except PermissionError as e:
324 | print(f"PermissionError: {e}")
325 |
326 | # shutil.move(src_file, dst_file)
327 | else:
328 | print(f"File {src_file} does not exist.")
329 |
330 |
331 |
332 |
333 | def hard_sample_in(nc_img_path,nc_pred_mask_path,nc_points_path,c_img_path,c_mask_path,c_points_path,new_choose_list):
334 | move_files(new_choose_list, nc_img_path, c_img_path)
335 | move_files(new_choose_list, nc_pred_mask_path, c_mask_path)
336 | move_files(new_choose_list, nc_points_path, c_points_path)
337 |
338 |
339 |
340 |
341 |
342 | def update_gt_update_degen_corr(pred, gt_masks, thresh_Tb, thresh_k, size, degen=0.9):
343 |
344 | update_gt_masks = gt_masks.copy()
345 |
346 | background_length = 33
347 | target_length = 3
348 |
349 | num_labels, label_image = cv2.connectedComponents((gt_masks > 0.5).astype(np.uint8))
350 |
351 | background_kernel = np.ones((background_length, background_length), np.uint8)
352 | target_kernel = np.ones((target_length, target_length), np.uint8)
353 |
354 | pred_max = pred.max()
355 | max_limitation = size[0] * size[1] * 0.0015
356 |
357 | combined_thresh_mask = np.zeros_like(pred, dtype=np.float32)
358 |
359 | for region_num in range(1, num_labels):
360 | region_coords = np.argwhere(label_image == region_num)
361 | centroid = np.mean(region_coords, axis=0).astype(int)
362 |
363 | cur_point_mask = np.zeros_like(pred, dtype=np.uint8)
364 | cur_point_mask[centroid[0], centroid[1]] = 1
365 |
366 | nbr_mask = cv2.dilate(cur_point_mask, background_kernel) > 0
367 | targets_mask = cv2.dilate(cur_point_mask, target_kernel) > 0
368 |
369 | region_size_ratio = len(region_coords) / max_limitation
370 | threshold_start = (pred * nbr_mask).max() * thresh_Tb
371 | threshold_delta = thresh_k * ((pred * nbr_mask).max() - threshold_start) * region_size_ratio
372 | threshold = threshold_start + threshold_delta
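        # The adaptive threshold starts at thresh_Tb times the local max response and moves
        # toward that max as the labeled region grows relative to max_limitation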
373 | threshold = threshold.cpu().numpy() if isinstance(threshold, torch.Tensor) else threshold
374 |
375 | thresh_mask = (pred * nbr_mask > threshold).astype(np.float32)
376 |
377 | num_labels_thresh, label_image_thresh = cv2.connectedComponents(thresh_mask.astype(np.uint8))
378 | for num_cur in range(1, num_labels_thresh):
379 | curr_mask = (label_image_thresh == num_cur).astype(np.float32)
380 | if np.sum(curr_mask * targets_mask) == 0:
381 | thresh_mask -= curr_mask
382 |
383 | combined_thresh_mask = np.maximum(combined_thresh_mask, thresh_mask)
384 |
385 | target_patch = (update_gt_masks * combined_thresh_mask + pred * combined_thresh_mask) / 2
386 | background_patch = update_gt_masks * (1 - combined_thresh_mask)* degen
387 | update_gt_masks = background_patch + target_patch
388 |
389 | update_gt_masks = np.maximum(update_gt_masks, (gt_masks == 1).astype(np.float32))
390 |
391 | return update_gt_masks
392 |
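# A minimal usage sketch (argument values are illustrative, not tuned defaults):
# given a sigmoid prediction `pred` and the current label `gt`, both float arrays
# in [0, 1] with shape (H, W):
#   new_gt = update_gt_update_degen_corr(pred, gt, thresh_Tb=0.5, thresh_k=0.5, size=(H, W), degen=0.9)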
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
--------------------------------------------------------------------------------