├── DataProcess
│   ├── __init__.py
│   ├── postdam
│   │   ├── CropPost.py
│   │   ├── FindMean.py
│   │   ├── TransPostLabel.py
│   │   ├── WriteSegTXT.py
│   │   ├── __init__.py
│   │   ├── make_point_label.py
│   │   └── rgb_to_value.txt
│   └── vaihingen
│       ├── CropVai.py
│       ├── FindMean.py
│       ├── TransVaiLabel.py
│       ├── WriteTXT.py
│       ├── __init__.py
│       ├── make_point_label.py
│       └── rgb_to_value.txt
├── README.md
├── datafiles
│   ├── color_dict.py
│   ├── potsdam
│   │   ├── seg_test.txt
│   │   ├── seg_train.txt
│   │   └── seg_val.txt
│   └── vaihingen
│       ├── seg_test.txt
│       ├── seg_train.txt
│       └── seg_val.txt
├── dataset
│   ├── __init__.py
│   ├── data_sec.py
│   ├── data_utils.py
│   ├── datapoint.py
│   └── transform.py
├── models
│   ├── DBFNet.py
│   ├── __init__.py
│   ├── base_model.py
│   ├── fbf.py
│   ├── network
│   │   ├── DeeplabV3.py
│   │   ├── LinkNet.py
│   │   ├── __init__.py
│   │   ├── fcn.py
│   │   ├── fpn.py
│   │   ├── resnet.py
│   │   └── segformer.py
│   ├── seg_model.py
│   └── tools
│       ├── __init__.py
│       └── densecrf.py
├── options
│   ├── Second_options.py
│   ├── __init__.py
│   └── point_options.py
├── predict.py
├── run
│   ├── __init__.py
│   ├── point
│   │   ├── p_predict_train.py
│   │   ├── p_test.py
│   │   └── p_train.py
│   └── second
│       ├── __init__.py
│       ├── sec_predict_train.py
│       ├── sec_test.py
│       └── sec_train.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-35.pyc
    │   ├── __init__.cpython-36.pyc
    │   ├── __init__.cpython-37.pyc
    │   ├── metric.cpython-36.pyc
    │   ├── metric.cpython-37.pyc
    │   ├── paint.cpython-35.pyc
    │   ├── paint.cpython-36.pyc
    │   ├── paint.cpython-37.pyc
    │   ├── query.cpython-37.pyc
    │   ├── util.cpython-35.pyc
    │   ├── util.cpython-36.pyc
    │   └── util.cpython-37.pyc
    ├── metric.py
    ├── paint.py
    ├── query.py
    └── util.py
/DataProcess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/DataProcess/__init__.py
--------------------------------------------------------------------------------
/DataProcess/postdam/CropPost.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tifffile
3 | import numpy as np
4 | import cv2
5 | from utils import check_dir, read, imsave
6 |
7 | train_name = [7, 8, 9, 10, 11, 12]
8 | test_name = [13, 14, 15]
9 | val_name = []  # NOTE: referenced in run() but missing from the original file; fill in the tile indices held out for validation
10 |
11 | def crop_overlap(img, label, vis, name, img_path, label_path, vis_path, size=256, stride=128):
12 | h, w, c = img.shape
13 | new_h, new_w = (h // size + 1) * size, (w // size + 1) * size
14 | num_h, num_w = (new_h // stride) - 1, (new_w // stride) - 1
15 |
16 | new_img = np.zeros([new_h, new_w, c]).astype(np.uint8)
17 | new_img[:h, :w, :] = img
18 |
19 | new_label = 255 * np.ones([new_h, new_w, 3]).astype(np.uint8)
20 | new_label[:h, :w, :] = label
21 |
22 | new_vis = np.zeros([new_h, new_w, 3]).astype(np.uint8)
23 | new_vis[:h, :w, :] = vis
24 |
25 | count = 0
26 |
27 | for i in range(num_h):
28 | for j in range(num_w):
29 | out = new_img[i * stride:i * stride + size, j * stride:j * stride + size, :]
30 | gt = new_label[i * stride:i * stride + size, j * stride:j * stride + size, :]
31 | v = new_vis[i * stride:i * stride + size, j * stride:j * stride + size, :]
32 |             assert v.shape == (size, size, 3), v.shape
33 |
34 | tifffile.imsave(img_path + '/' + str(name) + '_' + str(count) + '.tif', out)
35 | tifffile.imsave(label_path + '/' + str(name) + '_' + str(count) + '.tif', gt)
36 | tifffile.imsave(vis_path + '/' + str(name) + '_' + str(count) + '.tif', v)
37 |
38 | count += 1
39 |
40 |
41 | def crop(img, label, vis, name, img_path, label_path, vis_path, size=256):
42 | height, width, _ = img.shape
43 | h_size = height // size
44 | w_size = width // size
45 |
46 | count = 0
47 |
48 | for i in range(h_size):
49 | for j in range(w_size):
50 | out = img[i * size:(i + 1) * size, j * size:(j + 1) * size, :3]
51 | gt = label[i * size:(i + 1) * size, j * size:(j + 1) * size, :]
52 | v = vis[i * size:(i + 1) * size, j * size:(j + 1) * size, :]
53 |             assert v.shape == (size, size, 3)
54 |
55 | imsave(img_path + '/' + str(name) + '_' + str(count) + '.png', out)
56 | imsave(label_path + '/' + str(name) + '_' + str(count) + '.png', gt)
57 | imsave(vis_path + '/' + str(name) + '_' + str(count) + '.png', v)
58 |
59 | count += 1
60 |
61 |
62 | def save(img, label, vis, name, save_path, flag='val'):
63 | img_path = save_path + '/img'
64 | label_path = save_path + '/label'
65 | v_path = save_path + '/label_vis'
66 | check_dir(img_path), check_dir(label_path), check_dir(v_path)
67 |
68 | crop(img, label, vis, name, img_path, label_path, v_path)
69 |
70 |
71 | def run(root_path):
72 | img_path = root_path + '/dataset_origin/4_Ortho_RGBIR'
73 | label_path = root_path + '/dataset_origin/Labels'
74 | vis_path = root_path + '/dataset_origin/5_Labels_all'
75 |
76 | list = os.listdir(label_path)
77 | for i in list:
78 | img = read(os.path.join(img_path, i[:-9] + 'RGBIR.tif'))
79 | label = read(os.path.join(label_path, i))
80 | vis = read(os.path.join(vis_path, i))
81 |
82 | name = i[12:-10]
83 |
84 | if int(i[14:-10]) in test_name:
85 | print(name, 'test')
86 | test_path = root_path+'/test'
87 | check_dir(test_path)
88 | save(img, label, vis, name, test_path, flag='test')
89 |
90 | elif int(i[14:-10]) in val_name:
91 | print(name, 'val')
92 | val_path = root_path + '/val'
93 | check_dir(val_path)
94 | save(img, label, vis, name, val_path, flag='val')
95 |
96 | elif int(i[14:-10]) in train_name:
97 | print(name, 'train')
98 | train_path = root_path+'/train'
99 | check_dir(train_path)
100 | save(img, label, vis, name, train_path, flag='train')
101 |
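102 |
103 |
104 | if __name__ == '__main__':
105 |     # NOTE: entry point missing from the original file; added by analogy with
106 |     # CropVai.py, which calls run(root_path) on the dataset root.
107 |     root_path = '/media/hlf/Luffy/WLS/semantic/dataset/potsdam'
108 |     run(root_path)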
--------------------------------------------------------------------------------
/DataProcess/postdam/FindMean.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tifffile
4 |
5 | # means = [86.42521457, 92.37607528, 85.74658389, 98.17242502]
6 | path = '/home/ggm/WLS/semantic/dataset/potsdam/train/img'
7 |
8 | list = os.listdir(path)
9 |
10 | k = 0
11 | sum = np.zeros([4])
12 |
13 | for idx, i in enumerate(list):
14 | print(idx)
15 | k += 1
16 | img = tifffile.imread(os.path.join(path, i))
17 | img = img.reshape(256*256, -1)
18 |
19 | mean = np.mean(img, axis=0)
20 | sum += mean
21 |
22 |
23 | means = sum/k
24 | print(means)
25 |
26 | # std = [35.58409409, 35.45218542, 36.91464009, 36.3449891]
27 | # means = [86.42521457, 92.37607528, 85.74658389, 98.17242502]
28 | #
29 | # path = '/home/ggm/WLS/semantic/dataset/potsdam/train/img'
30 | #
31 | # list = os.listdir(path)
32 | # k = 0
33 | # sum = np.zeros([4])
34 | #
35 | # for idx, i in enumerate(list):
36 | # print(idx)
37 | # k += 1
38 | # img = tifffile.imread(os.path.join(path, i))
39 | # img = img.reshape(256*256, -1)
40 | #
41 | # x = (img - means) ** 2
42 | #
43 | # sum += np.sum(x, axis=0)
44 | #
45 | # std = np.sqrt(sum/(k*256*256))
46 | # print(std)
--------------------------------------------------------------------------------
/DataProcess/postdam/TransPostLabel.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | from utils import *
5 | import matplotlib.pyplot as plt
6 | import tifffile
7 | from tqdm import tqdm
8 |
9 |
10 | def value_to_rgb(anno):
11 | label2color_dict = {
12 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255)
13 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255)
14 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255)
15 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0)
16 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0)
17 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0)
18 | 255: [0, 0, 0], # boundary (RGB: 0, 0, 0)
19 | }
20 | # visualize
21 | visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8)
22 | for i in range(visual_anno.shape[0]): # i for h
23 | for j in range(visual_anno.shape[1]):
24 | # cv2: bgr
25 | color = label2color_dict[anno[i, j, 0]]
26 |
27 | visual_anno[i, j, 0] = color[0]
28 | visual_anno[i, j, 1] = color[1]
29 | visual_anno[i, j, 2] = color[2]
30 |
31 | return visual_anno
32 |
33 |
34 | def rgb_to_value(rgb, txt='/home/ggm/WLS/semantic/PointAnno/DataProcess/potsdam/rgb_to_value.txt'):
35 | key_arr = np.loadtxt(txt)
36 | array = np.zeros((rgb.shape[0], rgb.shape[1], 3), dtype=np.uint8)
37 |
38 | for translation in key_arr:
39 | r, g, b, value = translation
40 | tmp = [r, g, b]
41 | array[(rgb == tmp).all(axis=2)] = value
42 |
43 | return array
44 |
45 |
46 | def save_label_value(label_path, save_path):
47 | check_dir(save_path)
48 | label_list = os.listdir(label_path)
49 | for idx, i in enumerate(label_list):
50 | print(idx)
51 | path = os.path.join(label_path, i)
52 | label = tifffile.imread(path)
53 |
54 |         label = ((label > 128) * 255)  # snap noisy pixel values to pure 0/255 before the RGB lookup
55 |
56 | label_value = rgb_to_value(label)
57 | tifffile.imsave(os.path.join(save_path, i), label_value)
58 |
59 |
60 | def show_label(root_path, img_path, vis_path):
61 | list = os.listdir(root_path)
62 | for idx, i in tqdm(enumerate(list)):
63 | print(i)
64 | img = tifffile.imread(os.path.join(img_path, i[:-9] + 'RGBIR.tif'))[:, :, :3]
65 | vis = tifffile.imread(os.path.join(vis_path, i))
66 |
67 | label = tifffile.imread(os.path.join(root_path, i))
68 | label = value_to_rgb(label)
69 |
70 | fig, axs = plt.subplots(1, 3, figsize=(14, 4))
71 |
72 | axs[0].imshow(img.astype(np.uint8))
73 | axs[0].axis("off")
74 | axs[1].imshow(label.astype(np.uint8))
75 | axs[1].axis("off")
76 |
77 | axs[2].imshow(vis)
78 | axs[2].axis("off")
79 |
80 | plt.suptitle(os.path.basename(i), y=0.94)
81 | plt.tight_layout()
82 | plt.show()
83 | plt.close()
84 |
85 |
86 | def show_one(path='/home/ggm/WLS/semantic/dataset/potsdam/val/label/6_7_217.tif'):
87 | label = tifffile.imread(path)
88 | label = value_to_rgb(label)
89 |
90 | fig, axs = plt.subplots(1, 2, figsize=(14, 4))
91 |
92 | axs[0].imshow(label.astype(np.uint8))
93 | axs[0].axis("off")
94 |
95 | plt.tight_layout()
96 | plt.show()
97 | plt.close()
98 |
99 |
100 | if __name__ == '__main__':
101 | img_path = '/home/ggm/WLS/semantic/dataset/potsdam/dataset_origin/4_Ortho_RGBIR'
102 |
103 | path = '/home/ggm/WLS/semantic/dataset/potsdam/dataset_origin/5_Labels_all'
104 | save_path = '/home/ggm/WLS/semantic/dataset/potsdam/dataset_origin/Labels'
105 | check_dir(save_path)
106 |
107 | # save_label_value(path, save_path)
108 | show_label(save_path, img_path, path)
--------------------------------------------------------------------------------
/DataProcess/postdam/WriteSegTXT.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 |
5 | def write_train(path):
6 | train_txt = open(path + '/datafiles/potsdam/seg_train.txt', 'w')
7 | train_path = '/media/hlf/Luffy/WLS//semantic/dataset/potsdam/train/img'
8 | list = os.listdir(train_path)
9 |
10 | for idx, i in enumerate(list):
11 | train_txt.write(i + '\n')
12 | train_txt.close()
13 |
14 |
15 | def write_val(path):
16 | val_txt = open(path + '/datafiles/potsdam/seg_val.txt', 'w')
17 | val_path = '/media/hlf/Luffy/WLS//semantic/dataset/potsdam/val/img'
18 | list = os.listdir(val_path)
19 |
20 | for idx, i in enumerate(list):
21 | val_txt.write(i + '\n')
22 | val_txt.close()
23 |
24 |
25 | def write_test(path):
26 | test_txt = open(path + '/datafiles/potsdam/seg_test.txt', 'w')
27 | test_path = '/media/hlf/Luffy/WLS//semantic/dataset/potsdam/test/img'
28 | list = os.listdir(test_path)
29 |
30 | for idx, i in enumerate(list):
31 | test_txt.write(i + '\n')
32 | test_txt.close()
33 |
34 |
35 | if __name__ == '__main__':
36 | path = '/media/hlf/Luffy/WLS//PointAnno'
37 | write_train(path)
38 | write_val(path)
39 | write_test(path)
--------------------------------------------------------------------------------
/DataProcess/postdam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/DataProcess/postdam/__init__.py
--------------------------------------------------------------------------------
/DataProcess/postdam/make_point_label.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import tifffile
5 | import matplotlib.pyplot as plt
6 | from tqdm import tqdm
7 | from utils import check_dir, read, imsave
8 |
9 |
10 | def read_cls_color(cls):
11 | label2color_dict = {
12 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255)
13 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255)
14 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255)
15 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0)
16 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0)
17 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0)
18 | }
19 | return label2color_dict[cls]
20 |
21 |
22 | def draw_point(label, kernel_size=100, point_size=3):
23 | h, w, c = label.shape
24 | label_set = np.unique(label)
25 |
26 | new_mask = np.ones([h, w, c], np.uint8) * 255
27 | new_mask_vis = np.zeros([h, w, c], np.uint8)
28 |
29 | for cls in label_set:
30 | if cls != 255:
31 | color = read_cls_color(cls)
32 |
33 | temp_mask = np.zeros([h, w])
34 | temp_mask[label[:, :, 0] == cls] = 255
35 | temp_mask = np.asarray(temp_mask, dtype=np.uint8)
36 |             contours, hierarchy = cv2.findContours(temp_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2:]  # compatible with both OpenCV 3 and 4
37 | # cv2.drawContours(new_mask, contours, -1, color, 1)
38 |
39 | for i in range(len(contours)):
40 | area = cv2.contourArea(contours[i])
41 |                 if area > kernel_size:
42 | # distance to contour
43 | dist = np.empty([h, w], dtype=np.float32)
44 | for h_ in range(h):
45 | for w_ in range(w):
46 | dist[h_, w_] = cv2.pointPolygonTest(contours[i], (w_, h_), True)
47 |
48 | # make sure the point in the temp_mask
49 | temp_dist = temp_mask * dist
50 | min_, max_, _, maxdistpt = cv2.minMaxLoc(temp_dist)
51 | cx, cy = maxdistpt[0], maxdistpt[1]
52 |
53 | new_mask[cy:cy + point_size, cx:cx + point_size, :] = (cls, cls, cls)
54 | new_mask_vis[cy:cy + point_size, cx:cx + point_size, :] = color
55 |
56 | return new_mask, new_mask_vis
57 |
58 |
59 | def make(root_path):
60 | train_path = root_path + '/train'
61 | val_path = root_path + '/val'
62 |
63 | paths = [train_path, val_path]
64 | for path in paths:
65 | label_path = path + '/label'
66 |
67 | point_label_path = path + '/point_label'
68 | point_label_vis_path = path + '/point_label_vis'
69 | check_dir(point_label_path), check_dir(point_label_vis_path)
70 |
71 | list = os.listdir(label_path)
72 | for i in tqdm(list):
73 | label = os.path.join(label_path, i)
74 | label = read(label)
75 | new_mask, new_mask_vis = draw_point(label)
76 | imsave(point_label_path + '/' + i, new_mask)
77 | imsave(point_label_vis_path + '/' + i, new_mask_vis)
78 |
79 |
80 | if __name__ == '__main__':
81 | root_path = '/media/hlf/Luffy/WLS/semantic/dataset/potsdam'
82 | make(root_path)
83 |
--------------------------------------------------------------------------------
/DataProcess/postdam/rgb_to_value.txt:
--------------------------------------------------------------------------------
1 | 255 255 255 0
2 | 0 0 255 1
3 | 0 255 255 2
4 | 0 255 0 3
5 | 255 255 0 4
6 | 255 0 0 5
--------------------------------------------------------------------------------
/DataProcess/vaihingen/CropVai.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tifffile
3 | import numpy as np
4 | import cv2
5 | from utils import check_dir, read, imsave
6 |
7 | test_name = [2, 4, 6, 8, 10, 12, 14, 16, 20, 22, 24, 27, 29, 31, 33, 35, 38]
8 | val_name = [5]
9 |
10 |
11 | def crop_overlap(img, label, vis, name, img_path, label_path, vis_path, size=256, stride=128):
12 | h, w, c = img.shape
13 | new_h, new_w = (h // size + 1) * size, (w // size + 1) * size
14 | num_h, num_w = (new_h // stride) - 1, (new_w // stride) - 1
15 |
16 | new_img = np.zeros([new_h, new_w, c]).astype(np.uint8)
17 | new_img[:h, :w, :] = img
18 |
19 | new_label = 255 * np.ones([new_h, new_w, 3]).astype(np.uint8)
20 | new_label[:h, :w, :] = label
21 |
22 | new_vis = np.zeros([new_h, new_w, 3]).astype(np.uint8)
23 | new_vis[:h, :w, :] = vis
24 |
25 | count = 0
26 |
27 | for i in range(num_h):
28 | for j in range(num_w):
29 | out = new_img[i * stride:i * stride + size, j * stride:j * stride + size, :]
30 | gt = new_label[i * stride:i * stride + size, j * stride:j * stride + size, :]
31 | v = new_vis[i * stride:i * stride + size, j * stride:j * stride + size, :]
32 |             assert v.shape == (size, size, 3), v.shape
33 |
34 | imsave(img_path + '/' + str(name) + '_' + str(count) + '.png', out)
35 | imsave(label_path + '/' + str(name) + '_' + str(count) + '.png', gt)
36 | imsave(vis_path + '/' + str(name) + '_' + str(count) + '.png', v)
37 |
38 | count += 1
39 |
40 |
41 | def crop(img, label, vis, name, img_path, label_path, vis_path, size=256):
42 | height, width, _ = img.shape
43 | h_size = height // size
44 | w_size = width // size
45 |
46 | count = 0
47 |
48 | for i in range(h_size):
49 | for j in range(w_size):
50 | out = img[i * size:(i + 1) * size, j * size:(j + 1) * size, :]
51 | gt = label[i * size:(i + 1) * size, j * size:(j + 1) * size, :]
52 | v = vis[i * size:(i + 1) * size, j * size:(j + 1) * size, :]
53 |             assert v.shape == (size, size, 3)
54 |
55 | imsave(img_path + '/' + str(name) + '_' + str(count) + '.png', out)
56 | imsave(label_path + '/' + str(name) + '_' + str(count) + '.png', gt)
57 | imsave(vis_path + '/' + str(name) + '_' + str(count) + '.png', v)
58 |
59 | count += 1
60 |
61 |
62 | def save(img, label, vis, name, save_path, flag='val'):
63 | img_path = save_path + '/img'
64 | label_path = save_path + '/label'
65 | v_path = save_path + '/label_vis'
66 | check_dir(img_path), check_dir(label_path), check_dir(v_path)
67 |
68 | if flag == 'val':
69 | crop(img, label, vis, name, img_path, label_path, v_path)
70 | else:
71 | crop_overlap(img, label, vis, name, img_path, label_path, v_path)
72 |
73 |
74 | def run(root_path):
75 | data_path = root_path + '/dataset_origin/image'
76 | label_path = root_path + '/dataset_origin/gts'
77 | label_vis_path = root_path + '/dataset_origin/vis'
78 |
79 | list = os.listdir(label_path)
80 | for i in list:
81 | name = i[20:-4]
82 | print(i, name)
83 |
84 | img = read(os.path.join(data_path, i))
85 | label = read(os.path.join(label_path, i))
86 | label_vis = read(os.path.join(label_vis_path, i))
87 |
88 | if int(name) in test_name:
89 | print(name, 'test')
90 | test_path = root_path + '/test'
91 | check_dir(test_path)
92 | save(img, label, label_vis, name, test_path, flag='test')
93 |
94 | # elif int(name) in val_name:
95 | # print(name, 'val')
96 | # val_path = root_path + '/val'
97 | # check_dir(val_path)
98 | # save(img, label, label_vis, name, val_path, flag='val')
99 |
100 | else:
101 | print(name, 'train')
102 | train_path = root_path + '/train'
103 | check_dir(train_path)
104 | save(img, label, label_vis, name, train_path, flag='train')
105 |
106 |
107 | if __name__ == '__main__':
108 | root_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen'
109 | run(root_path)
--------------------------------------------------------------------------------
/DataProcess/vaihingen/FindMean.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tifffile
4 |
5 | # means = [119.14901543, 83.04203606, 81.79810095]
6 | path = '/home/ggm/WLS/semantic/dataset/vaihingen/train/img'
7 |
8 | list = os.listdir(path)
9 |
10 | k = 0
11 | sum = np.zeros([3])
12 |
13 | for idx, i in enumerate(list):
14 | print(idx)
15 | k += 1
16 | img = tifffile.imread(os.path.join(path, i))
17 | img = img.reshape(256*256, -1)
18 |
19 | mean = np.mean(img, axis=0)
20 | sum += mean
21 |
22 |
23 | means = sum/k
24 | print(means)
25 |
26 | # std = [55.63038161, 40.67145608, 38.61447761]
27 | # means = [119.14901543, 83.04203606, 81.79810095]
28 | #
29 | # path = '/home/ggm/WLS/semantic/dataset/vaihingen/train/img'
30 | #
31 | # list = os.listdir(path)
32 | # k = 0
33 | # sum = np.zeros([3])
34 | #
35 | # for idx, i in enumerate(list):
36 | # print(idx)
37 | # k += 1
38 | # img = tifffile.imread(os.path.join(path, i))
39 | # img = img.reshape(256*256, -1)
40 | #
41 | # x = (img - means) ** 2
42 | #
43 | # sum += np.sum(x, axis=0)
44 | #
45 | # std = np.sqrt(sum/(k*256*256))
46 | # print(std)
--------------------------------------------------------------------------------
/DataProcess/vaihingen/TransVaiLabel.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | from utils import *
5 | import matplotlib.pyplot as plt
6 | import tifffile
7 |
8 |
9 | def value_to_rgb(anno):
10 | label2color_dict = {
11 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255)
12 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255)
13 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255)
14 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0)
15 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0)
16 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0)
17 | 255: [0, 0, 0], # boundary (RGB: 0, 0, 0)
18 | }
19 | # visualize
20 | visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8)
21 | for i in range(visual_anno.shape[0]): # i for h
22 | for j in range(visual_anno.shape[1]):
23 | # cv2: bgr
24 | color = label2color_dict[anno[i, j, 0]]
25 |
26 | visual_anno[i, j, 0] = color[0]
27 | visual_anno[i, j, 1] = color[1]
28 | visual_anno[i, j, 2] = color[2]
29 |
30 | return visual_anno
31 |
32 |
33 | def rgb_to_value(rgb, txt='/media/hlf/Luffy/WLS/PointAnno/DataProcess/vaihingen/rgb_to_value.txt'):
34 | key_arr = np.loadtxt(txt)
35 | array = np.zeros((rgb.shape[0], rgb.shape[1], 3), dtype=np.uint8)
36 |
37 | for translation in key_arr:
38 | r, g, b, value = translation
39 | tmp = [r, g, b]
40 | array[(rgb == tmp).all(axis=2)] = value
41 |
42 | return array
43 |
44 |
45 | def save_label_value(label_path, save_path):
46 | check_dir(save_path)
47 | label_list = os.listdir(label_path)
48 | for idx, i in enumerate(label_list):
49 | print(idx)
50 | path = os.path.join(label_path, i)
51 | label = tifffile.imread(path)
52 |         label = ((label > 128) * 255)  # snap noisy pixel values to pure 0/255 before the RGB lookup
53 |
54 | label_value = rgb_to_value(label)
55 | tifffile.imsave(os.path.join(save_path, i), label_value)
56 |
57 |
58 | def show_label(root_path, img_path, vis_path):
59 | list = os.listdir(root_path)
60 | for idx, i in enumerate(list):
61 | label = tifffile.imread(os.path.join(root_path, i))
62 | print(np.unique(label))
63 | label = value_to_rgb(label)
64 |
65 | img = tifffile.imread(os.path.join(img_path, i))[:, :, :3]
66 | vis = tifffile.imread(os.path.join(vis_path, i))
67 |
68 | plt.subplot(1, 3, 1)
69 | plt.imshow(img)
70 | plt.subplot(1, 3, 2)
71 | plt.imshow(label)
72 | plt.subplot(1, 3, 3)
73 | plt.imshow(vis)
74 | plt.show()
75 |
76 |
77 | if __name__ == '__main__':
78 | img_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/dataset_origin/image'
79 | path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/dataset_origin/vis_noB'
80 | save_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/dataset_origin/gts_noB'
81 |
82 | save_label_value(path, save_path)
83 | show_label(save_path, img_path, path)
--------------------------------------------------------------------------------
/DataProcess/vaihingen/WriteTXT.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
6 |
7 |
8 | def write_train(path):
9 | train_txt = open(path + '/datafiles/vaihingen/sec_train.txt', 'w')
10 | train_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/sec/train/img'
11 | list = os.listdir(train_path)
12 |
13 | for idx, i in enumerate(list):
14 | train_txt.write(i + '\n')
15 | train_txt.close()
16 |
17 |
18 | def write_val(path):
19 | val_txt = open(path + '/datafiles/vaihingen/sec_val.txt', 'w')
20 | val_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/sec/val/img'
21 | list = os.listdir(val_path)
22 |
23 | for idx, i in enumerate(list):
24 | val_txt.write(i + '\n')
25 | val_txt.close()
26 |
27 |
28 | def write_test(path):
29 | test_txt = open(path + '/datafiles/vaihingen/seg_test.txt', 'w')
30 | test_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/test/img'
31 | list = os.listdir(test_path)
32 |
33 | for idx, i in enumerate(list):
34 | test_txt.write(i + '\n')
35 | test_txt.close()
36 |
37 |
38 | if __name__ == '__main__':
39 | path = '/media/hlf/Luffy/WLS/PointAnno'
40 | write_train(path)
41 | write_val(path)
42 | # write_test(path)
43 |
--------------------------------------------------------------------------------
/DataProcess/vaihingen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/DataProcess/vaihingen/__init__.py
--------------------------------------------------------------------------------
/DataProcess/vaihingen/make_point_label.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import tifffile
5 | import matplotlib.pyplot as plt
6 | from tqdm import tqdm
7 | from utils import check_dir, read, imsave
8 |
9 |
10 | def read_cls_color(cls):
11 | label2color_dict = {
12 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255)
13 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255)
14 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255)
15 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0)
16 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0)
17 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0)
18 | 255: [0, 0, 0]
19 | }
20 | return label2color_dict[cls]
21 |
22 |
23 | def draw_point(label, kernel_size=100):
24 | label = read(label)
25 |
26 | h, w, c = label.shape
27 | label_set = np.unique(label)
28 |
29 | new_mask = np.ones([h, w, c], np.uint8) * 255
30 | new_mask_vis = np.zeros([h, w, c], np.uint8)
31 |
32 | for cls in label_set:
33 | if cls != 255:
34 | color = read_cls_color(cls)
35 |
36 | temp_mask = np.zeros([h, w])
37 | temp_mask[label[:, :, 0] == cls] = 255
38 | temp_mask = np.asarray(temp_mask, dtype=np.uint8)
39 |             contours, hierarchy = cv2.findContours(temp_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2:]  # compatible with both OpenCV 3 and 4
40 | # cv2.drawContours(new_mask, contours, -1, color, 1)
41 |
42 | for i in range(len(contours)):
43 | area = cv2.contourArea(contours[i])
44 |                 if area > kernel_size:
45 | # distance to contour
46 | dist = np.empty([h, w], dtype=np.float32)
47 | for h_ in range(h):
48 | for w_ in range(w):
49 | dist[h_, w_] = cv2.pointPolygonTest(contours[i], (w_, h_), True)
50 |
51 | # make sure the point in the temp_mask
52 | temp_dist = temp_mask * dist
53 | min_, max_, _, maxdistpt = cv2.minMaxLoc(temp_dist)
54 | cx, cy = maxdistpt[0], maxdistpt[1]
55 |
56 | new_mask[cy:cy + 3, cx:cx + 3, :] = (cls, cls, cls)
57 | new_mask_vis[cy:cy + 3, cx:cx + 3, :] = color
58 |
59 | return new_mask, new_mask_vis
60 |
61 |
62 | def make(root_path):
63 | train_path = root_path + '/train'
64 | val_path = root_path + '/val'
65 | test_path = root_path + '/test'
66 |
67 | paths = [train_path]
68 | for path in paths:
69 | label_path = path + '/label'
70 |
71 | point_label_path = path + '/point_label'
72 | point_label_vis_path = path + '/point_label_vis'
73 | check_dir(point_label_path), check_dir(point_label_vis_path)
74 |
75 | list = os.listdir(label_path)
76 | for i in tqdm(list):
77 | label = os.path.join(label_path, i)
78 | new_mask, new_mask_vis = draw_point(label)
79 | imsave(point_label_path + '/' + i, new_mask)
80 | imsave(point_label_vis_path + '/' + i, new_mask_vis)
81 |
82 |
83 | if __name__ == '__main__':
84 | root_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen'
85 | make(root_path)
86 |
--------------------------------------------------------------------------------
/DataProcess/vaihingen/rgb_to_value.txt:
--------------------------------------------------------------------------------
1 | 255 255 255 0
2 | 0 0 255 1
3 | 0 255 255 2
4 | 0 255 0 3
5 | 255 255 0 4
6 | 255 0 0 5
7 | 0 0 0 255
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Bilateral Filtering Network (DBFNet)
2 | Code for our TIP 2022 paper [**"Deep Bilateral Filtering Network for Point-Supervised Semantic Segmentation in Remote Sensing Images"**](https://ieeexplore.ieee.org/document/9961229).
3 |
4 | Authors: Linshan Wu, Leyuan Fang, Jun Yue, Bob Zhang, Pedram Ghamisi, and Min He
5 |
6 | ## Getting Started
7 | ### Prepare Dataset
8 | Download the already-processed Potsdam and Vaihingen [datasets](https://drive.google.com/drive/folders/1CiYzJyBn1rV-xsrsYQ6o2HDQjdfnadHl).
9 |
10 | Or you can download the datasets from the official [website](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-vaihingen.aspx). Then, crop the original images and create point labels following our code in [Dataprocess](https://github.com/Luffy03/DBFNet/tree/master/DataProcess).
11 |
12 | If you want to run our code on your own datasets, the pre-processing code is also available in [Dataprocess](https://github.com/Luffy03/DBFNet/tree/master/DataProcess); a sketch of the typical run order is given below.
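
A minimal sketch of the preprocessing order for Vaihingen (the Potsdam scripts under `DataProcess/postdam` follow the same steps; all paths are hard-coded inside the scripts, so edit them to match your machine first):
```bash
# 1. convert the RGB ground truth into value labels
python DataProcess/vaihingen/TransVaiLabel.py
# 2. crop the large tiles into 256x256 patches and split train/test
python DataProcess/vaihingen/CropVai.py
# 3. derive one point label per connected region of each class
python DataProcess/vaihingen/make_point_label.py
# 4. write the train/val/test file lists under datafiles/
python DataProcess/vaihingen/WriteTXT.py
```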
13 |
14 | ## Evaluate
15 | ### 1. Download the [original datasets](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-vaihingen.aspx)
16 | ### 2. Download our [weights](https://drive.google.com/drive/folders/1CiYzJyBn1rV-xsrsYQ6o2HDQjdfnadHl)
17 | ### 3. Run our code
18 | ```bash
19 | python predict.py
20 | ```
21 |
22 | ## Train
23 | ### 1. Train DBFNet
24 | ```bash
25 | python run/point/p_train.py
26 | ```
27 | ### 2. Generate pseudo labels
28 | ```bash
29 | python run/point/p_predict_train.py
30 | ```
31 | ### 3. Recursive learning
32 | ```bash
33 | python run/second/sec_train.py
34 | ```
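
For reference, `dataset/data_sec.py` looks for the CRF-refined pseudo labels produced in step 2 under the `Point` save directory; a sketch of the expected layout (assuming the default `save_path` option):
```
<save_path>/<dataset>/Point/predict_train/crf/label   # pseudo labels used to train the second round
<save_path>/<dataset>/Point/predict_val/crf/label     # pseudo labels used for validation
```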
35 |
36 | ## Citation ✏️ 📄
37 |
38 | If you find this repo useful for your research, please consider citing the paper as follows:
39 |
40 | ```
41 | @ARTICLE{Wu_DBFNet,
42 | author={Wu, Linshan and Fang, Leyuan and Yue, Jun and Zhang, Bob and Ghamisi, Pedram and He, Min},
43 | journal={IEEE Transactions on Image Processing},
44 | title={Deep Bilateral Filtering Network for Point-Supervised Semantic Segmentation in Remote Sensing Images},
45 | year={2022},
46 | volume={31},
47 | number={},
48 | pages={7419-7434},
49 | doi={10.1109/TIP.2022.3222904}}
50 | @article{wu2024modeling,
51 | title={Modeling the Label Distributions for Weakly-Supervised Semantic Segmentation},
52 | author={Wu, Linshan and Zhong, Zhun and Ma, Jiayi and Wei, Yunchao and Chen, Hao and Fang, Leyuan and Li, Shutao},
53 | journal={arXiv preprint arXiv:2403.13225},
54 | year={2024}
55 | }
56 | @inproceedings{AGMM,
57 | title={Sparsely Annotated Semantic Segmentation with Adaptive Gaussian Mixtures},
58 | author={Wu, Linshan and Zhong, Zhun and Fang, Leyuan and He, Xingxin and Liu, Qiang and Ma, Jiayi and Chen, Hao},
59 | booktitle={IEEE Conf. Comput. Vis. Pattern Recog.},
60 | month={June},
61 | year={2023},
62 | pages={15454-15464}
63 | }
64 | ```
65 |
66 | For any question, please contact [Linshan Wu](mailto:15274891948@163.com).
67 |
--------------------------------------------------------------------------------
/datafiles/color_dict.py:
--------------------------------------------------------------------------------
1 | postdam_color_dict = {
2 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255)
3 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255)
4 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255)
5 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0)
6 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0)
7 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0)
8 |     255: [0, 0, 0],  # boundary / ignore index (distinct from class 5 clutter/background)
9 | }
10 |
--------------------------------------------------------------------------------
/datafiles/vaihingen/seg_val.txt:
--------------------------------------------------------------------------------
1 | 11_104.png
2 | 11_132.png
3 | 11_133.png
4 | 11_137.png
5 | 11_150.png
6 | 11_152.png
7 | 11_172.png
8 | 11_211.png
9 | 11_212.png
10 | 11_215.png
11 | 11_221.png
12 | 11_23.png
13 | 11_239.png
14 | 11_24.png
15 | 11_241.png
16 | 11_242.png
17 | 11_244.png
18 | 11_262.png
19 | 11_272.png
20 | 11_281.png
21 | 11_282.png
22 | 11_284.png
23 | 11_293.png
24 | 11_31.png
25 | 11_44.png
26 | 11_49.png
27 | 11_50.png
28 | 11_60.png
29 | 11_69.png
30 | 23_120.png
31 | 23_131.png
32 | 23_133.png
33 | 23_137.png
34 | 23_149.png
35 | 23_150.png
36 | 23_157.png
37 | 23_195.png
38 | 23_197.png
39 | 23_198.png
40 | 23_204.png
41 | 23_207.png
42 | 23_216.png
43 | 23_229.png
44 | 23_24.png
45 | 23_241.png
46 | 23_255.png
47 | 23_260.png
48 | 23_271.png
49 | 23_279.png
50 | 23_280.png
51 | 23_281.png
52 | 23_72.png
53 | 23_75.png
54 | 23_8.png
55 | 23_88.png
56 | 17_31.png
57 | 17_6.png
58 | 17_65.png
59 | 17_69.png
60 | 17_99.png
61 | 1_111.png
62 | 1_117.png
63 | 1_119.png
64 | 1_133.png
65 | 1_138.png
66 | 1_146.png
67 | 1_16.png
68 | 1_161.png
69 | 1_170.png
70 | 1_178.png
71 | 1_191.png
72 | 1_205.png
73 | 1_206.png
74 | 1_213.png
75 | 1_22.png
76 | 1_222.png
77 | 1_226.png
78 | 1_227.png
79 | 1_228.png
80 | 1_236.png
81 | 1_238.png
82 | 1_255.png
83 | 1_264.png
84 | 1_291.png
85 | 1_294.png
86 | 1_299.png
87 | 1_309.png
88 | 1_313.png
89 | 32_70.png
90 | 32_8.png
91 | 32_82.png
92 | 32_92.png
93 | 34_0.png
94 | 34_11.png
95 | 34_120.png
96 | 34_129.png
97 | 34_158.png
98 | 34_179.png
99 | 34_196.png
100 | 34_2.png
101 | 34_201.png
102 | 34_31.png
103 | 34_43.png
104 | 34_49.png
105 | 34_78.png
106 | 34_87.png
107 | 34_93.png
108 | 37_108.png
109 | 37_117.png
110 | 37_12.png
111 | 37_124.png
112 | 37_127.png
113 | 37_137.png
114 | 37_146.png
115 | 37_152.png
116 | 37_167.png
117 | 37_170.png
118 | 37_171.png
119 | 13_231.png
120 | 13_248.png
121 | 13_253.png
122 | 13_261.png
123 | 13_262.png
124 | 13_269.png
125 | 13_27.png
126 | 13_280.png
127 | 13_281.png
128 | 13_282.png
129 | 13_285.png
130 | 13_286.png
131 | 13_289.png
132 | 13_322.png
133 | 13_333.png
134 | 13_341.png
135 | 13_359.png
136 | 13_363.png
137 | 13_401.png
138 | 13_410.png
139 | 13_412.png
140 | 13_421.png
141 | 13_45.png
142 | 13_58.png
143 | 13_60.png
144 | 13_62.png
145 | 13_71.png
146 | 13_90.png
147 | 15_100.png
148 | 15_11.png
149 | 15_113.png
150 | 15_131.png
151 | 15_145.png
152 | 15_152.png
153 | 15_154.png
154 | 15_163.png
155 | 15_182.png
156 | 11_83.png
157 | 13_229.png
158 | 15_191.png
159 | 17_199.png
160 | 1_35.png
161 | 21_187.png
162 | 23_117.png
163 | 26_120.png
164 | 28_127.png
165 | 30_219.png
166 | 32_69.png
167 | 37_179.png
168 | 3_313.png
169 | 5_192.png
170 | 5_56.png
171 | 3_319.png
172 | 3_324.png
173 | 3_335.png
174 | 3_339.png
175 | 3_340.png
176 | 3_38.png
177 | 3_59.png
178 | 3_61.png
179 | 3_71.png
180 | 3_86.png
181 | 3_88.png
182 | 3_97.png
183 | 5_117.png
184 | 5_12.png
185 | 5_127.png
186 | 5_13.png
187 | 5_15.png
188 | 5_158.png
189 | 5_16.png
190 | 5_164.png
191 | 5_186.png
192 | 5_191.png
193 | 28_137.png
194 | 28_149.png
195 | 28_157.png
196 | 28_16.png
197 | 28_165.png
198 | 28_17.png
199 | 28_173.png
200 | 28_183.png
201 | 28_199.png
202 | 28_2.png
203 | 28_221.png
204 | 28_223.png
205 | 28_228.png
206 | 28_230.png
207 | 28_232.png
208 | 28_247.png
209 | 28_270.png
210 | 28_272.png
211 | 28_286.png
212 | 28_54.png
213 | 28_56.png
214 | 28_82.png
215 | 28_9.png
216 | 28_93.png
217 | 30_102.png
218 | 30_103.png
219 | 30_120.png
220 | 30_124.png
221 | 30_128.png
222 | 30_139.png
223 | 30_155.png
224 | 30_160.png
225 | 30_162.png
226 | 30_170.png
227 | 30_190.png
228 | 30_214.png
229 | 30_217.png
230 | 1_37.png
231 | 1_4.png
232 | 1_44.png
233 | 1_63.png
234 | 1_74.png
235 | 1_85.png
236 | 1_90.png
237 | 1_95.png
238 | 21_101.png
239 | 21_106.png
240 | 21_109.png
241 | 21_126.png
242 | 21_141.png
243 | 21_142.png
244 | 21_145.png
245 | 21_158.png
246 | 21_159.png
247 | 21_164.png
248 | 21_168.png
249 | 21_174.png
250 | 21_179.png
251 | 5_95.png
252 | 5_98.png
253 | 7_100.png
254 | 7_107.png
255 | 7_119.png
256 | 7_127.png
257 | 7_128.png
258 | 7_149.png
259 | 7_15.png
260 | 7_159.png
261 | 7_163.png
262 | 7_168.png
263 | 7_179.png
264 | 7_18.png
265 | 7_190.png
266 | 7_195.png
267 | 7_2.png
268 | 7_207.png
269 | 7_219.png
270 | 7_223.png
271 | 7_226.png
272 | 7_237.png
273 | 7_239.png
274 | 7_244.png
275 | 7_283.png
276 | 7_284.png
277 | 7_40.png
278 | 7_49.png
279 | 7_59.png
280 | 7_60.png
281 | 7_63.png
282 | 7_74.png
283 | 7_88.png
284 | 37_191.png
285 | 37_198.png
286 | 37_200.png
287 | 37_37.png
288 | 37_41.png
289 | 37_51.png
290 | 37_59.png
291 | 37_85.png
292 | 37_88.png
293 | 3_0.png
294 | 3_103.png
295 | 3_111.png
296 | 3_116.png
297 | 3_121.png
298 | 3_122.png
299 | 3_132.png
300 | 3_133.png
301 | 3_138.png
302 | 3_141.png
303 | 3_194.png
304 | 3_203.png
305 | 3_218.png
306 | 3_22.png
307 | 3_231.png
308 | 3_254.png
309 | 3_256.png
310 | 3_268.png
311 | 3_281.png
312 | 3_289.png
313 | 3_302.png
314 | 3_304.png
315 | 30_253.png
316 | 30_264.png
317 | 30_27.png
318 | 30_272.png
319 | 30_304.png
320 | 30_305.png
321 | 30_36.png
322 | 30_40.png
323 | 30_43.png
324 | 30_52.png
325 | 30_59.png
326 | 30_65.png
327 | 30_75.png
328 | 30_77.png
329 | 30_8.png
330 | 30_83.png
331 | 30_90.png
332 | 32_106.png
333 | 32_120.png
334 | 32_131.png
335 | 32_148.png
336 | 32_158.png
337 | 32_177.png
338 | 32_183.png
339 | 32_186.png
340 | 32_206.png
341 | 32_207.png
342 | 32_22.png
343 | 32_234.png
344 | 32_243.png
345 | 32_280.png
346 | 32_33.png
347 | 32_38.png
348 | 32_60.png
349 | 32_64.png
350 | 15_213.png
351 | 15_217.png
352 | 15_218.png
353 | 15_220.png
354 | 15_229.png
355 | 15_232.png
356 | 15_243.png
357 | 15_258.png
358 | 15_263.png
359 | 15_266.png
360 | 15_271.png
361 | 15_277.png
362 | 15_293.png
363 | 15_299.png
364 | 15_301.png
365 | 15_302.png
366 | 15_312.png
367 | 15_41.png
368 | 15_6.png
369 | 15_82.png
370 | 15_93.png
371 | 17_0.png
372 | 17_10.png
373 | 17_12.png
374 | 17_122.png
375 | 17_126.png
376 | 17_132.png
377 | 17_137.png
378 | 17_150.png
379 | 17_154.png
380 | 17_167.png
381 | 17_174.png
382 | 17_177.png
383 | 11_84.png
384 | 11_9.png
385 | 13_1.png
386 | 13_102.png
387 | 13_109.png
388 | 13_111.png
389 | 13_113.png
390 | 13_121.png
391 | 13_129.png
392 | 13_13.png
393 | 13_132.png
394 | 13_142.png
395 | 13_144.png
396 | 13_145.png
397 | 13_148.png
398 | 13_152.png
399 | 13_154.png
400 | 13_171.png
401 | 13_175.png
402 | 13_194.png
403 | 13_199.png
404 | 13_202.png
405 | 13_21.png
406 | 13_214.png
407 | 13_215.png
408 | 13_218.png
409 | 13_227.png
410 | 26_132.png
411 | 26_143.png
412 | 26_147.png
413 | 26_15.png
414 | 26_150.png
415 | 26_155.png
416 | 26_158.png
417 | 26_169.png
418 | 26_17.png
419 | 26_178.png
420 | 26_189.png
421 | 26_214.png
422 | 26_223.png
423 | 26_23.png
424 | 26_256.png
425 | 26_259.png
426 | 26_264.png
427 | 26_265.png
428 | 26_266.png
429 | 26_3.png
430 | 26_41.png
431 | 26_46.png
432 | 26_58.png
433 | 26_90.png
434 | 28_116.png
435 | 28_119.png
436 | 5_201.png
437 | 5_224.png
438 | 5_229.png
439 | 5_244.png
440 | 5_249.png
441 | 5_256.png
442 | 5_257.png
443 | 5_26.png
444 | 5_262.png
445 | 5_267.png
446 | 5_268.png
447 | 5_275.png
448 | 5_279.png
449 | 5_29.png
450 | 5_39.png
451 | 5_44.png
452 | 5_46.png
453 | 5_50.png
454 | 5_55.png
455 | 21_189.png
456 | 21_194.png
457 | 21_197.png
458 | 21_2.png
459 | 21_204.png
460 | 21_215.png
461 | 21_216.png
462 | 21_230.png
463 | 21_238.png
464 | 21_241.png
465 | 21_244.png
466 | 21_28.png
467 | 21_4.png
468 | 21_6.png
469 | 21_70.png
470 | 21_85.png
471 | 21_95.png
472 | 23_104.png
473 | 23_106.png
474 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from dataset.datapoint import Dataset_point
2 | from dataset.data_sec import Dataset_sec
3 |
4 | import dataset.transform as trans
--------------------------------------------------------------------------------
/dataset/data_sec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | import os
5 | from utils import util
6 | from tqdm import tqdm
7 | import multiprocessing
8 | from torchvision import transforms
9 | import dataset.transform as trans
10 | import tifffile
11 | import matplotlib.pyplot as plt
12 | from dataset.data_utils import label_transform, value_to_rgb
13 |
14 |
15 | class Dataset_sec(Dataset):
16 | def __init__(self, opt, save_path, file_name_txt_path, flag='train', transform=None):
17 | # first round
18 | save_path = save_path + '/Point'
19 | # save_path = save_path + '/Second'
20 |
21 | # parameters from opt
22 | self.data_root = os.path.join(opt.data_root, opt.dataset)
23 | self.dataset = opt.dataset
24 |
25 | if flag == 'train':
26 | self.img_path = self.data_root + '/train/img'
27 | self.label_path = save_path + '/predict_train/crf/label'
28 |
29 | elif flag == 'val':
30 | self.img_path = self.data_root + '/val/img'
31 | self.label_path = save_path + '/predict_val/crf/label'
32 |
33 | elif flag == 'predict_train':
34 | self.img_path = self.data_root + '/train/img'
35 | self.label_path = self.data_root + '/train/crf/label'
36 |
37 | elif flag == 'test':
38 | self.img_path = self.data_root + '/test/img'
39 | self.label_path = self.data_root + '/test/label'
40 |
41 | self.img_size = opt.img_size
42 | self.num_classes = opt.num_classes
43 | self.in_channels = opt.in_channels
44 |
45 | self.img_txt_path = file_name_txt_path
46 | self.flag = flag
47 | self.transform = transform
48 | self.img_label_path_pairs = self.get_img_label_path_pairs()
49 |
50 | def get_img_label_path_pairs(self):
51 | img_label_pair_list = {}
52 |
53 | with open(self.img_txt_path, 'r') as lines:
54 | for idx, line in enumerate(tqdm(lines)):
55 | name = line.strip("\n").split(' ')[0]
56 | path = os.path.join(self.img_path, name)
57 | img_label_pair_list.setdefault(idx, [path, name])
58 |
59 | return img_label_pair_list
60 |
61 | def make_clslabel(self, label):
62 | label_set = np.unique(label)
63 | cls_label = np.zeros(self.num_classes)
64 | for i in label_set:
65 | if i != 255:
66 | cls_label[i] += 1
67 | cls_label = torch.from_numpy(cls_label).float()
68 | return cls_label
69 |
70 | def data_transform(self, img, label):
71 | img = img[:, :, :self.in_channels]
72 | img = img.astype(np.float32).transpose(2, 0, 1)
73 | img = torch.from_numpy(img).float()
74 | if len(label.shape) == 3:
75 | label = label[:, :, 0]
76 | label = label.copy()
77 | label = torch.from_numpy(label).long()
78 |
79 | return img, label
80 |
81 | def __getitem__(self, index):
82 | item = self.img_label_path_pairs[index]
83 | img_path, name = item
84 |
85 | img = util.read(img_path)[:, :, :self.in_channels]
86 | label = util.read(os.path.join(self.label_path, name))
87 | if len(label.shape) != 3:
88 | label = np.expand_dims(label, axis=-1)
89 |
90 | if self.flag == 'train':
91 | if self.transform is not None:
92 | for t in self.transform.transforms:
93 | img, label = t([img, label])
94 |
95 | img = util.Normalize(img, flag=self.dataset)
96 | img, label = self.data_transform(img, label)
97 | return img, label, name
98 |
99 | def __len__(self):
100 |
101 | return len(self.img_label_path_pairs)
102 |
103 |
104 | if __name__ == "__main__":
105 |
106 | from utils import *
107 | from options import *
108 | from torch.utils.data import DataLoader
109 |
110 | opt = Sec_Options().parse()
111 | print(opt)
112 | save_path = os.path.join(opt.save_path, opt.dataset)
113 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
114 |
115 | train_transform = transforms.Compose([
116 | trans.Color_Aug(),
117 | trans.RandomHorizontalFlip(),
118 | trans.RandomVerticleFlip(),
119 | trans.RandomRotate90(),
120 | ])
121 | dataset = Dataset_sec(opt, save_path, train_txt_path, flag='train', transform=train_transform)
122 |
123 | loader = DataLoader(
124 | dataset=dataset,
125 | batch_size=1, num_workers=8, pin_memory=True, drop_last=True
126 | )
127 |
128 | for i in tqdm(loader):
129 | img, label, name = i
130 |
131 | pass
--------------------------------------------------------------------------------
/dataset/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 |
5 | def label_transform(dataset, label):
6 |     label_new = label
7 |     if dataset == 'potsdam' or dataset == 'vaihingen':
8 |         # merge 'clutter/background' (class 5) into the ignore index
9 |         label_new[label == 5] = 255
10 |
11 |     return label_new
14 |
15 |
16 | def value_to_rgb(anno, flag='potsdam'):
17 | if flag == 'potsdam' or flag == 'vaihingen':
18 | label2color_dict = {
19 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255)
20 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255)
21 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255)
22 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0)
23 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0)
24 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0)
25 | 255: [0, 0, 0]
26 | }
27 | else:
28 | label2color_dict = {}
29 |
30 | # visualize
31 | visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8)
32 | for i in range(visual_anno.shape[0]): # i for h
33 | for j in range(visual_anno.shape[1]):
34 | # cv2: bgr
35 | color = label2color_dict[anno[i, j, 0]]
36 |
37 | visual_anno[i, j, 0] = color[0]
38 | visual_anno[i, j, 1] = color[1]
39 | visual_anno[i, j, 2] = color[2]
40 |
41 | return visual_anno
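42 |
43 |
44 | if __name__ == '__main__':
45 |     # Minimal sanity check (not part of the original file): a 2x2 label map with
46 |     # classes 0/1 and the ignore value 255 should map to white, blue and black.
47 |     anno = np.array([[0, 1], [255, 0]], dtype=np.uint8).reshape(2, 2, 1)
48 |     print(value_to_rgb(anno, flag='potsdam'))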
--------------------------------------------------------------------------------
/dataset/datapoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | import os
5 | from utils import util
6 | from tqdm import tqdm
7 | import multiprocessing
8 | from torchvision import transforms
9 | import dataset.transform as trans
10 | import tifffile
11 | import matplotlib.pyplot as plt
12 | from dataset.data_utils import label_transform
13 |
14 |
15 | class Dataset_point(Dataset):
16 | def __init__(self, opt, file_name_txt_path, flag='train', transform=None):
17 | # parameters from opt
18 | self.data_root = os.path.join(opt.data_root, opt.dataset)
19 | self.dataset = opt.dataset
20 |
21 | if flag == 'test':
22 | self.img_path = self.data_root + '/test/img'
23 | if self.dataset == 'vaihingen':
24 | self.label_path = self.data_root + '/test/label_noB'
25 | else:
26 | self.label_path = self.data_root + '/test/label'
27 | elif flag == 'train':
28 | self.img_path = self.data_root + '/train/img'
29 | self.label_path = self.data_root + '/train/point_label'
30 | elif flag == 'val':
31 | self.img_path = self.data_root + '/val/img'
32 | self.label_path = self.data_root + '/val/point_label'
33 |
34 | elif flag == 'predict_train':
35 | self.img_path = self.data_root + '/train/img'
36 | self.label_path = self.data_root + '/train/label'
37 |
38 | self.img_size = opt.img_size
39 | self.num_classes = opt.num_classes
40 | self.in_channels = opt.in_channels
41 |
42 | self.img_txt_path = file_name_txt_path
43 | self.flag = flag
44 | self.transform = transform
45 | self.img_label_path_pairs = self.get_img_label_path_pairs()
46 |
47 | def get_img_label_path_pairs(self):
48 | img_label_pair_list = {}
49 |
50 | with open(self.img_txt_path, 'r') as lines:
51 | for idx, line in enumerate(tqdm(lines)):
52 | name = line.strip("\n").split(' ')[0]
53 | path = os.path.join(self.img_path, name)
54 | img_label_pair_list.setdefault(idx, [path, name])
55 |
56 | return img_label_pair_list
57 |
58 | def make_clslabel(self, label):
59 | label_set = np.unique(label)
60 | cls_label = np.zeros(self.num_classes)
61 | for i in label_set:
62 | if i != 255:
63 | cls_label[i] += 1
64 | return cls_label
65 |
66 | def data_transform(self, img, label, cls_label):
67 | img = img[:, :, :self.in_channels]
68 | img = img.astype(np.float32).transpose(2, 0, 1)
69 | img = torch.from_numpy(img).float()
70 | if len(label.shape) == 3:
71 | label = label[:, :, 0]
72 |
73 | label, cls_label = label.copy(), cls_label.copy()
74 |
75 | label = torch.from_numpy(label).long()
76 | cls_label = torch.from_numpy(cls_label).float()
77 |
78 | return img, label, cls_label
79 |
80 | def __getitem__(self, index):
81 | item = self.img_label_path_pairs[index]
82 | img_path, name = item
83 |
84 | img = util.read(img_path)[:, :, :3]
85 | img = util.Normalize(img, flag=self.dataset)
86 | label = util.read(os.path.join(self.label_path, name))
87 |
88 | label = label_transform(self.dataset, label)
89 | cls_label = self.make_clslabel(label)
90 |
91 | # data transform
92 | if self.transform is not None:
93 | for t in self.transform.transforms:
94 | img, label = t([img, label])
95 |
96 | img, label, cls_label = self.data_transform(img, label, cls_label)
97 | return img, label, cls_label, name
98 |
99 | def __len__(self):
100 |
101 | return len(self.img_label_path_pairs)
102 |
103 |
104 | def value_to_rgb(anno, flag='potsdam'):
105 | if flag == 'potsdam' or flag == 'vaihingen':
106 | label2color_dict = {
107 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255)
108 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255)
109 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255)
110 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0)
111 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0)
112 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0)
113 | 255: [0, 0, 0]
114 | }
115 | else:
116 | label2color_dict = {}
117 |
118 | # visualize
119 | visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8)
120 | for i in range(visual_anno.shape[0]): # i for h
121 | for j in range(visual_anno.shape[1]):
122 | # cv2: bgr
123 | color = label2color_dict[anno[i, j, 0]]
124 |
125 | visual_anno[i, j, 0] = color[0]
126 | visual_anno[i, j, 1] = color[1]
127 | visual_anno[i, j, 2] = color[2]
128 |
129 | return visual_anno
130 |
131 |
132 | def show_pair(opt, img, label, name):
133 | true_label = read('/media/hlf/Luffy/WLS/semantic/dataset/'+opt.dataset+'/train/label_vis/' + name[0])
134 | fig, axs = plt.subplots(1, 3, figsize=(14, 4))
135 |
136 | img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
137 | img = Normalize_back(img, flag=opt.dataset)
138 | axs[0].imshow(img[:, :, :3].astype(np.uint8))
139 | axs[0].axis("off")
140 |
141 | label = label.permute(1, 2, 0).cpu().numpy()
142 | print(np.unique(label))
143 | vis = value_to_rgb(label, flag=opt.dataset)
144 | axs[1].imshow(vis.astype(np.uint8))
145 | axs[1].axis("off")
146 |
147 | axs[2].imshow(true_label.astype(np.uint8))
148 | axs[2].axis("off")
149 |
150 | plt.tight_layout()
151 | plt.show()
152 | plt.close()
153 |
154 |
155 | if __name__ == "__main__":
156 |
157 | from utils import *
158 | from options import *
159 | from torch.utils.data import DataLoader
160 |
161 | opt = Point_Options().parse()
162 | print(opt)
163 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
164 |
165 | train_transform = transforms.Compose([
166 | trans.Scale(opt.img_size),
167 | trans.RandomHorizontalFlip(),
168 | trans.RandomVerticleFlip(),
169 | trans.RandomRotate90(),
170 | ])
171 |     dataset = Dataset_point(opt, train_txt_path, flag='train', transform=None)  # pass transform=train_transform to enable the augmentations defined above
172 |
173 | loader = DataLoader(
174 | dataset=dataset, shuffle=True,
175 | batch_size=1, num_workers=8, pin_memory=True, drop_last=True
176 | )
177 |
178 | for i in tqdm(loader):
179 | img, label, cls_label, name = i
180 | show_pair(opt, img, label, name)
181 | pass
--------------------------------------------------------------------------------
/dataset/transform.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import math
5 | import collections
6 | import random
7 |
8 |
9 | class Scale(object):
10 |     """Rescale the input image (numpy array) to the given size.
11 |
12 |     Args:
13 |         size (sequence or int): Desired output size. If size is a sequence
14 |             like (w, h), the output size is matched to it. If size is an int,
15 |             the output is a square image of (size, size).
16 |     """
21 |
22 | def __init__(self, size):
23 |         assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)  # collections.Iterable was removed in Python 3.10
24 | self.size = size
25 |
26 | def __call__(self, inputs):
27 | """
28 | Args:
29 | img (npy): Image to be scaled.
30 | """
31 |         outs = []
32 |         for input in inputs:
33 |             h = w = self.size
34 |             # np.resize (used in the original) repeats/truncates pixels instead of
35 |             # rescaling; cv2.resize interpolates, and nearest-neighbour keeps label values intact
36 |             img = cv2.resize(input, (w, h), interpolation=cv2.INTER_NEAREST)
37 |             outs.append(img)
38 | return outs
39 |
40 |
41 | class RandomHorizontalFlip(object):
42 | def __init__(self, u=0.5):
43 | self.u = u
44 |
45 | def __call__(self, inputs):
46 | if np.random.random() < self.u:
47 | new_inputs = []
48 | for input in inputs:
49 | input = np.flip(input, 0)
50 | new_inputs.append(input)
51 | return new_inputs
52 | else:
53 | return inputs
54 |
55 |
56 | class RandomVerticleFlip(object):
57 | def __init__(self, u=0.5):
58 | self.u = u
59 |
60 | def __call__(self, inputs):
61 | if np.random.random() < self.u:
62 | new_inputs = []
63 | for input in inputs:
64 | input = np.flip(input, 1)
65 | new_inputs.append(input)
66 | return new_inputs
67 | else:
68 | return inputs
69 |
70 |
71 | class RandomRotate90(object):
72 | def __init__(self, u=0.5):
73 | self.u = u
74 |
75 | def __call__(self, inputs):
76 | if np.random.random() < self.u:
77 | new_inputs = []
78 | for input in inputs:
79 | input = np.rot90(input)
80 | new_inputs.append(input)
81 | return new_inputs
82 | else:
83 | return inputs
84 |
85 |
86 | class Color_Aug(object):
87 | def __init__(self):
88 | self.contra_adj = 0.1
89 | self.bright_adj = 0.1
90 |
91 | def __call__(self, image):
92 | n_ch = image.shape[-1]
93 | ch_mean = np.mean(image, axis=(0, 1), keepdims=True).astype(np.float32)
94 |
95 | contra_mul = np.random.uniform(1 - self.contra_adj, 1 + self.contra_adj, (1, 1, n_ch)).astype(
96 | np.float32
97 | )
98 | bright_mul = np.random.uniform(1 - self.bright_adj, 1 + self.bright_adj, (1, 1, n_ch)).astype(
99 | np.float32
100 | )
101 |
102 | image = (image - ch_mean) * contra_mul + ch_mean * bright_mul
103 |
104 | return image
105 |
106 |
107 | class RandomHueSaturationValue(object):
108 | def __init__(self):
109 | self.hue_shift_limit = (-25, 25)
110 | self.sat_shift_limit = (-15, 15)
111 | self.val_shift_limit = (-15, 15)
112 |
113 | def change(self, image):
114 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
115 | h, s, v = cv2.split(image)
116 | hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1)
117 | hue_shift = np.uint8(hue_shift)
118 | h += hue_shift
119 | sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1])
120 | s = cv2.add(s, sat_shift)
121 | val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1])
122 | v = cv2.add(v, val_shift)
123 | image = cv2.merge((h, s, v))
124 | # image = cv2.merge((s, v))
125 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
126 | return image
127 |
128 | def __call__(self, img):
129 | if np.random.random() < 0.5:
130 | if img.shape[-1] == 4:
131 | img_3 = img[:, :, :3]
132 | img_3 = self.change(img_3)
133 | img[:, :, :3] = img_3
134 | else:
135 | img = self.change(img)
136 |
137 | return img
138 |
139 |
140 | def stretchImage(data, s=0.005, bins=2000):  # linear stretch: drop the lowest/highest 0.5% of pixel values, then stretch to [0, 1]
141 | ht = np.histogram(data, bins)
142 | d = np.cumsum(ht[0]) / float(data.size)
143 | lmin = 0
144 | lmax = bins - 1
145 | while lmin < bins:
146 | if d[lmin] >= s:
147 | break
148 | lmin += 1
149 | while lmax >= 0:
150 | if d[lmax] <= 1 - s:
151 | break
152 | lmax -= 1
153 | return np.clip((data - ht[1][lmin]) / (ht[1][lmax] - ht[1][lmin]), 0, 1)
154 |
155 |
156 | g_para = {}
157 |
158 |
159 | def getPara(radius=4):  # build the inverse-distance weight matrix for a given radius
160 | global g_para
161 | m = g_para.get(radius, None)
162 | if m is not None:
163 | return m
164 | size = radius * 2 + 1
165 | m = np.zeros((size, size))
166 | for h in range(-radius, radius + 1):
167 | for w in range(-radius, radius + 1):
168 | if h == 0 and w == 0:
169 | continue
170 | m[radius + h, radius + w] = 1.0 / math.sqrt(h ** 2 + w ** 2)
171 | m /= m.sum()
172 | g_para[radius] = m
173 | return m
174 |
175 |
176 | def zmIce(I, ratio=4, radius=300):  # standard ACE (automatic color equalization) implementation
177 | para = getPara(radius)
178 | height, width = I.shape
179 |     # pad indices: replicate the border by `radius` pixels on each side
180 |     zh = [0] * radius + list(range(height)) + [height - 1] * radius
181 |     zw = [0] * radius + list(range(width)) + [width - 1] * radius
183 | Z = I[np.ix_(zh, zw)]
184 | res = np.zeros(I.shape)
185 | for h in range(radius * 2 + 1):
186 | for w in range(radius * 2 + 1):
187 | if para[h][w] == 0:
188 | continue
189 | res += (para[h][w] * np.clip((I - Z[h:h + height, w:w + width]) * ratio, -1, 1))
190 | return res
191 |
192 |
193 | def zmIceFast(I, ratio=4, radius=300):  # fast single-channel ACE via recursive multi-scale refinement
194 | height, width = I.shape[:2]
195 | if min(height, width) <= 2:
196 | return np.zeros(I.shape) + 0.5
197 | Rs = cv2.resize(I, ((width + 1) // 2, (height + 1) // 2))
198 |     Rf = zmIceFast(Rs, ratio, radius)  # recurse on the half-resolution image
199 | Rf = cv2.resize(Rf, (width, height))
200 | Rs = cv2.resize(Rs, (width, height))
201 |
202 | return Rf + zmIce(I, ratio, radius) - zmIce(Rs, ratio, radius)
203 |
204 |
205 | def zmIceColor(I, ratio=4, radius=3):  # enhance R, G, B independently; ratio = contrast gain, radius = kernel radius
206 | res = np.zeros(I.shape)
207 | for k in range(3):
208 | res[:, :, k] = stretchImage(zmIceFast(I[:, :, k], ratio, radius))
209 | return res
210 |
211 |
212 | def do_gamma(image, gamma=1.0):
213 | image = image ** (1.0 / gamma)
214 | image = np.clip(image, 0, 1)
215 | return image
--------------------------------------------------------------------------------
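
A minimal usage sketch for the augmentations and ACE utilities above. This is a sketch under assumptions: the functions live in `dataset/transform.py` as the repository tree suggests, images are HxWx3 uint8 arrays, labels are HxW index maps, and the toy shapes are hypothetical.

    import numpy as np
    from dataset.transform import RandomRotate90, zmIceColor  # assumed module path

    # hypothetical 256x256 RGB patch and its label map
    img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
    label = np.random.randint(0, 6, (256, 256), dtype=np.uint8)

    # geometric transforms take a list so image and label stay spatially aligned
    img, label = RandomRotate90(u=1.0)([img, label])

    # zmIceColor expects a float image in [0, 1] and returns an enhanced one in [0, 1]
    enhanced = zmIceColor(img.astype(np.float32) / 255.0, ratio=4, radius=3)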
/models/DBFNet.py:
--------------------------------------------------------------------------------
1 | from models.base_model import *
2 | from models.tools import *
3 | import os
4 | from models.fbf import FBFModule
5 |
6 | def clean_mask(mask, cls_label):
7 |     """Remove the masks of classes that are not present in the patch."""
8 |     n, c = cls_label.size()
9 |     return mask * cls_label.view(n, c, 1, 1)
10 |
11 |
12 | def get_penalty(predict, cls_label):
13 | # cls_label: (n, c)
14 | # predict: (n, c, h, w)
15 | n, c, h, w = predict.size()
16 | predict = torch.softmax(predict, dim=1)
17 |
18 | # if a patch does not contain label c,
19 | # then none of the pixels in this patch can be assigned to label c
20 | loss0 = - (1 - cls_label.view(n, c, 1, 1)) * torch.log(1 - predict + 1e-6)
21 | loss0 = torch.mean(torch.sum(loss0, dim=1))
22 |
23 |     # if a patch contains only one class, every pixel should be assigned to that class
24 |     single = (torch.sum(cls_label, dim=-1, keepdim=True) == 1)  # avoid shadowing the builtin `sum`
25 |     loss1 = - (single * cls_label).view(n, c, 1, 1) * torch.log(predict + 1e-6)
26 | loss1 = torch.mean(torch.sum(loss1, dim=1))
27 | return loss0 + loss1
28 |
29 |
30 | def build_channels(backbone):
31 | if backbone == 'resnet34' or backbone == 'resnet18':
32 | channels = [64, 128, 256, 512]
33 |
34 | else:
35 | channels = [256, 512, 1024, 2048]
36 |
37 | return channels
38 |
39 |
40 | def get_numiter_dilations(flag):
41 | if flag == 'train':
42 |
43 | # for potsdam
44 | num_iter = 1
45 | dilations = [[1, 3, 5, 7], [1, 3, 5], [1, 3], [1]]
46 |
47 | else:
48 | # for potsdam
49 | num_iter = 3
50 | dilations = [[1], [1], [1], [1]]
51 |
52 | # for vaihingen
53 | # num_iter = 5
54 | # dilations = [[1], [1], [1], [1]]
55 | return num_iter, dilations
56 |
57 |
58 | class FBF_Layer(nn.Module):
59 | def __init__(self, in_c, out_c, num_iter, dilations):
60 | super(FBF_Layer, self).__init__()
61 | self.num_iter = num_iter
62 | self.fbf_m = FBFModule(num_iter=num_iter, dilations=dilations)
63 | self.conv = nn.Sequential(
64 | nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, padding=0),
65 | nn.BatchNorm2d(out_c),
66 | nn.ReLU(), )
67 |
68 | def forward(self, x):
69 | # x = self.conv0(x)
70 | x = self.fbf_m(x)
71 | x = self.conv(x)
72 |
73 | return x
74 |
75 |
76 | class Net(nn.Module):
77 | def __init__(self, opt, flag='train'):
78 | super(Net, self).__init__()
79 |         # backbone with output stride 32
80 | self.backbone = build_backbone(opt.backbone, output_stride=32)
81 | self.img_size = opt.img_size
82 |
83 | # build DBF layer
84 | num_iter, dilations = get_numiter_dilations(flag)
85 | channels = build_channels(opt.backbone)
86 | key_channels = 128
87 | self.FBF_layer1 = FBF_Layer(channels[0], key_channels, num_iter, dilations[0])
88 | self.FBF_layer2 = FBF_Layer(channels[1], key_channels, num_iter, dilations[1])
89 | self.FBF_layer3 = FBF_Layer(channels[2], key_channels, num_iter, dilations[2])
90 | self.FBF_layer4 = FBF_Layer(channels[3], key_channels, num_iter, dilations[3])
91 |
92 | self.last_conv = nn.Sequential(
93 | nn.Conv2d(4*128, 128, kernel_size=1, stride=1, padding=0, bias=False),
94 | nn.BatchNorm2d(128),
95 | nn.ReLU(),
96 | )
97 | self.FBF_layer = FBF_Layer(key_channels, key_channels, 1, dilations=[1])
98 | self.seg_decoder = nn.Sequential(
99 | nn.Conv2d(key_channels, key_channels, kernel_size=1, stride=1, padding=0),
100 | nn.BatchNorm2d(key_channels),
101 | nn.ReLU(),
102 | nn.Conv2d(key_channels, opt.num_classes, kernel_size=3, stride=1, padding=1),
103 | )
104 |
105 | def resize_out(self, output):
106 | if output.size()[-1] != self.img_size:
107 | output = F.interpolate(output, size=(self.img_size, self.img_size), mode='bilinear')
108 | return output
109 |
110 | def upsample_cat(self, p1, p2, p3, p4):
111 | p2 = nn.functional.interpolate(p2, size=p1.size()[2:], mode='bilinear', align_corners=True)
112 | p3 = nn.functional.interpolate(p3, size=p1.size()[2:], mode='bilinear', align_corners=True)
113 | p4 = nn.functional.interpolate(p4, size=p1.size()[2:], mode='bilinear', align_corners=True)
114 | return torch.cat([p1, p2, p3, p4], dim=1)
115 |
116 | def forward(self, x):
117 | # resnet 4 layers
118 | l1, l2, l3, l4 = self.backbone(x)
119 |
120 | # after feature bilateral filtering
121 | f1 = self.FBF_layer1(l1)
122 | f2 = self.FBF_layer2(l2)
123 | f3 = self.FBF_layer3(l3)
124 | f4 = self.FBF_layer4(l4)
125 |
126 | # fpn-like Top-down
127 | p4 = f4
128 |         p3 = F.interpolate(p4, size=f3.size()[2:], mode='bilinear') + f3
129 |         p2 = F.interpolate(p3, size=f2.size()[2:], mode='bilinear') + f2
130 |         p1 = F.interpolate(p2, size=f1.size()[2:], mode='bilinear') + f1
131 |
132 | cat = self.upsample_cat(p1, p2, p3, p4)
133 | feat = self.last_conv(cat)
134 | feat = self.FBF_layer(feat)
135 |
136 | out = self.seg_decoder(feat)
137 | out = self.resize_out(out)
138 |
139 | return out
140 |
141 | def forward_loss(self, x, label, cls_label):
142 | coarse_mask = self.forward(x)
143 |
144 | # get loss
145 | criterion = nn.CrossEntropyLoss(ignore_index=255)
146 | # penalty
147 | penalty = get_penalty(coarse_mask, cls_label)
148 |         # label: point annotations, i.e. point-level supervision
149 | seg_loss = criterion(coarse_mask, label)
150 |
151 | return seg_loss, penalty
152 |
153 |
154 | if __name__ == "__main__":
155 | import utils.util as util
156 | from options import *
157 | from torch.utils.data import DataLoader
158 | from dataset import *
159 | from tqdm import tqdm
160 | from torchvision import transforms
161 |
162 | opt = Point_Options().parse()
163 | save_path = os.path.join(opt.save_path, opt.dataset)
164 | train_txt_path, val_txt_path, test_txt_path = util.create_data_path(opt)
165 | log_path, checkpoint_path, predict_path, _, _ = util.create_save_path(opt)
166 |
167 | train_transform = transforms.Compose([
168 | trans.RandomHorizontalFlip(),
169 | trans.RandomVerticleFlip(),
170 | trans.RandomRotate90(),
171 | ])
172 | dataset = Dataset_point(opt, train_txt_path, flag='train', transform=train_transform)
173 |
174 | loader = DataLoader(
175 | dataset=dataset,
176 | batch_size=1, num_workers=8, pin_memory=True, drop_last=True
177 | )
178 | net = Net(opt, flag='train')
179 | net.cuda()
180 | torch.save({'state_dict': net.state_dict()},
181 | os.path.join(checkpoint_path, 'model.pth'))
182 |
183 | for i in tqdm(loader):
184 | # put the data from loader to cuda
185 | img, label, cls_label, name = i
186 | input, label, cls_label = img.cuda(non_blocking=True), \
187 | label.cuda(non_blocking=True), cls_label.cuda(non_blocking=True)
188 |
189 |         # model forward
190 |         out = net(input)
193 |
--------------------------------------------------------------------------------
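
To make the two terms of `get_penalty` concrete, here is a small CPU-only sketch; the shapes and class count are hypothetical, and it assumes the repository's dependencies are importable.

    import torch
    from models.DBFNet import get_penalty

    n, c, h, w = 2, 6, 8, 8
    predict = torch.randn(n, c, h, w)   # raw logits, softmaxed inside get_penalty

    cls_label = torch.zeros(n, c)
    cls_label[0, 0] = 1.                # patch 0 holds a single class: loss1 pushes all pixels to it
    cls_label[1, 1] = 1.
    cls_label[1, 2] = 1.                # patch 1 holds two classes: only the absent-class term applies

    penalty = get_penalty(predict, cls_label)
    print(penalty)                      # scalar, loss0 + loss1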
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.base_model import *
2 | from models.seg_model import *
3 | from models.DBFNet import *
4 |
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.network import *
5 | from options import *
6 |
7 |
8 | def build_backbone(backbone, output_stride, pretrained=True, in_c=3):
9 | if backbone == 'resnet50':
10 | return ResNet50(output_stride, pretrained=pretrained, in_c=in_c)
11 |
12 | elif backbone == 'resnet101':
13 | return ResNet101(output_stride, pretrained=pretrained, in_c=in_c)
14 |
15 | elif backbone == 'resnet34':
16 | return ResNet34(output_stride, pretrained=pretrained, in_c=in_c)
17 |
18 | elif backbone == 'resnet18':
19 | return ResNet18(output_stride, pretrained=pretrained, in_c=in_c)
20 |
21 | else:
22 | raise NotImplementedError
23 |
24 |
25 | def build_base_model(opt):
26 | # get in_channels
27 | if opt.base_model == 'Deeplab':
28 | model = DeepLab(backbone=opt.backbone, in_channels=opt.in_channels)
29 |
30 | elif opt.base_model == 'LinkNet':
31 | model = DLinkNet(opt)
32 |
33 | elif opt.base_model == 'FCN':
34 | model = FCN(opt)
35 |
36 | elif opt.base_model == 'UNet':
37 | model = UNet(opt)
38 |
39 |     else:
40 |         model = None
41 |         print('unrecognized base_model: %s' % opt.base_model)
42 | return model
43 |
44 |
45 | def build_channels(opt):
46 | if opt.base_model == 'Deeplab':
47 | channels = 256
48 | elif opt.base_model == 'STANet':
49 | channels = 128
50 | elif opt.base_model == 'LinkNet':
51 | channels = 32
52 | elif opt.base_model == 'FCN':
53 | channels = 32
54 | elif opt.base_model == 'UNet':
55 | channels = 32
56 |     else:
57 |         channels = None
58 |         print('unrecognized base_model: %s' % opt.base_model)
59 | return channels
60 |
61 |
62 |
--------------------------------------------------------------------------------
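
A sketch of how the two builders above pair up. The `Namespace` is a stand-in for the repository's option parser, and ImageNet weights are fetched on first use because the decoders construct their backbones with `pretrained=True`.

    import torch
    from argparse import Namespace
    from models.base_model import build_base_model, build_channels

    opt = Namespace(base_model='FCN', backbone='resnet18')
    model = build_base_model(opt)                  # decoder whose output width is build_channels(opt)
    feat = model(torch.rand(1, 3, 256, 256))
    assert feat.shape[1] == build_channels(opt)    # 32 channels for FCN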
/models/fbf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 |
6 | class LocalAffinity(nn.Module):
7 |
8 | def __init__(self, dilations=[1]):
9 | super(LocalAffinity, self).__init__()
10 | self.dilations = dilations
11 | weight = self._init_aff()
12 | self.register_buffer('kernel', weight)
13 |
14 | def _init_aff(self):
15 | # initialising the shift kernel
16 |         weight = torch.zeros(8, 1, 3, 3)  # device-agnostic; the registered buffer follows the module
17 |
18 | weight[0, 0, 0, 0] = 1
19 | weight[1, 0, 0, 1] = 1
20 | weight[2, 0, 0, 2] = 1
21 |
22 | weight[3, 0, 1, 0] = 1
23 | weight[4, 0, 1, 2] = 1
24 |
25 | weight[5, 0, 2, 0] = 1
26 | weight[6, 0, 2, 1] = 1
27 | weight[7, 0, 2, 2] = 1
28 |
29 | self.weight_check = weight.clone()
30 | return weight
31 |
32 | def forward(self, x):
33 | self.weight_check = self.weight_check.type_as(x)
34 | assert torch.all(self.weight_check.eq(self.kernel))
35 |
36 | B, K, H, W = x.size()
37 | x = x.view(B * K, 1, H, W)
38 |
39 | x_affs = []
40 | for d in self.dilations:
41 | x_pad = F.pad(x, [d] * 4, mode='replicate')
42 | x_aff = F.conv2d(x_pad, self.kernel, dilation=d)
43 | x_affs.append(x_aff)
44 |
45 | x_aff = torch.cat(x_affs, 1)
46 | return x_aff.view(B, K, -1, H, W)
47 |
48 |
49 | class FBFModule(nn.Module):
50 | def __init__(self, num_iter=5, dilations=[1]):
51 | # Dilated Local Affinity
52 | super(FBFModule, self).__init__()
53 | self.num_iter = num_iter
54 | self.aff_loc = LocalAffinity(dilations)
55 |
56 | def forward(self, feature):
57 | # feature: [BxCxHxW]
58 | n, c, h, w = feature.size()
59 |
60 | for _ in range(self.num_iter):
61 | f = self.aff_loc(feature) # [BxCxPxHxW]
62 | # dim2 represent the p neighbor-pixels' value
63 |
64 |             diff = torch.abs(feature.unsqueeze(2) - f)  # avoid shadowing the builtin `abs`
65 |             aff = torch.exp(-torch.mean(diff, dim=1, keepdim=True))
66 | # aff = F.cosine_similarity(feature.unsqueeze(2), f, dim=1).unsqueeze(1)
67 |
68 | aff = aff/torch.sum(aff, dim=2, keepdim=True) # [Bx1xPxHxW]
69 | # print(aff[0, 0, :, h//2, w//2])
70 | # dim2 represent the p neighbor-pixels' affinity
71 |
72 | feature = torch.sum(f * aff, dim=2)
73 |
74 | return feature
75 |
76 |
77 | if __name__ == '__main__':
78 | feature = torch.randn([4, 64, 256, 256], device='cuda')
79 |
80 | model = FBFModule()
81 | output = model(feature)
82 | print(output.shape)
--------------------------------------------------------------------------------
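
Because the shift kernel is registered as a buffer and created device-agnostically, the module runs on CPU as well; a minimal sketch of one filtering pass with toy shapes:

    import torch
    from models.fbf import FBFModule

    feat = torch.randn(1, 16, 32, 32)              # [B, C, H, W]
    fbf = FBFModule(num_iter=2, dilations=[1, 3])
    out = fbf(feat)                                # edge-preserving smoothing via local affinities
    assert out.shape == feat.shape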
/models/network/DeeplabV3.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | # @Filename: DeeplabV3
4 | # @Project : Glory
5 | # @date : 2020-12-28 21:52
6 | # @Author : Linshan
7 |
8 | import torch
9 | import math
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from itertools import chain
13 | from models.network.resnet import *
14 | from models.network.segformer import *
15 | '''
16 | -> ResNet BackBone
17 | '''
18 |
19 |
20 | def initialize_weights(*models):
21 | for model in models:
22 | for m in model.modules():
23 | if isinstance(m, nn.Conv2d):
24 | nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu')
25 | elif isinstance(m, nn.BatchNorm2d):
26 | m.weight.data.fill_(1.)
27 | m.bias.data.fill_(1e-4)
28 | elif isinstance(m, nn.Linear):
29 | m.weight.data.normal_(0.0, 0.0001)
30 | m.bias.data.zero_()
31 |
32 |
33 | '''
34 | -> The Atrous Spatial Pyramid Pooling
35 | '''
36 |
37 |
38 | def assp_branch(in_channels, out_channels, kernel_size, dilation):
39 |     padding = 0 if kernel_size == 1 else dilation
40 |     return nn.Sequential(
41 |         nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation, bias=False),
42 |         nn.BatchNorm2d(out_channels),
43 |         nn.ReLU(inplace=True))
44 |
45 |
46 | class ASSP(nn.Module):
47 | def __init__(self, in_channels, aspp_stride):
48 | super(ASSP, self).__init__()
49 |
50 |         assert aspp_stride in [8, 16], 'Only output strides of 8 or 16 are supported'
51 | if aspp_stride == 16:
52 | dilations = [1, 6, 12, 18]
53 | elif aspp_stride == 8:
54 | dilations = [1, 12, 24, 36]
55 |
56 | self.aspp1 = assp_branch(in_channels, 256, 1, dilation=dilations[0])
57 | self.aspp2 = assp_branch(in_channels, 256, 3, dilation=dilations[1])
58 | self.aspp3 = assp_branch(in_channels, 256, 3, dilation=dilations[2])
59 | self.aspp4 = assp_branch(in_channels, 256, 3, dilation=dilations[3])
60 |
61 | self.avg_pool = nn.Sequential(
62 | nn.AdaptiveAvgPool2d((1, 1)),
63 | nn.Conv2d(in_channels, 256, 1, bias=False),
64 | nn.BatchNorm2d(256),
65 | nn.ReLU(inplace=True))
66 |
67 | self.conv1 = nn.Conv2d(256 * 5, 256, 1, bias=False)
68 | self.bn1 = nn.BatchNorm2d(256)
69 | self.relu = nn.ReLU(inplace=True)
70 | self.dropout = nn.Dropout(0.5)
71 |
72 | initialize_weights(self)
73 |
74 | def forward(self, x):
75 | x1 = self.aspp1(x)
76 | x2 = self.aspp2(x)
77 | x3 = self.aspp3(x)
78 | x4 = self.aspp4(x)
79 |
80 | x5 = F.interpolate(self.avg_pool(x), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=True)
81 |
82 | x = self.conv1(torch.cat((x1, x2, x3, x4, x5), dim=1))
83 | x = self.bn1(x)
84 | x = self.dropout(self.relu(x))
85 |
86 | return x
87 |
88 |
89 | '''
90 | -> Decoder
91 | '''
92 |
93 |
94 | class Decoder(nn.Module):
95 | def __init__(self, low_level_channels):
96 | super(Decoder, self).__init__()
97 | self.conv1 = nn.Conv2d(low_level_channels, 48, 1, bias=False)
98 | self.bn1 = nn.BatchNorm2d(48)
99 | self.relu = nn.ReLU(inplace=True)
100 |
101 | # Table 2, best performance with two 3x3 convs
102 | self.output = nn.Sequential(
103 | nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1, bias=False),
104 | nn.BatchNorm2d(256),
105 | nn.ReLU(inplace=True),
106 | nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False),
107 | nn.BatchNorm2d(256),
108 | nn.ReLU(inplace=True)
109 | )
110 | initialize_weights(self)
111 |
112 | def forward(self, x, low_level_features):
113 | low_level_features = self.conv1(low_level_features)
114 | low_level_features = self.relu(self.bn1(low_level_features))
115 | H, W = low_level_features.size(2), low_level_features.size(3)
116 |
117 | x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
118 | x = self.output(torch.cat((low_level_features, x), dim=1))
119 | return x
120 |
121 |
122 | '''
123 | -> Deeplab V3 +
124 | '''
125 |
126 |
127 | class DeepLab(nn.Module):
128 | def __init__(self, backbone, in_channels, pretrained=True,
129 | output_stride=16, aspp_stride=8, freeze_bn=False, **_):
130 |
131 | super(DeepLab, self).__init__()
132 |         assert any(k in backbone for k in ('resnet', 'resnest', 'mit'))
133 |
134 | if 'resnet18' in backbone:
135 | self.backbone = ResNet18(output_stride=output_stride, pretrained=pretrained, in_c=in_channels)
136 | low_level_channels = 64
137 | aspp_channels = 512
138 |
139 | elif 'resnet34' in backbone:
140 | self.backbone = ResNet34(output_stride=output_stride, pretrained=pretrained, in_c=in_channels)
141 | low_level_channels = 64
142 | aspp_channels = 512
143 |
144 | elif 'resnet50' in backbone:
145 | self.backbone = ResNet50(output_stride=output_stride, pretrained=pretrained, in_c=in_channels)
146 | low_level_channels = 256
147 | aspp_channels = 2048
148 |
149 | elif 'resnet101' in backbone:
150 | self.backbone = ResNet101(output_stride=output_stride, pretrained=pretrained, in_c=in_channels)
151 | low_level_channels = 256
152 | aspp_channels = 2048
153 |
154 | elif 'mit_b0' in backbone:
155 | self.backbone = mit_b0()
156 | checkpoint = torch.load('/media/hlf/Luffy/WLS/ContestCD/pretrained_former/segformer_b0.pth', map_location=torch.device('cpu'))
157 | self.backbone.load_state_dict(checkpoint['state_dict'])
158 | print('load pretrained transformer')
159 | low_level_channels = 32
160 | aspp_channels = 256
161 |
162 | elif 'mit_b3' in backbone:
163 | self.backbone = mit_b3()
164 | checkpoint = torch.load('/media/hlf/Luffy/WLS/ContestCD/pretrained_former/segformer_b3.pth', map_location=torch.device('cpu'))
165 | self.backbone.load_state_dict(checkpoint['state_dict'])
166 | print('load pretrained transformer')
167 | low_level_channels = 64
168 | aspp_channels = 512
169 |
170 | else:
171 | low_level_channels = None
172 | aspp_channels = None
173 |
174 | self.ASSP = ASSP(in_channels=aspp_channels, aspp_stride=aspp_stride)
175 |
176 | self.decoder = Decoder(low_level_channels)
177 |
178 | if freeze_bn:
179 | self.freeze_bn()
180 |
181 | def forward(self, x):
182 | _, _, h, w = x.size()
183 | x = self.backbone(x)
184 |
185 | feature = self.ASSP(x[3])
186 | x = self.decoder(feature, x[0])
187 | # x = F.interpolate(x, size=(h, w), mode='bilinear')
188 |
189 | return x
190 |
191 | def freeze_bn(self):
192 | for module in self.modules():
193 | if isinstance(module, nn.BatchNorm2d): module.eval()
194 |
195 |
196 | if __name__ == '__main__':
197 |
198 |     model = DeepLab(backbone='resnet18', in_channels=4, pretrained=False)
199 |
200 | x = torch.rand(2, 4, 256, 256)
201 | x = model(x)
202 |
203 | print(x.shape)
--------------------------------------------------------------------------------
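
For reference, a sketch of the output geometry, assuming a 3-channel ResNet-18 backbone; the decoder fuses 1/4-resolution low-level features, so the result stays at 1/4 of the input size.

    import torch
    from models.network.DeeplabV3 import DeepLab

    model = DeepLab(backbone='resnet18', in_channels=3, pretrained=False)
    feat = model(torch.rand(2, 3, 256, 256))
    print(feat.shape)    # torch.Size([2, 256, 64, 64]): 256 channels at 1/4 resolution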
/models/network/LinkNet.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | # @Filename: LinkNet
4 | # @Project : Glory
5 | # @date : 2020-12-28 21:29
6 | # @Author : Linshan
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.autograd import Variable
11 | import torch.nn.functional as F
12 | from functools import partial
13 | from models.network.resnet import *
14 |
15 | nonlinearity = partial(F.relu, inplace=True)
16 |
17 |
18 | class Dblock_more_dilate(nn.Module):
19 | def __init__(self, channel):
20 | super(Dblock_more_dilate, self).__init__()
21 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
22 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
23 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
24 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
25 | self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
26 | for m in self.modules():
27 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
28 | if m.bias is not None:
29 | m.bias.data.zero_()
30 |
31 | def forward(self, x):
32 | dilate1_out = nonlinearity(self.dilate1(x))
33 | dilate2_out = nonlinearity(self.dilate2(dilate1_out))
34 | dilate3_out = nonlinearity(self.dilate3(dilate2_out))
35 | dilate4_out = nonlinearity(self.dilate4(dilate3_out))
36 | dilate5_out = nonlinearity(self.dilate5(dilate4_out))
37 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out
38 | return out
39 |
40 |
41 | class Dblock(nn.Module):
42 | def __init__(self, channel):
43 | super(Dblock, self).__init__()
44 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
45 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
46 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
47 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
48 | # self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
49 | for m in self.modules():
50 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
51 | if m.bias is not None:
52 | m.bias.data.zero_()
53 |
54 | def forward(self, x):
55 | dilate1_out = nonlinearity(self.dilate1(x))
56 | dilate2_out = nonlinearity(self.dilate2(dilate1_out))
57 | dilate3_out = nonlinearity(self.dilate3(dilate2_out))
58 | dilate4_out = nonlinearity(self.dilate4(dilate3_out))
59 | # dilate5_out = nonlinearity(self.dilate5(dilate4_out))
60 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out # + dilate5_out
61 | return out
62 |
63 |
64 | class DecoderBlock(nn.Module):
65 | def __init__(self, in_channels, n_filters):
66 | super(DecoderBlock, self).__init__()
67 |
68 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
69 | self.norm1 = nn.BatchNorm2d(in_channels // 4)
70 | self.relu1 = nonlinearity
71 |
72 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1)
73 | self.norm2 = nn.BatchNorm2d(in_channels // 4)
74 | self.relu2 = nonlinearity
75 |
76 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
77 | self.norm3 = nn.BatchNorm2d(n_filters)
78 | self.relu3 = nonlinearity
79 |
80 | def forward(self, x):
81 | x = self.conv1(x)
82 | x = self.norm1(x)
83 | x = self.relu1(x)
84 | x = self.deconv2(x)
85 | x = self.norm2(x)
86 | x = self.relu2(x)
87 | x = self.conv3(x)
88 | x = self.norm3(x)
89 | x = self.relu3(x)
90 | return x
91 |
92 |
93 | class DLinkNet(nn.Module):
94 | def __init__(self, opt):
95 | super(DLinkNet, self).__init__()
96 | if 'resnet50' in opt.backbone:
97 | self.backbone = ResNet50(pretrained=True)
98 | filters = [256, 512, 1024, 2048]
99 |
100 | elif 'resnet101' in opt.backbone:
101 |             self.backbone = ResNet101(pretrained=True)
102 | filters = [256, 512, 1024, 2048]
103 |
104 | elif 'resnet34' in opt.backbone:
105 | self.backbone = ResNet34(pretrained=True)
106 | filters = [64, 128, 256, 512]
107 |
108 | elif 'resnet18' in opt.backbone:
109 | self.backbone = ResNet18(pretrained=True)
110 | filters = [64, 128, 256, 512]
111 | else:
112 | filters = None
113 |
114 | self.dblock_master = Dblock(filters[3])
115 |
116 | self.decoder4_master = DecoderBlock(filters[3], filters[2])
117 | self.decoder3_master = DecoderBlock(filters[2], filters[1])
118 | self.decoder2_master = DecoderBlock(filters[1], filters[0])
119 | self.decoder1_master = DecoderBlock(filters[0], filters[0])
120 |
121 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
122 | self.finalrelu1_master = nonlinearity
123 |
124 | def forward(self, x):
125 | x = self.backbone(x)
126 |
127 | # center_master
128 | e4 = self.dblock_master(x[3])
129 | # decoder_master
130 | d4 = self.decoder4_master(e4)
131 | d3 = self.decoder3_master(d4)
132 | d2 = self.decoder2_master(d3)
133 | d1 = self.decoder1_master(d2)
134 |
135 | out = self.finaldeconv1_master(d1)
136 | out = self.finalrelu1_master(out)
137 |
138 | return out
139 |
140 |
141 | if __name__ == '__main__':
142 |     from argparse import Namespace
143 |
144 |     # minimal stand-in for the option parser; the pretrained backbone expects 3-channel input
145 |     model = DLinkNet(Namespace(backbone='resnet34'))
146 |
147 |     x = torch.rand(2, 3, 512, 512)
148 |     x = model(x)
149 |
150 |     print(x.shape)
--------------------------------------------------------------------------------
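
The dilated cascade in `Dblock` is what gives D-LinkNet its large receptive field: each 3x3 conv with dilation d widens the field by 2*d, so the 1-2-4-8 cascade covers a 31-pixel window. A quick check of that arithmetic:

    # receptive field of a cascade of 3x3 convs with the given dilations
    def cascade_rf(dilations):
        rf = 1
        for d in dilations:
            rf += 2 * d
        return rf

    print(cascade_rf([1, 2, 4, 8]))        # 31 (Dblock)
    print(cascade_rf([1, 2, 4, 8, 16]))    # 63 (Dblock_more_dilate)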
/models/network/__init__.py:
--------------------------------------------------------------------------------
1 | from models.network.LinkNet import *
2 | from models.network.DeeplabV3 import *
3 | # from models.network.STANet import STANet  # STANet.py is not included in this repository
4 | from models.network.fcn import FCN, UNet
5 | from models.network.resnet import *
6 |
7 |
8 |
--------------------------------------------------------------------------------
/models/network/fcn.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from models.network.LinkNet import *
9 |
10 |
11 | class FCN(nn.Module):
12 | def __init__(self, opt):
13 | super(FCN, self).__init__()
14 | if 'resnet50' in opt.backbone:
15 | self.backbone = ResNet50(pretrained=True)
16 | filters = [256, 512, 1024, 2048]
17 |
18 | elif 'resnet101' in opt.backbone:
19 |             self.backbone = ResNet101(pretrained=True)
20 | filters = [256, 512, 1024, 2048]
21 |
22 | elif 'resnet34' in opt.backbone:
23 | self.backbone = ResNet34(pretrained=True)
24 | filters = [64, 128, 256, 512]
25 |
26 | elif 'resnet18' in opt.backbone:
27 | self.backbone = ResNet18(pretrained=True)
28 | filters = [64, 128, 256, 512]
29 | else:
30 | filters = None
31 |
32 | self.decoder4_master = DecoderBlock(filters[3], filters[2])
33 | self.decoder3_master = DecoderBlock(filters[2], filters[1])
34 | self.decoder2_master = DecoderBlock(filters[1], filters[0])
35 | self.decoder1_master = DecoderBlock(filters[0], filters[0])
36 |
37 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
38 | self.finalrelu1_master = nonlinearity
39 |
40 | def forward(self, x):
41 | x = self.backbone(x)
42 |
43 | # center_master
44 | e4 = x[3]
45 | # decoder_master
46 | d4 = self.decoder4_master(e4)
47 | d3 = self.decoder3_master(d4)
48 | d2 = self.decoder2_master(d3)
49 | d1 = self.decoder1_master(d2)
50 |
51 | out = self.finaldeconv1_master(d1)
52 | out = self.finalrelu1_master(out)
53 |
54 | return out
55 |
56 |
57 | class UNet(nn.Module):
58 | def __init__(self, opt):
59 | super(UNet, self).__init__()
60 | if 'resnet50' in opt.backbone:
61 | self.backbone = ResNet50(pretrained=True)
62 | filters = [256, 512, 1024, 2048]
63 |
64 | elif 'resnet101' in opt.backbone:
65 |             self.backbone = ResNet101(pretrained=True)
66 | filters = [256, 512, 1024, 2048]
67 |
68 | elif 'resnet34' in opt.backbone:
69 | self.backbone = ResNet34(pretrained=True)
70 | filters = [64, 128, 256, 512]
71 |
72 | elif 'resnet18' in opt.backbone:
73 | self.backbone = ResNet18(pretrained=True)
74 | filters = [64, 128, 256, 512]
75 | else:
76 | filters = None
77 |
78 | self.decoder4_master = DecoderBlock(filters[3], filters[2])
79 | self.decoder3_master = DecoderBlock(filters[2]*2, filters[1])
80 | self.decoder2_master = DecoderBlock(filters[1]*2, filters[0])
81 | self.decoder1_master = DecoderBlock(filters[0]*2, filters[0])
82 |
83 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
84 | self.finalrelu1_master = nonlinearity
85 |
86 | def forward(self, x):
87 | x = self.backbone(x)
88 |
89 | # center_master
90 | e4 = x[3]
91 | # decoder_master
92 | d4 = torch.cat([self.decoder4_master(e4), x[2]], dim=1)
93 | d3 = torch.cat([self.decoder3_master(d4), x[1]], dim=1)
94 | d2 = torch.cat([self.decoder2_master(d3), x[0]], dim=1)
95 | d1 = self.decoder1_master(d2)
96 |
97 | out = self.finaldeconv1_master(d1)
98 | out = self.finalrelu1_master(out)
99 |
100 |         return out
--------------------------------------------------------------------------------
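
Both decoders above upsample back to full resolution and emit 32 channels; a sketch with a stand-in `Namespace` for the option parser (ImageNet weights are downloaded on first use since the backbones are built with `pretrained=True`):

    import torch
    from argparse import Namespace
    from models.network.fcn import UNet

    model = UNet(Namespace(backbone='resnet18'))
    out = model(torch.rand(2, 3, 256, 256))
    print(out.shape)     # torch.Size([2, 32, 256, 256]): skip connections restore full resolution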
/models/network/fpn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | import torch.nn.functional as F
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8 | """3x3 convolution with padding"""
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=dilation, groups=groups, bias=False, dilation=dilation)
11 |
12 |
13 | def conv1x1(in_planes, out_planes, stride=1):
14 | """1x1 convolution"""
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
16 |
17 |
18 | class BasicBlock(nn.Module):
19 | expansion = 1
20 |
21 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None, groups=1,
22 | base_width=64):
23 | super(BasicBlock, self).__init__()
24 | if BatchNorm is None:
25 | BatchNorm = nn.BatchNorm2d
26 | if groups != 1 or base_width != 64:
27 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
28 | if dilation > 1:
29 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
30 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
31 | self.conv1 = conv3x3(inplanes, planes, stride)
32 | self.bn1 = BatchNorm(planes)
33 | self.relu = nn.ReLU(inplace=True)
34 | self.conv2 = conv3x3(planes, planes)
35 | self.bn2 = BatchNorm(planes)
36 | self.downsample = downsample
37 | self.stride = stride
38 |
39 | def forward(self, x):
40 | identity = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 |
49 | if self.downsample is not None:
50 | identity = self.downsample(x)
51 |
52 | out += identity
53 | out = self.relu(out)
54 |
55 | return out
56 |
57 |
58 | class Bottleneck(nn.Module):
59 | expansion = 4
60 |
61 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None, groups=1,
62 | base_width=64):
63 | super(Bottleneck, self).__init__()
64 | width = int(planes * (base_width / 64.)) * groups
65 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
66 | self.bn1 = BatchNorm(width)
67 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
68 | dilation=dilation, padding=dilation, bias=False, groups=groups)
69 | self.bn2 = BatchNorm(width)
70 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
71 | self.bn3 = BatchNorm(planes * 4)
72 | self.relu = nn.ReLU(inplace=True)
73 | self.downsample = downsample
74 | self.stride = stride
75 | self.dilation = dilation
76 |
77 | def forward(self, x):
78 | residual = x
79 |
80 | out = self.conv1(x)
81 | out = self.bn1(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv2(out)
85 | out = self.bn2(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 |
91 | if self.downsample is not None:
92 | residual = self.downsample(x)
93 |
94 | out += residual
95 | out = self.relu(out)
96 |
97 | return out
98 |
99 |
100 | class ResNet(nn.Module):
101 |
102 | def __init__(self, arch, block, layers, output_stride, BatchNorm, pretrained=True):
103 | self.inplanes = 64
104 | self.layers = layers
105 | self.arch = arch
106 | super(ResNet, self).__init__()
107 | blocks = [1, 2, 4]
108 | if output_stride == 16:
109 | strides = [1, 2, 2, 1]
110 | dilations = [1, 1, 1, 2]
111 | elif output_stride == 8:
112 | strides = [1, 2, 1, 1]
113 | dilations = [1, 1, 2, 4]
114 | else:
115 | strides = [1, 2, 2, 2]
116 | dilations = [1, 1, 1, 1]
117 |
118 | if arch == 'resnext50':
119 | self.base_width = 4
120 | self.groups = 32
121 | else:
122 | self.base_width = 64
123 | self.groups = 1
124 |
125 | # Modules
126 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
127 | bias=False)
128 | self.bn1 = BatchNorm(64)
129 | self.relu = nn.ReLU(inplace=True)
130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
131 |
132 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0],
133 | BatchNorm=BatchNorm)
134 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1],
135 | BatchNorm=BatchNorm)
136 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2],
137 | BatchNorm=BatchNorm)
138 | if self.arch == 'resnet18':
139 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3],
140 | BatchNorm=BatchNorm)
141 | else:
142 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3],
143 | BatchNorm=BatchNorm)
144 |
145 |         if self.arch == 'resnet18':
146 |             lat_in, out_c = [64, 128, 256, 512], 128
147 |         else:
148 |             lat_in, out_c = [256, 512, 1024, 2048], 256
149 |
150 |         self.toplayer = nn.Conv2d(lat_in[3], out_c, kernel_size=1, stride=1, padding=0)
151 |         self.latlayer1 = nn.Conv2d(lat_in[2], out_c, kernel_size=1, stride=1, padding=0)
152 |         self.latlayer2 = nn.Conv2d(lat_in[1], out_c, kernel_size=1, stride=1, padding=0)
153 |         self.latlayer3 = nn.Conv2d(lat_in[0], out_c, kernel_size=1, stride=1, padding=0)
154 |
155 |         # the 3x3 smoothing convs share the lateral channel width for every backbone
156 |         self.smooth1 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
157 |         self.smooth2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
158 |         self.smooth3 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
159 |
160 | self._init_weight()
161 |
162 | if pretrained:
163 | self._load_pretrained_model()
164 |
165 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
166 | downsample = None
167 | if stride != 1 or self.inplanes != planes * block.expansion:
168 | downsample = nn.Sequential(
169 | nn.Conv2d(self.inplanes, planes * block.expansion,
170 | kernel_size=1, stride=stride, bias=False),
171 | BatchNorm(planes * block.expansion),
172 | )
173 |
174 | layers = []
175 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm, groups=self.groups,
176 | base_width=self.base_width))
177 | self.inplanes = planes * block.expansion
178 | for i in range(1, blocks):
179 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm, groups=self.groups,
180 | base_width=self.base_width))
181 |
182 | return nn.Sequential(*layers)
183 |
184 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
185 | downsample = None
186 | if stride != 1 or self.inplanes != planes * block.expansion:
187 | downsample = nn.Sequential(
188 | nn.Conv2d(self.inplanes, planes * block.expansion,
189 | kernel_size=1, stride=stride, bias=False),
190 | BatchNorm(planes * block.expansion),
191 | )
192 |
193 | layers = []
194 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0] * dilation,
195 | downsample=downsample, BatchNorm=BatchNorm, groups=self.groups, base_width=self.base_width))
196 | self.inplanes = planes * block.expansion
197 | for i in range(1, len(blocks)):
198 | layers.append(block(self.inplanes, planes, stride=1,
199 | dilation=blocks[i] * dilation, BatchNorm=BatchNorm, groups=self.groups,
200 | base_width=self.base_width))
201 |
202 | return nn.Sequential(*layers)
203 |
204 | def forward(self, input):
205 | # Bottom-up
206 | x = self.conv1(input)
207 | x = self.bn1(x)
208 | x = self.relu(x)
209 | c1 = self.maxpool(x)
210 | c2 = self.layer1(c1) # x4
211 | c3 = self.layer2(c2) # x8
212 | c4 = self.layer3(c3) # x16
213 |         c5 = self.layer4(c4)  # x16, or x32 when output_stride is 32
214 |
215 | # Top-down
216 | p5 = self.toplayer(c5)
217 |         p4 = F.interpolate(p5, size=c4.size()[2:], mode='bilinear') + self.latlayer1(c4)
218 |         p3 = F.interpolate(p4, size=c3.size()[2:], mode='bilinear') + self.latlayer2(c3)
219 |         p2 = F.interpolate(p3, size=c2.size()[2:], mode='bilinear') + self.latlayer3(c2)
220 |
221 | p4 = self.smooth1(p4)
222 | p3 = self.smooth2(p3)
223 | p2 = self.smooth3(p2)
224 |
225 | return p2, p3, p4, p5
226 |
227 | def _init_weight(self):
228 | for m in self.modules():
229 | if isinstance(m, nn.Conv2d):
230 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
231 | m.weight.data.normal_(0, math.sqrt(2. / n))
232 | elif isinstance(m, nn.BatchNorm2d):
233 | m.weight.data.fill_(1)
234 | m.bias.data.zero_()
235 |
236 | def _load_pretrained_model(self):
237 | if self.arch == 'resnet101':
238 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
239 | elif self.arch == 'resnet50':
240 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')
241 | elif self.arch == 'resnet18':
242 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet18-5c106cde.pth')
243 | elif self.arch == 'resnext50':
244 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth')
245 | model_dict = {}
246 | state_dict = self.state_dict()
247 | for k, v in pretrain_dict.items():
248 | if k in state_dict:
249 | model_dict[k] = v
250 | state_dict.update(model_dict)
251 | self.load_state_dict(state_dict)
252 |
253 |
254 | def FPN101(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True):
255 | """Constructs a ResNet-101 model.
256 | Args:
257 | pretrained (bool): If True, returns a model pre-trained on ImageNet
258 | """
259 | model = ResNet('resnet101', Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
260 | return model
261 |
262 |
263 | def FPN50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True):
264 | """Constructs a ResNet-50 model.
265 | Args:
266 | pretrained (bool): If True, returns a model pre-trained on ImageNet
267 | """
268 | model = ResNet('resnet50', Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained)
269 | return model
270 |
271 |
272 | def FPN18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True):
273 | model = ResNet('resnet18', BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, pretrained=pretrained)
274 | return model
275 |
276 |
277 | def ResNext50_FPN(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True):
278 | model = ResNet('resnext50', Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained)
279 | return model
280 |
281 |
282 | if __name__ == "__main__":
283 | import torch
284 | from torchstat import stat
285 |
286 | model = FPN18(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=32)
287 | stat(model.cpu(), input_size=(3, 256, 256))
288 |
289 | input = torch.rand(1, 3, 480, 640)
290 | output = model(input)
291 | for out in output:
292 | print(out.size())
293 | # print(low_level_feat.size())
--------------------------------------------------------------------------------
/models/network/resnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | # @Filename: Deeplab
4 | # @Project : Glory
5 | # @date : 2020-12-28 21:47
6 | # @Author : Linshan
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.utils.model_zoo as model_zoo
12 | import math
13 |
14 |
15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
16 | """3x3 convolution with padding"""
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18 | padding=dilation, groups=groups, bias=False, dilation=dilation)
19 |
20 |
21 | model_urls = {
22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
27 | }
28 |
29 |
30 | def ResNet18(output_stride=32, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
31 | """
32 | output, low_level_feat:
33 | 512, 256, 128, 64, 64
34 | """
35 | model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, in_c=in_c)
36 | if in_c != 3:
37 | pretrained = False
38 | if pretrained:
39 | model._load_pretrained_model(model_urls['resnet18'])
40 | return model
41 |
42 |
43 | def ResNet34(output_stride=32, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
44 | """
45 | output, low_level_feat:
46 | 512, 64
47 | """
48 | model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
49 | if in_c != 3:
50 | pretrained = False
51 | if pretrained:
52 | model._load_pretrained_model(model_urls['resnet34'])
53 | return model
54 |
55 |
56 | def ResNet50(output_stride=32, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
57 | """
58 | output, low_level_feat:
59 | 2048, 256
60 | """
61 | model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
62 | if in_c != 3:
63 | pretrained=False
64 | if pretrained:
65 | model._load_pretrained_model(model_urls['resnet50'])
66 | return model
67 |
68 |
69 | def ResNet101(output_stride=32, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
70 | """
71 | output, low_level_feat:
72 | 2048, 256
73 | """
74 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, in_c=in_c)
75 | if in_c != 3:
76 | pretrained=False
77 | if pretrained:
78 | model._load_pretrained_model(model_urls['resnet101'])
79 | return model
80 |
81 |
82 | class BasicBlock(nn.Module):
83 | expansion = 1
84 |
85 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
86 | super(BasicBlock, self).__init__()
87 |
88 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
89 | dilation=dilation, padding=dilation, bias=False)
90 | self.bn1 = BatchNorm(planes)
91 |
92 | self.relu = nn.ReLU(inplace=True)
93 | self.conv2 = conv3x3(planes, planes)
94 | self.bn2 = BatchNorm(planes)
95 | self.downsample = downsample
96 | self.stride = stride
97 |
98 | def forward(self, x):
99 | identity = x
100 |
101 | out = self.conv1(x)
102 | out = self.bn1(out)
103 | out = self.relu(out)
104 |
105 | out = self.conv2(out)
106 | out = self.bn2(out)
107 |
108 | if self.downsample is not None:
109 | identity = self.downsample(x)
110 |
111 | out += identity
112 | out = self.relu(out)
113 |
114 | return out
115 |
116 |
117 | class Bottleneck(nn.Module):
118 | expansion = 4
119 |
120 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
121 | super(Bottleneck, self).__init__()
122 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
123 | self.bn1 = BatchNorm(planes)
124 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
125 | dilation=dilation, padding=dilation, bias=False)
126 | self.bn2 = BatchNorm(planes)
127 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
128 | self.bn3 = BatchNorm(planes * 4)
129 | self.relu = nn.ReLU(inplace=True)
130 | self.downsample = downsample
131 | self.stride = stride
132 | self.dilation = dilation
133 |
134 | def forward(self, x):
135 | residual = x
136 |
137 | out = self.conv1(x)
138 | out = self.bn1(out)
139 | out = self.relu(out)
140 |
141 | out = self.conv2(out)
142 | out = self.bn2(out)
143 | out = self.relu(out)
144 |
145 | out = self.conv3(out)
146 | out = self.bn3(out)
147 |
148 | if self.downsample is not None:
149 | residual = self.downsample(x)
150 |
151 | out += residual
152 | out = self.relu(out)
153 |
154 | return out
155 |
156 |
157 | class ResNet(nn.Module):
158 |
159 | def __init__(self, block, layers, output_stride, BatchNorm, in_c=4):
160 |
161 | self.inplanes = 64
162 | self.in_c = in_c
163 | super(ResNet, self).__init__()
164 | blocks = [1, 2, 4]
165 | if output_stride == 32:
166 | strides = [1, 2, 2, 2]
167 | dilations = [1, 1, 1, 1]
168 | elif output_stride == 16:
169 | strides = [1, 2, 2, 1]
170 | dilations = [1, 1, 1, 2]
171 | elif output_stride == 8:
172 | strides = [1, 2, 1, 1]
173 | dilations = [1, 1, 2, 4]
174 | elif output_stride == 4:
175 | strides = [1, 1, 1, 1]
176 | dilations = [1, 2, 4, 8]
177 | else:
178 | raise NotImplementedError
179 |
180 | # Modules
181 | self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=2, padding=3,
182 | bias=False)
183 | self.bn1 = BatchNorm(64)
184 | self.relu = nn.ReLU(inplace=True)
185 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
186 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
187 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
188 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
189 | #self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
190 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
191 | self._init_weight()
192 |
193 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
194 | downsample = None
195 | if stride != 1 or self.inplanes != planes * block.expansion:
196 | downsample = nn.Sequential(
197 | nn.Conv2d(self.inplanes, planes * block.expansion,
198 | kernel_size=1, stride=stride, bias=False),
199 | BatchNorm(planes * block.expansion),
200 | )
201 |
202 | layers = []
203 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
204 | self.inplanes = planes * block.expansion
205 | for i in range(1, blocks):
206 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
207 |
208 | return nn.Sequential(*layers)
209 |
210 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
211 | downsample = None
212 | if stride != 1 or self.inplanes != planes * block.expansion:
213 | downsample = nn.Sequential(
214 | nn.Conv2d(self.inplanes, planes * block.expansion,
215 | kernel_size=1, stride=stride, bias=False),
216 | BatchNorm(planes * block.expansion),
217 | )
218 |
219 | layers = []
220 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
221 | downsample=downsample, BatchNorm=BatchNorm))
222 | self.inplanes = planes * block.expansion
223 | for i in range(1, len(blocks)):
224 | layers.append(block(self.inplanes, planes, stride=1,
225 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
226 |
227 | return nn.Sequential(*layers)
228 |
229 | def forward(self, input):
230 | x = self.conv1(input)
231 | x = self.bn1(x)
232 | x = self.relu(x)
233 | x = self.maxpool(x) # | 4
234 |
235 | x = self.layer1(x) # | 4
236 | low_level_feat2 = x # | 4
237 |
238 | x = self.layer2(x) # | 8
239 | low_level_feat3 = x
240 |
241 | x = self.layer3(x) # | 16
242 | low_level_feat4 = x
243 |
244 | x = self.layer4(x) # | 32
245 |
246 | return low_level_feat2, low_level_feat3, low_level_feat4, x
247 |
248 | def _init_weight(self):
249 | for m in self.modules():
250 | if isinstance(m, nn.Conv2d):
251 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
252 | m.weight.data.normal_(0, math.sqrt(2. / n))
253 | elif isinstance(m, nn.BatchNorm2d):
254 | m.weight.data.fill_(1)
255 | m.bias.data.zero_()
256 |
257 | def _load_pretrained_model(self, model_path):
258 | pretrain_dict = model_zoo.load_url(model_path)
259 | model_dict = {}
260 | state_dict = self.state_dict()
261 | for k, v in pretrain_dict.items():
262 | if k in state_dict:
263 | model_dict[k] = v
264 | state_dict.update(model_dict)
265 | self.load_state_dict(state_dict)
266 | print('load pretrained model')
--------------------------------------------------------------------------------
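
The backbone returns the four feature stages that DBFNet and the decoders consume; a quick shape check, assuming the default output_stride of 32:

    import torch
    from models.network.resnet import ResNet18

    backbone = ResNet18(output_stride=32, pretrained=False)
    c1, c2, c3, c4 = backbone(torch.rand(1, 3, 256, 256))
    print(c1.shape, c2.shape, c3.shape, c4.shape)
    # 1/4, 1/8, 1/16 and 1/32 resolution with 64, 128, 256 and 512 channels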
/models/network/segformer.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | # @Filename: segformer
4 | # @Project : ContestCD
5 | # @date : 2021-10-19 18:34
6 | # @Author : Linshan
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from functools import partial
12 |
13 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
14 | from timm.models.registry import register_model
15 | from timm.models.vision_transformer import _cfg
16 | # from mmseg.models.builder import BACKBONES
17 | # from mmseg.utils import get_root_logger
18 | # from mmcv.runner import load_checkpoint
19 | import math
20 |
21 |
22 | class Mlp(nn.Module):
23 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
24 | super().__init__()
25 | out_features = out_features or in_features
26 | hidden_features = hidden_features or in_features
27 | self.fc1 = nn.Linear(in_features, hidden_features)
28 | self.dwconv = DWConv(hidden_features)
29 | self.act = act_layer()
30 | self.fc2 = nn.Linear(hidden_features, out_features)
31 | self.drop = nn.Dropout(drop)
32 |
33 | # self.apply(self._init_weights)
34 |
35 | def _init_weights(self, m):
36 | if isinstance(m, nn.Linear):
37 | trunc_normal_(m.weight, std=.02)
38 | if isinstance(m, nn.Linear) and m.bias is not None:
39 | nn.init.constant_(m.bias, 0)
40 | elif isinstance(m, nn.LayerNorm):
41 | nn.init.constant_(m.bias, 0)
42 | nn.init.constant_(m.weight, 1.0)
43 | elif isinstance(m, nn.Conv2d):
44 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
45 | fan_out //= m.groups
46 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
47 | if m.bias is not None:
48 | m.bias.data.zero_()
49 |
50 | def forward(self, x, H, W):
51 | x = self.fc1(x)
52 | x = self.dwconv(x, H, W)
53 | x = self.act(x)
54 | x = self.drop(x)
55 | x = self.fc2(x)
56 | x = self.drop(x)
57 | return x
58 |
59 |
60 | class DWConv(nn.Module):
61 | def __init__(self, dim=768):
62 | super(DWConv, self).__init__()
63 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
64 |
65 | def forward(self, x, H, W):
66 | B, N, C = x.shape
67 | x = x.transpose(1, 2).view(B, C, H, W)
68 | x = self.dwconv(x)
69 | x = x.flatten(2).transpose(1, 2)
70 |
71 | return x
72 |
73 |
74 | class Attention(nn.Module):
75 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
76 | super().__init__()
77 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
78 |
79 | self.dim = dim
80 | self.num_heads = num_heads
81 | head_dim = dim // num_heads
82 | self.scale = qk_scale or head_dim ** -0.5
83 |
84 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
85 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
86 | self.attn_drop = nn.Dropout(attn_drop)
87 | self.proj = nn.Linear(dim, dim)
88 | self.proj_drop = nn.Dropout(proj_drop)
89 |
90 | self.sr_ratio = sr_ratio
91 | if sr_ratio > 1:
92 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
93 | self.norm = nn.LayerNorm(dim)
94 |
95 | # self.apply(self._init_weights)
96 |
97 | def _init_weights(self, m):
98 | if isinstance(m, nn.Linear):
99 | trunc_normal_(m.weight, std=.02)
100 | if isinstance(m, nn.Linear) and m.bias is not None:
101 | nn.init.constant_(m.bias, 0)
102 | elif isinstance(m, nn.LayerNorm):
103 | nn.init.constant_(m.bias, 0)
104 | nn.init.constant_(m.weight, 1.0)
105 | elif isinstance(m, nn.Conv2d):
106 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
107 | fan_out //= m.groups
108 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
109 | if m.bias is not None:
110 | m.bias.data.zero_()
111 |
112 | def forward(self, x, H, W):
113 | B, N, C = x.shape
114 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
115 |
116 | if self.sr_ratio > 1:
117 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
118 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
119 | x_ = self.norm(x_)
120 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
121 | else:
122 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
123 | k, v = kv[0], kv[1]
124 |
125 | attn = (q @ k.transpose(-2, -1)) * self.scale
126 | attn = attn.softmax(dim=-1)
127 | attn = self.attn_drop(attn)
128 |
129 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
130 | x = self.proj(x)
131 | x = self.proj_drop(x)
132 |
133 | return x
134 |
135 |
136 | class Block(nn.Module):
137 |
138 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
139 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
140 | super().__init__()
141 | self.norm1 = norm_layer(dim)
142 | self.attn = Attention(
143 | dim,
144 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
145 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
146 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
147 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
148 | self.norm2 = norm_layer(dim)
149 | mlp_hidden_dim = int(dim * mlp_ratio)
150 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
151 |
152 | # self.apply(self._init_weights)
153 |
154 | def _init_weights(self, m):
155 | if isinstance(m, nn.Linear):
156 | trunc_normal_(m.weight, std=.02)
157 | if isinstance(m, nn.Linear) and m.bias is not None:
158 | nn.init.constant_(m.bias, 0)
159 | elif isinstance(m, nn.LayerNorm):
160 | nn.init.constant_(m.bias, 0)
161 | nn.init.constant_(m.weight, 1.0)
162 | elif isinstance(m, nn.Conv2d):
163 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
164 | fan_out //= m.groups
165 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
166 | if m.bias is not None:
167 | m.bias.data.zero_()
168 |
169 | def forward(self, x, H, W):
170 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
171 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
172 |
173 | return x
174 |
175 |
176 | class OverlapPatchEmbed(nn.Module):
177 | """ Image to Patch Embedding
178 | """
179 |
180 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
181 | super().__init__()
182 | img_size = to_2tuple(img_size)
183 | patch_size = to_2tuple(patch_size)
184 |
185 | self.img_size = img_size
186 | self.patch_size = patch_size
187 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
188 | self.num_patches = self.H * self.W
189 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
190 | padding=(patch_size[0] // 2, patch_size[1] // 2))
191 | self.norm = nn.LayerNorm(embed_dim)
192 |
193 | # self.apply(self._init_weights)
194 |
195 | def _init_weights(self, m):
196 | if isinstance(m, nn.Linear):
197 | trunc_normal_(m.weight, std=.02)
198 | if isinstance(m, nn.Linear) and m.bias is not None:
199 | nn.init.constant_(m.bias, 0)
200 | elif isinstance(m, nn.LayerNorm):
201 | nn.init.constant_(m.bias, 0)
202 | nn.init.constant_(m.weight, 1.0)
203 | elif isinstance(m, nn.Conv2d):
204 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
205 | fan_out //= m.groups
206 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
207 | if m.bias is not None:
208 | m.bias.data.zero_()
209 |
210 | def forward(self, x):
211 | x = self.proj(x)
212 | _, _, H, W = x.shape
213 | x = x.flatten(2).transpose(1, 2) # size (B, H*W, C)
214 | # print(x.shape)
215 | x = self.norm(x)
216 |
217 | return x, H, W
218 |
219 |
220 | class MixVisionTransformer(nn.Module):
221 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
222 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
223 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
224 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
225 | super().__init__()
226 | self.num_classes = num_classes
227 | self.depths = depths
228 |
229 | # patch_embed
230 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
231 | embed_dim=embed_dims[0])
232 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
233 | embed_dim=embed_dims[1])
234 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
235 | embed_dim=embed_dims[2])
236 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
237 | embed_dim=embed_dims[3])
238 |
239 | # transformer encoder
240 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
241 | cur = 0
242 | self.block1 = nn.ModuleList([Block(
243 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
244 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
245 | sr_ratio=sr_ratios[0])
246 | for i in range(depths[0])])
247 | self.norm1 = norm_layer(embed_dims[0])
248 |
249 | cur += depths[0]
250 | self.block2 = nn.ModuleList([Block(
251 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
252 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
253 | sr_ratio=sr_ratios[1])
254 | for i in range(depths[1])])
255 | self.norm2 = norm_layer(embed_dims[1])
256 |
257 | cur += depths[1]
258 | self.block3 = nn.ModuleList([Block(
259 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
260 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
261 | sr_ratio=sr_ratios[2])
262 | for i in range(depths[2])])
263 | self.norm3 = norm_layer(embed_dims[2])
264 |
265 | cur += depths[2]
266 | self.block4 = nn.ModuleList([Block(
267 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
268 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
269 | sr_ratio=sr_ratios[3])
270 | for i in range(depths[3])])
271 | self.norm4 = norm_layer(embed_dims[3])
272 |
273 | # classification head
274 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
275 |
276 | # self.apply(self._init_weights)
277 |
278 | def _init_weights(self, m):
279 | if isinstance(m, nn.Linear):
280 | trunc_normal_(m.weight, std=.02)
281 | if isinstance(m, nn.Linear) and m.bias is not None:
282 | nn.init.constant_(m.bias, 0)
283 | elif isinstance(m, nn.LayerNorm):
284 | nn.init.constant_(m.bias, 0)
285 | nn.init.constant_(m.weight, 1.0)
286 | elif isinstance(m, nn.Conv2d):
287 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
288 | fan_out //= m.groups
289 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
290 | if m.bias is not None:
291 | m.bias.data.zero_()
292 |
293 | def reset_drop_path(self, drop_path_rate):
294 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
295 | cur = 0
296 | for i in range(self.depths[0]):
297 | self.block1[i].drop_path.drop_prob = dpr[cur + i]
298 |
299 | cur += self.depths[0]
300 | for i in range(self.depths[1]):
301 | self.block2[i].drop_path.drop_prob = dpr[cur + i]
302 |
303 | cur += self.depths[1]
304 | for i in range(self.depths[2]):
305 | self.block3[i].drop_path.drop_prob = dpr[cur + i]
306 |
307 | cur += self.depths[2]
308 | for i in range(self.depths[3]):
309 | self.block4[i].drop_path.drop_prob = dpr[cur + i]
310 |
311 | def freeze_patch_emb(self):
312 | self.patch_embed1.requires_grad = False
313 |
314 | @torch.jit.ignore
315 | def no_weight_decay(self):
316 |         return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # excluding pos_embed/cls_token from weight decay may work better
317 |
318 |     def get_classifier(self):
319 |         return self.head  # note: the head above is disabled by default; call reset_classifier first
320 |
321 |     def reset_classifier(self, num_classes, global_pool=''):
322 |         self.num_classes = num_classes
323 |         self.head = nn.Linear(self.norm4.normalized_shape[0], num_classes) if num_classes > 0 else nn.Identity()  # in_features = embed_dims[3], recovered from norm4
324 |
325 | def forward_features(self, x):
326 | B = x.shape[0]
327 | outs = []
328 |
329 | # stage 1
330 | x, H, W = self.patch_embed1(x)
331 | for i, blk in enumerate(self.block1):
332 | x = blk(x, H, W)
333 | x = self.norm1(x)
334 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
335 | outs.append(x)
336 |
337 | # stage 2
338 | x, H, W = self.patch_embed2(x)
339 | for i, blk in enumerate(self.block2):
340 | x = blk(x, H, W)
341 | x = self.norm2(x)
342 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
343 | outs.append(x)
344 |
345 | # stage 3
346 | x, H, W = self.patch_embed3(x)
347 | for i, blk in enumerate(self.block3):
348 | x = blk(x, H, W)
349 | x = self.norm3(x)
350 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
351 | outs.append(x)
352 |
353 | # stage 4
354 | x, H, W = self.patch_embed4(x)
355 | for i, blk in enumerate(self.block4):
356 | x = blk(x, H, W)
357 | x = self.norm4(x)
358 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
359 | outs.append(x)
360 |
361 | return outs
362 |
363 | def forward(self, x):
364 | x = self.forward_features(x)
365 | # x = self.head(x)
366 |
367 | return x
368 |
369 |
370 | class mit_b0(MixVisionTransformer):
371 | def __init__(self, **kwargs):
372 | super(mit_b0, self).__init__(
373 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
374 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
375 | drop_rate=0.0, drop_path_rate=0.1)
376 |
377 |
378 | class mit_b1(MixVisionTransformer):
379 | def __init__(self, **kwargs):
380 | super(mit_b1, self).__init__(
381 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
382 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
383 | drop_rate=0.0, drop_path_rate=0.1)
384 |
385 |
386 | class mit_b2(MixVisionTransformer):
387 | def __init__(self, **kwargs):
388 | super(mit_b2, self).__init__(
389 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
390 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
391 | drop_rate=0.0, drop_path_rate=0.1)
392 |
393 |
394 | class mit_b3(MixVisionTransformer):
395 | def __init__(self, **kwargs):
396 | super(mit_b3, self).__init__(
397 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
398 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
399 | drop_rate=0.0, drop_path_rate=0.1)
400 |
401 |
402 | class mit_b4(MixVisionTransformer):
403 | def __init__(self, **kwargs):
404 | super(mit_b4, self).__init__(
405 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
406 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
407 | drop_rate=0.0, drop_path_rate=0.1)
408 |
409 |
410 | class mit_b5(MixVisionTransformer):
411 | def __init__(self, **kwargs):
412 | super(mit_b5, self).__init__(
413 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
414 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
415 | drop_rate=0.0, drop_path_rate=0.1)
416 |
417 |
418 | def load_check(model, path):
419 | checkpoint = torch.load(path, map_location=torch.device('cpu'))
420 | state_dict = checkpoint['state_dict']
421 |
422 | from collections import OrderedDict
423 | new_state_dict = OrderedDict()
424 | for k, v in state_dict.items():
425 | if k[:8] == 'backbone':
426 |             name = k[9:]  # strip the leading 'backbone.' prefix (9 characters)
427 |             new_state_dict[name] = v  # re-key the weight under the stripped name
428 |
429 | model.load_state_dict(new_state_dict)
430 | print('load pretrained transformer')
431 | return model
432 |
433 |
434 | if __name__ == '__main__':
435 | import torch
436 | model = mit_b1()
437 | # checkpoint = torch.load('/media/hlf/Luffy/WLS/ContestCD/segformer_b3.pth', map_location=torch.device('cpu'))
438 | # model.load_state_dict(checkpoint['state_dict'])
439 | # print('load pretrained transformer')
440 |
441 | # model = load_check(model, path='/media/hlf/Luffy/WLS/ContestCD/segformer.b2.512x512.ade.160k.pth')
442 | # torch.save({'state_dict': model.state_dict()},
443 | # '/media/hlf/Luffy/WLS/ContestCD/segformer_b2.pth', _use_new_zipfile_serialization=False)
444 |
445 | input = torch.rand(1, 3, 512, 512)
446 |     outputs = model(input)
447 |     # forward returns a list of four multi-scale feature maps, so print each one
448 |     for output in outputs:
449 |         print(output.size())
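450 |
451 | # Shape sketch for the 512x512 smoke test above: the four feature maps come out at
452 | # strides 4/8/16/32 (i.e. 128/64/32/16 pixels square), with channel widths taken
453 | # from embed_dims (mit_b1: [64, 128, 320, 512]).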
--------------------------------------------------------------------------------
/models/seg_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.network import *
5 | from options import *
6 | from models.base_model import build_base_model, build_channels
7 | from models.tools import *
8 |
9 |
10 | class seg_decoder(nn.Module):
11 | def __init__(self, in_channels, num_classes):
12 | super(seg_decoder, self).__init__()
13 | self.last_layer = nn.Sequential(
14 | nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0),
15 | nn.BatchNorm2d(in_channels),
16 | nn.ReLU(),
17 | nn.Conv2d(in_channels, num_classes, kernel_size=3, stride=1, padding=1)
18 | )
19 |
20 | def forward(self, x):
21 | x = self.last_layer(x)
22 | return x
23 |
24 |
25 | class Seg_Net(nn.Module):
26 | def __init__(self, opt):
27 | super(Seg_Net, self).__init__()
28 | self.base_model = build_base_model(opt)
29 | self.in_channels = build_channels(opt)
30 | self.seg_decoder = seg_decoder(self.in_channels, opt.num_classes)
31 | self.img_size = opt.img_size
32 |
33 | def resize_out(self, output):
34 | if output.size()[-1] != self.img_size:
35 | output = F.interpolate(output, size=(self.img_size, self.img_size), mode='bilinear')
36 | return output
37 |
38 | def forward(self, x):
39 | x = self.base_model(x)
40 | # feat = x
41 |
42 | x = self.seg_decoder(x)
43 | x = self.resize_out(x)
44 |
45 | return x
46 |
47 |
48 | if __name__ == "__main__":
49 | opt = Seg_Options().parse()
50 | net = Seg_Net(opt)
51 | net.cuda()
52 | x = torch.rand(4, 3, 512, 512).cuda()
53 | label = torch.randint(0, 8, (4, 512, 512)).cuda()
54 | output = net(x)
55 |
56 | criterion = nn.CrossEntropyLoss(ignore_index=5)
57 | loss = criterion(output, label)
58 |
59 | print(output.shape)
60 | print(loss)
--------------------------------------------------------------------------------
/models/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from models.tools.densecrf import get_crf
2 |
--------------------------------------------------------------------------------
/models/tools/densecrf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import tifffile
6 | import os
7 | import matplotlib.pyplot as plt
8 | import pydensecrf.densecrf as dcrf
9 | from dataset.datapoint import value_to_rgb
10 | from pydensecrf.utils import create_pairwise_bilateral, create_pairwise_gaussian
11 |
12 |
13 | def softmax(x):
14 | x_row_max = x.max(axis=-1)
15 | x_row_max = x_row_max.reshape(list(x.shape)[:-1]+[1])
16 | x = x - x_row_max
17 | x_exp = np.exp(x)
18 | x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1]+[1])
19 | softmax = x_exp / x_exp_row_sum
20 | return softmax
21 |
22 |
23 | def get_crf(opt, mask, img):
24 | mask = np.transpose(mask, (2, 0, 1))
25 | img = np.ascontiguousarray(img)
26 |
27 | unary = -np.log(mask + 1e-8)
28 | unary = unary.reshape((opt.num_classes, -1))
29 | unary = np.ascontiguousarray(unary)
30 |
31 | d = dcrf.DenseCRF2D(opt.img_size, opt.img_size, opt.num_classes)
32 | d.setUnaryEnergy(unary)
33 |
34 | d.addPairwiseGaussian(sxy=5, compat=3)
35 | d.addPairwiseBilateral(sxy=10, srgb=13, rgbim=img, compat=10)
36 |
37 | output = d.inference(10)
38 |
39 |     pred = np.argmax(output, axis=0).reshape((opt.img_size, opt.img_size))
40 |
41 |     return pred
42 |
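43 | # Minimal usage sketch (hypothetical inputs; assumes opt.img_size == 256, opt.num_classes == 5):
44 | #   probs = np.random.rand(256, 256, 5).astype(np.float32)
45 | #   probs /= probs.sum(axis=-1, keepdims=True)  # a valid per-pixel distribution for the -log unary
46 | #   rgb = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
47 | #   refined = get_crf(opt, probs, rgb)  # -> (256, 256) array of class indices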
--------------------------------------------------------------------------------
/options/Second_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | class Sec_Options():
5 | """This classification defines options used during both training and test time.
6 | It also implements several helper functions such as parsing, printing, and saving the options.
7 | It also gathers additional options defined in functions in both dataset classification and model classification.
8 | """
9 |
10 | def __init__(self):
11 | """Reset the classification; indicates the classification hasn't been initailized"""
12 | self.initialized = False
13 |
14 | def initialize(self, parser):
15 | """Define the common options that are used in both training and test."""
16 | # basic parameters
17 | parser.add_argument('--data_root', type=str, default='/media/hlf/Luffy/WLS/semantic/dataset', help='path to dataroot')
18 | parser.add_argument('--dataset', type=str, default='potsdam', help='[chesapeake|potsdam|vaihingen|GID]')
19 | parser.add_argument('--experiment_name', type=str, default='Second', help='name of the experiment. It decides where to load datafiles, store samples and models')
20 | parser.add_argument('--save_path', type=str, default='/media/hlf/Luffy/WLS/PointAnno/save', help='models are saved here')
21 | parser.add_argument('--data_inform_path', type=str, default='/media/hlf/Luffy/WLS/PointAnno/datafiles', help='path to files about the datafiles information')
22 |
23 | # model parameters
24 | parser.add_argument('--base_model', type=str, default='Deeplab', help='choose which base model. [HRNet18|HRNet48|Deeplab]')
25 | parser.add_argument('--backbone', type=str, default='resnet18', help='which resnet')
26 | parser.add_argument('--num_classes', type=int, default=5, help='classes')
27 |
28 | # train parameters
29 | parser.add_argument('--loss', type=str, default='OHEM_loss', help='choose which hr_loss')
30 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
31 |         parser.add_argument('--pin', type=bool, default=True, help='pin_memory or not')  # note: argparse parses any non-empty string as True for type=bool
32 |
33 | parser.add_argument('--num_workers', type=int, default=10, help='number of workers')
34 | parser.add_argument('--img_size', type=int, default=256, help='image size')
35 | parser.add_argument('--in_channels', type=int, default=3, help='input channels')
36 |
37 | parser.add_argument('--num_epochs', type=int, default=100, help='num of epochs')
38 | parser.add_argument('--base_lr', type=float, default=1e-3, help='base learning rate')
39 | parser.add_argument('--decay', type=float, default=5e-4, help='decay')
40 | parser.add_argument('--log_interval', type=int, default=60, help='how long to log')
41 | parser.add_argument('--resume', type=bool, default=False, help='resume the saved checkpoint or not')
42 |
43 | self.initialized = True
44 | return parser
45 |
46 | def gather_options(self):
47 | """Initialize our parser with basic options(only once).
48 | Add additional model-specific and dataset-specific options.
49 | These options are defined in the function
50 | in model and dataset classes.
51 | """
52 | if not self.initialized: # check if it has been initialized
53 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
54 | parser = self.initialize(parser)
55 |
56 | # get the basic options
57 | opt, _ = parser.parse_known_args()
58 | self.parser = parser
59 | return parser.parse_args()
60 |
61 | def parse(self):
62 | """Parse our options, create checkpoints directory suffix, and set up gpu device."""
63 | opt = self.gather_options()
64 | self.opt = opt
65 |
66 | return self.opt
67 |
68 |
69 | if __name__ == '__main__':
70 |     opt = Sec_Options().parse()
71 | print(opt)
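72 |
73 | # Example invocation (paths below are placeholders for your own setup):
74 | #   python options/Second_options.py --dataset vaihingen --batch_size 16 --base_lr 5e-4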
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | from options.point_options import *
2 | from options.Second_options import *
3 |
--------------------------------------------------------------------------------
/options/point_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | class Point_Options():
5 | """This classification defines options used during both training and test time.
6 | It also implements several helper functions such as parsing, printing, and saving the options.
7 | It also gathers additional options defined in functions in both dataset classification and model classification.
8 | """
9 |
10 | def __init__(self):
11 | """Reset the classification; indicates the classification hasn't been initailized"""
12 | self.initialized = False
13 |
14 | def initialize(self, parser):
15 | """Define the common options that are used in both training and test."""
16 | # basic parameters
17 | parser.add_argument('--data_root', type=str, default='/media/hlf/Luffy/WLS/semantic/dataset', help='your path to dataroot')
18 | parser.add_argument('--dataset', type=str, default='potsdam', help='[potsdam|vaihingen]')
19 | parser.add_argument('--experiment_name', type=str, default='Point', help='name of the experiment. It decides where to load datafiles, store samples and models')
20 | parser.add_argument('--save_path', type=str, default='/media/hlf/Luffy/WLS/PointAnno/save', help='models are saved here')
21 | parser.add_argument('--data_inform_path', type=str, default='/media/hlf/Luffy/WLS/PointAnno/datafiles', help='path to files about the datafiles information')
22 |
23 | # model parameter
24 | parser.add_argument('--backbone', type=str, default='resnet18', help='which resnet')
25 | parser.add_argument('--out_stride', type=int, default=32, help='out_stride')
26 | parser.add_argument('--num_classes', type=int, default=5, help='classes')
27 |
28 | # train parameters
29 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
30 | parser.add_argument('--pin', type=bool, default=True, help='pin_memory or not')
31 |
32 | parser.add_argument('--num_workers', type=int, default=10, help='number of workers')
33 | parser.add_argument('--img_size', type=int, default=256, help='image size')
34 | parser.add_argument('--in_channels', type=int, default=3, help='input channels')
35 |
36 | parser.add_argument('--num_epochs', type=int, default=100, help='num of epochs')
37 | parser.add_argument('--base_lr', type=float, default=1e-3, help='base learning rate')
38 | parser.add_argument('--decay', type=float, default=5e-4, help='decay')
39 |         parser.add_argument('--log_interval', type=int, default=60, help='how often to log, set to 100 batches')
40 | parser.add_argument('--resume', type=bool, default=False, help='resume the saved checkpoint or not')
41 |
42 | self.initialized = True
43 | return parser
44 |
45 | def gather_options(self):
46 | """Initialize our parser with basic options(only once).
47 | Add additional model-specific and dataset-specific options.
48 | These options are defined in the function
49 | in model and dataset classes.
50 | """
51 | if not self.initialized: # check if it has been initialized
52 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
53 | parser = self.initialize(parser)
54 |
55 | # get the basic options
56 | opt, _ = parser.parse_known_args()
57 | self.parser = parser
58 | return parser.parse_args()
59 |
60 | def parse(self):
61 | """Parse our options, create checkpoints directory suffix, and set up gpu device."""
62 | opt = self.gather_options()
63 | self.opt = opt
64 |
65 | return self.opt
66 |
67 |
68 | if __name__ == '__main__':
69 | opt = Point_Options().parse()
70 | print(opt)
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from dataset import *
3 | from matplotlib import pyplot as plt
4 | from torch.utils.data import DataLoader
5 | from tqdm import tqdm
6 | from datafiles.color_dict import *
7 | from models.tools import get_crf
8 | import random
9 | from dataset.data_utils import value_to_rgb
10 | # from models.tools import get_crf
11 | import ttach as tta
12 |
13 | print("PyTorch Version: ", torch.__version__)
14 | print('cuda', torch.version.cuda)
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
16 | device = torch.device("cuda:0")
17 | print('Device:', device)
18 |
19 |
20 | def save(visual, name, path):
21 | check_dir(path)
22 | imsave(path+'/'+name[:-4]+'.png', visual)
23 |
24 |
25 | def eval_tile(opt, out, label, name, stride=128):  # 'eval' would shadow the Python builtin
26 | compute_metric = IOUMetric(num_classes=opt.num_classes)
27 | hist = np.zeros([opt.num_classes, opt.num_classes])
28 |
29 | h, w, _ = label.shape
30 | num_h, num_w = h//stride, w//stride
31 | for i in range(num_h):
32 | for j in range(num_w):
33 | o = out[i*stride:(i+1)*stride, j*stride:(j+1)*stride]
34 | l = label[i*stride:(i+1)*stride, j*stride:(j+1)*stride, 0]
35 | hist += compute_metric.get_hist(o, l)
36 |     # evaluate the accumulated histogram
37 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist)
38 | print(name)
39 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)))
40 | print('------'*5)
41 | return hist
42 |
43 |
44 | def predict_tile(opt, model, img, label, label_vis, name,
45 | save_path, dataset='potsdam', size=256):
46 | im = img.copy()
47 | with torch.no_grad():
48 | model.eval()
49 | h, w, c = img.shape
50 | img = Normalize(img, flag=dataset)
51 | img = img.astype(np.float32).transpose(2, 0, 1)
52 | img = torch.from_numpy(img).unsqueeze(0).float().cuda()
53 |         output = pre_slide(model, img, num_classes=opt.num_classes, tile_size=(size, size), tta=True)
54 |
55 | output = output.squeeze(0).permute(1, 2, 0).data.cpu().numpy().astype(np.float32)
56 |
57 | output = output[:h, :w, :]
58 | output = np.argmax(output, axis=-1)
59 |         hist = eval_tile(opt, output, label, name)
60 |
61 | # output = np.expand_dims(output, axis=-1)
62 | # predict = value_to_rgb(output, flag=opt.dataset)
63 | # fig, axs = plt.subplots(1, 3, figsize=(20, 8))
64 | # axs[0].imshow(im.astype(np.uint8))
65 | # axs[0].axis("off")
66 | # axs[1].imshow(label_vis.astype(np.uint8))
67 | # axs[1].axis("off")
68 | # axs[2].imshow(predict.astype(np.uint8))
69 | # axs[2].axis("off")
70 | # plt.suptitle(os.path.basename(name), y=0.94)
71 | # plt.tight_layout()
72 | # plt.show()
73 | # plt.close()
74 |
75 | # save(predict, name)
76 |
77 | return hist
78 |
79 |
80 | def run_potsdam(root_path):
81 | img_path = root_path + '/4_Ortho_RGBIR'
82 | label_path = root_path + '/Labels'
83 | vis_path = root_path + '/5_Labels_all'
84 | save_path = './save/potsdam/predict_masks'
85 | train_name = [7, 8, 9, 10, 11, 12]
86 |
87 | opt = Sec_Options().parse()
88 | model = Seg_Net(opt)
89 | checkpoint = torch.load('./save/potsdam_0.8719.pth', map_location=torch.device('cpu'))
90 | model.load_state_dict(checkpoint['state_dict'])
91 | model = model.cuda()
92 | model.eval()
93 |
94 | hist = np.zeros([opt.num_classes, opt.num_classes])
95 |     names = os.listdir(label_path)
96 |     for i in names:
97 |         # tiles 7..12 were used for training; evaluate only the remaining tiles
98 |         if int(i[14:-10]) in train_name:
99 |             continue
100 |         img = read(os.path.join(img_path, i[:-9] + 'RGBIR.tif'))[:, :, :3]
101 |         label = read(os.path.join(label_path, i))
102 |         vis = read(os.path.join(vis_path, i))
103 |         hist += predict_tile(opt, model, img, label, vis, i, save_path, dataset='potsdam')
104 |
105 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist)
106 | print('total')
107 | print(acc)
108 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)))
109 | print('------' * 5)
110 |
111 |
112 | def run_vaihingen(root_path):
113 | save_path = './save/vaihingen/predict_masks'
114 | img_path = root_path + '/image'
115 | label_path = root_path + '/gts_noB'
116 | vis_path = root_path + '/vis_noB'
117 | test_name = [2, 4, 6, 8, 10, 12, 14, 16, 20, 22, 24, 27, 29, 31, 33, 35, 38]
118 |
119 | opt = Sec_Options().parse()
120 | model = Seg_Net(opt)
121 |
122 | checkpoint = torch.load('./save/vaihingen_0.8867.pth', map_location=torch.device('cpu'))
123 | model.load_state_dict(checkpoint['state_dict'])
124 | model = model.cuda()
125 | model.eval()
126 |
127 | hist = np.zeros([opt.num_classes, opt.num_classes])
128 |     names = os.listdir(label_path)
129 |     for i in names:
130 |         # evaluate only the designated test tiles
131 |         if int(i[20:-4]) not in test_name:
132 |             continue
133 |         img = read(os.path.join(img_path, i))
134 |         label = read(os.path.join(label_path, i))
135 |         vis = read(os.path.join(vis_path, i))
136 |         hist += predict_tile(opt, model, img, label, vis, i, save_path, dataset='vaihingen')
137 |
138 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist)
139 | print('total')
140 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)))
141 | print('------' * 5)
142 |
143 |
144 | if __name__ == '__main__':
145 | potsdam_path = '/media/hlf/Luffy/WLS/semantic/dataset/potsdam/dataset_origin'
146 | run_potsdam(potsdam_path)
147 |
148 | # vaihingen_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/dataset_origin'
149 | # run_vaihingen(vaihingen_path)
150 |
151 |
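152 | # Minimal single-tile sketch (hypothetical file names, potsdam conventions):
153 | #   opt = Sec_Options().parse()
154 | #   model = Seg_Net(opt).cuda().eval()
155 | #   img = read('top_potsdam_3_13_RGBIR.tif')[:, :, :3]
156 | #   label = vis = read('top_potsdam_3_13_label.tif')
157 | #   hist = predict_tile(opt, model, img, label, vis, '3_13.tif', './save', dataset='potsdam')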
--------------------------------------------------------------------------------
/run/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/run/__init__.py
--------------------------------------------------------------------------------
/run/point/p_predict_train.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from dataset import *
3 | from torchvision import transforms
4 | from matplotlib import pyplot as plt
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 | from datafiles.color_dict import *
8 | from models.DBFNet import clean_mask
9 | from models.tools import get_crf
10 | import tifffile
11 | import random
12 | import ttach as tta
13 |
14 | print("PyTorch Version: ", torch.__version__)
15 | print('cuda', torch.version.cuda)
16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
17 | device = torch.device("cuda:0")
18 | print('Device:', device)
19 |
20 |
21 | def predict_val():
22 | opt = Point_Options().parse()
23 | log_path, checkpoint_path, predict_test_path, predict_train_path, predict_val_path = create_save_path(opt)
24 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
25 |
26 | # load train and val dataset
27 | val_dataset = Dataset_point(opt, val_txt_path, flag='predict_val', transform=None)
28 | loader = DataLoader(val_dataset, batch_size=1, num_workers=opt.num_workers,
29 | pin_memory=opt.pin)
30 |
31 | model = Net(opt, flag='test')
32 | checkpoint = torch.load(checkpoint_path + '/model_best_0.8553.pth', map_location=torch.device('cpu'))
33 | model.load_state_dict(checkpoint['state_dict'])
34 | print('resume success')
35 | model = model.cuda()
36 | model.eval()
37 |
38 | for batch_idx, batch in enumerate(tqdm(loader)):
39 | img, label, cls_label, filename = batch
40 | with torch.no_grad():
41 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True), \
42 | cls_label.cuda(non_blocking=True)
43 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(),
44 | merge_mode='mean')
45 | coarse_mask = tta_model(input)
46 | torch.cuda.synchronize()
47 |
48 | # save train predict after softmax and clean
49 | coarse_mask = F.softmax(coarse_mask, dim=1)
50 | mask_clean = clean_mask(coarse_mask, cls_label)
51 | mask = mask_clean.squeeze(0).permute(1, 2, 0).cpu().numpy()
52 | mask = mask.astype(np.float32)
53 |
54 | # get crf
55 | img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
56 | img = Normalize_back(img, flag=opt.dataset)
57 | crf_out = get_crf(opt, mask, img.astype(np.uint8))
58 |
59 |             # save crf out
60 | predict_path_crf = predict_val_path + '/crf'
61 | check_dir(predict_path_crf)
62 | save_pred_anno_numpy(crf_out, predict_path_crf, filename, dict=postdam_color_dict, flag=False)
63 |
64 |
65 | def main():
66 | opt = Point_Options().parse()
67 | log_path, checkpoint_path, predict_test_path, predict_train_path, predict_val_path = create_save_path(opt)
68 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
69 |
70 | # load train and val dataset
71 | train_dataset = Dataset_point(opt, train_txt_path, flag='predict_train', transform=None)
72 | loader = DataLoader(train_dataset, batch_size=1, num_workers=opt.num_workers,
73 | pin_memory=opt.pin)
74 |
75 | model = Net(opt, flag='test')
76 | checkpoint = torch.load(checkpoint_path+'/model_best_0.8553.pth', map_location=torch.device('cpu'))
77 | model.load_state_dict(checkpoint['state_dict'])
78 | print('resume success')
79 | model = model.to(device)
80 | model.eval()
81 |
82 | # numpy for crf
83 | compute_metric_ny = IOUMetric(num_classes=opt.num_classes)
84 | hist_ny = np.zeros([opt.num_classes, opt.num_classes])
85 |
86 | for batch_idx, batch in enumerate(tqdm(loader)):
87 | img, label, cls_label, filename = batch
88 | with torch.no_grad():
89 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True),\
90 | cls_label.cuda(non_blocking=True)
91 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(),
92 | merge_mode='mean')
93 | coarse_mask = tta_model(input)
94 | torch.cuda.synchronize()
95 |
96 | # save train predict after softmax and clean
97 | coarse_mask = F.softmax(coarse_mask, dim=1)
98 | mask_clean = clean_mask(coarse_mask, cls_label)
99 |
100 | mask = mask_clean.squeeze(0).permute(1, 2, 0).cpu().numpy()
101 | mask = mask.astype(np.float32)
102 |
103 | # get crf
104 | img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
105 | img = Normalize_back(img, flag=opt.dataset)
106 | crf_out = get_crf(opt, mask, img.astype(np.uint8))
107 |             # save crf out
108 | predict_path_crf = predict_train_path + '/crf'
109 | check_dir(predict_path_crf)
110 | save_pred_anno_numpy(crf_out, predict_path_crf, filename, dict=postdam_color_dict, flag=False)
111 |
112 | label_ny = label.squeeze(0).data.cpu().numpy()
113 | hist_ny += compute_metric_ny.get_hist(crf_out, label_ny)
114 |
115 | # crf out's metric
116 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist_ny)
117 |
118 | print('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa))
119 | print('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)))
120 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)))
121 |
122 | result_path = log_path[:-4] + '/result_predict_train.txt'
123 | result_txt = open(result_path, 'a')
124 | result_txt.write('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa) + '\n')
125 | result_txt.write('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)) + '\n')
126 | result_txt.write('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)) + '\n')
127 | result_txt.write('---------------------------' + '\n')
128 | result_txt.close()
129 |
130 |
131 | def show(predict_path, test_path):
132 | path = predict_path + '/label_vis'
133 | list = os.listdir(path)
134 |
135 | img_path = test_path + '/img'
136 | point_label_path = test_path + '/point_label_vis'
137 | label_path = test_path + '/label_vis'
138 |
139 | for i in list:
140 | predict = read(os.path.join(path, i))
141 |
142 | img = read(os.path.join(img_path, i))[:, :, :3]
143 | label = read(os.path.join(label_path, i))
144 | point_label = read(os.path.join(point_label_path, i))
145 |
146 | fig, axs = plt.subplots(1, 4, figsize=(14, 4))
147 |
148 | axs[0].imshow(img.astype(np.uint8))
149 | axs[0].axis("off")
150 | axs[1].imshow(label.astype(np.uint8))
151 | axs[1].axis("off")
152 | axs[2].imshow(predict.astype(np.uint8))
153 | axs[2].axis("off")
154 | axs[3].imshow(point_label.astype(np.uint8))
155 | axs[3].axis("off")
156 |
157 | plt.suptitle(os.path.basename(i), y=0.94)
158 | plt.tight_layout()
159 | plt.show()
160 | plt.close()
161 |
162 |
163 | if __name__ == '__main__':
164 | main()
165 | predict_val()
166 |
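167 | # Pipeline sketch: D4 test-time augmentation -> softmax -> clean_mask with the
168 | # image-level class labels -> dense CRF -> results saved under predict_*/crf as
169 | # pseudo-labels for the second-stage training.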
--------------------------------------------------------------------------------
/run/point/p_test.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from dataset import *
3 | from matplotlib import pyplot as plt
4 | from torch.utils.data import DataLoader
5 | from tqdm import tqdm
6 | from datafiles.color_dict import *
7 | from models.tools import get_crf
8 | import random
9 | import ttach as tta
10 | print("PyTorch Version: ", torch.__version__)
11 | print('cuda', torch.version.cuda)
12 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
13 | device = torch.device("cuda:0")
14 | print('Device:', device)
15 |
16 |
17 | def main():
18 | opt = Point_Options().parse()
19 | log_path, checkpoint_path, predict_path, _, _ = create_save_path(opt)
20 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
21 |
22 | # load train and val dataset
23 | test_dataset = Dataset_point(opt, test_txt_path, flag='test', transform=None)
24 | loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=opt.num_workers,
25 | pin_memory=opt.pin)
26 |
27 | model = Net(opt, flag='test')
28 | checkpoint = torch.load(checkpoint_path + '/model_best_0.8553.pth', map_location=torch.device('cpu'))
29 | model.load_state_dict(checkpoint['state_dict'])
30 | print('resume success')
31 | model = model.to(device)
32 | model.eval()
33 |
34 | compute_metric = IOUMetric(num_classes=opt.num_classes)
35 | hist = np.zeros([opt.num_classes, opt.num_classes])
36 |
37 | for batch_idx, batch in enumerate(tqdm(loader)):
38 | img, label, cls_label, filename = batch
39 | with torch.no_grad():
40 | input, label = img.cuda(non_blocking=True), label.cuda(non_blocking=True)
41 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(),
42 | merge_mode='mean')
43 | output = tta_model(input)
44 | torch.cuda.synchronize()
45 |
46 | mask = F.softmax(output, dim=1)
47 | mask = mask.squeeze(0).permute(1, 2, 0).cpu().numpy()
48 | mask = mask.astype(np.float32)
49 |
50 | img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
51 | img = Normalize_back(img, flag=opt.dataset)
52 |
53 | crf_out = get_crf(opt, mask, img.astype(np.uint8))
54 | # save_pred_anno_numpy(crf_out, predict_path, filename, dict=postdam_color_dict, flag=True)
55 | label = label.squeeze(0).data.cpu().numpy()
56 |
57 | hist += compute_metric.get_hist(crf_out, label)
58 |
59 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist)
60 |
61 | print('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa))
62 | print('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)))
63 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)))
64 |
65 |
66 | if __name__ == '__main__':
67 | main()
68 |
69 |
--------------------------------------------------------------------------------
/run/point/p_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torchvision import transforms
4 | from torch.autograd import Variable
5 | import shutil
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 | from collections import OrderedDict
9 | from options import *
10 | from utils import *
11 | from dataset import *
12 |
13 | print("PyTorch Version: ", torch.__version__)
14 | print('cuda', torch.version.cuda)
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
16 | device = torch.device("cuda:0")
17 | print('Device:', device)
18 |
19 |
20 | def main():
21 | opt = Point_Options().parse()
22 | log_path, checkpoint_path, _, _, _ = create_save_path(opt)
23 | train_txt_path, val_txt_path, _ = create_data_path(opt)
24 |
25 | # Define logger
26 | logger, tensorboard_log_dir = create_logger(log_path)
27 | logger.info(opt)
28 | # Define Transformation
29 | train_transform = transforms.Compose([
30 | trans.RandomHorizontalFlip(),
31 | trans.RandomVerticleFlip(),
32 | trans.RandomRotate90(),
33 | ])
34 |
35 | # load train and val dataset
36 | train_dataset = Dataset_point(opt, train_txt_path, flag='train', transform=train_transform)
37 | val_dataset = Dataset_point(opt, val_txt_path, flag='val', transform=None)
38 |
39 | # Create training and validation dataloaders
40 | train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True,
41 | num_workers=opt.num_workers, pin_memory=opt.pin, drop_last=True)
42 |
43 | model = Net(opt)
44 | if opt.resume:
45 | checkpoint = torch.load(checkpoint_path+'/model_best.pth', map_location=torch.device('cpu'))
46 | model.load_state_dict(checkpoint['state_dict'])
47 | print('resume success')
48 |
49 | model = model.to(device)
50 | optimizer = optim.AdamW(model.parameters(), lr=opt.base_lr, amsgrad=True)
51 | best_metric = 1e16
52 |
53 | for epoch in range(opt.num_epochs):
54 | time_start = time.time()
55 |
56 | train(opt, epoch, model, train_loader, optimizer, logger)
57 | metric = validate(opt, model, val_dataset, logger)
58 |
59 | logger.info('best_val_metric:%.4f current_val_metric:%.4f' % (best_metric, metric))
60 | if metric < best_metric:
61 | logger.info('epoch:%d Save to model_best' % (epoch))
62 | torch.save({'state_dict': model.state_dict()},
63 | os.path.join(checkpoint_path, 'model_best.pth'))
64 | best_metric = metric
65 |
66 | end_time = time.time()
67 | time_cost = end_time - time_start
68 | logger.info("Epoch %d Time %d ----------------------" % (epoch, time_cost))
69 | logger.info('\n')
70 |
71 |
72 | def train(opt, epoch, model, loader, optimizer, logger):
73 | model.train()
74 |
75 | loss_m = AverageMeter()
76 | seg_loss_m = AverageMeter()
77 | pse_loss_m = AverageMeter()
78 | penalty_m = AverageMeter()
79 |
80 | last_idx = len(loader) - 1
81 |
82 | for batch_idx, batch in enumerate(loader):
83 | step = epoch * len(loader) + batch_idx
84 |         adjust_learning_rate(opt.base_lr, optimizer, step, len(loader), num_epochs=30)  # note: the decay schedule runs over 30 epochs, independent of opt.num_epochs
85 | lr = get_lr(optimizer)
86 |
87 | last_batch = batch_idx == last_idx
88 | img, label, cls_label, _ = batch
89 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True), \
90 | cls_label.cuda(non_blocking=True)
91 | seg_loss, penalty = model.forward_loss(input, label, cls_label)
92 | loss = seg_loss + penalty
93 |
94 | seg_loss_m.update(seg_loss.item(), input.size(0))
95 | penalty_m.update(penalty.item(), input.size(0))
96 |
97 | loss_m.update(loss.item(), input.size(0))
98 |
99 | optimizer.zero_grad()
100 | loss.backward()
101 | optimizer.step()
102 |
103 | torch.cuda.synchronize()
104 |
105 |         log_interval = len(loader) // 3  # log three times per epoch, overriding opt.log_interval
106 | if last_batch or batch_idx % log_interval == 0:
107 | logger.info('Train:{} [{:>4d}/{} ({:>3.0f}%)] '
108 | 'Loss:({loss.avg:>6.4f}) '
109 | 'segloss:({seg_loss.avg:>6.4f}) '
110 | #'pseloss:({pse_loss.avg:>6.4f}) '
111 | 'penal:({penal.avg:>6.4f}) '
112 | 'LR:{lr:.3e} '.format(
113 | epoch,
114 | batch_idx, len(loader),
115 | 100. * batch_idx / last_idx,
116 | loss=loss_m,
117 | seg_loss=seg_loss_m,
118 | pse_loss=pse_loss_m,
119 | penal=penalty_m,
120 | lr=lr))
121 |
122 | return OrderedDict([('loss', loss_m.avg)])
123 |
124 |
125 | def validate(opt, model, val_dataset, logger):
126 | val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, pin_memory=opt.pin)
127 | model.eval()
128 | loss_m = AverageMeter()
129 | seg_loss_m = AverageMeter()
130 | pse_loss_m = AverageMeter()
131 | penalty_m = AverageMeter()
132 |
133 | for batch_idx, batch in enumerate(val_loader):
134 | img, label, cls_label, _ = batch
135 | with torch.no_grad():
136 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True), \
137 | cls_label.cuda(non_blocking=True)
138 | out = model.forward(input)
139 |
140 | criterion = nn.CrossEntropyLoss(ignore_index=255)
141 | # penalty
142 | penalty = get_penalty(out, cls_label)
143 | # label: point annotations point-level supervision
144 | seg_loss = criterion(out, label)
145 | loss = seg_loss + penalty
146 |
147 | seg_loss_m.update(seg_loss.item(), input.size(0))
148 | # pse_loss_m.update(pse_loss.item(), input.size(0))
149 | penalty_m.update(penalty.item(), input.size(0))
150 | loss_m.update(loss.item(), input.size(0))
151 |
152 | logger.info('VAL:'
153 | 'Loss:({loss.avg:>6.4f}) '
154 | 'segloss:({seg_loss.avg:>6.4f}) '
155 | # 'pseloss:({pse_loss.avg:>6.4f}) '
156 | 'penal:({penal.avg:>6.4f}) '.format(
157 | loss=loss_m,
158 | seg_loss=seg_loss_m,
159 | pse_loss=pse_loss_m,
160 | penal=penalty_m
161 | ))
162 |
163 | return loss_m.avg
164 |
165 |
166 | if __name__ == '__main__':
167 | main()
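168 |
169 | # Validation sketch: the metric is cross-entropy on the point labels (ignore_index=255
170 | # for unlabeled pixels) plus the class-presence penalty from get_penalty, and the
171 | # checkpoint with the lowest validation loss is kept.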
--------------------------------------------------------------------------------
/run/second/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/run/second/__init__.py
--------------------------------------------------------------------------------
/run/second/sec_predict_train.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from dataset import *
3 | from torchvision import transforms
4 | from matplotlib import pyplot as plt
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 | from datafiles.color_dict import *
8 | from models.DBFNet import clean_mask  # models has no MyModel module; clean_mask lives in DBFNet
9 | import tifffile
10 | import random
11 | import ttach as tta
12 |
13 | print("PyTorch Version: ", torch.__version__)
14 | print('cuda', torch.version.cuda)
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
16 | device = torch.device("cuda:0")
17 | print('Device:', device)
18 |
19 |
20 | def predict_val():
21 | opt = Sec_Options().parse()
22 | log_path, checkpoint_path, predict_test_path, predict_train_path, predict_val_path = create_save_path(opt)
23 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
24 |
25 | # load val dataset
26 | val_dataset = Dataset_point(opt, val_txt_path, flag='val', transform=None)
27 | loader = DataLoader(val_dataset, batch_size=1, num_workers=opt.num_workers,
28 | pin_memory=opt.pin)
29 |
30 | model = Seg_Net(opt)
31 | checkpoint = torch.load(checkpoint_path + '/model_best_0.8184.pth', map_location=torch.device('cpu'))
32 | model.load_state_dict(checkpoint['state_dict'])
33 | print('resume success')
34 | model = model.to(device)
35 | model.eval()
36 |
37 | for batch_idx, batch in enumerate(tqdm(loader)):
38 | img, label, cls_label, filename = batch
39 | with torch.no_grad():
40 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True), \
41 | cls_label.cuda(non_blocking=True)
42 | coarse_mask = model.forward(input)
43 | torch.cuda.synchronize()
44 |
45 | # save train predict after softmax and clean
46 | coarse_mask = F.softmax(coarse_mask, dim=1)
47 | mask_clean = clean_mask(coarse_mask, cls_label)
48 | mask_clean = torch.argmax(mask_clean, dim=1)
49 | save_pred_anno(mask_clean, predict_val_path, filename, dict=postdam_color_dict, flag=False)
50 |
51 |
52 | def main():
53 | opt = Sec_Options().parse()
54 | save_path = os.path.join(opt.save_path, opt.dataset)
55 |
56 | log_path, checkpoint_path, predict_test_path, predict_train_path, predict_val_path = create_save_path(opt)
57 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
58 |
59 | # load train and val dataset
60 | train_dataset = Dataset_point(opt, train_txt_path, flag='predict_train', transform=None)
61 | loader = DataLoader(train_dataset, batch_size=1, num_workers=opt.num_workers,
62 | pin_memory=opt.pin)
63 |
64 | model = Seg_Net(opt)
65 | checkpoint = torch.load(checkpoint_path+'/model_best_0.8184.pth', map_location=torch.device('cpu'))
66 | model.load_state_dict(checkpoint['state_dict'])
67 | print('resume success')
68 | model = model.to(device)
69 | model.eval()
70 |
71 | compute_metric = IOUMetric_tensor(num_classes=opt.num_classes)
72 | hist = torch.zeros([opt.num_classes, opt.num_classes]).cuda()
73 |
74 | for batch_idx, batch in enumerate(tqdm(loader)):
75 | img, label, cls_label, filename = batch
76 | with torch.no_grad():
77 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True),\
78 | cls_label.cuda(non_blocking=True)
79 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(),
80 | merge_mode='mean')
81 | coarse_mask = tta_model.forward(input)
82 | torch.cuda.synchronize()
83 |
84 | output = torch.argmax(coarse_mask, dim=1)
85 |
86 | # save train predict after softmax and clean
87 | coarse_mask = F.softmax(coarse_mask, dim=1)
88 | mask_clean = clean_mask(coarse_mask, cls_label)
89 | mask_clean = torch.argmax(mask_clean, dim=1)
90 | save_pred_anno(mask_clean, predict_train_path, filename, dict=postdam_color_dict, flag=True)
91 |
92 | hist += compute_metric.get_hist(output, label)
93 |
94 | hist = hist.data.cpu().numpy()
95 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist)
96 |
97 | print('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa))
98 | print('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)))
99 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)))
100 |
101 | result_path = log_path[:-4] + '/result_predict_train.txt'
102 | result_txt = open(result_path, 'a')
103 | result_txt.write('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa) + '\n')
104 | result_txt.write('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)) + '\n')
105 | result_txt.write('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)) + '\n')
106 | result_txt.write('---------------------------' + '\n')
107 | result_txt.close()
108 |
109 |
110 | def show(predict_path, test_path):
111 | path = predict_path + '/label_vis'
112 | list = os.listdir(path)
113 |
114 | img_path = test_path + '/img'
115 | point_label_path = test_path + '/point_label_vis'
116 | label_path = test_path + '/label_vis'
117 |
118 | for i in list:
119 | predict = read(os.path.join(path, i))
120 |
121 | img = read(os.path.join(img_path, i))[:, :, :3]
122 | label = read(os.path.join(label_path, i))
123 | point_label = read(os.path.join(point_label_path, i))
124 |
125 | fig, axs = plt.subplots(1, 4, figsize=(14, 4))
126 |
127 | axs[0].imshow(img.astype(np.uint8))
128 | axs[0].axis("off")
129 | axs[1].imshow(label.astype(np.uint8))
130 | axs[1].axis("off")
131 | axs[2].imshow(predict.astype(np.uint8))
132 | axs[2].axis("off")
133 | axs[3].imshow(point_label.astype(np.uint8))
134 | axs[3].axis("off")
135 |
136 | plt.suptitle(os.path.basename(i), y=0.94)
137 | plt.tight_layout()
138 | plt.show()
139 | plt.close()
140 |
141 |
142 | if __name__ == '__main__':
143 | main()
144 |
145 | # opt = Point_iter_second_Options().parse()
146 | # log_path, checkpoint_path, predict_path, predict_train_path, _ = create_save_path(opt)
147 | #
148 | # label_path = '/home/ggm/WLS/semantic/dataset/potsdam/train'
149 | # show(predict_train_path, label_path)
150 |
151 | predict_val()
--------------------------------------------------------------------------------
/run/second/sec_test.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from dataset import *
3 | from matplotlib import pyplot as plt
4 | from torch.utils.data import DataLoader
5 | from tqdm import tqdm
6 | from datafiles.color_dict import *
7 | from models.tools import get_crf
8 | from models.DBFNet import clean_mask  # clean_mask is used in the vaihingen branches below
9 | import random, ttach as tta
10 |
11 | print("PyTorch Version: ", torch.__version__)
12 | print('cuda', torch.version.cuda)
13 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
14 | device = torch.device("cuda:0")
15 | print('Device:', device)
16 |
17 |
18 | def main():
19 | opt = Sec_Options().parse()
20 | save_path = os.path.join(opt.save_path, opt.dataset)
21 | log_path, checkpoint_path, predict_path, _, _ = create_save_path(opt)
22 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
23 |
24 | # load train and val dataset
25 | test_dataset = Dataset_point(opt, test_txt_path, flag='test', transform=None)
26 | loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=opt.num_workers,
27 | pin_memory=opt.pin)
28 |
29 | model = Seg_Net(opt)
30 | checkpoint = torch.load(checkpoint_path+'/model_best.pth', map_location=torch.device('cpu'))
31 | model.load_state_dict(checkpoint['state_dict'])
32 | print('resume success')
33 | model = model.to(device)
34 | model.eval()
35 |
36 | compute_metric = IOUMetric_tensor(num_classes=opt.num_classes)
37 | hist = torch.zeros([opt.num_classes, opt.num_classes]).cuda()
38 |
39 | for batch_idx, batch in enumerate(tqdm(loader)):
40 | img, label, cls_label, filename = batch
41 | with torch.no_grad():
42 | input, label = img.cuda(non_blocking=True), label.cuda(non_blocking=True)
43 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(),
44 | merge_mode='mean')
45 |             output = tta_model(input)  # run inference through the TTA wrapper built above
46 | torch.cuda.synchronize()
47 |
48 | if opt.dataset == 'vaihingen':
49 | cls_label = cls_label.cuda(non_blocking=True)
50 | coarse_mask = F.softmax(output, dim=1)
51 | mask_clean = clean_mask(coarse_mask, cls_label)
52 | output = torch.argmax(mask_clean, dim=1)
53 | else:
54 | output = torch.argmax(output, dim=1)
55 |
56 | # save_pred_anno(output, predict_path, filename, dict=postdam_color_dict, flag=True)
57 | hist += compute_metric.get_hist(output, label)
58 |
59 | hist = hist.data.cpu().numpy()
60 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist)
61 |
62 | print('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa))
63 | print('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)))
64 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)))
65 |
66 | # result_path = log_path[:-4] + '/result_test.txt'
67 | # result_txt = open(result_path, 'a')
68 | # result_txt.write('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa) + '\n')
69 | # result_txt.write('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)) + '\n')
70 | # result_txt.write('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)) + '\n')
71 | # result_txt.write('---------------------------' + '\n')
72 | # result_txt.close()
73 |
74 |
75 | def p_test_withcrf():
76 | opt = Sec_Options().parse()
77 | log_path, checkpoint_path, predict_path, _, _ = create_save_path(opt)
78 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
79 |
80 | # load train and val dataset
81 | test_dataset = Dataset_point(opt, test_txt_path, flag='test', transform=None)
82 | loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=opt.num_workers,
83 | pin_memory=opt.pin)
84 |
85 | model = Seg_Net(opt)
86 | checkpoint = torch.load(checkpoint_path + '/model_best.pth', map_location=torch.device('cpu'))
87 | model.load_state_dict(checkpoint['state_dict'])
88 | print('resume success')
89 | model = model.to(device)
90 | model.eval()
91 |
92 | compute_metric = IOUMetric(num_classes=opt.num_classes)
93 | hist = np.zeros([opt.num_classes, opt.num_classes])
94 |
95 | for batch_idx, batch in enumerate(tqdm(loader)):
96 | img, label, cls_label, filename = batch
97 | with torch.no_grad():
98 | input, label = img.cuda(non_blocking=True), label.cuda(non_blocking=True)
99 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(),
100 | merge_mode='mean')
101 | output = tta_model(input)
102 | torch.cuda.synchronize()
103 |
104 | if opt.dataset == 'vaihingen':
105 | cls_label = cls_label.cuda(non_blocking=True)
106 | coarse_mask = F.softmax(output, dim=1)
107 | mask = clean_mask(coarse_mask, cls_label)
108 | else:
109 | mask = F.softmax(output, dim=1)
110 |
111 | mask = mask.squeeze(0).permute(1, 2, 0).cpu().numpy()
112 | mask = mask.astype(np.float32)
113 |
114 | img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
115 |             img = Normalize_back(img, flag=opt.dataset)  # follow the selected dataset instead of hard-coding potsdam
116 |
117 | crf_out = get_crf(opt, mask, img.astype(np.uint8))
118 | label = label.squeeze(0).data.cpu().numpy()
119 |
120 | # save crf out
121 | predict_path_crf = predict_path + '/crf'
122 | check_dir(predict_path_crf)
123 | # save_pred_anno_numpy(crf_out, predict_path_crf, filename, dict=nuclear_color_dict, flag=True)
124 |
125 | hist += compute_metric.get_hist(crf_out, label)
126 |
127 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist)
128 |
129 | print('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa))
130 | print('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)))
131 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)))
132 |
133 |
134 | def show(predict_path, test_path):
135 | # crf_path = predict_path[:-10] + '/crf/label_vis'
136 | img_path = test_path + '/img'
137 | label_path = test_path + '/label_vis'
138 |
139 | list = os.listdir(predict_path)
140 | random.shuffle(list)
141 |
142 | for i in list:
143 | predict = read(os.path.join(predict_path, i))
144 | img = read(os.path.join(img_path, i))[:, :, :3]
145 | # crf = read(os.path.join(crf_path, i))
146 | label = read(os.path.join(label_path, i))
147 |
148 | fig, axs = plt.subplots(1, 3, figsize=(14, 4))
149 |
150 | axs[0].imshow(img.astype(np.uint8))
151 | axs[0].axis("off")
152 | axs[1].imshow(label.astype(np.uint8))
153 | axs[1].axis("off")
154 |
155 | axs[2].imshow(predict.astype(np.uint8))
156 | axs[2].axis("off")
157 |
158 | # axs[3].imshow(crf.astype(np.uint8))
159 | # axs[3].axis("off")
160 |
161 | plt.suptitle(os.path.basename(i), y=0.94)
162 | plt.tight_layout()
163 | plt.show()
164 | plt.close()
165 |
166 |
167 | if __name__ == '__main__':
168 | main()
169 | p_test_withcrf()
170 |
171 | # predict_path = '/home/ggm/WLS/semantic/PointAnno/save/nuclear/Second/predict_test/label_vis'
172 | # test_path = '/home/ggm/WLS/semantic/dataset/nuclear/test'
173 | # show(predict_path, test_path)
--------------------------------------------------------------------------------
/run/second/sec_train.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from torchvision import transforms
3 | from torch.autograd import Variable
4 | import shutil
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 | from collections import OrderedDict
8 | from options import *
9 | from utils import *
10 | from dataset import *
11 |
12 |
13 | print("PyTorch Version: ", torch.__version__)
14 | print('cuda', torch.version.cuda)
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
16 | device = torch.device("cuda:0")
17 | print('Device:', device)
18 |
19 |
20 | def main():
21 | opt = Sec_Options().parse()
22 | save_path = os.path.join(opt.save_path, opt.dataset)
23 | log_path, checkpoint_path, _, _, _ = create_save_path(opt)
24 | train_txt_path, val_txt_path, _ = create_data_path(opt)
25 |
26 | # Define logger
27 | logger, tensorboard_log_dir = create_logger(log_path)
28 | logger.info(opt)
29 | # Define Transformation
30 | train_transform = transforms.Compose([
31 | trans.RandomHorizontalFlip(),
32 | trans.RandomVerticleFlip(),
33 | trans.RandomRotate90(),
34 | ])
35 |
36 | # load train and val dataset
37 | train_dataset = Dataset_sec(opt, save_path, train_txt_path, flag='train', transform=train_transform)
38 | val_dataset = Dataset_sec(opt, save_path, val_txt_path, flag='val', transform=None)
39 |
40 | # Create training and validation dataloaders
41 | train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True,
42 | num_workers=opt.num_workers, pin_memory=opt.pin, drop_last=True)
43 |
44 | model = Seg_Net(opt)
45 | if opt.resume:
46 | checkpoint = torch.load(checkpoint_path+'/model_best.pth', map_location=torch.device('cpu'))
47 | model.load_state_dict(checkpoint['state_dict'])
48 | print('resume success')
49 |
50 | model = model.to(device)
51 | optimizer = optim.AdamW(model.parameters(), lr=opt.base_lr, amsgrad=True)
52 | best_metric = -1e16
53 |
54 | for epoch in range(opt.num_epochs):
55 | time_start = time.time()
56 |
57 | train(opt, epoch, model, train_loader, optimizer, logger)
58 | metric = validate(opt, epoch, model, val_dataset, logger)
59 |
60 | logger.info('best_val_metric:%.4f current_val_metric:%.4f' % (best_metric, metric))
61 | if metric > best_metric:
62 | logger.info('epoch:%d Save to model_best' % (epoch))
63 | torch.save({'state_dict': model.state_dict()},
64 | os.path.join(checkpoint_path, 'model_best.pth'))
65 | best_metric = metric
66 |
67 | end_time = time.time()
68 | time_cost = end_time - time_start
69 | logger.info("Epoch %d Time %d ----------------------" % (epoch, time_cost))
70 | logger.info('\n')
71 |
72 |
73 | def train(opt, epoch, model, loader, optimizer, logger):
74 |     start_iter_epoch = 15  # only used by the commented-out soft-label schedule below
75 |
76 | model.train()
77 | loss_m = AverageMeter()
78 | soft_loss_m = AverageMeter()
79 | last_idx = len(loader) - 1
80 |
81 | for batch_idx, batch in enumerate(loader):
82 | # step = epoch * len(loader) + batch_idx
83 | # adjust_learning_rate(opt.base_lr, optimizer, step, len(loader))
84 | lr = get_lr(optimizer)
85 |
86 | last_batch = batch_idx == last_idx
87 | img, label, name = batch
88 | input, label = img.cuda(non_blocking=True), label.cuda(non_blocking=True)
89 |
90 | output = model(input)
91 |
92 | criterion = nn.CrossEntropyLoss(ignore_index=255)
93 | loss = criterion(output, label)
94 | # if epoch <= start_iter_epoch:
95 | # loss = criterion(output, label)
96 | #
97 | # else:
98 | # soft_loss = get_soft_loss(output, soft_label)
99 | # loss = criterion(output, label) + soft_loss
100 | # soft_loss_m.update(soft_loss.item(), input.size(0))
101 |
102 | loss_m.update(loss.item(), input.size(0))
103 | optimizer.zero_grad()
104 | loss.backward()
105 | optimizer.step()
106 | torch.cuda.synchronize()
107 |
108 |         log_interval = max(len(loader) // 3, 1)  # guard against loaders with fewer than 3 batches
109 | if last_batch or batch_idx % log_interval == 0:
110 | logger.info('Train:{} [{:>4d}/{} ({:>3.0f}%)] '
111 | 'Loss:({loss.avg:>6.4f}) '
112 | #'soft:({soft.avg:>6.4f}) '
113 | 'LR:{lr:.3e} '.format(
114 | epoch,
115 | batch_idx, len(loader),
116 | 100. * batch_idx / last_idx,
117 | loss=loss_m,
118 | lr=lr))
119 |
120 | # if epoch >= start_iter_epoch:
121 | # logger.info('predict_train')
122 | # predict_train(opt, model)
123 |
124 | return OrderedDict([('loss', loss_m.avg)])
125 |
126 |
127 | def predict_train(opt, model):
128 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
129 | save_path = os.path.join(opt.save_path, opt.dataset)
130 | train_dataset = Dataset_soft(opt, save_path, train_txt_path, flag='train', transform=None)
131 | loader = DataLoader(train_dataset, batch_size=1, num_workers=opt.num_workers,
132 | pin_memory=opt.pin)
133 |
134 | model.eval()
135 | for batch_idx, batch in enumerate(tqdm(loader)):
136 | img, label, cls_label, soft_label, filename = batch
137 | with torch.no_grad():
138 | input, cls_label = img.cuda(non_blocking=True), cls_label.cuda(non_blocking=True)
139 | coarse_mask = model.forward(input)
140 | mask = F.softmax(coarse_mask, dim=1)
141 | mask = clean_mask(mask, cls_label)
142 | mask = mask.squeeze(0).permute(1, 2, 0).data.cpu().numpy()
143 |
144 | save_mask_path = save_path + '/Second/soft_label'
145 | check_dir(save_mask_path)
146 | path = save_mask_path + '/' + filename[0][:-4] + '.npy'
147 |
148 | if os.path.isfile(path):
149 | soft_label = np.load(path)
150 | new_soft_label = (soft_label + mask) / 2
151 | np.save(path, new_soft_label)
152 | else:
153 | np.save(path, mask)
154 |
155 |
156 | def validate(opt, epoch, model, val_dataset, logger):
157 | compute_metric = IOUMetric_tensor(num_classes=opt.num_classes)
158 | hist = torch.zeros([opt.num_classes, opt.num_classes]).cuda()
159 |
160 | val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, pin_memory=opt.pin)
161 | model.eval()
162 |
163 | for batch_idx, batch in enumerate(val_loader):
164 | with torch.no_grad():
165 | img, label, name = batch
166 | input, label = img.cuda(non_blocking=True), label.cuda(non_blocking=True)
167 | output = model(input)
168 | torch.cuda.synchronize()
169 |
170 | output = torch.argmax(output, dim=1)
171 | hist += compute_metric.get_hist(output, label)
172 |
173 | iou = torch.diag(hist) / (hist.sum(dim=1) + hist.sum(dim=0) - torch.diag(hist))
174 | miou = torch.mean(iou).float()
175 |
176 | logger.info('epoch:%d current_miou:%.4f current_iou:%s' % (epoch, miou, str(iou)))
177 |
178 | return miou
179 |
180 |
181 | if __name__ == '__main__':
182 | main()
--------------------------------------------------------------------------------
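
A note on the soft-label refresh in predict_train above: blending the stored array with the new prediction via (soft_label + mask) / 2 is an exponential moving average in which later predictions dominate (each older prediction's weight roughly halves per update), not a uniform mean over all epochs. A minimal standalone sketch of that update rule, using a placeholder path:

import os
import numpy as np

def update_soft_label(path, mask):
    # Blend the stored soft label with the new prediction, as predict_train
    # does; repeated application weights recent predictions exponentially more.
    if os.path.isfile(path):
        stored = np.load(path)
        np.save(path, (stored + mask) / 2)
    else:
        np.save(path, mask)

# toy check: after two further updates the first mask's weight has decayed to 1/4
update_soft_label('/tmp/soft.npy', np.ones((2, 2, 3)))
update_soft_label('/tmp/soft.npy', np.zeros((2, 2, 3)))
update_soft_label('/tmp/soft.npy', np.zeros((2, 2, 3)))
print(np.load('/tmp/soft.npy')[0, 0, 0])  # 0.25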
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.util import *
2 | from utils.metric import *
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metric.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/metric.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metric.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/metric.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/paint.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/paint.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/paint.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/paint.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/paint.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/paint.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/query.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/query.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/util.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/util.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import numpy as np
4 | import torch
5 | import math
6 |
7 |
8 | def cal_kappa(hist):
9 | if hist.sum() == 0:
10 | po = 0
11 | pe = 1
12 | kappa = 0
13 | else:
14 | po = np.diag(hist).sum() / hist.sum()
15 | pe = np.matmul(hist.sum(1), hist.sum(0).T) / hist.sum() ** 2
16 | if pe == 1:
17 | kappa = 0
18 | else:
19 | kappa = (po - pe) / (1 - pe)
20 | return kappa
21 |
22 |
23 | def cal_fscore(hist):
24 | TP = np.diag(hist)
25 | FP = hist.sum(axis=0) - np.diag(hist)
26 | FN = hist.sum(axis=1) - np.diag(hist)
27 | TN = hist.sum() - (FP + FN + TP)
28 |
29 | precision = TP / (TP + FP)
30 | recall = TP / (TP + FN)
31 |
32 | f_score = 2 * precision * recall / (precision + recall)
33 | m_f_score = np.mean(f_score)
34 | return f_score, m_f_score
35 |
36 |
37 | class IOUMetric:
38 | """
39 | Class to calculate mean-iou using fast_hist method
40 | """
41 |
42 | def __init__(self, num_classes):
43 | self.num_classes = num_classes
44 | self.hist = np.zeros((num_classes, num_classes))
45 |
46 | def get_hist(self, label_pred, label_true):
47 |         # keep only the pixels whose labels fall in the valid class range (ignore label removed)
48 |         mask = (label_true >= 0) & (label_true < self.num_classes)
49 |         # np.bincount counts how often each value in [0, n**2 - 1] occurs; reshaped to an (n, n) confusion matrix
50 | hist = np.bincount(
51 | self.num_classes * label_true[mask].astype(int) +
52 | label_pred[mask], minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
53 | return hist
54 |
55 |     # inputs: predicted and ground-truth label maps
56 |     # semantic segmentation assigns a label to every pixel
57 | def evaluate(self, predictions, gts):
58 | for lp, lt in zip(predictions, gts):
59 | assert len(lp.flatten()) == len(lt.flatten())
60 | self.hist += self.get_hist(lp.flatten(), lt.flatten())
61 | # miou
62 | iou = np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0) - np.diag(self.hist))
63 | miou = np.nanmean(iou)
64 | # dice
65 | dice = 2 * np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0))
66 | mdice = np.nanmean(dice)
67 |
68 |         # ----------------- other metrics ------------------------------
69 | # mean acc
70 | acc = np.diag(self.hist).sum() / self.hist.sum()
71 | acc_cls = np.nanmean(np.diag(self.hist) / self.hist.sum(axis=1))
72 | freq = self.hist.sum(axis=1) / self.hist.sum()
73 | fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()
74 | return acc, acc_cls, iou, miou, dice, mdice, fwavacc
75 |
76 |
77 | class IOUMetric_tensor:
78 | """
79 | Class to calculate mean-iou with tensor_type using fast_hist method
80 | """
81 |
82 | def __init__(self, num_classes):
83 | self.num_classes = num_classes
84 | self.hist = torch.zeros([num_classes, num_classes])
85 |
86 | def get_hist(self, label_pred, label_true):
87 |         # keep only the pixels whose labels fall in the valid class range (ignore label removed)
88 |         mask = (label_true >= 0) & (label_true < self.num_classes)
89 |         # torch.bincount counts how often each value in [0, n**2 - 1] occurs; reshaped to an (n, n) confusion matrix
90 | hist = torch.bincount(
91 | self.num_classes * label_true[mask] +
92 | label_pred[mask], minlength=self.num_classes ** 2).view(self.num_classes, self.num_classes)
93 | return hist
94 |
95 |     # inputs: predicted and ground-truth label maps
96 |     # semantic segmentation assigns a label to every pixel
97 | def evaluate(self, predictions, gts):
98 | for lp, lt in zip(predictions, gts):
99 | assert len(lp.flatten()) == len(lt.flatten())
100 | self.hist += self.get_hist(lp.flatten(), lt.flatten())
101 | # miou
102 | iou = torch.diag(self.hist) / (self.hist.sum(dim=1) + self.hist.sum(dim=0) - torch.diag(self.hist))
103 | miou = torch.mean(iou)
104 | # dice
105 | dice = 2 * torch.diag(self.hist) / (self.hist.sum(dim=1) + self.hist.sum(dim=0))
106 | mdice = torch.mean(dice)
107 |
108 |         # ----------------- other metrics ------------------------------
109 | # mean acc
110 | acc = torch.diag(self.hist).sum() / self.hist.sum()
111 |         acc_cls = torch.mean(torch.diag(self.hist) / self.hist.sum(dim=1))  # torch.diag, not np.diag, on a tensor
112 | freq = self.hist.sum(dim=1) / self.hist.sum()
113 | fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()
114 | return acc, acc_cls, iou, miou, dice, mdice, fwavacc
115 |
116 |
117 | def eval_hist(hist):
118 | # hist must be numpy
119 | kappa = cal_kappa(hist)
120 | iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
121 | miou = np.nanmean(iou)
122 |
123 | # f_score
124 | f_score, m_f_score = cal_fscore(hist)
125 |
126 | # mean acc
127 | acc = np.diag(hist).sum() / hist.sum()
128 | acc_cls = np.diag(hist) / hist.sum(axis=1)
129 |
130 | return iou, miou, kappa, acc, acc_cls, f_score, m_f_score
131 |
132 |
133 | def cls_accuracy(output, target):
134 |     """Computes precision, recall, overall accuracy and F1 for binary multi-label predictions"""
135 | y_true = target.data.cpu().numpy()
136 | y_pred = output.data.cpu().numpy()
137 |
138 |     # True Positives: entries that are 1 in both y_true and y_pred
139 |     TP = np.sum(np.multiply(y_true, y_pred))
140 |
141 |     # False Positives: 0 in y_true but predicted as 1
142 |     FP = np.sum(np.logical_and(np.equal(y_true, 0), np.equal(y_pred, 1)))
143 |     # False Negatives: 1 in y_true but predicted as 0
144 |     FN = np.sum(np.logical_and(np.equal(y_true, 1), np.equal(y_pred, 0)))
145 |     # True Negatives: entries that are 0 in both y_true and y_pred
146 |     TN = np.sum(np.logical_and(np.equal(y_true, 0), np.equal(y_pred, 0)))
147 |
148 |     # derive accuracy, precision, recall and F1 from the counts above
149 |     kappa = (TP + TN) / (TP + FP + FN + TN)  # overall accuracy, despite the variable name
150 |     P = TP / (TP + FP)  # precision: predicted positives that are truly positive
151 |     R = TP / (TP + FN)  # recall: true positives that were predicted positive
152 | F1 = 2 * P * R / (P + R)
153 |
154 | return P, R, kappa, F1
155 |
156 |
157 | def per_cls_accuracy(output, target):
158 | y_true = target.data.cpu().numpy()
159 | y_pred = output.data.cpu().numpy()
160 | acc_all = []
161 | for cls in range(y_true.shape[1]):
162 | true = np.sum(y_true[:, cls] == y_pred[:, cls])
163 | acc = true/y_true.shape[0]
164 | acc_all.append(acc)
165 | return acc_all
166 |
167 |
168 | if __name__ == '__main__':
169 | a = torch.randint(0, 2, size=[30, 4])
170 | b = torch.randint(0, 2, size=[30, 4])
171 |
172 | P, R, kappa, F1 = cls_accuracy(a, b)
173 | print(P, R, kappa, F1)
174 | acc_all = per_cls_accuracy(a, b)
175 | print(acc_all)
--------------------------------------------------------------------------------
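
For orientation, a small usage sketch of IOUMetric and eval_hist on synthetic labels (the class count and tile shapes here are illustrative, not the datasets' real ones):

import numpy as np
from utils.metric import IOUMetric, eval_hist

num_classes = 6
metric = IOUMetric(num_classes=num_classes)

# two fake 8x8 prediction / ground-truth pairs
preds = [np.random.randint(0, num_classes, (8, 8)) for _ in range(2)]
gts = [np.random.randint(0, num_classes, (8, 8)) for _ in range(2)]

acc, acc_cls, iou, miou, dice, mdice, fwavacc = metric.evaluate(preds, gts)
print('mIoU: %.4f  mDice: %.4f' % (miou, mdice))

# eval_hist consumes an accumulated confusion matrix directly
iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(metric.hist)
print('kappa: %.4f  mean F-score: %.4f' % (kappa, m_f_score))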
/utils/paint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from PIL import Image
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from datafiles.color_dict import color_dict
7 |
8 |
9 | def hex_to_rgb(value):
10 | lv = len(value)
11 | return tuple(int(value[i:i + lv // 3], 16) for i in range(0, lv, lv // 3))
12 |
13 |
14 | def create_visual_anno(anno, dict, flag=None):
15 | # assert np.max(anno) <= num_classes-1, "only %d classes are supported, add new color in label2color_dict" % (num_classes)
16 |
17 | if flag == 'hex':
18 | rgb_dict = {}
19 | for keys, hex_value in dict.items():
20 | rgb_value = hex_to_rgb(hex_value)
21 | rgb_dict[keys] = rgb_value
22 | else:
23 | rgb_dict = dict
24 |
25 | # visualize
26 | visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8)
27 | for i in range(visual_anno.shape[0]): # i for h
28 | for j in range(visual_anno.shape[1]):
29 | color = rgb_dict[anno[i, j]]
30 | visual_anno[i, j, 0] = color[0]
31 | visual_anno[i, j, 1] = color[1]
32 | visual_anno[i, j, 2] = color[2]
33 |
34 | return visual_anno
35 |
36 |
--------------------------------------------------------------------------------
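
The per-pixel loop in create_visual_anno is pure Python and slow on large tiles; an equivalent vectorized lookup (a sketch, assuming integer class ids and an {id: (r, g, b)} palette like the repo's color_dict) produces the same image far faster:

import numpy as np

def create_visual_anno_fast(anno, rgb_dict):
    # Build a (max_id + 1, 3) lookup table once, then index it with the
    # whole annotation map instead of looping over every pixel.
    lut = np.zeros((max(rgb_dict) + 1, 3), dtype=np.uint8)
    for cls_id, color in rgb_dict.items():
        lut[cls_id] = color
    return lut[anno]

# toy example with a hypothetical 3-class palette
palette = {0: (255, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255)}
anno = np.random.randint(0, 3, (4, 4))
print(create_visual_anno_fast(anno, palette).shape)  # (4, 4, 3)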
/utils/query.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | def ask_for_query(index, path):
7 | index = list(index)
8 | query_all = []
9 |
10 | path_list = os.listdir(path)
11 |
12 | for idx in index:
13 | query = []
14 | for path_index in path_list:
15 | if (set(list(path_index)) & set(index)) == set(idx):
16 | query.append(path_index)
17 | query_all.append((idx, query))
18 |
19 | query_dict = dict(query_all)
20 |
21 | return query_dict
22 |
23 |
24 | def ask_for_query_new(index, path):
25 | index = list(index)
26 | query_all = []
27 |
28 | path_list = os.listdir(path)
29 |
30 | for i in path_list:
31 | # the shared
32 | I = (set(list(i)) & set(index))
33 | if I != set():
34 | if len(index) > 1:
35 | if I != set(index):
36 | query_all.append(i)
37 | else:
38 | query_all.append(i)
39 |
40 | return query_all
41 |
42 |
43 | if __name__ == '__main__':
44 | path = '/home/ggm/WLS/semantic/dataset/potsdam/train/img_index'
45 |
46 | index = '012'
47 | query = ask_for_query_new(index, path)
48 | print(query)
49 |
50 |
--------------------------------------------------------------------------------
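
ask_for_query_new matches on character sets: a file name is returned when it shares at least one character with the query string but, for multi-character queries, not all of them. Note that it compares raw os.listdir names, so extension characters ('.tif' and the like) also take part in the intersection. A self-contained sketch of the same rule on an in-memory name list:

def query_new(index, names):
    # keep names that share some, but (for multi-character queries)
    # not all, of the query's characters
    picked = []
    for name in names:
        shared = set(name) & set(index)
        if shared and (len(index) == 1 or shared != set(index)):
            picked.append(name)
    return picked

print(query_new('012', ['01', '02', '12', '3']))  # ['01', '02', '12']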
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import logging
6 | import time
7 | import json
8 | import cv2
9 | import os
10 | from utils.paint import create_visual_anno
11 | from pathlib import Path
12 | from models import *
13 | from PIL import Image
14 | import tifffile
15 | import numpy as np
16 | import ttach as tta
17 | from math import *
18 |
19 |
20 | def read(path):
21 | if path.endswith('.tif'):
22 | return tifffile.imread(path)
23 | else:
24 | img = Image.open(path)
25 | return np.asarray(img)
26 |
27 |
28 | def imsave(path, img):
29 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
30 | cv2.imwrite(path[:-4] + '.png', img)
31 |
32 |
33 | def check_dir(dir):
34 | if not os.path.exists(dir):
35 | os.mkdir(dir)
36 |
37 |
38 | def get_lr(optimizer):
39 | for param_group in optimizer.param_groups:
40 | return param_group['lr']
41 |
42 |
43 | class AverageMeter:
44 | """Computes and stores the average and current value"""
45 | def __init__(self):
46 | self.reset()
47 |
48 | def reset(self):
49 | self.val = 0
50 | self.avg = 0
51 | self.sum = 0
52 | self.count = 0
53 |
54 | def update(self, val, n=1):
55 | self.val = val
56 | self.sum += val * n
57 | self.count += n
58 | self.avg = self.sum / self.count
59 |
60 |
61 | def adjust_learning_rate(learning_rate, optimizer, step, length, num_epochs=20):
62 | """Sets the learning rate to the initial LR decayed by 10 every 20 epochs"""
63 | stride = num_epochs * length
64 | lr = learning_rate * (0.1 ** (step // stride))
65 | if step % stride == 0:
66 | print("learning_rate change to:%.8f" % (lr))
67 | for param_group in optimizer.param_groups:
68 | param_group['lr'] = lr
69 |
70 |
71 | def save_pred_anno(out, save_pred_dir, filename, dict, flag=False):
72 | out = out.squeeze(0)
73 | check_dir(save_pred_dir)
74 |
75 | # save predict dir
76 | save_value_dir = save_pred_dir + '/label'
77 | check_dir(save_value_dir)
78 | save_value = os.path.join(save_value_dir, filename[0].split('/')[-1])
79 | label = out.data.cpu().numpy()
80 | cv2.imwrite(save_value, label)
81 |
82 | if flag is True:
83 | # save predict_visual dir
84 | save_anno_dir = save_pred_dir + '/label_vis'
85 | check_dir(save_anno_dir)
86 | save_anno = os.path.join(save_anno_dir, filename[0].split('/')[-1])
87 |
88 | img_visual = create_visual_anno(out.data.cpu().numpy(), dict=dict, flag=flag)
89 | b, g, r = cv2.split(img_visual)
90 | img_visual_rgb = cv2.merge([r, g, b])
91 | cv2.imwrite(save_anno, img_visual_rgb)
92 |
93 |
94 | def save_pred_anno_numpy(out, save_pred_dir, filename, dict, flag=False):
95 | check_dir(save_pred_dir)
96 |
97 | # save predict dir
98 | save_value_dir = save_pred_dir + '/label'
99 | check_dir(save_value_dir)
100 | save_value = os.path.join(save_value_dir, filename[0].split('/')[-1])
101 |
102 | cv2.imwrite(save_value, out)
103 |
104 | if flag is True:
105 | # save predict_visual dir
106 | save_anno_dir = save_pred_dir + '/label_vis'
107 | check_dir(save_anno_dir)
108 | save_anno = os.path.join(save_anno_dir, filename[0].split('/')[-1])
109 |
110 | img_visual = create_visual_anno(out, dict=dict, flag=flag)
111 | b, g, r = cv2.split(img_visual)
112 | img_visual_rgb = cv2.merge([r, g, b])
113 | cv2.imwrite(save_anno, img_visual_rgb)
114 |
115 |
116 | def save2json(metric_dict, save_path):
117 | file_ = open(save_path, 'w')
118 | file_.write(json.dumps(metric_dict, ensure_ascii=False,indent=2))
119 | file_.close()
120 |
121 |
122 | def create_save_path(opt):
123 | save_path = os.path.join(opt.save_path, opt.dataset)
124 | exp_path = os.path.join(save_path, opt.experiment_name)
125 |
126 | log_path = os.path.join(exp_path, 'log')
127 | checkpoint_path = os.path.join(exp_path, 'checkpoint')
128 | predict_test_path = os.path.join(exp_path, 'predict_test')
129 | predict_train_path = os.path.join(exp_path, 'predict_train')
130 | predict_val_path = os.path.join(exp_path, 'predict_val')
131 |
132 | check_dir(save_path), check_dir(exp_path), check_dir(log_path), check_dir(checkpoint_path), \
133 | check_dir(predict_test_path), check_dir(predict_train_path), check_dir(predict_val_path)
134 |
135 | return log_path, checkpoint_path, predict_test_path, predict_train_path, predict_val_path
136 |
137 |
138 | def create_data_path(opt):
139 | data_inform_path = os.path.join(opt.data_inform_path, opt.dataset)
140 |
141 | # for vai second
142 | # train_txt_path = os.path.join(data_inform_path, 'sec_train.txt')
143 | # val_txt_path = os.path.join(data_inform_path, 'sec_val.txt')
144 |
145 | train_txt_path = os.path.join(data_inform_path, 'seg_train.txt')
146 | val_txt_path = os.path.join(data_inform_path, 'seg_val.txt')
147 | test_txt_path = os.path.join(data_inform_path, 'seg_test.txt')
148 |
149 | return train_txt_path, val_txt_path, test_txt_path
150 |
151 |
152 | def create_logger(log_path):
153 | time_str = time.strftime('%Y-%m-%d-%H-%M')
154 |
155 | log_file = '{}.log'.format(time_str)
156 |
157 | final_log_file = os.path.join(log_path , log_file)
158 | head = '%(asctime)-15s %(message)s'
159 | logging.basicConfig(filename=str(final_log_file),
160 | format=head)
161 | logger = logging.getLogger()
162 | logger.setLevel(logging.INFO)
163 | console = logging.StreamHandler()
164 | logging.getLogger('').addHandler(console)
165 |
166 | tensorboard_log_dir = Path(log_path)/'scalar'/time_str
167 | print('=>creating {}'.format(tensorboard_log_dir))
168 | tensorboard_log_dir.mkdir(parents=True, exist_ok=True)
169 |
170 | return logger, str(tensorboard_log_dir)
171 |
172 |
173 | def resize_label(label, size):
174 | if len(label.size()) == 3:
175 | label = label.unsqueeze(1)
176 | label = F.interpolate(label, size=(size, size), mode='bilinear', align_corners=True)
177 |
178 | return label
179 |
180 |
181 | def get_mean_std(flag):
182 | if flag == 'potsdam':
183 | means = [86.42521457, 92.37607528, 85.74658389]
184 | std = [35.58409409, 35.45218542, 36.91464009]
185 | elif flag == 'vaihingen':
186 | means = [119.14901543, 83.04203606, 81.79810095]
187 | std = [55.63038161, 40.67145608, 38.61447761]
188 |
189 | else:
190 | means = 0
191 | std = 0
192 | print('error')
193 | return means, std
194 |
195 |
196 | def Normalize(img, flag='potsdam'):
197 | means, std = get_mean_std(flag)
198 | img = (img - means) / std
199 |
200 | return img
201 |
202 |
203 | def Normalize_back(img, flag='potsdam'):
204 | means, std = get_mean_std(flag)
205 |
206 | means = means[:3]
207 | std = std[:3]
208 |
209 | img = img * std + means
210 |
211 | return img
212 |
213 |
214 | def pad_image(img, target_size):
215 | """Pad an image up to the target size."""
216 | rows_missing = target_size[0] - img.shape[2]
217 | cols_missing = target_size[1] - img.shape[3]
218 |     padded_img = F.pad(img, (0, cols_missing, 0, rows_missing), 'constant', 0)  # pad right/bottom (F.pad order: left, right, top, bottom)
219 | return padded_img
220 |
221 |
222 | def pre_slide(model, image, num_classes=7, tile_size=(512, 512), tta=False):
223 |     image_size = image.shape  # larger than (1, 3, 512, 512), e.g. (1, 3, 1024, 1024)
224 |     overlap = 1 / 2  # consecutive windows overlap by half a tile
225 |
226 |     stride = ceil(tile_size[0] * (1 - overlap))  # sliding stride: 512 * (1 - 1/2) = 256
227 |     tile_rows = int(ceil((image_size[2] - tile_size[0]) / stride) + 1)  # e.g. (1024 - 512) / 256 + 1 = 3
228 |     tile_cols = int(ceil((image_size[3] - tile_size[1]) / stride) + 1)  # e.g. (1024 - 512) / 256 + 1 = 3
229 |
230 |     full_probs = torch.zeros((1, num_classes, image_size[2], image_size[3])).cuda()  # accumulated probability map
231 |
232 |     count_predictions = torch.zeros((1, 1, image_size[2], image_size[3])).cuda()  # per-pixel count of covering windows
233 |     tile_counter = 0  # number of tiles processed
234 |
235 |     for row in range(tile_rows):
236 |         for col in range(tile_cols):
237 |             x1 = int(col * stride)  # window start column
238 |             y1 = int(row * stride)  # window start row
239 |             x2 = min(x1 + tile_size[1], image_size[3])  # window end column, clipped to image width
240 |             y2 = min(y1 + tile_size[0], image_size[2])  # window end row, clipped to image height
241 |             x1 = max(int(x2 - tile_size[1]), 0)  # re-anchor x1 so the window keeps its full width
242 |             y1 = max(int(y2 - tile_size[0]), 0)  # re-anchor y1 so the window keeps its full height
243 |
244 |             img = image[:, :, y1:y2, x1:x2]  # crop the current window
245 |             padded_img = pad_image(img, tile_size)  # pad so the crop is exactly tile_size
246 |
247 |             tile_counter += 1
248 |             # print("Predicting tile %i" % tile_counter)
249 |
250 |             # run the crop through the network to obtain class probabilities
251 |             # use softmax
252 |             if tta is True:
253 |                 padded = tta_predict(model, padded_img)
254 |             else:
255 |                 padded = model(padded_img)
256 |                 padded = F.softmax(padded, dim=1)
257 |
258 |             pre = padded[:, :, 0:img.shape[2], 0:img.shape[3]]  # cut back to the unpadded crop size
259 |
260 |             count_predictions[:, :, y1:y2, x1:x2] += 1  # increment the window count in this region
261 |             full_probs[:, :, y1:y2, x1:x2] += pre  # accumulate this window's probabilities
262 |
263 |     # average the predictions in the overlapping regions
264 |     full_probs /= count_predictions  # accumulated probabilities / counts = mean probability
265 |
266 |     return full_probs  # averaged probability map, shape (1, num_classes, H, W)
267 |
268 |
269 | def tta_predict(model, img):
270 | tta_transforms = tta.Compose(
271 | [
272 | tta.HorizontalFlip(),
273 | tta.Rotate90(angles=[0, 90, 180, 270]),
274 | ])
275 |
276 | xs = []
277 |
278 | for t in tta_transforms:
279 | aug_img = t.augment_image(img)
280 | aug_x = model(aug_img)
281 | aug_x = F.softmax(aug_x, dim=1)
282 |
283 | x = t.deaugment_mask(aug_x)
284 | xs.append(x)
285 |
286 | xs = torch.cat(xs, 0)
287 | x = torch.mean(xs, dim=0, keepdim=True)
288 |
289 | return x
--------------------------------------------------------------------------------
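
Finally, a minimal driver for pre_slide (a sketch: the 1x1-conv stand-in model, image size and class count are placeholders, and a CUDA device is assumed since pre_slide allocates its accumulators on the GPU):

import torch
import torch.nn as nn
from utils.util import pre_slide

# stand-in network: maps 3 input channels to 7 class scores per pixel
model = nn.Conv2d(3, 7, kernel_size=1).cuda().eval()

image = torch.randn(1, 3, 1024, 1024).cuda()
with torch.no_grad():
    probs = pre_slide(model, image, num_classes=7, tile_size=(512, 512), tta=False)

pred = torch.argmax(probs, dim=1)  # (1, 1024, 1024) label map
print(pred.shape)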