├── DataProcess ├── __init__.py ├── postdam │ ├── CropPost.py │ ├── FindMean.py │ ├── TransPostLabel.py │ ├── WriteSegTXT.py │ ├── __init__.py │ ├── make_point_label.py │ └── rgb_to_value.txt └── vaihingen │ ├── CropVai.py │ ├── FindMean.py │ ├── TransVaiLabel.py │ ├── WriteTXT.py │ ├── __init__.py │ ├── make_point_label.py │ └── rgb_to_value.txt ├── README.md ├── datafiles ├── color_dict.py ├── potsdam │ ├── seg_test.txt │ ├── seg_train.txt │ └── seg_val.txt └── vaihingen │ ├── seg_test.txt │ ├── seg_train.txt │ └── seg_val.txt ├── dataset ├── __init__.py ├── data_sec.py ├── data_utils.py ├── datapoint.py └── transform.py ├── models ├── DBFNet.py ├── __init__.py ├── base_model.py ├── fbf.py ├── network │ ├── DeeplabV3.py │ ├── LinkNet.py │ ├── __init__.py │ ├── fcn.py │ ├── fpn.py │ ├── resnet.py │ └── segformer.py ├── seg_model.py └── tools │ ├── __init__.py │ └── densecrf.py ├── options ├── Second_options.py ├── __init__.py └── point_options.py ├── predict.py ├── run ├── __init__.py ├── point │ ├── p_predict_train.py │ ├── p_test.py │ └── p_train.py └── second │ ├── __init__.py │ ├── sec_predict_train.py │ ├── sec_test.py │ └── sec_train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-35.pyc ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── metric.cpython-36.pyc ├── metric.cpython-37.pyc ├── paint.cpython-35.pyc ├── paint.cpython-36.pyc ├── paint.cpython-37.pyc ├── query.cpython-37.pyc ├── util.cpython-35.pyc ├── util.cpython-36.pyc └── util.cpython-37.pyc ├── metric.py ├── paint.py ├── query.py └── util.py /DataProcess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/DataProcess/__init__.py -------------------------------------------------------------------------------- /DataProcess/postdam/CropPost.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tifffile 3 | import numpy as np 4 | import cv2 5 | from utils import check_dir, read, imsave 6 | 7 | train_name = [7, 8, 9, 10, 11, 12] 8 | test_name = [13, 14, 15] 9 | 10 | 11 | def crop_overlap(img, label, vis, name, img_path, label_path, vis_path, size=256, stride=128): 12 | h, w, c = img.shape 13 | new_h, new_w = (h // size + 1) * size, (w // size + 1) * size 14 | num_h, num_w = (new_h // stride) - 1, (new_w // stride) - 1 15 | 16 | new_img = np.zeros([new_h, new_w, c]).astype(np.uint8) 17 | new_img[:h, :w, :] = img 18 | 19 | new_label = 255 * np.ones([new_h, new_w, 3]).astype(np.uint8) 20 | new_label[:h, :w, :] = label 21 | 22 | new_vis = np.zeros([new_h, new_w, 3]).astype(np.uint8) 23 | new_vis[:h, :w, :] = vis 24 | 25 | count = 0 26 | 27 | for i in range(num_h): 28 | for j in range(num_w): 29 | out = new_img[i * stride:i * stride + size, j * stride:j * stride + size, :] 30 | gt = new_label[i * stride:i * stride + size, j * stride:j * stride + size, :] 31 | v = new_vis[i * stride:i * stride + size, j * stride:j * stride + size, :] 32 | assert v.shape == (256, 256, 3), print(v.shape) 33 | 34 | tifffile.imsave(img_path + '/' + str(name) + '_' + str(count) + '.tif', out) 35 | tifffile.imsave(label_path + '/' + str(name) + '_' + str(count) + '.tif', gt) 36 | tifffile.imsave(vis_path + '/' + str(name) + '_' + str(count) + '.tif', v) 37 | 38 | count += 1 39 | 40 | 41 | def crop(img, label, vis, name, img_path, label_path, vis_path, size=256): 42 | height, width, _ = img.shape 43 | h_size = height // size 44 | w_size = 
width // size 45 | 46 | count = 0 47 | 48 | for i in range(h_size): 49 | for j in range(w_size): 50 | out = img[i * size:(i + 1) * size, j * size:(j + 1) * size, :3] 51 | gt = label[i * size:(i + 1) * size, j * size:(j + 1) * size, :] 52 | v = vis[i * size:(i + 1) * size, j * size:(j + 1) * size, :] 53 | assert v.shape == (256, 256, 3) 54 | 55 | imsave(img_path + '/' + str(name) + '_' + str(count) + '.png', out) 56 | imsave(label_path + '/' + str(name) + '_' + str(count) + '.png', gt) 57 | imsave(vis_path + '/' + str(name) + '_' + str(count) + '.png', v) 58 | 59 | count += 1 60 | 61 | 62 | def save(img, label, vis, name, save_path, flag='val'): 63 | img_path = save_path + '/img' 64 | label_path = save_path + '/label' 65 | v_path = save_path + '/label_vis' 66 | check_dir(img_path), check_dir(label_path), check_dir(v_path) 67 | 68 | crop(img, label, vis, name, img_path, label_path, v_path) 69 | 70 | 71 | def run(root_path): 72 | img_path = root_path + '/dataset_origin/4_Ortho_RGBIR' 73 | label_path = root_path + '/dataset_origin/Labels' 74 | vis_path = root_path + '/dataset_origin/5_Labels_all' 75 | val_name = []  # assumed placeholder: 'val_name' is referenced below but never defined in this script; fill in the validation tile ids before running 76 | list = os.listdir(label_path) 77 | for i in list: 78 | img = read(os.path.join(img_path, i[:-9] + 'RGBIR.tif')) 79 | label = read(os.path.join(label_path, i)) 80 | vis = read(os.path.join(vis_path, i)) 81 | 82 | name = i[12:-10] 83 | 84 | if int(i[14:-10]) in test_name: 85 | print(name, 'test') 86 | test_path = root_path+'/test' 87 | check_dir(test_path) 88 | save(img, label, vis, name, test_path, flag='test') 89 | 90 | elif int(i[14:-10]) in val_name: 91 | print(name, 'val') 92 | val_path = root_path + '/val' 93 | check_dir(val_path) 94 | save(img, label, vis, name, val_path, flag='val') 95 | 96 | elif int(i[14:-10]) in train_name: 97 | print(name, 'train') 98 | train_path = root_path+'/train' 99 | check_dir(train_path) 100 | save(img, label, vis, name, train_path, flag='train') 101 | -------------------------------------------------------------------------------- /DataProcess/postdam/FindMean.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tifffile 4 | 5 | # means = [86.42521457, 92.37607528, 85.74658389, 98.17242502] 6 | path = '/home/ggm/WLS/semantic/dataset/potsdam/train/img' 7 | 8 | list = os.listdir(path) 9 | 10 | k = 0 11 | sum = np.zeros([4]) 12 | 13 | for idx, i in enumerate(list): 14 | print(idx) 15 | k += 1 16 | img = tifffile.imread(os.path.join(path, i)) 17 | img = img.reshape(256*256, -1) 18 | 19 | mean = np.mean(img, axis=0) 20 | sum += mean 21 | 22 | 23 | means = sum/k 24 | print(means) 25 | 26 | # std = [35.58409409, 35.45218542, 36.91464009, 36.3449891] 27 | # means = [86.42521457, 92.37607528, 85.74658389, 98.17242502] 28 | # 29 | # path = '/home/ggm/WLS/semantic/dataset/potsdam/train/img' 30 | # 31 | # list = os.listdir(path) 32 | # k = 0 33 | # sum = np.zeros([4]) 34 | # 35 | # for idx, i in enumerate(list): 36 | # print(idx) 37 | # k += 1 38 | # img = tifffile.imread(os.path.join(path, i)) 39 | # img = img.reshape(256*256, -1) 40 | # 41 | # x = (img - means) ** 2 42 | # 43 | # sum += np.sum(x, axis=0) 44 | # 45 | # std = np.sqrt(sum/(k*256*256)) 46 | # print(std) -------------------------------------------------------------------------------- /DataProcess/postdam/TransPostLabel.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | from utils import * 5 | import matplotlib.pyplot as plt 6 | import
tifffile 7 | from tqdm import tqdm 8 | 9 | 10 | def value_to_rgb(anno): 11 | label2color_dict = { 12 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255) 13 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255) 14 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255) 15 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0) 16 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0) 17 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0) 18 | 255: [0, 0, 0], # boundary (RGB: 0, 0, 0) 19 | } 20 | # visualize 21 | visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8) 22 | for i in range(visual_anno.shape[0]): # i for h 23 | for j in range(visual_anno.shape[1]): 24 | # cv2: bgr 25 | color = label2color_dict[anno[i, j, 0]] 26 | 27 | visual_anno[i, j, 0] = color[0] 28 | visual_anno[i, j, 1] = color[1] 29 | visual_anno[i, j, 2] = color[2] 30 | 31 | return visual_anno 32 | 33 | 34 | def rgb_to_value(rgb, txt='/home/ggm/WLS/semantic/PointAnno/DataProcess/potsdam/rgb_to_value.txt'): 35 | key_arr = np.loadtxt(txt) 36 | array = np.zeros((rgb.shape[0], rgb.shape[1], 3), dtype=np.uint8) 37 | 38 | for translation in key_arr: 39 | r, g, b, value = translation 40 | tmp = [r, g, b] 41 | array[(rgb == tmp).all(axis=2)] = value 42 | 43 | return array 44 | 45 | 46 | def save_label_value(label_path, save_path): 47 | check_dir(save_path) 48 | label_list = os.listdir(label_path) 49 | for idx, i in enumerate(label_list): 50 | print(idx) 51 | path = os.path.join(label_path, i) 52 | label = tifffile.imread(path) 53 | 54 | label = ((label > 128) * 255) 55 | 56 | label_value = rgb_to_value(label) 57 | tifffile.imsave(os.path.join(save_path, i), label_value) 58 | 59 | 60 | def show_label(root_path, img_path, vis_path): 61 | list = os.listdir(root_path) 62 | for idx, i in tqdm(enumerate(list)): 63 | print(i) 64 | img = tifffile.imread(os.path.join(img_path, i[:-9] + 'RGBIR.tif'))[:, :, :3] 65 | vis = tifffile.imread(os.path.join(vis_path, i)) 66 | 67 | label = tifffile.imread(os.path.join(root_path, i)) 68 | label = value_to_rgb(label) 69 | 70 | fig, axs = plt.subplots(1, 3, figsize=(14, 4)) 71 | 72 | axs[0].imshow(img.astype(np.uint8)) 73 | axs[0].axis("off") 74 | axs[1].imshow(label.astype(np.uint8)) 75 | axs[1].axis("off") 76 | 77 | axs[2].imshow(vis) 78 | axs[2].axis("off") 79 | 80 | plt.suptitle(os.path.basename(i), y=0.94) 81 | plt.tight_layout() 82 | plt.show() 83 | plt.close() 84 | 85 | 86 | def show_one(path='/home/ggm/WLS/semantic/dataset/potsdam/val/label/6_7_217.tif'): 87 | label = tifffile.imread(path) 88 | label = value_to_rgb(label) 89 | 90 | fig, axs = plt.subplots(1, 2, figsize=(14, 4)) 91 | 92 | axs[0].imshow(label.astype(np.uint8)) 93 | axs[0].axis("off") 94 | 95 | plt.tight_layout() 96 | plt.show() 97 | plt.close() 98 | 99 | 100 | if __name__ == '__main__': 101 | img_path = '/home/ggm/WLS/semantic/dataset/potsdam/dataset_origin/4_Ortho_RGBIR' 102 | 103 | path = '/home/ggm/WLS/semantic/dataset/potsdam/dataset_origin/5_Labels_all' 104 | save_path = '/home/ggm/WLS/semantic/dataset/potsdam/dataset_origin/Labels' 105 | check_dir(save_path) 106 | 107 | # save_label_value(path, save_path) 108 | show_label(save_path, img_path, path) -------------------------------------------------------------------------------- /DataProcess/postdam/WriteSegTXT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | 5 | def write_train(path): 6 | train_txt = open(path + '/datafiles/potsdam/seg_train.txt', 'w') 7 | train_path = 
'/media/hlf/Luffy/WLS//semantic/dataset/potsdam/train/img' 8 | list = os.listdir(train_path) 9 | 10 | for idx, i in enumerate(list): 11 | train_txt.write(i + '\n') 12 | train_txt.close() 13 | 14 | 15 | def write_val(path): 16 | val_txt = open(path + '/datafiles/potsdam/seg_val.txt', 'w') 17 | val_path = '/media/hlf/Luffy/WLS//semantic/dataset/potsdam/val/img' 18 | list = os.listdir(val_path) 19 | 20 | for idx, i in enumerate(list): 21 | val_txt.write(i + '\n') 22 | val_txt.close() 23 | 24 | 25 | def write_test(path): 26 | test_txt = open(path + '/datafiles/potsdam/seg_test.txt', 'w') 27 | test_path = '/media/hlf/Luffy/WLS//semantic/dataset/potsdam/test/img' 28 | list = os.listdir(test_path) 29 | 30 | for idx, i in enumerate(list): 31 | test_txt.write(i + '\n') 32 | test_txt.close() 33 | 34 | 35 | if __name__ == '__main__': 36 | path = '/media/hlf/Luffy/WLS//PointAnno' 37 | write_train(path) 38 | write_val(path) 39 | write_test(path) -------------------------------------------------------------------------------- /DataProcess/postdam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/DataProcess/postdam/__init__.py -------------------------------------------------------------------------------- /DataProcess/postdam/make_point_label.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import tifffile 5 | import matplotlib.pyplot as plt 6 | from tqdm import tqdm 7 | from utils import check_dir, read, imsave 8 | 9 | 10 | def read_cls_color(cls): 11 | label2color_dict = { 12 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255) 13 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255) 14 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255) 15 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0) 16 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0) 17 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0) 18 | } 19 | return label2color_dict[cls] 20 | 21 | 22 | def draw_point(label, kernal_size=100, point_size=3): 23 | h, w, c = label.shape 24 | label_set = np.unique(label) 25 | 26 | new_mask = np.ones([h, w, c], np.uint8) * 255 27 | new_mask_vis = np.zeros([h, w, c], np.uint8) 28 | 29 | for cls in label_set: 30 | if cls != 255: 31 | color = read_cls_color(cls) 32 | 33 | temp_mask = np.zeros([h, w]) 34 | temp_mask[label[:, :, 0] == cls] = 255 35 | temp_mask = np.asarray(temp_mask, dtype=np.uint8) 36 | _, contours, hierarchy = cv2.findContours(temp_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 37 | # cv2.drawContours(new_mask, contours, -1, color, 1) 38 | 39 | for i in range(len(contours)): 40 | area = cv2.contourArea(contours[i]) 41 | if area > kernal_size: 42 | # distance to contour 43 | dist = np.empty([h, w], dtype=np.float32) 44 | for h_ in range(h): 45 | for w_ in range(w): 46 | dist[h_, w_] = cv2.pointPolygonTest(contours[i], (w_, h_), True) 47 | 48 | # make sure the point in the temp_mask 49 | temp_dist = temp_mask * dist 50 | min_, max_, _, maxdistpt = cv2.minMaxLoc(temp_dist) 51 | cx, cy = maxdistpt[0], maxdistpt[1] 52 | 53 | new_mask[cy:cy + point_size, cx:cx + point_size, :] = (cls, cls, cls) 54 | new_mask_vis[cy:cy + point_size, cx:cx + point_size, :] = color 55 | 56 | return new_mask, new_mask_vis 57 | 58 | 59 | def make(root_path): 60 | train_path = root_path + '/train' 61 | val_path = root_path + '/val' 62 | 63 | paths = [train_path, val_path] 64 | for path in paths: 65 | 
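        # For each split directory: read every dense label tile, place one point
        # annotation inside each sufficiently large connected region via
        # draw_point() above, and save the sparse point label together with a
        # color visualization alongside it.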
label_path = path + '/label' 66 | 67 | point_label_path = path + '/point_label' 68 | point_label_vis_path = path + '/point_label_vis' 69 | check_dir(point_label_path), check_dir(point_label_vis_path) 70 | 71 | list = os.listdir(label_path) 72 | for i in tqdm(list): 73 | label = os.path.join(label_path, i) 74 | label = read(label) 75 | new_mask, new_mask_vis = draw_point(label) 76 | imsave(point_label_path + '/' + i, new_mask) 77 | imsave(point_label_vis_path + '/' + i, new_mask_vis) 78 | 79 | 80 | if __name__ == '__main__': 81 | root_path = '/media/hlf/Luffy/WLS/semantic/dataset/potsdam' 82 | make(root_path) 83 | -------------------------------------------------------------------------------- /DataProcess/postdam/rgb_to_value.txt: -------------------------------------------------------------------------------- 1 | 255 255 255 0 2 | 0 0 255 1 3 | 0 255 255 2 4 | 0 255 0 3 5 | 255 255 0 4 6 | 255 0 0 5 -------------------------------------------------------------------------------- /DataProcess/vaihingen/CropVai.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tifffile 3 | import numpy as np 4 | import cv2 5 | from utils import check_dir, read, imsave 6 | 7 | test_name = [2, 4, 6, 8, 10, 12, 14, 16, 20, 22, 24, 27, 29, 31, 33, 35, 38] 8 | val_name = [5] 9 | 10 | 11 | def crop_overlap(img, label, vis, name, img_path, label_path, vis_path, size=256, stride=128): 12 | h, w, c = img.shape 13 | new_h, new_w = (h // size + 1) * size, (w // size + 1) * size 14 | num_h, num_w = (new_h // stride) - 1, (new_w // stride) - 1 15 | 16 | new_img = np.zeros([new_h, new_w, c]).astype(np.uint8) 17 | new_img[:h, :w, :] = img 18 | 19 | new_label = 255 * np.ones([new_h, new_w, 3]).astype(np.uint8) 20 | new_label[:h, :w, :] = label 21 | 22 | new_vis = np.zeros([new_h, new_w, 3]).astype(np.uint8) 23 | new_vis[:h, :w, :] = vis 24 | 25 | count = 0 26 | 27 | for i in range(num_h): 28 | for j in range(num_w): 29 | out = new_img[i * stride:i * stride + size, j * stride:j * stride + size, :] 30 | gt = new_label[i * stride:i * stride + size, j * stride:j * stride + size, :] 31 | v = new_vis[i * stride:i * stride + size, j * stride:j * stride + size, :] 32 | assert v.shape == (256, 256, 3), print(v.shape) 33 | 34 | imsave(img_path + '/' + str(name) + '_' + str(count) + '.png', out) 35 | imsave(label_path + '/' + str(name) + '_' + str(count) + '.png', gt) 36 | imsave(vis_path + '/' + str(name) + '_' + str(count) + '.png', v) 37 | 38 | count += 1 39 | 40 | 41 | def crop(img, label, vis, name, img_path, label_path, vis_path, size=256): 42 | height, width, _ = img.shape 43 | h_size = height // size 44 | w_size = width // size 45 | 46 | count = 0 47 | 48 | for i in range(h_size): 49 | for j in range(w_size): 50 | out = img[i * size:(i + 1) * size, j * size:(j + 1) * size, :] 51 | gt = label[i * size:(i + 1) * size, j * size:(j + 1) * size, :] 52 | v = vis[i * size:(i + 1) * size, j * size:(j + 1) * size, :] 53 | assert v.shape == (256, 256, 3) 54 | 55 | imsave(img_path + '/' + str(name) + '_' + str(count) + '.png', out) 56 | imsave(label_path + '/' + str(name) + '_' + str(count) + '.png', gt) 57 | imsave(vis_path + '/' + str(name) + '_' + str(count) + '.png', v) 58 | 59 | count += 1 60 | 61 | 62 | def save(img, label, vis, name, save_path, flag='val'): 63 | img_path = save_path + '/img' 64 | label_path = save_path + '/label' 65 | v_path = save_path + '/label_vis' 66 | check_dir(img_path), check_dir(label_path), check_dir(v_path) 67 | 68 | if flag == 'val': 69 | 
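        # validation tiles are cut into non-overlapping 256x256 crops via crop();
        # all other splits fall through to crop_overlap() (stride 128) in the else branch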
crop(img, label, vis, name, img_path, label_path, v_path) 70 | else: 71 | crop_overlap(img, label, vis, name, img_path, label_path, v_path) 72 | 73 | 74 | def run(root_path): 75 | data_path = root_path + '/dataset_origin/image' 76 | label_path = root_path + '/dataset_origin/gts' 77 | label_vis_path = root_path + '/dataset_origin/vis' 78 | 79 | list = os.listdir(label_path) 80 | for i in list: 81 | name = i[20:-4] 82 | print(i, name) 83 | 84 | img = read(os.path.join(data_path, i)) 85 | label = read(os.path.join(label_path, i)) 86 | label_vis = read(os.path.join(label_vis_path, i)) 87 | 88 | if int(name) in test_name: 89 | print(name, 'test') 90 | test_path = root_path + '/test' 91 | check_dir(test_path) 92 | save(img, label, label_vis, name, test_path, flag='test') 93 | 94 | # elif int(name) in val_name: 95 | # print(name, 'val') 96 | # val_path = root_path + '/val' 97 | # check_dir(val_path) 98 | # save(img, label, label_vis, name, val_path, flag='val') 99 | 100 | else: 101 | print(name, 'train') 102 | train_path = root_path + '/train' 103 | check_dir(train_path) 104 | save(img, label, label_vis, name, train_path, flag='train') 105 | 106 | 107 | if __name__ == '__main__': 108 | root_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen' 109 | run(root_path) -------------------------------------------------------------------------------- /DataProcess/vaihingen/FindMean.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tifffile 4 | 5 | # means = [119.14901543, 83.04203606, 81.79810095] 6 | path = '/home/ggm/WLS/semantic/dataset/vaihingen/train/img' 7 | 8 | list = os.listdir(path) 9 | 10 | k = 0 11 | sum = np.zeros([3]) 12 | 13 | for idx, i in enumerate(list): 14 | print(idx) 15 | k += 1 16 | img = tifffile.imread(os.path.join(path, i)) 17 | img = img.reshape(256*256, -1) 18 | 19 | mean = np.mean(img, axis=0) 20 | sum += mean 21 | 22 | 23 | means = sum/k 24 | print(means) 25 | 26 | # std = [55.63038161, 40.67145608, 38.61447761] 27 | # means = [119.14901543, 83.04203606, 81.79810095] 28 | # 29 | # path = '/home/ggm/WLS/semantic/dataset/vaihingen/train/img' 30 | # 31 | # list = os.listdir(path) 32 | # k = 0 33 | # sum = np.zeros([3]) 34 | # 35 | # for idx, i in enumerate(list): 36 | # print(idx) 37 | # k += 1 38 | # img = tifffile.imread(os.path.join(path, i)) 39 | # img = img.reshape(256*256, -1) 40 | # 41 | # x = (img - means) ** 2 42 | # 43 | # sum += np.sum(x, axis=0) 44 | # 45 | # std = np.sqrt(sum/(k*256*256)) 46 | # print(std) -------------------------------------------------------------------------------- /DataProcess/vaihingen/TransVaiLabel.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | from utils import * 5 | import matplotlib.pyplot as plt 6 | import tifffile 7 | 8 | 9 | def value_to_rgb(anno): 10 | label2color_dict = { 11 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255) 12 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255) 13 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255) 14 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0) 15 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0) 16 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0) 17 | 255: [0, 0, 0], # boundary (RGB: 0, 0, 0) 18 | } 19 | # visualize 20 | visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8) 21 | for i in range(visual_anno.shape[0]): # i for h 22 | for j in range(visual_anno.shape[1]): 23 | # cv2: bgr 24 | 
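                # look up the display color for this pixel's class value
                # (255 = boundary, drawn black); this per-pixel loop is simple
                # but slow -- a vectorized palette lookup would be faster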
color = label2color_dict[anno[i, j, 0]] 25 | 26 | visual_anno[i, j, 0] = color[0] 27 | visual_anno[i, j, 1] = color[1] 28 | visual_anno[i, j, 2] = color[2] 29 | 30 | return visual_anno 31 | 32 | 33 | def rgb_to_value(rgb, txt='/media/hlf/Luffy/WLS/PointAnno/DataProcess/vaihingen/rgb_to_value.txt'): 34 | key_arr = np.loadtxt(txt) 35 | array = np.zeros((rgb.shape[0], rgb.shape[1], 3), dtype=np.uint8) 36 | 37 | for translation in key_arr: 38 | r, g, b, value = translation 39 | tmp = [r, g, b] 40 | array[(rgb == tmp).all(axis=2)] = value 41 | 42 | return array 43 | 44 | 45 | def save_label_value(label_path, save_path): 46 | check_dir(save_path) 47 | label_list = os.listdir(label_path) 48 | for idx, i in enumerate(label_list): 49 | print(idx) 50 | path = os.path.join(label_path, i) 51 | label = tifffile.imread(path) 52 | label = ((label > 128) * 255) 53 | 54 | label_value = rgb_to_value(label) 55 | tifffile.imsave(os.path.join(save_path, i), label_value) 56 | 57 | 58 | def show_label(root_path, img_path, vis_path): 59 | list = os.listdir(root_path) 60 | for idx, i in enumerate(list): 61 | label = tifffile.imread(os.path.join(root_path, i)) 62 | print(np.unique(label)) 63 | label = value_to_rgb(label) 64 | 65 | img = tifffile.imread(os.path.join(img_path, i))[:, :, :3] 66 | vis = tifffile.imread(os.path.join(vis_path, i)) 67 | 68 | plt.subplot(1, 3, 1) 69 | plt.imshow(img) 70 | plt.subplot(1, 3, 2) 71 | plt.imshow(label) 72 | plt.subplot(1, 3, 3) 73 | plt.imshow(vis) 74 | plt.show() 75 | 76 | 77 | if __name__ == '__main__': 78 | img_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/dataset_origin/image' 79 | path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/dataset_origin/vis_noB' 80 | save_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/dataset_origin/gts_noB' 81 | 82 | save_label_value(path, save_path) 83 | show_label(save_path, img_path, path) -------------------------------------------------------------------------------- /DataProcess/vaihingen/WriteTXT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import os 5 | import random 6 | 7 | 8 | def write_train(path): 9 | train_txt = open(path + '/datafiles/vaihingen/sec_train.txt', 'w') 10 | train_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/sec/train/img' 11 | list = os.listdir(train_path) 12 | 13 | for idx, i in enumerate(list): 14 | train_txt.write(i + '\n') 15 | train_txt.close() 16 | 17 | 18 | def write_val(path): 19 | val_txt = open(path + '/datafiles/vaihingen/sec_val.txt', 'w') 20 | val_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/sec/val/img' 21 | list = os.listdir(val_path) 22 | 23 | for idx, i in enumerate(list): 24 | val_txt.write(i + '\n') 25 | val_txt.close() 26 | 27 | 28 | def write_test(path): 29 | test_txt = open(path + '/datafiles/vaihingen/seg_test.txt', 'w') 30 | test_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/test/img' 31 | list = os.listdir(test_path) 32 | 33 | for idx, i in enumerate(list): 34 | test_txt.write(i + '\n') 35 | test_txt.close() 36 | 37 | 38 | if __name__ == '__main__': 39 | path = '/media/hlf/Luffy/WLS/PointAnno' 40 | write_train(path) 41 | write_val(path) 42 | # write_test(path) 43 | -------------------------------------------------------------------------------- /DataProcess/vaihingen/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/DataProcess/vaihingen/__init__.py -------------------------------------------------------------------------------- /DataProcess/vaihingen/make_point_label.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import tifffile 5 | import matplotlib.pyplot as plt 6 | from tqdm import tqdm 7 | from utils import check_dir, read, imsave 8 | 9 | 10 | def read_cls_color(cls): 11 | label2color_dict = { 12 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255) 13 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255) 14 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255) 15 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0) 16 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0) 17 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0) 18 | 255: [0, 0, 0] 19 | } 20 | return label2color_dict[cls] 21 | 22 | 23 | def draw_point(label, kernal_size=100): 24 | label = read(label) 25 | 26 | h, w, c = label.shape 27 | label_set = np.unique(label) 28 | 29 | new_mask = np.ones([h, w, c], np.uint8) * 255 30 | new_mask_vis = np.zeros([h, w, c], np.uint8) 31 | 32 | for cls in label_set: 33 | if cls != 255: 34 | color = read_cls_color(cls) 35 | 36 | temp_mask = np.zeros([h, w]) 37 | temp_mask[label[:, :, 0] == cls] = 255 38 | temp_mask = np.asarray(temp_mask, dtype=np.uint8) 39 | _, contours, hierarchy = cv2.findContours(temp_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 40 | # cv2.drawContours(new_mask, contours, -1, color, 1) 41 | 42 | for i in range(len(contours)): 43 | area = cv2.contourArea(contours[i]) 44 | if area > kernal_size: 45 | # distance to contour 46 | dist = np.empty([h, w], dtype=np.float32) 47 | for h_ in range(h): 48 | for w_ in range(w): 49 | dist[h_, w_] = cv2.pointPolygonTest(contours[i], (w_, h_), True) 50 | 51 | # make sure the point in the temp_mask 52 | temp_dist = temp_mask * dist 53 | min_, max_, _, maxdistpt = cv2.minMaxLoc(temp_dist) 54 | cx, cy = maxdistpt[0], maxdistpt[1] 55 | 56 | new_mask[cy:cy + 3, cx:cx + 3, :] = (cls, cls, cls) 57 | new_mask_vis[cy:cy + 3, cx:cx + 3, :] = color 58 | 59 | return new_mask, new_mask_vis 60 | 61 | 62 | def make(root_path): 63 | train_path = root_path + '/train' 64 | val_path = root_path + '/val' 65 | test_path = root_path + '/test' 66 | 67 | paths = [train_path] 68 | for path in paths: 69 | label_path = path + '/label' 70 | 71 | point_label_path = path + '/point_label' 72 | point_label_vis_path = path + '/point_label_vis' 73 | check_dir(point_label_path), check_dir(point_label_vis_path) 74 | 75 | list = os.listdir(label_path) 76 | for i in tqdm(list): 77 | label = os.path.join(label_path, i) 78 | new_mask, new_mask_vis = draw_point(label) 79 | imsave(point_label_path + '/' + i, new_mask) 80 | imsave(point_label_vis_path + '/' + i, new_mask_vis) 81 | 82 | 83 | if __name__ == '__main__': 84 | root_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen' 85 | make(root_path) 86 | -------------------------------------------------------------------------------- /DataProcess/vaihingen/rgb_to_value.txt: -------------------------------------------------------------------------------- 1 | 255 255 255 0 2 | 0 0 255 1 3 | 0 255 255 2 4 | 0 255 0 3 5 | 255 255 0 4 6 | 255 0 0 5 7 | 0 0 0 255 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Bilateral Filtering 
Network (DBFNet) 2 | Code for our TIP 2022 paper [**"Deep Bilateral Filtering Network for Point-Supervised Semantic Segmentation in Remote Sensing Images"**](https://ieeexplore.ieee.org/document/9961229). 3 | 4 | Authors: Linshan Wu, Leyuan Fang, Jun Yue, Bob Zhang, Pedram Ghamisi, and Min He 5 | 6 | ## Getting Started 7 | ### Prepare Dataset 8 | Download the pre-processed Potsdam and Vaihingen [datasets](https://drive.google.com/drive/folders/1CiYzJyBn1rV-xsrsYQ6o2HDQjdfnadHl). 9 | 10 | Alternatively, download the datasets from the official [website](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-vaihingen.aspx). Then crop the original images and create point labels following our code in [Dataprocess](https://github.com/Luffy03/DBFNet/tree/master/DataProcess). 11 | 12 | If you want to run our code on your own datasets, the pre-processing code is also available in [Dataprocess](https://github.com/Luffy03/DBFNet/tree/master/DataProcess). 13 | 14 | ## Evaluate 15 | ### 1. Download the [original datasets](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-vaihingen.aspx) 16 | ### 2. Download our [weights](https://drive.google.com/drive/folders/1CiYzJyBn1rV-xsrsYQ6o2HDQjdfnadHl) 17 | ### 3. Run our code 18 | ```bash 19 | python predict.py 20 | ``` 21 | 22 | ## Train 23 | ### 1. Train DBFNet 24 | ```bash 25 | python run/point/p_train.py 26 | ``` 27 | ### 2. Generate pseudo labels 28 | ```bash 29 | python run/point/p_predict_train.py 30 | ``` 31 | ### 3. Recursive learning 32 | ```bash 33 | python run/second/sec_train.py 34 | ``` 35 | 36 | ## Citation ✏️ 📄 37 | 38 | If you find this repo useful for your research, please consider citing the paper as follows: 39 | 40 | ``` 41 | @ARTICLE{Wu_DBFNet, 42 | author={Wu, Linshan and Fang, Leyuan and Yue, Jun and Zhang, Bob and Ghamisi, Pedram and He, Min}, 43 | journal={IEEE Transactions on Image Processing}, 44 | title={Deep Bilateral Filtering Network for Point-Supervised Semantic Segmentation in Remote Sensing Images}, 45 | year={2022}, 46 | volume={31}, 47 | number={}, 48 | pages={7419-7434}, 49 | doi={10.1109/TIP.2022.3222904}} 50 | @article{wu2024modeling, 51 | title={Modeling the Label Distributions for Weakly-Supervised Semantic Segmentation}, 52 | author={Wu, Linshan and Zhong, Zhun and Ma, Jiayi and Wei, Yunchao and Chen, Hao and Fang, Leyuan and Li, Shutao}, 53 | journal={arXiv preprint arXiv:2403.13225}, 54 | year={2024} 55 | } 56 | @inproceedings{AGMM, 57 | title={Sparsely Annotated Semantic Segmentation with Adaptive Gaussian Mixtures}, 58 | author={Wu, Linshan and Zhong, Zhun and Fang, Leyuan and He, Xingxin and Liu, Qiang and Ma, Jiayi and Chen, Hao}, 59 | booktitle={IEEE Conf. Comput. Vis. Pattern Recog.}, 60 | month={June}, 61 | year={2023}, 62 | pages={15454-15464} 63 | } 64 | ``` 65 | 66 | For any questions, please contact [Linshan Wu](mailto:15274891948@163.com).
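
## Expected Dataset Layout

For reference, the scripts in [Dataprocess](https://github.com/Luffy03/DBFNet/tree/master/DataProcess) produce (and the dataloaders read) a directory layout roughly like the one below under each dataset root. This sketch is inferred from the crop and point-label scripts; adjust it to your own paths:

```
<data_root>/potsdam (or vaihingen)
├── dataset_origin      # original ISPRS tiles and labels
├── train
│   ├── img
│   ├── label
│   ├── label_vis
│   ├── point_label
│   └── point_label_vis
├── val
│   └── (same sub-folders as train)
└── test
    ├── img
    ├── label
    └── label_vis
```

For Vaihingen, `dataset/datapoint.py` additionally reads boundary-free test labels from `test/label_noB`.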
67 | -------------------------------------------------------------------------------- /datafiles/color_dict.py: -------------------------------------------------------------------------------- 1 | postdam_color_dict = { 2 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255) 3 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255) 4 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255) 5 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0) 6 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0) 7 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0) 8 | 255: [0, 0, 0], # background 9 | } 10 | -------------------------------------------------------------------------------- /datafiles/vaihingen/seg_val.txt: -------------------------------------------------------------------------------- 1 | 11_104.png 2 | 11_132.png 3 | 11_133.png 4 | 11_137.png 5 | 11_150.png 6 | 11_152.png 7 | 11_172.png 8 | 11_211.png 9 | 11_212.png 10 | 11_215.png 11 | 11_221.png 12 | 11_23.png 13 | 11_239.png 14 | 11_24.png 15 | 11_241.png 16 | 11_242.png 17 | 11_244.png 18 | 11_262.png 19 | 11_272.png 20 | 11_281.png 21 | 11_282.png 22 | 11_284.png 23 | 11_293.png 24 | 11_31.png 25 | 11_44.png 26 | 11_49.png 27 | 11_50.png 28 | 11_60.png 29 | 11_69.png 30 | 23_120.png 31 | 23_131.png 32 | 23_133.png 33 | 23_137.png 34 | 23_149.png 35 | 23_150.png 36 | 23_157.png 37 | 23_195.png 38 | 23_197.png 39 | 23_198.png 40 | 23_204.png 41 | 23_207.png 42 | 23_216.png 43 | 23_229.png 44 | 23_24.png 45 | 23_241.png 46 | 23_255.png 47 | 23_260.png 48 | 23_271.png 49 | 23_279.png 50 | 23_280.png 51 | 23_281.png 52 | 23_72.png 53 | 23_75.png 54 | 23_8.png 55 | 23_88.png 56 | 17_31.png 57 | 17_6.png 58 | 17_65.png 59 | 17_69.png 60 | 17_99.png 61 | 1_111.png 62 | 1_117.png 63 | 1_119.png 64 | 1_133.png 65 | 1_138.png 66 | 1_146.png 67 | 1_16.png 68 | 1_161.png 69 | 1_170.png 70 | 1_178.png 71 | 1_191.png 72 | 1_205.png 73 | 1_206.png 74 | 1_213.png 75 | 1_22.png 76 | 1_222.png 77 | 1_226.png 78 | 1_227.png 79 | 1_228.png 80 | 1_236.png 81 | 1_238.png 82 | 1_255.png 83 | 1_264.png 84 | 1_291.png 85 | 1_294.png 86 | 1_299.png 87 | 1_309.png 88 | 1_313.png 89 | 32_70.png 90 | 32_8.png 91 | 32_82.png 92 | 32_92.png 93 | 34_0.png 94 | 34_11.png 95 | 34_120.png 96 | 34_129.png 97 | 34_158.png 98 | 34_179.png 99 | 34_196.png 100 | 34_2.png 101 | 34_201.png 102 | 34_31.png 103 | 34_43.png 104 | 34_49.png 105 | 34_78.png 106 | 34_87.png 107 | 34_93.png 108 | 37_108.png 109 | 37_117.png 110 | 37_12.png 111 | 37_124.png 112 | 37_127.png 113 | 37_137.png 114 | 37_146.png 115 | 37_152.png 116 | 37_167.png 117 | 37_170.png 118 | 37_171.png 119 | 13_231.png 120 | 13_248.png 121 | 13_253.png 122 | 13_261.png 123 | 13_262.png 124 | 13_269.png 125 | 13_27.png 126 | 13_280.png 127 | 13_281.png 128 | 13_282.png 129 | 13_285.png 130 | 13_286.png 131 | 13_289.png 132 | 13_322.png 133 | 13_333.png 134 | 13_341.png 135 | 13_359.png 136 | 13_363.png 137 | 13_401.png 138 | 13_410.png 139 | 13_412.png 140 | 13_421.png 141 | 13_45.png 142 | 13_58.png 143 | 13_60.png 144 | 13_62.png 145 | 13_71.png 146 | 13_90.png 147 | 15_100.png 148 | 15_11.png 149 | 15_113.png 150 | 15_131.png 151 | 15_145.png 152 | 15_152.png 153 | 15_154.png 154 | 15_163.png 155 | 15_182.png 156 | 11_83.png 157 | 13_229.png 158 | 15_191.png 159 | 17_199.png 160 | 1_35.png 161 | 21_187.png 162 | 23_117.png 163 | 26_120.png 164 | 28_127.png 165 | 30_219.png 166 | 32_69.png 167 | 37_179.png 168 | 3_313.png 169 | 5_192.png 170 | 5_56.png 171 | 3_319.png 172 | 3_324.png 173 | 3_335.png 174 | 3_339.png 175 | 
3_340.png 176 | 3_38.png 177 | 3_59.png 178 | 3_61.png 179 | 3_71.png 180 | 3_86.png 181 | 3_88.png 182 | 3_97.png 183 | 5_117.png 184 | 5_12.png 185 | 5_127.png 186 | 5_13.png 187 | 5_15.png 188 | 5_158.png 189 | 5_16.png 190 | 5_164.png 191 | 5_186.png 192 | 5_191.png 193 | 28_137.png 194 | 28_149.png 195 | 28_157.png 196 | 28_16.png 197 | 28_165.png 198 | 28_17.png 199 | 28_173.png 200 | 28_183.png 201 | 28_199.png 202 | 28_2.png 203 | 28_221.png 204 | 28_223.png 205 | 28_228.png 206 | 28_230.png 207 | 28_232.png 208 | 28_247.png 209 | 28_270.png 210 | 28_272.png 211 | 28_286.png 212 | 28_54.png 213 | 28_56.png 214 | 28_82.png 215 | 28_9.png 216 | 28_93.png 217 | 30_102.png 218 | 30_103.png 219 | 30_120.png 220 | 30_124.png 221 | 30_128.png 222 | 30_139.png 223 | 30_155.png 224 | 30_160.png 225 | 30_162.png 226 | 30_170.png 227 | 30_190.png 228 | 30_214.png 229 | 30_217.png 230 | 1_37.png 231 | 1_4.png 232 | 1_44.png 233 | 1_63.png 234 | 1_74.png 235 | 1_85.png 236 | 1_90.png 237 | 1_95.png 238 | 21_101.png 239 | 21_106.png 240 | 21_109.png 241 | 21_126.png 242 | 21_141.png 243 | 21_142.png 244 | 21_145.png 245 | 21_158.png 246 | 21_159.png 247 | 21_164.png 248 | 21_168.png 249 | 21_174.png 250 | 21_179.png 251 | 5_95.png 252 | 5_98.png 253 | 7_100.png 254 | 7_107.png 255 | 7_119.png 256 | 7_127.png 257 | 7_128.png 258 | 7_149.png 259 | 7_15.png 260 | 7_159.png 261 | 7_163.png 262 | 7_168.png 263 | 7_179.png 264 | 7_18.png 265 | 7_190.png 266 | 7_195.png 267 | 7_2.png 268 | 7_207.png 269 | 7_219.png 270 | 7_223.png 271 | 7_226.png 272 | 7_237.png 273 | 7_239.png 274 | 7_244.png 275 | 7_283.png 276 | 7_284.png 277 | 7_40.png 278 | 7_49.png 279 | 7_59.png 280 | 7_60.png 281 | 7_63.png 282 | 7_74.png 283 | 7_88.png 284 | 37_191.png 285 | 37_198.png 286 | 37_200.png 287 | 37_37.png 288 | 37_41.png 289 | 37_51.png 290 | 37_59.png 291 | 37_85.png 292 | 37_88.png 293 | 3_0.png 294 | 3_103.png 295 | 3_111.png 296 | 3_116.png 297 | 3_121.png 298 | 3_122.png 299 | 3_132.png 300 | 3_133.png 301 | 3_138.png 302 | 3_141.png 303 | 3_194.png 304 | 3_203.png 305 | 3_218.png 306 | 3_22.png 307 | 3_231.png 308 | 3_254.png 309 | 3_256.png 310 | 3_268.png 311 | 3_281.png 312 | 3_289.png 313 | 3_302.png 314 | 3_304.png 315 | 30_253.png 316 | 30_264.png 317 | 30_27.png 318 | 30_272.png 319 | 30_304.png 320 | 30_305.png 321 | 30_36.png 322 | 30_40.png 323 | 30_43.png 324 | 30_52.png 325 | 30_59.png 326 | 30_65.png 327 | 30_75.png 328 | 30_77.png 329 | 30_8.png 330 | 30_83.png 331 | 30_90.png 332 | 32_106.png 333 | 32_120.png 334 | 32_131.png 335 | 32_148.png 336 | 32_158.png 337 | 32_177.png 338 | 32_183.png 339 | 32_186.png 340 | 32_206.png 341 | 32_207.png 342 | 32_22.png 343 | 32_234.png 344 | 32_243.png 345 | 32_280.png 346 | 32_33.png 347 | 32_38.png 348 | 32_60.png 349 | 32_64.png 350 | 15_213.png 351 | 15_217.png 352 | 15_218.png 353 | 15_220.png 354 | 15_229.png 355 | 15_232.png 356 | 15_243.png 357 | 15_258.png 358 | 15_263.png 359 | 15_266.png 360 | 15_271.png 361 | 15_277.png 362 | 15_293.png 363 | 15_299.png 364 | 15_301.png 365 | 15_302.png 366 | 15_312.png 367 | 15_41.png 368 | 15_6.png 369 | 15_82.png 370 | 15_93.png 371 | 17_0.png 372 | 17_10.png 373 | 17_12.png 374 | 17_122.png 375 | 17_126.png 376 | 17_132.png 377 | 17_137.png 378 | 17_150.png 379 | 17_154.png 380 | 17_167.png 381 | 17_174.png 382 | 17_177.png 383 | 11_84.png 384 | 11_9.png 385 | 13_1.png 386 | 13_102.png 387 | 13_109.png 388 | 13_111.png 389 | 13_113.png 390 | 13_121.png 391 | 13_129.png 392 | 13_13.png 393 | 13_132.png 394 
| 13_142.png 395 | 13_144.png 396 | 13_145.png 397 | 13_148.png 398 | 13_152.png 399 | 13_154.png 400 | 13_171.png 401 | 13_175.png 402 | 13_194.png 403 | 13_199.png 404 | 13_202.png 405 | 13_21.png 406 | 13_214.png 407 | 13_215.png 408 | 13_218.png 409 | 13_227.png 410 | 26_132.png 411 | 26_143.png 412 | 26_147.png 413 | 26_15.png 414 | 26_150.png 415 | 26_155.png 416 | 26_158.png 417 | 26_169.png 418 | 26_17.png 419 | 26_178.png 420 | 26_189.png 421 | 26_214.png 422 | 26_223.png 423 | 26_23.png 424 | 26_256.png 425 | 26_259.png 426 | 26_264.png 427 | 26_265.png 428 | 26_266.png 429 | 26_3.png 430 | 26_41.png 431 | 26_46.png 432 | 26_58.png 433 | 26_90.png 434 | 28_116.png 435 | 28_119.png 436 | 5_201.png 437 | 5_224.png 438 | 5_229.png 439 | 5_244.png 440 | 5_249.png 441 | 5_256.png 442 | 5_257.png 443 | 5_26.png 444 | 5_262.png 445 | 5_267.png 446 | 5_268.png 447 | 5_275.png 448 | 5_279.png 449 | 5_29.png 450 | 5_39.png 451 | 5_44.png 452 | 5_46.png 453 | 5_50.png 454 | 5_55.png 455 | 21_189.png 456 | 21_194.png 457 | 21_197.png 458 | 21_2.png 459 | 21_204.png 460 | 21_215.png 461 | 21_216.png 462 | 21_230.png 463 | 21_238.png 464 | 21_241.png 465 | 21_244.png 466 | 21_28.png 467 | 21_4.png 468 | 21_6.png 469 | 21_70.png 470 | 21_85.png 471 | 21_95.png 472 | 23_104.png 473 | 23_106.png 474 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from dataset.datapoint import Dataset_point 2 | from dataset.data_sec import Dataset_sec 3 | 4 | import dataset.transform as trans -------------------------------------------------------------------------------- /dataset/data_sec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import os 5 | from utils import util 6 | from tqdm import tqdm 7 | import multiprocessing 8 | from torchvision import transforms 9 | import dataset.transform as trans 10 | import tifffile 11 | import matplotlib.pyplot as plt 12 | from dataset.data_utils import label_transform, value_to_rgb 13 | 14 | 15 | class Dataset_sec(Dataset): 16 | def __init__(self, opt, save_path, file_name_txt_path, flag='train', transform=None): 17 | # first round 18 | save_path = save_path + '/Point' 19 | # save_path = save_path + '/Second' 20 | 21 | # parameters from opt 22 | self.data_root = os.path.join(opt.data_root, opt.dataset) 23 | self.dataset = opt.dataset 24 | 25 | if flag == 'train': 26 | self.img_path = self.data_root + '/train/img' 27 | self.label_path = save_path + '/predict_train/crf/label' 28 | 29 | elif flag == 'val': 30 | self.img_path = self.data_root + '/val/img' 31 | self.label_path = save_path + '/predict_val/crf/label' 32 | 33 | elif flag == 'predict_train': 34 | self.img_path = self.data_root + '/train/img' 35 | self.label_path = self.data_root + '/train/crf/label' 36 | 37 | elif flag == 'test': 38 | self.img_path = self.data_root + '/test/img' 39 | self.label_path = self.data_root + '/test/label' 40 | 41 | self.img_size = opt.img_size 42 | self.num_classes = opt.num_classes 43 | self.in_channels = opt.in_channels 44 | 45 | self.img_txt_path = file_name_txt_path 46 | self.flag = flag 47 | self.transform = transform 48 | self.img_label_path_pairs = self.get_img_label_path_pairs() 49 | 50 | def get_img_label_path_pairs(self): 51 | img_label_pair_list = {} 52 | 53 | with open(self.img_txt_path, 'r') as lines: 54 | for idx, 
line in enumerate(tqdm(lines)): 55 | name = line.strip("\n").split(' ')[0] 56 | path = os.path.join(self.img_path, name) 57 | img_label_pair_list.setdefault(idx, [path, name]) 58 | 59 | return img_label_pair_list 60 | 61 | def make_clslabel(self, label): 62 | label_set = np.unique(label) 63 | cls_label = np.zeros(self.num_classes) 64 | for i in label_set: 65 | if i != 255: 66 | cls_label[i] += 1 67 | cls_label = torch.from_numpy(cls_label).float() 68 | return cls_label 69 | 70 | def data_transform(self, img, label): 71 | img = img[:, :, :self.in_channels] 72 | img = img.astype(np.float32).transpose(2, 0, 1) 73 | img = torch.from_numpy(img).float() 74 | if len(label.shape) == 3: 75 | label = label[:, :, 0] 76 | label = label.copy() 77 | label = torch.from_numpy(label).long() 78 | 79 | return img, label 80 | 81 | def __getitem__(self, index): 82 | item = self.img_label_path_pairs[index] 83 | img_path, name = item 84 | 85 | img = util.read(img_path)[:, :, :self.in_channels] 86 | label = util.read(os.path.join(self.label_path, name)) 87 | if len(label.shape) != 3: 88 | label = np.expand_dims(label, axis=-1) 89 | 90 | if self.flag == 'train': 91 | if self.transform is not None: 92 | for t in self.transform.transforms: 93 | img, label = t([img, label]) 94 | 95 | img = util.Normalize(img, flag=self.dataset) 96 | img, label = self.data_transform(img, label) 97 | return img, label, name 98 | 99 | def __len__(self): 100 | 101 | return len(self.img_label_path_pairs) 102 | 103 | 104 | if __name__ == "__main__": 105 | 106 | from utils import * 107 | from options import * 108 | from torch.utils.data import DataLoader 109 | 110 | opt = Sec_Options().parse() 111 | print(opt) 112 | save_path = os.path.join(opt.save_path, opt.dataset) 113 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt) 114 | 115 | train_transform = transforms.Compose([ 116 | trans.Color_Aug(), 117 | trans.RandomHorizontalFlip(), 118 | trans.RandomVerticleFlip(), 119 | trans.RandomRotate90(), 120 | ]) 121 | dataset = Dataset_sec(opt, save_path, train_txt_path, flag='train', transform=train_transform) 122 | 123 | loader = DataLoader( 124 | dataset=dataset, 125 | batch_size=1, num_workers=8, pin_memory=True, drop_last=True 126 | ) 127 | 128 | for i in tqdm(loader): 129 | img, label, name = i 130 | 131 | pass -------------------------------------------------------------------------------- /dataset/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | def label_transform(dataset, label): 6 | label_new = label 7 | if dataset == 'potsdam' or dataset == 'vaihingen': 8 | label_new[label == 5] = 255 9 | 10 | else: 11 | return label_new 12 | 13 | return label_new 14 | 15 | 16 | def value_to_rgb(anno, flag='potsdam'): 17 | if flag == 'potsdam' or flag == 'vaihingen': 18 | label2color_dict = { 19 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255) 20 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255) 21 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255) 22 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0) 23 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0) 24 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0) 25 | 255: [0, 0, 0] 26 | } 27 | else: 28 | label2color_dict = {} 29 | 30 | # visualize 31 | visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8) 32 | for i in range(visual_anno.shape[0]): # i for h 33 | for j in range(visual_anno.shape[1]): 34 | # cv2: bgr 35 | color = label2color_dict[anno[i, j, 0]] 36 | 37 | 
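            # copy the looked-up color channel by channel; note that for any
            # dataset other than potsdam/vaihingen the dict above is empty, so
            # this lookup would raise a KeyError -- extend label2color_dict first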
visual_anno[i, j, 0] = color[0] 38 | visual_anno[i, j, 1] = color[1] 39 | visual_anno[i, j, 2] = color[2] 40 | 41 | return visual_anno -------------------------------------------------------------------------------- /dataset/datapoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import os 5 | from utils import util 6 | from tqdm import tqdm 7 | import multiprocessing 8 | from torchvision import transforms 9 | import dataset.transform as trans 10 | import tifffile 11 | import matplotlib.pyplot as plt 12 | from dataset.data_utils import label_transform 13 | 14 | 15 | class Dataset_point(Dataset): 16 | def __init__(self, opt, file_name_txt_path, flag='train', transform=None): 17 | # parameters from opt 18 | self.data_root = os.path.join(opt.data_root, opt.dataset) 19 | self.dataset = opt.dataset 20 | 21 | if flag == 'test': 22 | self.img_path = self.data_root + '/test/img' 23 | if self.dataset == 'vaihingen': 24 | self.label_path = self.data_root + '/test/label_noB' 25 | else: 26 | self.label_path = self.data_root + '/test/label' 27 | elif flag == 'train': 28 | self.img_path = self.data_root + '/train/img' 29 | self.label_path = self.data_root + '/train/point_label' 30 | elif flag == 'val': 31 | self.img_path = self.data_root + '/val/img' 32 | self.label_path = self.data_root + '/val/point_label' 33 | 34 | elif flag == 'predict_train': 35 | self.img_path = self.data_root + '/train/img' 36 | self.label_path = self.data_root + '/train/label' 37 | 38 | self.img_size = opt.img_size 39 | self.num_classes = opt.num_classes 40 | self.in_channels = opt.in_channels 41 | 42 | self.img_txt_path = file_name_txt_path 43 | self.flag = flag 44 | self.transform = transform 45 | self.img_label_path_pairs = self.get_img_label_path_pairs() 46 | 47 | def get_img_label_path_pairs(self): 48 | img_label_pair_list = {} 49 | 50 | with open(self.img_txt_path, 'r') as lines: 51 | for idx, line in enumerate(tqdm(lines)): 52 | name = line.strip("\n").split(' ')[0] 53 | path = os.path.join(self.img_path, name) 54 | img_label_pair_list.setdefault(idx, [path, name]) 55 | 56 | return img_label_pair_list 57 | 58 | def make_clslabel(self, label): 59 | label_set = np.unique(label) 60 | cls_label = np.zeros(self.num_classes) 61 | for i in label_set: 62 | if i != 255: 63 | cls_label[i] += 1 64 | return cls_label 65 | 66 | def data_transform(self, img, label, cls_label): 67 | img = img[:, :, :self.in_channels] 68 | img = img.astype(np.float32).transpose(2, 0, 1) 69 | img = torch.from_numpy(img).float() 70 | if len(label.shape) == 3: 71 | label = label[:, :, 0] 72 | 73 | label, cls_label = label.copy(), cls_label.copy() 74 | 75 | label = torch.from_numpy(label).long() 76 | cls_label = torch.from_numpy(cls_label).float() 77 | 78 | return img, label, cls_label 79 | 80 | def __getitem__(self, index): 81 | item = self.img_label_path_pairs[index] 82 | img_path, name = item 83 | 84 | img = util.read(img_path)[:, :, :3] 85 | img = util.Normalize(img, flag=self.dataset) 86 | label = util.read(os.path.join(self.label_path, name)) 87 | 88 | label = label_transform(self.dataset, label) 89 | cls_label = self.make_clslabel(label) 90 | 91 | # data transform 92 | if self.transform is not None: 93 | for t in self.transform.transforms: 94 | img, label = t([img, label]) 95 | 96 | img, label, cls_label = self.data_transform(img, label, cls_label) 97 | return img, label, cls_label, name 98 | 99 | def __len__(self): 100 | 101 | 
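        # dataset length = number of image entries parsed from the split txt file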
return len(self.img_label_path_pairs) 102 | 103 | 104 | def value_to_rgb(anno, flag='potsdam'): 105 | if flag == 'potsdam' or flag == 'vaihingen': 106 | label2color_dict = { 107 | 0: [255, 255, 255], # Impervious surfaces (RGB: 255, 255, 255) 108 | 1: [0, 0, 255], # Building (RGB: 0, 0, 255) 109 | 2: [0, 255, 255], # Low vegetation (RGB: 0, 255, 255) 110 | 3: [0, 255, 0], # Tree (RGB: 0, 255, 0) 111 | 4: [255, 255, 0], # Car (RGB: 255, 255, 0) 112 | 5: [255, 0, 0], # Clutter/background (RGB: 255, 0, 0) 113 | 255: [0, 0, 0] 114 | } 115 | else: 116 | label2color_dict = {} 117 | 118 | # visualize 119 | visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8) 120 | for i in range(visual_anno.shape[0]): # i for h 121 | for j in range(visual_anno.shape[1]): 122 | # cv2: bgr 123 | color = label2color_dict[anno[i, j, 0]] 124 | 125 | visual_anno[i, j, 0] = color[0] 126 | visual_anno[i, j, 1] = color[1] 127 | visual_anno[i, j, 2] = color[2] 128 | 129 | return visual_anno 130 | 131 | 132 | def show_pair(opt, img, label, name): 133 | true_label = read('/media/hlf/Luffy/WLS/semantic/dataset/'+opt.dataset+'/train/label_vis/' + name[0]) 134 | fig, axs = plt.subplots(1, 3, figsize=(14, 4)) 135 | 136 | img = img.squeeze(0).permute(1, 2, 0).cpu().numpy() 137 | img = Normalize_back(img, flag=opt.dataset) 138 | axs[0].imshow(img[:, :, :3].astype(np.uint8)) 139 | axs[0].axis("off") 140 | 141 | label = label.permute(1, 2, 0).cpu().numpy() 142 | print(np.unique(label)) 143 | vis = value_to_rgb(label, flag=opt.dataset) 144 | axs[1].imshow(vis.astype(np.uint8)) 145 | axs[1].axis("off") 146 | 147 | axs[2].imshow(true_label.astype(np.uint8)) 148 | axs[2].axis("off") 149 | 150 | plt.tight_layout() 151 | plt.show() 152 | plt.close() 153 | 154 | 155 | if __name__ == "__main__": 156 | 157 | from utils import * 158 | from options import * 159 | from torch.utils.data import DataLoader 160 | 161 | opt = Point_Options().parse() 162 | print(opt) 163 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt) 164 | 165 | train_transform = transforms.Compose([ 166 | trans.Scale(opt.img_size), 167 | trans.RandomHorizontalFlip(), 168 | trans.RandomVerticleFlip(), 169 | trans.RandomRotate90(), 170 | ]) 171 | dataset = Dataset_point(opt, train_txt_path, flag='train', transform=None) 172 | 173 | loader = DataLoader( 174 | dataset=dataset, shuffle=True, 175 | batch_size=1, num_workers=8, pin_memory=True, drop_last=True 176 | ) 177 | 178 | for i in tqdm(loader): 179 | img, label, cls_label, name = i 180 | show_pair(opt, img, label, name) 181 | pass -------------------------------------------------------------------------------- /dataset/transform.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import math 5 | import collections 6 | import random 7 | 8 | 9 | class Scale(object): 10 | """Rescale the input PIL.Image to the given size. 11 | 12 | Args: 13 | size (sequence or int): Desired output size. If size is a sequence like 14 | (w, h), output size will be matched to this. If size is an int, 15 | smaller edge of the image will be matched to this number. 16 | i.e, if height > width, then image will be rescaled to 17 | (size * height / width, size) 18 | interpolation (int, optional): Desired interpolation. 
Default is 19 | ``PIL.Image.BILINEAR`` 20 | """ 21 | 22 | def __init__(self, size): 23 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 24 | self.size = size 25 | 26 | def __call__(self, inputs): 27 | """ 28 | Args: 29 | img (npy): Image to be scaled. 30 | """ 31 | outs = [] 32 | for input in inputs: 33 | h = w = self.size 34 | oh, ow, c = input.shape 35 | img = np.resize(input, (h, w, c))  # caution: np.resize repeats/truncates data rather than interpolating; an interpolating resize (e.g. cv2.resize) is usually what is intended 36 | outs.append(img) 37 | 38 | return outs 39 | 40 | 41 | class RandomHorizontalFlip(object): 42 | def __init__(self, u=0.5): 43 | self.u = u 44 | 45 | def __call__(self, inputs): 46 | if np.random.random() < self.u: 47 | new_inputs = [] 48 | for input in inputs: 49 | input = np.flip(input, 0) 50 | new_inputs.append(input) 51 | return new_inputs 52 | else: 53 | return inputs 54 | 55 | 56 | class RandomVerticleFlip(object): 57 | def __init__(self, u=0.5): 58 | self.u = u 59 | 60 | def __call__(self, inputs): 61 | if np.random.random() < self.u: 62 | new_inputs = [] 63 | for input in inputs: 64 | input = np.flip(input, 1) 65 | new_inputs.append(input) 66 | return new_inputs 67 | else: 68 | return inputs 69 | 70 | 71 | class RandomRotate90(object): 72 | def __init__(self, u=0.5): 73 | self.u = u 74 | 75 | def __call__(self, inputs): 76 | if np.random.random() < self.u: 77 | new_inputs = [] 78 | for input in inputs: 79 | input = np.rot90(input) 80 | new_inputs.append(input) 81 | return new_inputs 82 | else: 83 | return inputs 84 | 85 | 86 | class Color_Aug(object): 87 | def __init__(self): 88 | self.contra_adj = 0.1 89 | self.bright_adj = 0.1 90 | 91 | def __call__(self, image): 92 | n_ch = image.shape[-1] 93 | ch_mean = np.mean(image, axis=(0, 1), keepdims=True).astype(np.float32) 94 | 95 | contra_mul = np.random.uniform(1 - self.contra_adj, 1 + self.contra_adj, (1, 1, n_ch)).astype( 96 | np.float32 97 | ) 98 | bright_mul = np.random.uniform(1 - self.bright_adj, 1 + self.bright_adj, (1, 1, n_ch)).astype( 99 | np.float32 100 | ) 101 | 102 | image = (image - ch_mean) * contra_mul + ch_mean * bright_mul 103 | 104 | return image 105 | 106 | 107 | class RandomHueSaturationValue(object): 108 | def __init__(self): 109 | self.hue_shift_limit = (-25, 25) 110 | self.sat_shift_limit = (-15, 15) 111 | self.val_shift_limit = (-15, 15) 112 | 113 | def change(self, image): 114 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 115 | h, s, v = cv2.split(image) 116 | hue_shift = np.random.randint(self.hue_shift_limit[0], self.hue_shift_limit[1] + 1) 117 | hue_shift = np.uint8(hue_shift) 118 | h += hue_shift 119 | sat_shift = np.random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]) 120 | s = cv2.add(s, sat_shift) 121 | val_shift = np.random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]) 122 | v = cv2.add(v, val_shift) 123 | image = cv2.merge((h, s, v)) 124 | # image = cv2.merge((s, v)) 125 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 126 | return image 127 | 128 | def __call__(self, img): 129 | if np.random.random() < 0.5: 130 | if img.shape[-1] == 4: 131 | img_3 = img[:, :, :3] 132 | img_3 = self.change(img_3) 133 | img[:, :, :3] = img_3 134 | else: 135 | img = self.change(img) 136 | 137 | return img 138 | 139 | 140 | def stretchImage(data, s=0.005, bins=2000):  # linear stretch: discard the lowest and highest s fraction (default 0.5%) of pixel values, then rescale linearly to [0, 1] 141 | ht = np.histogram(data, bins) 142 | d = np.cumsum(ht[0]) / float(data.size) 143 | lmin = 0 144 | lmax = bins - 1 145 | while lmin < bins: 146 | if d[lmin] >= s: 147 | break 148 | lmin += 1 149 | while lmax >= 0: 150 | if d[lmax] <= 1 - s: 151 | break 152 |
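# walk lmax downward until no more than the top s fraction of pixels lies above bin lmax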
lmax -= 1 153 | return np.clip((data - ht[1][lmin]) / (ht[1][lmax] - ht[1][lmin]), 0, 1) 154 | 155 | 156 | g_para = {} 157 | 158 | 159 | def getPara(radius=4):  # build the weight matrix for a given radius 160 | global g_para 161 | m = g_para.get(radius, None) 162 | if m is not None: 163 | return m 164 | size = radius * 2 + 1 165 | m = np.zeros((size, size)) 166 | for h in range(-radius, radius + 1): 167 | for w in range(-radius, radius + 1): 168 | if h == 0 and w == 0: 169 | continue 170 | m[radius + h, radius + w] = 1.0 / math.sqrt(h ** 2 + w ** 2) 171 | m /= m.sum() 172 | g_para[radius] = m 173 | return m 174 | 175 | 176 | def zmIce(I, ratio=4, radius=300):  # conventional ACE implementation 177 | para = getPara(radius) 178 | height, width = I.shape 179 | # zh,zw = [0]*radius + range(height) + [height-1]*radius, [0]*radius + range(width) + [width -1]*radius 180 | zh, zw = [0] * radius + [x for x in range(height)] + [height - 1] * radius, [0] * radius + [x for x in 181 | range(width)] + [ 182 | width - 1] * radius 183 | Z = I[np.ix_(zh, zw)] 184 | res = np.zeros(I.shape) 185 | for h in range(radius * 2 + 1): 186 | for w in range(radius * 2 + 1): 187 | if para[h][w] == 0: 188 | continue 189 | res += (para[h][w] * np.clip((I - Z[h:h + height, w:w + width]) * ratio, -1, 1)) 190 | return res 191 | 192 | 193 | def zmIceFast(I, ratio=4, radius=300):  # fast single-channel ACE enhancement 194 | height, width = I.shape[:2] 195 | if min(height, width) <= 2: 196 | return np.zeros(I.shape) + 0.5 197 | Rs = cv2.resize(I, ((width + 1) // 2, (height + 1) // 2)) 198 | Rf = zmIceFast(Rs, ratio, radius)  # recursive call 199 | Rf = cv2.resize(Rf, (width, height)) 200 | Rs = cv2.resize(Rs, (width, height)) 201 | 202 | return Rf + zmIce(I, ratio, radius) - zmIce(Rs, ratio, radius) 203 | 204 | 205 | def zmIceColor(I, ratio=4, radius=3):  # enhance the three RGB channels separately; ratio is the contrast-enhancement factor, radius is the convolution kernel radius 206 | res = np.zeros(I.shape) 207 | for k in range(3): 208 | res[:, :, k] = stretchImage(zmIceFast(I[:, :, k], ratio, radius)) 209 | return res 210 | 211 | 212 | def do_gamma(image, gamma=1.0): 213 | image = image ** (1.0 / gamma) 214 | image = np.clip(image, 0, 1) 215 | return image -------------------------------------------------------------------------------- /models/DBFNet.py: -------------------------------------------------------------------------------- 1 | from models.base_model import * 2 | from models.tools import * 3 | import os 4 | from models.fbf import FBFModule 5 | 6 | def clean_mask(mask, cls_label): 7 | """Remove any masks of labels that are not present.""" 8 | n, c = cls_label.size() 9 | return mask * cls_label.view(n, c, 1, 1) 10 | 11 | 12 | def get_penalty(predict, cls_label): 13 | # cls_label: (n, c) 14 | # predict: (n, c, h, w) 15 | n, c, h, w = predict.size() 16 | predict = torch.softmax(predict, dim=1) 17 | 18 | # if a patch does not contain label c, 19 | # then none of the pixels in this patch can be assigned to label c 20 | loss0 = - (1 - cls_label.view(n, c, 1, 1)) * torch.log(1 - predict + 1e-6) 21 | loss0 = torch.mean(torch.sum(loss0, dim=1)) 22 | 23 | # if a patch has only one type, then the whole patch should be assigned to this type 24 | sum = (torch.sum(cls_label, dim=-1, keepdim=True) == 1) 25 | loss1 = - (sum * cls_label).view(n, c, 1, 1) * torch.log(predict + 1e-6) 26 | loss1 = torch.mean(torch.sum(loss1, dim=1)) 27 | return loss0 + loss1 28 | 29 | 30 | def build_channels(backbone): 31 | if backbone == 'resnet34' or backbone == 'resnet18': 32 | channels = [64, 128, 256, 512] 33 | 34 | else: 35 | channels = [256, 512, 1024, 2048] 36 | 37 | return channels 38 | 39 | 40 | def
get_numiter_dilations(flag): 41 | if flag == 'train': 42 | 43 | # for potsdam 44 | num_iter = 1 45 | dilations = [[1, 3, 5, 7], [1, 3, 5], [1, 3], [1]] 46 | 47 | else: 48 | # for potsdam 49 | num_iter = 3 50 | dilations = [[1], [1], [1], [1]] 51 | 52 | # for vaihingen 53 | # num_iter = 5 54 | # dilations = [[1], [1], [1], [1]] 55 | return num_iter, dilations 56 | 57 | 58 | class FBF_Layer(nn.Module): 59 | def __init__(self, in_c, out_c, num_iter, dilations): 60 | super(FBF_Layer, self).__init__() 61 | self.num_iter = num_iter 62 | self.fbf_m = FBFModule(num_iter=num_iter, dilations=dilations) 63 | self.conv = nn.Sequential( 64 | nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, padding=0), 65 | nn.BatchNorm2d(out_c), 66 | nn.ReLU(), ) 67 | 68 | def forward(self, x): 69 | # x = self.conv0(x) 70 | x = self.fbf_m(x) 71 | x = self.conv(x) 72 | 73 | return x 74 | 75 | 76 | class Net(nn.Module): 77 | def __init__(self, opt, flag='train'): 78 | super(Net, self).__init__() 79 | # stride as 32 80 | self.backbone = build_backbone(opt.backbone, output_stride=32) 81 | self.img_size = opt.img_size 82 | 83 | # build DBF layer 84 | num_iter, dilations = get_numiter_dilations(flag) 85 | channels = build_channels(opt.backbone) 86 | key_channels = 128 87 | self.FBF_layer1 = FBF_Layer(channels[0], key_channels, num_iter, dilations[0]) 88 | self.FBF_layer2 = FBF_Layer(channels[1], key_channels, num_iter, dilations[1]) 89 | self.FBF_layer3 = FBF_Layer(channels[2], key_channels, num_iter, dilations[2]) 90 | self.FBF_layer4 = FBF_Layer(channels[3], key_channels, num_iter, dilations[3]) 91 | 92 | self.last_conv = nn.Sequential( 93 | nn.Conv2d(4*128, 128, kernel_size=1, stride=1, padding=0, bias=False), 94 | nn.BatchNorm2d(128), 95 | nn.ReLU(), 96 | ) 97 | self.FBF_layer = FBF_Layer(key_channels, key_channels, 1, dilations=[1]) 98 | self.seg_decoder = nn.Sequential( 99 | nn.Conv2d(key_channels, key_channels, kernel_size=1, stride=1, padding=0), 100 | nn.BatchNorm2d(key_channels), 101 | nn.ReLU(), 102 | nn.Conv2d(key_channels, opt.num_classes, kernel_size=3, stride=1, padding=1), 103 | ) 104 | 105 | def resize_out(self, output): 106 | if output.size()[-1] != self.img_size: 107 | output = F.interpolate(output, size=(self.img_size, self.img_size), mode='bilinear') 108 | return output 109 | 110 | def upsample_cat(self, p1, p2, p3, p4): 111 | p2 = nn.functional.interpolate(p2, size=p1.size()[2:], mode='bilinear', align_corners=True) 112 | p3 = nn.functional.interpolate(p3, size=p1.size()[2:], mode='bilinear', align_corners=True) 113 | p4 = nn.functional.interpolate(p4, size=p1.size()[2:], mode='bilinear', align_corners=True) 114 | return torch.cat([p1, p2, p3, p4], dim=1) 115 | 116 | def forward(self, x): 117 | # resnet 4 layers 118 | l1, l2, l3, l4 = self.backbone(x) 119 | 120 | # after feature bilateral filtering 121 | f1 = self.FBF_layer1(l1) 122 | f2 = self.FBF_layer2(l2) 123 | f3 = self.FBF_layer3(l3) 124 | f4 = self.FBF_layer4(l4) 125 | 126 | # fpn-like Top-down 127 | p4 = f4 128 | p3 = F.upsample(p4, size=f3.size()[2:], mode='bilinear') + f3 129 | p2 = F.upsample(p3, size=f2.size()[2:], mode='bilinear') + f2 130 | p1 = F.upsample(p2, size=f1.size()[2:], mode='bilinear') + f1 131 | 132 | cat = self.upsample_cat(p1, p2, p3, p4) 133 | feat = self.last_conv(cat) 134 | feat = self.FBF_layer(feat) 135 | 136 | out = self.seg_decoder(feat) 137 | out = self.resize_out(out) 138 | 139 | return out 140 | 141 | def forward_loss(self, x, label, cls_label): 142 | coarse_mask = self.forward(x) 143 | 144 | # get loss 145 | 
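# Training mixes two signals: a partial cross-entropy over the sparse point labels (unlabeled pixels carry 255 and are skipped via ignore_index), and the image-level penalty defined above, which rules out classes absent from the patch and pushes single-class patches toward that class.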
criterion = nn.CrossEntropyLoss(ignore_index=255) 146 | # penalty 147 | penalty = get_penalty(coarse_mask, cls_label) 148 | # seg_loss: partial cross-entropy on the sparse point annotations (point-level supervision) 149 | seg_loss = criterion(coarse_mask, label) 150 | 151 | return seg_loss, penalty 152 | 153 | 154 | if __name__ == "__main__": 155 | import utils.util as util 156 | from options import * 157 | from torch.utils.data import DataLoader 158 | from dataset import * 159 | from tqdm import tqdm 160 | from torchvision import transforms 161 | 162 | opt = Point_Options().parse() 163 | save_path = os.path.join(opt.save_path, opt.dataset) 164 | train_txt_path, val_txt_path, test_txt_path = util.create_data_path(opt) 165 | log_path, checkpoint_path, predict_path, _, _ = util.create_save_path(opt) 166 | 167 | train_transform = transforms.Compose([ 168 | trans.RandomHorizontalFlip(), 169 | trans.RandomVerticleFlip(), 170 | trans.RandomRotate90(), 171 | ]) 172 | dataset = Dataset_point(opt, train_txt_path, flag='train', transform=train_transform) 173 | 174 | loader = DataLoader( 175 | dataset=dataset, 176 | batch_size=1, num_workers=8, pin_memory=True, drop_last=True 177 | ) 178 | net = Net(opt, flag='train') 179 | net.cuda() 180 | torch.save({'state_dict': net.state_dict()}, 181 | os.path.join(checkpoint_path, 'model.pth')) 182 | 183 | for i in tqdm(loader): 184 | # put the data from loader to cuda 185 | img, label, cls_label, name = i 186 | input, label, cls_label = img.cuda(non_blocking=True), \ 187 | label.cuda(non_blocking=True), cls_label.cuda(non_blocking=True) 188 | 189 | # model forward 190 | out = net.forward(input) 191 | 192 | pass -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.base_model import * 2 | from models.seg_model import * 3 | from models.DBFNet import * 4 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.network import * 5 | from options import * 6 | 7 | 8 | def build_backbone(backbone, output_stride, pretrained=True, in_c=3): 9 | if backbone == 'resnet50': 10 | return ResNet50(output_stride, pretrained=pretrained, in_c=in_c) 11 | 12 | elif backbone == 'resnet101': 13 | return ResNet101(output_stride, pretrained=pretrained, in_c=in_c) 14 | 15 | elif backbone == 'resnet34': 16 | return ResNet34(output_stride, pretrained=pretrained, in_c=in_c) 17 | 18 | elif backbone == 'resnet18': 19 | return ResNet18(output_stride, pretrained=pretrained, in_c=in_c) 20 | 21 | else: 22 | raise NotImplementedError 23 | 24 | 25 | def build_base_model(opt): 26 | # get in_channels 27 | if opt.base_model == 'Deeplab': 28 | model = DeepLab(backbone=opt.backbone, in_channels=opt.in_channels) 29 | 30 | elif opt.base_model == 'LinkNet': 31 | model = DLinkNet(opt) 32 | 33 | elif opt.base_model == 'FCN': 34 | model = FCN(opt) 35 | 36 | elif opt.base_model == 'UNet': 37 | model = UNet(opt) 38 | 39 | else: 40 | model = None 41 | print('model none') 42 | return model 43 | 44 | 45 | def build_channels(opt): 46 | if opt.base_model == 'Deeplab': 47 | channels = 256 48 | elif opt.base_model == 'STANet': 49 | channels = 128 50 | elif opt.base_model == 'LinkNet': 51 | channels = 32 52 | elif opt.base_model == 'FCN': 53 | channels = 32 54 | elif opt.base_model ==
'UNet': 55 | channels = 32 56 | else: 57 | channels = None 58 | print('model none') 59 | return channels 60 | 61 | 62 | -------------------------------------------------------------------------------- /models/fbf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | 6 | class LocalAffinity(nn.Module): 7 | 8 | def __init__(self, dilations=[1]): 9 | super(LocalAffinity, self).__init__() 10 | self.dilations = dilations 11 | weight = self._init_aff() 12 | self.register_buffer('kernel', weight) 13 | 14 | def _init_aff(self): 15 | # initialising the shift kernel; created on CPU so register_buffer can move it with the module (the hard-coded device='cuda' broke CPU use) 16 | weight = torch.zeros(8, 1, 3, 3) 17 | 18 | weight[0, 0, 0, 0] = 1 19 | weight[1, 0, 0, 1] = 1 20 | weight[2, 0, 0, 2] = 1 21 | 22 | weight[3, 0, 1, 0] = 1 23 | weight[4, 0, 1, 2] = 1 24 | 25 | weight[5, 0, 2, 0] = 1 26 | weight[6, 0, 2, 1] = 1 27 | weight[7, 0, 2, 2] = 1 28 | 29 | self.weight_check = weight.clone() 30 | return weight 31 | 32 | def forward(self, x): 33 | self.weight_check = self.weight_check.type_as(x) 34 | assert torch.all(self.weight_check.eq(self.kernel)) 35 | 36 | B, K, H, W = x.size() 37 | x = x.view(B * K, 1, H, W) 38 | 39 | x_affs = [] 40 | for d in self.dilations: 41 | x_pad = F.pad(x, [d] * 4, mode='replicate') 42 | x_aff = F.conv2d(x_pad, self.kernel, dilation=d) 43 | x_affs.append(x_aff) 44 | 45 | x_aff = torch.cat(x_affs, 1) 46 | return x_aff.view(B, K, -1, H, W) 47 | 48 | 49 | class FBFModule(nn.Module): 50 | def __init__(self, num_iter=5, dilations=[1]): 51 | # Dilated Local Affinity 52 | super(FBFModule, self).__init__() 53 | self.num_iter = num_iter 54 | self.aff_loc = LocalAffinity(dilations) 55 | 56 | def forward(self, feature): 57 | # feature: [BxCxHxW] 58 | n, c, h, w = feature.size() 59 | 60 | for _ in range(self.num_iter): 61 | f = self.aff_loc(feature)  # [BxCxPxHxW] 62 | # dim 2 holds the values of the P neighbouring pixels 63 | 64 | diff = torch.abs(feature.unsqueeze(2) - f)  # renamed from `abs`, which shadowed the builtin 65 | aff = torch.exp(-torch.mean(diff, dim=1, keepdim=True)) 66 | # aff = F.cosine_similarity(feature.unsqueeze(2), f, dim=1).unsqueeze(1) 67 | 68 | aff = aff/torch.sum(aff, dim=2, keepdim=True)  # [Bx1xPxHxW] 69 | # print(aff[0, 0, :, h//2, w//2]) 70 | # dim 2 holds the affinities of the P neighbouring pixels 71 | 72 | feature = torch.sum(f * aff, dim=2) 73 | 74 | return feature 75 | 76 | 77 | if __name__ == '__main__': 78 | feature = torch.randn([4, 64, 256, 256], device='cuda') 79 | 80 | model = FBFModule().cuda()  # move the registered kernel buffer to the same device as the input 81 | output = model(feature) 82 | print(output.shape) -------------------------------------------------------------------------------- /models/network/DeeplabV3.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | # @Filename: DeeplabV3 4 | # @Project : Glory 5 | # @date : 2020-12-28 21:52 6 | # @Author : Linshan 7 | 8 | import torch 9 | import math 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from itertools import chain 13 | from models.network.resnet import * 14 | from models.network.segformer import * 15 | ''' 16 | -> ResNet BackBone 17 | ''' 18 | 19 | 20 | def initialize_weights(*models): 21 | for model in models: 22 | for m in model.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') 25 | elif isinstance(m, nn.BatchNorm2d): 26 | m.weight.data.fill_(1.)
27 | m.bias.data.fill_(1e-4) 28 | elif isinstance(m, nn.Linear): 29 | m.weight.data.normal_(0.0, 0.0001) 30 | m.bias.data.zero_() 31 | 32 | 33 | ''' 34 | -> The Atrous Spatial Pyramid Pooling 35 | ''' 36 | 37 | 38 | def assp_branch(in_channels, out_channels, kernel_size, dilation): 39 | padding = 0 if kernel_size == 1 else dilation 40 | return nn.Sequential( 41 | nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation, bias=False), 42 | nn.BatchNorm2d(out_channels), 43 | nn.ReLU(inplace=True)) 44 | 45 | 46 | class ASSP(nn.Module): 47 | def __init__(self, in_channels, aspp_stride): 48 | super(ASSP, self).__init__() 49 | 50 | assert aspp_stride in [8, 16], 'Only output strides of 8 or 16 are supported' 51 | if aspp_stride == 16: 52 | dilations = [1, 6, 12, 18] 53 | elif aspp_stride == 8: 54 | dilations = [1, 12, 24, 36] 55 | 56 | self.aspp1 = assp_branch(in_channels, 256, 1, dilation=dilations[0]) 57 | self.aspp2 = assp_branch(in_channels, 256, 3, dilation=dilations[1]) 58 | self.aspp3 = assp_branch(in_channels, 256, 3, dilation=dilations[2]) 59 | self.aspp4 = assp_branch(in_channels, 256, 3, dilation=dilations[3]) 60 | 61 | self.avg_pool = nn.Sequential( 62 | nn.AdaptiveAvgPool2d((1, 1)), 63 | nn.Conv2d(in_channels, 256, 1, bias=False), 64 | nn.BatchNorm2d(256), 65 | nn.ReLU(inplace=True)) 66 | 67 | self.conv1 = nn.Conv2d(256 * 5, 256, 1, bias=False) 68 | self.bn1 = nn.BatchNorm2d(256) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.dropout = nn.Dropout(0.5) 71 | 72 | initialize_weights(self) 73 | 74 | def forward(self, x): 75 | x1 = self.aspp1(x) 76 | x2 = self.aspp2(x) 77 | x3 = self.aspp3(x) 78 | x4 = self.aspp4(x) 79 | 80 | x5 = F.interpolate(self.avg_pool(x), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=True) 81 | 82 | x = self.conv1(torch.cat((x1, x2, x3, x4, x5), dim=1)) 83 | x = self.bn1(x) 84 | x = self.dropout(self.relu(x)) 85 | 86 | return x 87 | 88 | 89 | ''' 90 | -> Decoder 91 | ''' 92 | 93 | 94 | class Decoder(nn.Module): 95 | def __init__(self, low_level_channels): 96 | super(Decoder, self).__init__() 97 | self.conv1 = nn.Conv2d(low_level_channels, 48, 1, bias=False) 98 | self.bn1 = nn.BatchNorm2d(48) 99 | self.relu = nn.ReLU(inplace=True) 100 | 101 | # Table 2, best performance with two 3x3 convs 102 | self.output = nn.Sequential( 103 | nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1, bias=False), 104 | nn.BatchNorm2d(256), 105 | nn.ReLU(inplace=True), 106 | nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False), 107 | nn.BatchNorm2d(256), 108 | nn.ReLU(inplace=True) 109 | ) 110 | initialize_weights(self) 111 | 112 | def forward(self, x, low_level_features): 113 | low_level_features = self.conv1(low_level_features) 114 | low_level_features = self.relu(self.bn1(low_level_features)) 115 | H, W = low_level_features.size(2), low_level_features.size(3) 116 | 117 | x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) 118 | x = self.output(torch.cat((low_level_features, x), dim=1)) 119 | return x 120 | 121 | 122 | ''' 123 | -> Deeplab V3 + 124 | ''' 125 | 126 | 127 | class DeepLab(nn.Module): 128 | def __init__(self, backbone, in_channels, pretrained=True, 129 | output_stride=16, aspp_stride=8, freeze_bn=False, **_): 130 | 131 | super(DeepLab, self).__init__() 132 | assert ('resnet' in backbone) or ('resnest' in backbone) or ('mit' in backbone)  # the original `('resnet' or 'resnest' in backbone)` was always truthy 133 | 134 | if 'resnet18' in backbone: 135 | self.backbone = ResNet18(output_stride=output_stride, pretrained=pretrained, in_c=in_channels) 136 | low_level_channels = 64 137 | aspp_channels = 512 138 | 139 | elif
'resnet34' in backbone: 140 | self.backbone = ResNet34(output_stride=output_stride, pretrained=pretrained, in_c=in_channels) 141 | low_level_channels = 64 142 | aspp_channels = 512 143 | 144 | elif 'resnet50' in backbone: 145 | self.backbone = ResNet50(output_stride=output_stride, pretrained=pretrained, in_c=in_channels) 146 | low_level_channels = 256 147 | aspp_channels = 2048 148 | 149 | elif 'resnet101' in backbone: 150 | self.backbone = ResNet101(output_stride=output_stride, pretrained=pretrained, in_c=in_channels) 151 | low_level_channels = 256 152 | aspp_channels = 2048 153 | 154 | elif 'mit_b0' in backbone: 155 | self.backbone = mit_b0() 156 | checkpoint = torch.load('/media/hlf/Luffy/WLS/ContestCD/pretrained_former/segformer_b0.pth', map_location=torch.device('cpu')) 157 | self.backbone.load_state_dict(checkpoint['state_dict']) 158 | print('load pretrained transformer') 159 | low_level_channels = 32 160 | aspp_channels = 256 161 | 162 | elif 'mit_b3' in backbone: 163 | self.backbone = mit_b3() 164 | checkpoint = torch.load('/media/hlf/Luffy/WLS/ContestCD/pretrained_former/segformer_b3.pth', map_location=torch.device('cpu')) 165 | self.backbone.load_state_dict(checkpoint['state_dict']) 166 | print('load pretrained transformer') 167 | low_level_channels = 64 168 | aspp_channels = 512 169 | 170 | else: 171 | low_level_channels = None 172 | aspp_channels = None 173 | 174 | self.ASSP = ASSP(in_channels=aspp_channels, aspp_stride=aspp_stride) 175 | 176 | self.decoder = Decoder(low_level_channels) 177 | 178 | if freeze_bn: 179 | self.freeze_bn() 180 | 181 | def forward(self, x): 182 | _, _, h, w = x.size() 183 | x = self.backbone(x) 184 | 185 | feature = self.ASSP(x[3]) 186 | x = self.decoder(feature, x[0]) 187 | # x = F.interpolate(x, size=(h, w), mode='bilinear') 188 | 189 | return x 190 | 191 | def freeze_bn(self): 192 | for module in self.modules(): 193 | if isinstance(module, nn.BatchNorm2d): module.eval() 194 | 195 | 196 | if __name__ == '__main__': 197 | 198 | model = DeepLab(backbone='resnet50', in_channels=4)  # DeepLab has no num_classes argument; in_channels=4 matches the 4-channel test input below 199 | 200 | x = torch.rand(2, 4, 256, 256) 201 | x = model(x) 202 | 203 | print(x.shape) -------------------------------------------------------------------------------- /models/network/LinkNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | # @Filename: LinkNet 4 | # @Project : Glory 5 | # @date : 2020-12-28 21:29 6 | # @Author : Linshan 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | from functools import partial 13 | from models.network.resnet import * 14 | 15 | nonlinearity = partial(F.relu, inplace=True) 16 | 17 | 18 | class Dblock_more_dilate(nn.Module): 19 | def __init__(self, channel): 20 | super(Dblock_more_dilate, self).__init__() 21 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 22 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) 23 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) 24 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) 25 | self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 28 | if m.bias is not None: 29 | m.bias.data.zero_() 30 | 31 | def forward(self, x): 32 | dilate1_out = nonlinearity(self.dilate1(x)) 33 | dilate2_out =
nonlinearity(self.dilate2(dilate1_out)) 34 | dilate3_out = nonlinearity(self.dilate3(dilate2_out)) 35 | dilate4_out = nonlinearity(self.dilate4(dilate3_out)) 36 | dilate5_out = nonlinearity(self.dilate5(dilate4_out)) 37 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out 38 | return out 39 | 40 | 41 | class Dblock(nn.Module): 42 | def __init__(self, channel): 43 | super(Dblock, self).__init__() 44 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 45 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) 46 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) 47 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) 48 | # self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 51 | if m.bias is not None: 52 | m.bias.data.zero_() 53 | 54 | def forward(self, x): 55 | dilate1_out = nonlinearity(self.dilate1(x)) 56 | dilate2_out = nonlinearity(self.dilate2(dilate1_out)) 57 | dilate3_out = nonlinearity(self.dilate3(dilate2_out)) 58 | dilate4_out = nonlinearity(self.dilate4(dilate3_out)) 59 | # dilate5_out = nonlinearity(self.dilate5(dilate4_out)) 60 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out  # + dilate5_out 61 | return out 62 | 63 | 64 | class DecoderBlock(nn.Module): 65 | def __init__(self, in_channels, n_filters): 66 | super(DecoderBlock, self).__init__() 67 | 68 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1) 69 | self.norm1 = nn.BatchNorm2d(in_channels // 4) 70 | self.relu1 = nonlinearity 71 | 72 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1) 73 | self.norm2 = nn.BatchNorm2d(in_channels // 4) 74 | self.relu2 = nonlinearity 75 | 76 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1) 77 | self.norm3 = nn.BatchNorm2d(n_filters) 78 | self.relu3 = nonlinearity 79 | 80 | def forward(self, x): 81 | x = self.conv1(x) 82 | x = self.norm1(x) 83 | x = self.relu1(x) 84 | x = self.deconv2(x) 85 | x = self.norm2(x) 86 | x = self.relu2(x) 87 | x = self.conv3(x) 88 | x = self.norm3(x) 89 | x = self.relu3(x) 90 | return x 91 | 92 | 93 | class DLinkNet(nn.Module): 94 | def __init__(self, opt): 95 | super(DLinkNet, self).__init__() 96 | if 'resnet50' in opt.backbone: 97 | self.backbone = ResNet50(pretrained=True) 98 | filters = [256, 512, 1024, 2048] 99 | 100 | elif 'resnet101' in opt.backbone: 101 | self.backbone = ResNet101(pretrained=True)  # the stray `opt.` prefix on ResNet101 was a typo 102 | filters = [256, 512, 1024, 2048] 103 | 104 | elif 'resnet34' in opt.backbone: 105 | self.backbone = ResNet34(pretrained=True) 106 | filters = [64, 128, 256, 512] 107 | 108 | elif 'resnet18' in opt.backbone: 109 | self.backbone = ResNet18(pretrained=True) 110 | filters = [64, 128, 256, 512] 111 | else: 112 | filters = None 113 | 114 | self.dblock_master = Dblock(filters[3]) 115 | 116 | self.decoder4_master = DecoderBlock(filters[3], filters[2]) 117 | self.decoder3_master = DecoderBlock(filters[2], filters[1]) 118 | self.decoder2_master = DecoderBlock(filters[1], filters[0]) 119 | self.decoder1_master = DecoderBlock(filters[0], filters[0]) 120 | 121 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) 122 | self.finalrelu1_master = nonlinearity 123 | 124 | def forward(self, x): 125 | x = self.backbone(x) 126 | 127 | # center_master 128 | e4 = self.dblock_master(x[3])
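# the D-block keeps the spatial size but widens the receptive field: cascaded 3x3 convs with dilations 1/2/4/8 are summed together with the input, so e4 carries multi-scale context into the decoder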
129 | # decoder_master 130 | d4 = self.decoder4_master(e4) 131 | d3 = self.decoder3_master(d4) 132 | d2 = self.decoder2_master(d3) 133 | d1 = self.decoder1_master(d2) 134 | 135 | out = self.finaldeconv1_master(d1) 136 | out = self.finalrelu1_master(out) 137 | 138 | return out 139 | 140 | 141 | if __name__ == '__main__': 142 | 143 | class _Opt:  # minimal stand-in options object so the smoke test runs; DLinkNet only reads opt.backbone here 144 | backbone = 'resnet34' 145 | 146 | model = DLinkNet(_Opt()) 147 | 148 | x = torch.rand(2, 3, 512, 512)  # the ImageNet-pretrained backbone expects 3-channel input 149 | x = model(x) 150 | 151 | print(x.shape) -------------------------------------------------------------------------------- /models/network/__init__.py: -------------------------------------------------------------------------------- 1 | from models.network.LinkNet import * 2 | from models.network.DeeplabV3 import * 3 | # from models.network.STANet import STANet  # STANet.py is not present in this tree; left commented so the package stays importable 4 | from models.network.fcn import FCN, UNet 5 | from models.network.resnet import * 6 | 7 | 8 | -------------------------------------------------------------------------------- /models/network/fcn.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from models.network.LinkNet import * 9 | 10 | 11 | class FCN(nn.Module): 12 | def __init__(self, opt): 13 | super(FCN, self).__init__() 14 | if 'resnet50' in opt.backbone: 15 | self.backbone = ResNet50(pretrained=True) 16 | filters = [256, 512, 1024, 2048] 17 | 18 | elif 'resnet101' in opt.backbone: 19 | self.backbone = ResNet101(pretrained=True)  # `opt.ResNet101` was a typo 20 | filters = [256, 512, 1024, 2048] 21 | 22 | elif 'resnet34' in opt.backbone: 23 | self.backbone = ResNet34(pretrained=True) 24 | filters = [64, 128, 256, 512] 25 | 26 | elif 'resnet18' in opt.backbone: 27 | self.backbone = ResNet18(pretrained=True) 28 | filters = [64, 128, 256, 512] 29 | else: 30 | filters = None 31 | 32 | self.decoder4_master = DecoderBlock(filters[3], filters[2]) 33 | self.decoder3_master = DecoderBlock(filters[2], filters[1]) 34 | self.decoder2_master = DecoderBlock(filters[1], filters[0]) 35 | self.decoder1_master = DecoderBlock(filters[0], filters[0]) 36 | 37 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) 38 | self.finalrelu1_master = nonlinearity 39 | 40 | def forward(self, x): 41 | x = self.backbone(x) 42 | 43 | # center_master 44 | e4 = x[3] 45 | # decoder_master 46 | d4 = self.decoder4_master(e4) 47 | d3 = self.decoder3_master(d4) 48 | d2 = self.decoder2_master(d3) 49 | d1 = self.decoder1_master(d2) 50 | 51 | out = self.finaldeconv1_master(d1) 52 | out = self.finalrelu1_master(out) 53 | 54 | return out 55 | 56 | 57 | class UNet(nn.Module): 58 | def __init__(self, opt): 59 | super(UNet, self).__init__() 60 | if 'resnet50' in opt.backbone: 61 | self.backbone = ResNet50(pretrained=True) 62 | filters = [256, 512, 1024, 2048] 63 | 64 | elif 'resnet101' in opt.backbone: 65 | self.backbone = ResNet101(pretrained=True)  # `opt.ResNet101` was a typo 66 | filters = [256, 512, 1024, 2048] 67 | 68 | elif 'resnet34' in opt.backbone: 69 | self.backbone = ResNet34(pretrained=True) 70 | filters = [64, 128, 256, 512] 71 | 72 | elif 'resnet18' in opt.backbone: 73 | self.backbone = ResNet18(pretrained=True) 74 | filters = [64, 128, 256, 512] 75 | else: 76 | filters = None 77 | 78 | self.decoder4_master = DecoderBlock(filters[3], filters[2]) 79 | self.decoder3_master = DecoderBlock(filters[2]*2, filters[1]) 80 | self.decoder2_master = DecoderBlock(filters[1]*2, filters[0]) 81 | self.decoder1_master = DecoderBlock(filters[0]*2, filters[0]) 82 | 83 | self.finaldeconv1_master =
nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) 84 | self.finalrelu1_master = nonlinearity 85 | 86 | def forward(self, x): 87 | x = self.backbone(x) 88 | 89 | # center_master 90 | e4 = x[3] 91 | # decoder_master 92 | d4 = torch.cat([self.decoder4_master(e4), x[2]], dim=1) 93 | d3 = torch.cat([self.decoder3_master(d4), x[1]], dim=1) 94 | d2 = torch.cat([self.decoder2_master(d3), x[0]], dim=1) 95 | d1 = self.decoder1_master(d2) 96 | 97 | out = self.finaldeconv1_master(d1) 98 | out = self.finalrelu1_master(out) 99 | 100 | 101 | return out -------------------------------------------------------------------------------- /models/network/fpn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=dilation, groups=groups, bias=False, dilation=dilation) 11 | 12 | 13 | def conv1x1(in_planes, out_planes, stride=1): 14 | """1x1 convolution""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None, groups=1, 22 | base_width=64): 23 | super(BasicBlock, self).__init__() 24 | if BatchNorm is None: 25 | BatchNorm = nn.BatchNorm2d 26 | if groups != 1 or base_width != 64: 27 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 28 | if dilation > 1: 29 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 30 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = BatchNorm(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = BatchNorm(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | identity = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | identity = self.downsample(x) 51 | 52 | out += identity 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None, groups=1, 62 | base_width=64): 63 | super(Bottleneck, self).__init__() 64 | width = int(planes * (base_width / 64.)) * groups 65 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False) 66 | self.bn1 = BatchNorm(width) 67 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 68 | dilation=dilation, padding=dilation, bias=False, groups=groups) 69 | self.bn2 = BatchNorm(width) 70 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 71 | self.bn3 = BatchNorm(planes * 4) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | self.dilation = dilation 76 | 77 | def forward(self, x): 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = 
self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | 102 | def __init__(self, arch, block, layers, output_stride, BatchNorm, pretrained=True): 103 | self.inplanes = 64 104 | self.layers = layers 105 | self.arch = arch 106 | super(ResNet, self).__init__() 107 | blocks = [1, 2, 4] 108 | if output_stride == 16: 109 | strides = [1, 2, 2, 1] 110 | dilations = [1, 1, 1, 2] 111 | elif output_stride == 8: 112 | strides = [1, 2, 1, 1] 113 | dilations = [1, 1, 2, 4] 114 | else: 115 | strides = [1, 2, 2, 2] 116 | dilations = [1, 1, 1, 1] 117 | 118 | if arch == 'resnext50': 119 | self.base_width = 4 120 | self.groups = 32 121 | else: 122 | self.base_width = 64 123 | self.groups = 1 124 | 125 | # Modules 126 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 127 | bias=False) 128 | self.bn1 = BatchNorm(64) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | 132 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], 133 | BatchNorm=BatchNorm) 134 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], 135 | BatchNorm=BatchNorm) 136 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], 137 | BatchNorm=BatchNorm) 138 | if self.arch == 'resnet18': 139 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], 140 | BatchNorm=BatchNorm) 141 | else: 142 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], 143 | BatchNorm=BatchNorm) 144 | 145 | if self.arch == 'resnet18': 146 | self.toplayer = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) 147 | self.latlayer1 = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0) 148 | self.latlayer2 = nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0) 149 | self.latlayer3 = nn.Conv2d(64, 128, kernel_size=1, stride=1, padding=0) 150 | else: 151 | self.toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0) 152 | self.latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 153 | self.latlayer2 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0) 154 | self.latlayer3 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) 155 | fpn_c = 128 if self.arch == 'resnet18' else 256  # smooth convs must match the lateral channel width; the hard-coded 128 broke the resnet50/101/resnext paths 156 | self.smooth1 = nn.Conv2d(fpn_c, fpn_c, kernel_size=3, stride=1, padding=1) 157 | self.smooth2 = nn.Conv2d(fpn_c, fpn_c, kernel_size=3, stride=1, padding=1) 158 | self.smooth3 = nn.Conv2d(fpn_c, fpn_c, kernel_size=3, stride=1, padding=1) 159 | 160 | self._init_weight() 161 | 162 | if pretrained: 163 | self._load_pretrained_model() 164 | 165 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 166 | downsample = None 167 | if stride != 1 or self.inplanes != planes * block.expansion: 168 | downsample = nn.Sequential( 169 | nn.Conv2d(self.inplanes, planes * block.expansion, 170 | kernel_size=1, stride=stride, bias=False), 171 | BatchNorm(planes * block.expansion), 172 | ) 173 | 174 | layers = [] 175 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm, groups=self.groups, 176 | base_width=self.base_width)) 177 | self.inplanes = planes * block.expansion 178 | for i in range(1, blocks): 179 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm, groups=self.groups, 180 |
base_width=self.base_width)) 181 | 182 | return nn.Sequential(*layers) 183 | 184 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 185 | downsample = None 186 | if stride != 1 or self.inplanes != planes * block.expansion: 187 | downsample = nn.Sequential( 188 | nn.Conv2d(self.inplanes, planes * block.expansion, 189 | kernel_size=1, stride=stride, bias=False), 190 | BatchNorm(planes * block.expansion), 191 | ) 192 | 193 | layers = [] 194 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0] * dilation, 195 | downsample=downsample, BatchNorm=BatchNorm, groups=self.groups, base_width=self.base_width)) 196 | self.inplanes = planes * block.expansion 197 | for i in range(1, len(blocks)): 198 | layers.append(block(self.inplanes, planes, stride=1, 199 | dilation=blocks[i] * dilation, BatchNorm=BatchNorm, groups=self.groups, 200 | base_width=self.base_width)) 201 | 202 | return nn.Sequential(*layers) 203 | 204 | def forward(self, input): 205 | # Bottom-up 206 | x = self.conv1(input) 207 | x = self.bn1(x) 208 | x = self.relu(x) 209 | c1 = self.maxpool(x) 210 | c2 = self.layer1(c1) # x4 211 | c3 = self.layer2(c2) # x8 212 | c4 = self.layer3(c3) # x16 213 | c5 = self.layer4(c4) # x16 214 | 215 | # Top-down 216 | p5 = self.toplayer(c5) 217 | p4 = F.upsample(p5, size=c4.size()[2:], mode='bilinear') + self.latlayer1(c4) 218 | p3 = F.upsample(p4, size=c3.size()[2:], mode='bilinear') + self.latlayer2(c3) 219 | p2 = F.upsample(p3, size=c2.size()[2:], mode='bilinear') + self.latlayer3(c2) 220 | 221 | p4 = self.smooth1(p4) 222 | p3 = self.smooth2(p3) 223 | p2 = self.smooth3(p2) 224 | 225 | return p2, p3, p4, p5 226 | 227 | def _init_weight(self): 228 | for m in self.modules(): 229 | if isinstance(m, nn.Conv2d): 230 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 231 | m.weight.data.normal_(0, math.sqrt(2. / n)) 232 | elif isinstance(m, nn.BatchNorm2d): 233 | m.weight.data.fill_(1) 234 | m.bias.data.zero_() 235 | 236 | def _load_pretrained_model(self): 237 | if self.arch == 'resnet101': 238 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 239 | elif self.arch == 'resnet50': 240 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth') 241 | elif self.arch == 'resnet18': 242 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet18-5c106cde.pth') 243 | elif self.arch == 'resnext50': 244 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth') 245 | model_dict = {} 246 | state_dict = self.state_dict() 247 | for k, v in pretrain_dict.items(): 248 | if k in state_dict: 249 | model_dict[k] = v 250 | state_dict.update(model_dict) 251 | self.load_state_dict(state_dict) 252 | 253 | 254 | def FPN101(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True): 255 | """Constructs a ResNet-101 model. 256 | Args: 257 | pretrained (bool): If True, returns a model pre-trained on ImageNet 258 | """ 259 | model = ResNet('resnet101', Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 260 | return model 261 | 262 | 263 | def FPN50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True): 264 | """Constructs a ResNet-50 model. 
265 | Args: 266 | pretrained (bool): If True, returns a model pre-trained on ImageNet 267 | """ 268 | model = ResNet('resnet50', Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained) 269 | return model 270 | 271 | 272 | def FPN18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True): 273 | model = ResNet('resnet18', BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, pretrained=pretrained) 274 | return model 275 | 276 | 277 | def ResNext50_FPN(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True): 278 | model = ResNet('resnext50', Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained) 279 | return model 280 | 281 | 282 | if __name__ == "__main__": 283 | import torch 284 | from torchstat import stat 285 | 286 | model = FPN18(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=32) 287 | stat(model.cpu(), input_size=(3, 256, 256)) 288 | 289 | input = torch.rand(1, 3, 480, 640) 290 | output = model(input) 291 | for out in output: 292 | print(out.size()) 293 | # print(low_level_feat.size()) -------------------------------------------------------------------------------- /models/network/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | # @Filename: Deeplab 4 | # @Project : Glory 5 | # @date : 2020-12-28 21:47 6 | # @Author : Linshan 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.model_zoo as model_zoo 12 | import math 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=dilation, groups=groups, bias=False, dilation=dilation) 19 | 20 | 21 | model_urls = { 22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 27 | } 28 | 29 | 30 | def ResNet18(output_stride=32, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3): 31 | """ 32 | output, low_level_feat: 33 | 512, 256, 128, 64, 64 34 | """ 35 | model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, in_c=in_c) 36 | if in_c != 3: 37 | pretrained = False 38 | if pretrained: 39 | model._load_pretrained_model(model_urls['resnet18']) 40 | return model 41 | 42 | 43 | def ResNet34(output_stride=32, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3): 44 | """ 45 | output, low_level_feat: 46 | 512, 64 47 | """ 48 | model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c) 49 | if in_c != 3: 50 | pretrained = False 51 | if pretrained: 52 | model._load_pretrained_model(model_urls['resnet34']) 53 | return model 54 | 55 | 56 | def ResNet50(output_stride=32, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3): 57 | """ 58 | output, low_level_feat: 59 | 2048, 256 60 | """ 61 | model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c) 62 | if in_c != 3: 63 | pretrained=False 64 | if pretrained: 65 | model._load_pretrained_model(model_urls['resnet50']) 66 | return model 67 | 68 | 69 | def ResNet101(output_stride=32, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3): 70 | """ 71 | output, low_level_feat: 72 | 2048, 256 73 | """ 74 | model = 
ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, in_c=in_c) 75 | if in_c != 3: 76 | pretrained=False 77 | if pretrained: 78 | model._load_pretrained_model(model_urls['resnet101']) 79 | return model 80 | 81 | 82 | class BasicBlock(nn.Module): 83 | expansion = 1 84 | 85 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 86 | super(BasicBlock, self).__init__() 87 | 88 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 89 | dilation=dilation, padding=dilation, bias=False) 90 | self.bn1 = BatchNorm(planes) 91 | 92 | self.relu = nn.ReLU(inplace=True) 93 | self.conv2 = conv3x3(planes, planes) 94 | self.bn2 = BatchNorm(planes) 95 | self.downsample = downsample 96 | self.stride = stride 97 | 98 | def forward(self, x): 99 | identity = x 100 | 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv2(out) 106 | out = self.bn2(out) 107 | 108 | if self.downsample is not None: 109 | identity = self.downsample(x) 110 | 111 | out += identity 112 | out = self.relu(out) 113 | 114 | return out 115 | 116 | 117 | class Bottleneck(nn.Module): 118 | expansion = 4 119 | 120 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 121 | super(Bottleneck, self).__init__() 122 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 123 | self.bn1 = BatchNorm(planes) 124 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 125 | dilation=dilation, padding=dilation, bias=False) 126 | self.bn2 = BatchNorm(planes) 127 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 128 | self.bn3 = BatchNorm(planes * 4) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.downsample = downsample 131 | self.stride = stride 132 | self.dilation = dilation 133 | 134 | def forward(self, x): 135 | residual = x 136 | 137 | out = self.conv1(x) 138 | out = self.bn1(out) 139 | out = self.relu(out) 140 | 141 | out = self.conv2(out) 142 | out = self.bn2(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv3(out) 146 | out = self.bn3(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | 157 | class ResNet(nn.Module): 158 | 159 | def __init__(self, block, layers, output_stride, BatchNorm, in_c=4): 160 | 161 | self.inplanes = 64 162 | self.in_c = in_c 163 | super(ResNet, self).__init__() 164 | blocks = [1, 2, 4] 165 | if output_stride == 32: 166 | strides = [1, 2, 2, 2] 167 | dilations = [1, 1, 1, 1] 168 | elif output_stride == 16: 169 | strides = [1, 2, 2, 1] 170 | dilations = [1, 1, 1, 2] 171 | elif output_stride == 8: 172 | strides = [1, 2, 1, 1] 173 | dilations = [1, 1, 2, 4] 174 | elif output_stride == 4: 175 | strides = [1, 1, 1, 1] 176 | dilations = [1, 2, 4, 8] 177 | else: 178 | raise NotImplementedError 179 | 180 | # Modules 181 | self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=2, padding=3, 182 | bias=False) 183 | self.bn1 = BatchNorm(64) 184 | self.relu = nn.ReLU(inplace=True) 185 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 186 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 187 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 188 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], 
BatchNorm=BatchNorm) 189 | #self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 190 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 191 | self._init_weight() 192 | 193 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 194 | downsample = None 195 | if stride != 1 or self.inplanes != planes * block.expansion: 196 | downsample = nn.Sequential( 197 | nn.Conv2d(self.inplanes, planes * block.expansion, 198 | kernel_size=1, stride=stride, bias=False), 199 | BatchNorm(planes * block.expansion), 200 | ) 201 | 202 | layers = [] 203 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 204 | self.inplanes = planes * block.expansion 205 | for i in range(1, blocks): 206 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 207 | 208 | return nn.Sequential(*layers) 209 | 210 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 211 | downsample = None 212 | if stride != 1 or self.inplanes != planes * block.expansion: 213 | downsample = nn.Sequential( 214 | nn.Conv2d(self.inplanes, planes * block.expansion, 215 | kernel_size=1, stride=stride, bias=False), 216 | BatchNorm(planes * block.expansion), 217 | ) 218 | 219 | layers = [] 220 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 221 | downsample=downsample, BatchNorm=BatchNorm)) 222 | self.inplanes = planes * block.expansion 223 | for i in range(1, len(blocks)): 224 | layers.append(block(self.inplanes, planes, stride=1, 225 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 226 | 227 | return nn.Sequential(*layers) 228 | 229 | def forward(self, input): 230 | x = self.conv1(input) 231 | x = self.bn1(x) 232 | x = self.relu(x) 233 | x = self.maxpool(x) # | 4 234 | 235 | x = self.layer1(x) # | 4 236 | low_level_feat2 = x # | 4 237 | 238 | x = self.layer2(x) # | 8 239 | low_level_feat3 = x 240 | 241 | x = self.layer3(x) # | 16 242 | low_level_feat4 = x 243 | 244 | x = self.layer4(x) # | 32 245 | 246 | return low_level_feat2, low_level_feat3, low_level_feat4, x 247 | 248 | def _init_weight(self): 249 | for m in self.modules(): 250 | if isinstance(m, nn.Conv2d): 251 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 252 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 253 | elif isinstance(m, nn.BatchNorm2d): 254 | m.weight.data.fill_(1) 255 | m.bias.data.zero_() 256 | 257 | def _load_pretrained_model(self, model_path): 258 | pretrain_dict = model_zoo.load_url(model_path) 259 | model_dict = {} 260 | state_dict = self.state_dict() 261 | for k, v in pretrain_dict.items(): 262 | if k in state_dict: 263 | model_dict[k] = v 264 | state_dict.update(model_dict) 265 | self.load_state_dict(state_dict) 266 | print('load pretrained model') -------------------------------------------------------------------------------- /models/network/segformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | # @Filename: segformer 4 | # @Project : ContestCD 5 | # @date : 2021-10-19 18:34 6 | # @Author : Linshan 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from functools import partial 12 | 13 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 14 | from timm.models.registry import register_model 15 | from timm.models.vision_transformer import _cfg 16 | # from mmseg.models.builder import BACKBONES 17 | # from mmseg.utils import get_root_logger 18 | # from mmcv.runner import load_checkpoint 19 | import math 20 | 21 | 22 | class Mlp(nn.Module): 23 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.fc1 = nn.Linear(in_features, hidden_features) 28 | self.dwconv = DWConv(hidden_features) 29 | self.act = act_layer() 30 | self.fc2 = nn.Linear(hidden_features, out_features) 31 | self.drop = nn.Dropout(drop) 32 | 33 | # self.apply(self._init_weights) 34 | 35 | def _init_weights(self, m): 36 | if isinstance(m, nn.Linear): 37 | trunc_normal_(m.weight, std=.02) 38 | if isinstance(m, nn.Linear) and m.bias is not None: 39 | nn.init.constant_(m.bias, 0) 40 | elif isinstance(m, nn.LayerNorm): 41 | nn.init.constant_(m.bias, 0) 42 | nn.init.constant_(m.weight, 1.0) 43 | elif isinstance(m, nn.Conv2d): 44 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 45 | fan_out //= m.groups 46 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 47 | if m.bias is not None: 48 | m.bias.data.zero_() 49 | 50 | def forward(self, x, H, W): 51 | x = self.fc1(x) 52 | x = self.dwconv(x, H, W) 53 | x = self.act(x) 54 | x = self.drop(x) 55 | x = self.fc2(x) 56 | x = self.drop(x) 57 | return x 58 | 59 | 60 | class DWConv(nn.Module): 61 | def __init__(self, dim=768): 62 | super(DWConv, self).__init__() 63 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 64 | 65 | def forward(self, x, H, W): 66 | B, N, C = x.shape 67 | x = x.transpose(1, 2).view(B, C, H, W) 68 | x = self.dwconv(x) 69 | x = x.flatten(2).transpose(1, 2) 70 | 71 | return x 72 | 73 | 74 | class Attention(nn.Module): 75 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 76 | super().__init__() 77 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
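# spatial-reduction attention (SegFormer): when sr_ratio > 1, keys and values are computed from a feature map downsampled by a strided sr_ratio x sr_ratio conv (self.sr below), shrinking the attention matrix from N x N to N x (N / sr_ratio^2) while queries stay at full resolution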
78 | 79 | self.dim = dim 80 | self.num_heads = num_heads 81 | head_dim = dim // num_heads 82 | self.scale = qk_scale or head_dim ** -0.5 83 | 84 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 85 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 86 | self.attn_drop = nn.Dropout(attn_drop) 87 | self.proj = nn.Linear(dim, dim) 88 | self.proj_drop = nn.Dropout(proj_drop) 89 | 90 | self.sr_ratio = sr_ratio 91 | if sr_ratio > 1: 92 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 93 | self.norm = nn.LayerNorm(dim) 94 | 95 | # self.apply(self._init_weights) 96 | 97 | def _init_weights(self, m): 98 | if isinstance(m, nn.Linear): 99 | trunc_normal_(m.weight, std=.02) 100 | if isinstance(m, nn.Linear) and m.bias is not None: 101 | nn.init.constant_(m.bias, 0) 102 | elif isinstance(m, nn.LayerNorm): 103 | nn.init.constant_(m.bias, 0) 104 | nn.init.constant_(m.weight, 1.0) 105 | elif isinstance(m, nn.Conv2d): 106 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 107 | fan_out //= m.groups 108 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 109 | if m.bias is not None: 110 | m.bias.data.zero_() 111 | 112 | def forward(self, x, H, W): 113 | B, N, C = x.shape 114 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 115 | 116 | if self.sr_ratio > 1: 117 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 118 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 119 | x_ = self.norm(x_) 120 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 121 | else: 122 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 123 | k, v = kv[0], kv[1] 124 | 125 | attn = (q @ k.transpose(-2, -1)) * self.scale 126 | attn = attn.softmax(dim=-1) 127 | attn = self.attn_drop(attn) 128 | 129 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 130 | x = self.proj(x) 131 | x = self.proj_drop(x) 132 | 133 | return x 134 | 135 | 136 | class Block(nn.Module): 137 | 138 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 139 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 140 | super().__init__() 141 | self.norm1 = norm_layer(dim) 142 | self.attn = Attention( 143 | dim, 144 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 145 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 146 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 147 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 148 | self.norm2 = norm_layer(dim) 149 | mlp_hidden_dim = int(dim * mlp_ratio) 150 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 151 | 152 | # self.apply(self._init_weights) 153 | 154 | def _init_weights(self, m): 155 | if isinstance(m, nn.Linear): 156 | trunc_normal_(m.weight, std=.02) 157 | if isinstance(m, nn.Linear) and m.bias is not None: 158 | nn.init.constant_(m.bias, 0) 159 | elif isinstance(m, nn.LayerNorm): 160 | nn.init.constant_(m.bias, 0) 161 | nn.init.constant_(m.weight, 1.0) 162 | elif isinstance(m, nn.Conv2d): 163 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 164 | fan_out //= m.groups 165 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 166 | if m.bias is not None: 167 | m.bias.data.zero_() 168 | 169 | def forward(self, x, H, W): 170 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 171 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 172 | 173 | return x 174 | 175 | 176 | class OverlapPatchEmbed(nn.Module): 177 | """ Image to Patch Embedding 178 | """ 179 | 180 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 181 | super().__init__() 182 | img_size = to_2tuple(img_size) 183 | patch_size = to_2tuple(patch_size) 184 | 185 | self.img_size = img_size 186 | self.patch_size = patch_size 187 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 188 | self.num_patches = self.H * self.W 189 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 190 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 191 | self.norm = nn.LayerNorm(embed_dim) 192 | 193 | # self.apply(self._init_weights) 194 | 195 | def _init_weights(self, m): 196 | if isinstance(m, nn.Linear): 197 | trunc_normal_(m.weight, std=.02) 198 | if isinstance(m, nn.Linear) and m.bias is not None: 199 | nn.init.constant_(m.bias, 0) 200 | elif isinstance(m, nn.LayerNorm): 201 | nn.init.constant_(m.bias, 0) 202 | nn.init.constant_(m.weight, 1.0) 203 | elif isinstance(m, nn.Conv2d): 204 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 205 | fan_out //= m.groups 206 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 207 | if m.bias is not None: 208 | m.bias.data.zero_() 209 | 210 | def forward(self, x): 211 | x = self.proj(x) 212 | _, _, H, W = x.shape 213 | x = x.flatten(2).transpose(1, 2) # size (B, H*W, C) 214 | # print(x.shape) 215 | x = self.norm(x) 216 | 217 | return x, H, W 218 | 219 | 220 | class MixVisionTransformer(nn.Module): 221 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 222 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 223 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 224 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 225 | super().__init__() 226 | self.num_classes = num_classes 227 | self.depths = depths 228 | 229 | # patch_embed 230 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 231 | embed_dim=embed_dims[0]) 232 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 233 | embed_dim=embed_dims[1]) 234 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 235 | embed_dim=embed_dims[2]) 236 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, 
in_chans=embed_dims[2], 237 | embed_dim=embed_dims[3]) 238 | 239 | # transformer encoder 240 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 241 | cur = 0 242 | self.block1 = nn.ModuleList([Block( 243 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 244 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 245 | sr_ratio=sr_ratios[0]) 246 | for i in range(depths[0])]) 247 | self.norm1 = norm_layer(embed_dims[0]) 248 | 249 | cur += depths[0] 250 | self.block2 = nn.ModuleList([Block( 251 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 252 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 253 | sr_ratio=sr_ratios[1]) 254 | for i in range(depths[1])]) 255 | self.norm2 = norm_layer(embed_dims[1]) 256 | 257 | cur += depths[1] 258 | self.block3 = nn.ModuleList([Block( 259 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 260 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 261 | sr_ratio=sr_ratios[2]) 262 | for i in range(depths[2])]) 263 | self.norm3 = norm_layer(embed_dims[2]) 264 | 265 | cur += depths[2] 266 | self.block4 = nn.ModuleList([Block( 267 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 268 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 269 | sr_ratio=sr_ratios[3]) 270 | for i in range(depths[3])]) 271 | self.norm4 = norm_layer(embed_dims[3]) 272 | 273 | # classification head 274 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 275 | 276 | # self.apply(self._init_weights) 277 | 278 | def _init_weights(self, m): 279 | if isinstance(m, nn.Linear): 280 | trunc_normal_(m.weight, std=.02) 281 | if isinstance(m, nn.Linear) and m.bias is not None: 282 | nn.init.constant_(m.bias, 0) 283 | elif isinstance(m, nn.LayerNorm): 284 | nn.init.constant_(m.bias, 0) 285 | nn.init.constant_(m.weight, 1.0) 286 | elif isinstance(m, nn.Conv2d): 287 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 288 | fan_out //= m.groups 289 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 290 | if m.bias is not None: 291 | m.bias.data.zero_() 292 | 293 | def reset_drop_path(self, drop_path_rate): 294 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 295 | cur = 0 296 | for i in range(self.depths[0]): 297 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 298 | 299 | cur += self.depths[0] 300 | for i in range(self.depths[1]): 301 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 302 | 303 | cur += self.depths[1] 304 | for i in range(self.depths[2]): 305 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 306 | 307 | cur += self.depths[2] 308 | for i in range(self.depths[3]): 309 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 310 | 311 | def freeze_patch_emb(self): 312 | self.patch_embed1.requires_grad = False 313 | 314 | @torch.jit.ignore 315 | def no_weight_decay(self): 316 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 317 | 318 | def get_classifier(self): 319 | return self.head 320 | 321 | def reset_classifier(self, num_classes, global_pool=''): 322 | self.num_classes = num_classes 323 | self.head = 
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 324 | 325 | def forward_features(self, x): 326 | B = x.shape[0] 327 | outs = [] 328 | 329 | # stage 1 330 | x, H, W = self.patch_embed1(x) 331 | for i, blk in enumerate(self.block1): 332 | x = blk(x, H, W) 333 | x = self.norm1(x) 334 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 335 | outs.append(x) 336 | 337 | # stage 2 338 | x, H, W = self.patch_embed2(x) 339 | for i, blk in enumerate(self.block2): 340 | x = blk(x, H, W) 341 | x = self.norm2(x) 342 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 343 | outs.append(x) 344 | 345 | # stage 3 346 | x, H, W = self.patch_embed3(x) 347 | for i, blk in enumerate(self.block3): 348 | x = blk(x, H, W) 349 | x = self.norm3(x) 350 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 351 | outs.append(x) 352 | 353 | # stage 4 354 | x, H, W = self.patch_embed4(x) 355 | for i, blk in enumerate(self.block4): 356 | x = blk(x, H, W) 357 | x = self.norm4(x) 358 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 359 | outs.append(x) 360 | 361 | return outs 362 | 363 | def forward(self, x): 364 | x = self.forward_features(x) 365 | # x = self.head(x) 366 | 367 | return x 368 | 369 | 370 | class mit_b0(MixVisionTransformer): 371 | def __init__(self, **kwargs): 372 | super(mit_b0, self).__init__( 373 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 374 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 375 | drop_rate=0.0, drop_path_rate=0.1) 376 | 377 | 378 | class mit_b1(MixVisionTransformer): 379 | def __init__(self, **kwargs): 380 | super(mit_b1, self).__init__( 381 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 382 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 383 | drop_rate=0.0, drop_path_rate=0.1) 384 | 385 | 386 | class mit_b2(MixVisionTransformer): 387 | def __init__(self, **kwargs): 388 | super(mit_b2, self).__init__( 389 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 390 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 391 | drop_rate=0.0, drop_path_rate=0.1) 392 | 393 | 394 | class mit_b3(MixVisionTransformer): 395 | def __init__(self, **kwargs): 396 | super(mit_b3, self).__init__( 397 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 398 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 399 | drop_rate=0.0, drop_path_rate=0.1) 400 | 401 | 402 | class mit_b4(MixVisionTransformer): 403 | def __init__(self, **kwargs): 404 | super(mit_b4, self).__init__( 405 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 406 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 407 | drop_rate=0.0, drop_path_rate=0.1) 408 | 409 | 410 | class mit_b5(MixVisionTransformer): 411 | def __init__(self, **kwargs): 412 | super(mit_b5, self).__init__( 413 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 414 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 415 | drop_rate=0.0, drop_path_rate=0.1) 416 | 417 | 418 | def load_check(model, 
path):
419 |     checkpoint = torch.load(path, map_location=torch.device('cpu'))
420 |     state_dict = checkpoint['state_dict']
421 | 
422 |     from collections import OrderedDict
423 |     new_state_dict = OrderedDict()
424 |     for k, v in state_dict.items():
425 |         if k[:8] == 'backbone':
426 |             name = k[9:]  # strip the 'backbone.' prefix (9 characters) so the keys match this model's state_dict
427 |             new_state_dict[name] = v  # keep the weight tensor unchanged under the renamed key
428 | 
429 |     model.load_state_dict(new_state_dict)
430 |     print('load pretrained transformer')
431 |     return model
432 | 
433 | 
434 | if __name__ == '__main__':
435 |     import torch
436 |     model = mit_b1()
437 |     # checkpoint = torch.load('/media/hlf/Luffy/WLS/ContestCD/segformer_b3.pth', map_location=torch.device('cpu'))
438 |     # model.load_state_dict(checkpoint['state_dict'])
439 |     # print('load pretrained transformer')
440 | 
441 |     # model = load_check(model, path='/media/hlf/Luffy/WLS/ContestCD/segformer.b2.512x512.ade.160k.pth')
442 |     # torch.save({'state_dict': model.state_dict()},
443 |     #            '/media/hlf/Luffy/WLS/ContestCD/segformer_b2.pth', _use_new_zipfile_serialization=False)
444 | 
445 |     input = torch.rand(1, 3, 512, 512)
446 |     outputs = model(input)
447 |     # forward() returns a list of four stage feature maps, not a single tensor
448 |     for output in outputs:
449 |         print(output.size())
--------------------------------------------------------------------------------
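The commented-out lines in `__main__` above sketch the intended workflow: run an official SegFormer checkpoint through `load_check` once to strip the 'backbone.' prefixes, save the converted weights, then reload them directly afterwards. A minimal sketch, assuming such a checkpoint is available locally (both paths below are placeholders, not files shipped with this repo):

    import torch
    model = mit_b2()
    model = load_check(model, path='/path/to/segformer.b2.512x512.ade.160k.pth')  # hypothetical path
    torch.save({'state_dict': model.state_dict()}, '/path/to/segformer_b2.pth')
    feats = model(torch.rand(1, 3, 512, 512))  # list of four stage feature maps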
/models/seg_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.network import *
5 | from options import *
6 | from models.base_model import build_base_model, build_channels
7 | from models.tools import *
8 | 
9 | 
10 | class seg_decoder(nn.Module):
11 |     def __init__(self, in_channels, num_classes):
12 |         super(seg_decoder, self).__init__()
13 |         self.last_layer = nn.Sequential(
14 |             nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0),
15 |             nn.BatchNorm2d(in_channels),
16 |             nn.ReLU(),
17 |             nn.Conv2d(in_channels, num_classes, kernel_size=3, stride=1, padding=1)
18 |         )
19 | 
20 |     def forward(self, x):
21 |         x = self.last_layer(x)
22 |         return x
23 | 
24 | 
25 | class Seg_Net(nn.Module):
26 |     def __init__(self, opt):
27 |         super(Seg_Net, self).__init__()
28 |         self.base_model = build_base_model(opt)
29 |         self.in_channels = build_channels(opt)
30 |         self.seg_decoder = seg_decoder(self.in_channels, opt.num_classes)
31 |         self.img_size = opt.img_size
32 | 
33 |     def resize_out(self, output):
34 |         if output.size()[-1] != self.img_size:
35 |             output = F.interpolate(output, size=(self.img_size, self.img_size), mode='bilinear')
36 |         return output
37 | 
38 |     def forward(self, x):
39 |         x = self.base_model(x)
40 |         # feat = x
41 | 
42 |         x = self.seg_decoder(x)
43 |         x = self.resize_out(x)
44 | 
45 |         return x
46 | 
47 | 
48 | if __name__ == "__main__":
49 |     opt = Sec_Options().parse()  # the options package defines Sec_Options/Point_Options, not Seg_Options
50 |     net = Seg_Net(opt)
51 |     net.cuda()
52 |     x = torch.rand(4, 3, 512, 512).cuda()
53 |     label = torch.randint(0, opt.num_classes, (4, 512, 512)).cuda()  # keep labels in [0, num_classes) so the loss does not assert
54 |     output = net(x)
55 | 
56 |     criterion = nn.CrossEntropyLoss(ignore_index=255)
57 |     loss = criterion(output, label)
58 | 
59 |     print(output.shape)
60 |     print(loss)
--------------------------------------------------------------------------------
/models/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from models.tools.densecrf import get_crf
2 | 
--------------------------------------------------------------------------------
/models/tools/densecrf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import tifffile
6 | import os
7 | import matplotlib.pyplot as plt
8 | import pydensecrf.densecrf as dcrf
9 | from dataset.datapoint import value_to_rgb
10 | from pydensecrf.utils import create_pairwise_bilateral, create_pairwise_gaussian
11 | 
12 | 
13 | def softmax(x):
14 |     x_row_max = x.max(axis=-1)
15 |     x_row_max = x_row_max.reshape(list(x.shape)[:-1] + [1])
16 |     x = x - x_row_max
17 |     x_exp = np.exp(x)
18 |     x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1] + [1])
19 |     softmax = x_exp / x_exp_row_sum
20 |     return softmax
21 | 
22 | 
23 | def get_crf(opt, mask, img):
24 |     mask = np.transpose(mask, (2, 0, 1))  # (H, W, C) -> (C, H, W)
25 |     img = np.ascontiguousarray(img)
26 | 
27 |     unary = -np.log(mask + 1e-8)  # unary potentials are negative log-probabilities
28 |     unary = unary.reshape((opt.num_classes, -1))
29 |     unary = np.ascontiguousarray(unary)
30 | 
31 |     d = dcrf.DenseCRF2D(opt.img_size, opt.img_size, opt.num_classes)
32 |     d.setUnaryEnergy(unary)
33 | 
34 |     d.addPairwiseGaussian(sxy=5, compat=3)
35 |     d.addPairwiseBilateral(sxy=10, srgb=13, rgbim=img, compat=10)
36 | 
37 |     output = d.inference(10)
38 | 
39 |     pred = np.argmax(output, axis=0).reshape((opt.img_size, opt.img_size))  # renamed from `map` to avoid shadowing the built-in
40 | 
41 |     return pred
42 | 
--------------------------------------------------------------------------------
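For orientation: the run/ scripts feed `get_crf` a softmaxed (H, W, C) probability map plus the matching uint8 RGB image. A condensed sketch of that call pattern, assuming `output` is a (1, C, H, W) logit tensor and `img` the normalized input batch (see run/point/p_test.py for the full loop):

    mask = F.softmax(output, dim=1)
    mask = mask.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.float32)  # (H, W, C) probabilities
    img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    img = Normalize_back(img, flag=opt.dataset)  # undo normalization back to 0-255
    crf_out = get_crf(opt, mask, img.astype(np.uint8))  # (H, W) refined label map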
/options/Second_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | 
4 | class Sec_Options():
5 |     """This class defines options used during both training and test time.
6 |     It also implements several helper functions such as parsing, printing, and saving the options.
7 |     It also gathers additional options defined in the dataset and model classes.
8 |     """
9 | 
10 |     def __init__(self):
11 |         """Reset the class; indicates the class hasn't been initialized"""
12 |         self.initialized = False
13 | 
14 |     def initialize(self, parser):
15 |         """Define the common options that are used in both training and test."""
16 |         # basic parameters
17 |         parser.add_argument('--data_root', type=str, default='/media/hlf/Luffy/WLS/semantic/dataset', help='path to dataroot')
18 |         parser.add_argument('--dataset', type=str, default='potsdam', help='[chesapeake|potsdam|vaihingen|GID]')
19 |         parser.add_argument('--experiment_name', type=str, default='Second', help='name of the experiment. It decides where to load datafiles, store samples and models')
20 |         parser.add_argument('--save_path', type=str, default='/media/hlf/Luffy/WLS/PointAnno/save', help='models are saved here')
21 |         parser.add_argument('--data_inform_path', type=str, default='/media/hlf/Luffy/WLS/PointAnno/datafiles', help='path to files about the datafiles information')
22 | 
23 |         # model parameters
24 |         parser.add_argument('--base_model', type=str, default='Deeplab', help='choose which base model. [HRNet18|HRNet48|Deeplab]')
25 |         parser.add_argument('--backbone', type=str, default='resnet18', help='which resnet')
26 |         parser.add_argument('--num_classes', type=int, default=5, help='classes')
27 | 
28 |         # train parameters
29 |         parser.add_argument('--loss', type=str, default='OHEM_loss', help='choose which loss function')
30 |         parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
31 |         parser.add_argument('--pin', type=bool, default=True, help='pin_memory or not')
32 | 
33 |         parser.add_argument('--num_workers', type=int, default=10, help='number of workers')
34 |         parser.add_argument('--img_size', type=int, default=256, help='image size')
35 |         parser.add_argument('--in_channels', type=int, default=3, help='input channels')
36 | 
37 |         parser.add_argument('--num_epochs', type=int, default=100, help='num of epochs')
38 |         parser.add_argument('--base_lr', type=float, default=1e-3, help='base learning rate')
39 |         parser.add_argument('--decay', type=float, default=5e-4, help='decay')
40 |         parser.add_argument('--log_interval', type=int, default=60, help='how often to log')
41 |         parser.add_argument('--resume', type=bool, default=False, help='resume the saved checkpoint or not')  # note: any non-empty string passed to a bool flag parses as True
42 | 
43 |         self.initialized = True
44 |         return parser
45 | 
46 |     def gather_options(self):
47 |         """Initialize our parser with basic options (only once).
48 |         Add additional model-specific and dataset-specific options.
49 |         These options are defined in the corresponding functions
50 |         in the model and dataset classes.
51 |         """
52 |         if not self.initialized:  # check if it has been initialized
53 |             parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
54 |             parser = self.initialize(parser)
55 | 
56 |         # get the basic options
57 |         opt, _ = parser.parse_known_args()
58 |         self.parser = parser
59 |         return parser.parse_args()
60 | 
61 |     def parse(self):
62 |         """Parse our options, create checkpoints directory suffix, and set up gpu device."""
63 |         opt = self.gather_options()
64 |         self.opt = opt
65 | 
66 |         return self.opt
67 | 
68 | 
69 | if __name__ == '__main__':
70 |     opt = Sec_Options().parse()  # was Point_Options, which is not defined in this module
71 |     print(opt)
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | from options.point_options import *
2 | from options.Second_options import *
3 | 
--------------------------------------------------------------------------------
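Both option classes follow the same instantiate-then-parse pattern, so every entry script in run/ reduces to a few lines (a sketch; `Point_Options` works identically):

    from options import Sec_Options
    opt = Sec_Options().parse()  # argparse.Namespace holding the defaults above
    print(opt.dataset, opt.num_classes, opt.base_lr)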
/options/point_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | 
4 | class Point_Options():
5 |     """This class defines options used during both training and test time.
6 |     It also implements several helper functions such as parsing, printing, and saving the options.
7 |     It also gathers additional options defined in the dataset and model classes.
8 |     """
9 | 
10 |     def __init__(self):
11 |         """Reset the class; indicates the class hasn't been initialized"""
12 |         self.initialized = False
13 | 
14 |     def initialize(self, parser):
15 |         """Define the common options that are used in both training and test."""
16 |         # basic parameters
17 |         parser.add_argument('--data_root', type=str, default='/media/hlf/Luffy/WLS/semantic/dataset', help='your path to dataroot')
18 |         parser.add_argument('--dataset', type=str, default='potsdam', help='[potsdam|vaihingen]')
19 |         parser.add_argument('--experiment_name', type=str, default='Point', help='name of the experiment. It decides where to load datafiles, store samples and models')
20 |         parser.add_argument('--save_path', type=str, default='/media/hlf/Luffy/WLS/PointAnno/save', help='models are saved here')
21 |         parser.add_argument('--data_inform_path', type=str, default='/media/hlf/Luffy/WLS/PointAnno/datafiles', help='path to files about the datafiles information')
22 | 
23 |         # model parameter
24 |         parser.add_argument('--backbone', type=str, default='resnet18', help='which resnet')
25 |         parser.add_argument('--out_stride', type=int, default=32, help='out_stride')
26 |         parser.add_argument('--num_classes', type=int, default=5, help='classes')
27 | 
28 |         # train parameters
29 |         parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
30 |         parser.add_argument('--pin', type=bool, default=True, help='pin_memory or not')
31 | 
32 |         parser.add_argument('--num_workers', type=int, default=10, help='number of workers')
33 |         parser.add_argument('--img_size', type=int, default=256, help='image size')
34 |         parser.add_argument('--in_channels', type=int, default=3, help='input channels')
35 | 
36 |         parser.add_argument('--num_epochs', type=int, default=100, help='num of epochs')
37 |         parser.add_argument('--base_lr', type=float, default=1e-3, help='base learning rate')
38 |         parser.add_argument('--decay', type=float, default=5e-4, help='decay')
39 |         parser.add_argument('--log_interval', type=int, default=60, help='how often to log, e.g. every 100 batches')
40 |         parser.add_argument('--resume', type=bool, default=False, help='resume the saved checkpoint or not')
41 | 
42 |         self.initialized = True
43 |         return parser
44 | 
45 |     def gather_options(self):
46 |         """Initialize our parser with basic options (only once).
47 |         Add additional model-specific and dataset-specific options.
48 |         These options are defined in the corresponding functions
49 |         in the model and dataset classes.
50 | """ 51 | if not self.initialized: # check if it has been initialized 52 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 53 | parser = self.initialize(parser) 54 | 55 | # get the basic options 56 | opt, _ = parser.parse_known_args() 57 | self.parser = parser 58 | return parser.parse_args() 59 | 60 | def parse(self): 61 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 62 | opt = self.gather_options() 63 | self.opt = opt 64 | 65 | return self.opt 66 | 67 | 68 | if __name__ == '__main__': 69 | opt = Point_Options().parse() 70 | print(opt) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from dataset import * 3 | from matplotlib import pyplot as plt 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | from datafiles.color_dict import * 7 | from models.tools import get_crf 8 | import random 9 | from dataset.data_utils import value_to_rgb 10 | # from models.tools import get_crf 11 | import ttach as tta 12 | 13 | print("PyTorch Version: ", torch.__version__) 14 | print('cuda', torch.version.cuda) 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | device = torch.device("cuda:0") 17 | print('Device:', device) 18 | 19 | 20 | def save(visual, name, path): 21 | check_dir(path) 22 | imsave(path+'/'+name[:-4]+'.png', visual) 23 | 24 | 25 | def eval(opt, out, label, name, stride=128): 26 | compute_metric = IOUMetric(num_classes=opt.num_classes) 27 | hist = np.zeros([opt.num_classes, opt.num_classes]) 28 | 29 | h, w, _ = label.shape 30 | num_h, num_w = h//stride, w//stride 31 | for i in range(num_h): 32 | for j in range(num_w): 33 | o = out[i*stride:(i+1)*stride, j*stride:(j+1)*stride] 34 | l = label[i*stride:(i+1)*stride, j*stride:(j+1)*stride, 0] 35 | hist += compute_metric.get_hist(o, l) 36 | # # evaluate 37 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist) 38 | print(name) 39 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score))) 40 | print('------'*5) 41 | return hist 42 | 43 | 44 | def predict_tile(opt, model, img, label, label_vis, name, 45 | save_path, dataset='potsdam', size=256): 46 | im = img.copy() 47 | with torch.no_grad(): 48 | model.eval() 49 | h, w, c = img.shape 50 | img = Normalize(img, flag=dataset) 51 | img = img.astype(np.float32).transpose(2, 0, 1) 52 | img = torch.from_numpy(img).unsqueeze(0).float().cuda() 53 | output = pre_slide(model, img, num_classes=5, tile_size=(size, size), tta=True) 54 | 55 | output = output.squeeze(0).permute(1, 2, 0).data.cpu().numpy().astype(np.float32) 56 | 57 | output = output[:h, :w, :] 58 | output = np.argmax(output, axis=-1) 59 | hist = eval(opt, output, label, name) 60 | 61 | # output = np.expand_dims(output, axis=-1) 62 | # predict = value_to_rgb(output, flag=opt.dataset) 63 | # fig, axs = plt.subplots(1, 3, figsize=(20, 8)) 64 | # axs[0].imshow(im.astype(np.uint8)) 65 | # axs[0].axis("off") 66 | # axs[1].imshow(label_vis.astype(np.uint8)) 67 | # axs[1].axis("off") 68 | # axs[2].imshow(predict.astype(np.uint8)) 69 | # axs[2].axis("off") 70 | # plt.suptitle(os.path.basename(name), y=0.94) 71 | # plt.tight_layout() 72 | # plt.show() 73 | # plt.close() 74 | 75 | # save(predict, name) 76 | 77 | return hist 78 | 79 | 80 | def run_potsdam(root_path, final=True): 81 | img_path = root_path + '/4_Ortho_RGBIR' 82 | label_path = root_path + '/Labels' 83 | vis_path = root_path 
+ '/5_Labels_all' 84 | save_path = './save/potsdam/predict_masks' 85 | train_name = [7, 8, 9, 10, 11, 12] 86 | 87 | opt = Sec_Options().parse() 88 | model = Seg_Net(opt) 89 | checkpoint = torch.load('./save/potsdam_0.8719.pth', map_location=torch.device('cpu')) 90 | model.load_state_dict(checkpoint['state_dict']) 91 | model = model.cuda() 92 | model.eval() 93 | 94 | hist = np.zeros([opt.num_classes, opt.num_classes]) 95 | list = os.listdir(label_path) 96 | for i in list: 97 | if int(i[14:-10]) in train_name: 98 | pass 99 | else: 100 | img = read(os.path.join(img_path, i[:-9] + 'RGBIR.tif'))[:, :, :3] 101 | label = read(os.path.join(label_path, i)) 102 | vis = read(os.path.join(vis_path, i)) 103 | hist += predict_tile(opt, model, img, label, vis, i, save_path, dataset='potsdam') 104 | 105 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist) 106 | print('total') 107 | print(acc) 108 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score))) 109 | print('------' * 5) 110 | 111 | 112 | def run_vaihingen(root_path): 113 | save_path = './save/vaihingen/predict_masks' 114 | img_path = root_path + '/image' 115 | label_path = root_path + '/gts_noB' 116 | vis_path = root_path + '/vis_noB' 117 | test_name = [2, 4, 6, 8, 10, 12, 14, 16, 20, 22, 24, 27, 29, 31, 33, 35, 38] 118 | 119 | opt = Sec_Options().parse() 120 | model = Seg_Net(opt) 121 | 122 | checkpoint = torch.load('./save/vaihingen_0.8867.pth', map_location=torch.device('cpu')) 123 | model.load_state_dict(checkpoint['state_dict']) 124 | model = model.cuda() 125 | model.eval() 126 | 127 | hist = np.zeros([opt.num_classes, opt.num_classes]) 128 | list = os.listdir(label_path) 129 | for i in list: 130 | if int(i[20:-4]) not in test_name: 131 | pass 132 | else: 133 | img = read(os.path.join(img_path, i)) 134 | label = read(os.path.join(label_path, i)) 135 | vis = read(os.path.join(vis_path, i)) 136 | hist += predict_tile(opt, model, img, label, vis, i, save_path, dataset='vaihingen') 137 | 138 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist) 139 | print('total') 140 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score))) 141 | print('------' * 5) 142 | 143 | 144 | if __name__ == '__main__': 145 | potsdam_path = '/media/hlf/Luffy/WLS/semantic/dataset/potsdam/dataset_origin' 146 | run_potsdam(potsdam_path) 147 | 148 | # vaihingen_path = '/media/hlf/Luffy/WLS/semantic/dataset/vaihingen/dataset_origin' 149 | # run_vaihingen(vaihingen_path) 150 | 151 | -------------------------------------------------------------------------------- /run/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/run/__init__.py -------------------------------------------------------------------------------- /run/point/p_predict_train.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from dataset import * 3 | from torchvision import transforms 4 | from matplotlib import pyplot as plt 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | from datafiles.color_dict import * 8 | from models.DBFNet import clean_mask 9 | from models.tools import get_crf 10 | import tifffile 11 | import random 12 | import ttach as tta 13 | 14 | print("PyTorch Version: ", torch.__version__) 15 | print('cuda', torch.version.cuda) 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | device = torch.device("cuda:0") 18 | print('Device:', device) 19 
| 20 | 21 | def predict_val(): 22 | opt = Point_Options().parse() 23 | log_path, checkpoint_path, predict_test_path, predict_train_path, predict_val_path = create_save_path(opt) 24 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt) 25 | 26 | # load train and val dataset 27 | val_dataset = Dataset_point(opt, val_txt_path, flag='predict_val', transform=None) 28 | loader = DataLoader(val_dataset, batch_size=1, num_workers=opt.num_workers, 29 | pin_memory=opt.pin) 30 | 31 | model = Net(opt, flag='test') 32 | checkpoint = torch.load(checkpoint_path + '/model_best_0.8553.pth', map_location=torch.device('cpu')) 33 | model.load_state_dict(checkpoint['state_dict']) 34 | print('resume success') 35 | model = model.cuda() 36 | model.eval() 37 | 38 | for batch_idx, batch in enumerate(tqdm(loader)): 39 | img, label, cls_label, filename = batch 40 | with torch.no_grad(): 41 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True), \ 42 | cls_label.cuda(non_blocking=True) 43 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), 44 | merge_mode='mean') 45 | coarse_mask = tta_model(input) 46 | torch.cuda.synchronize() 47 | 48 | # save train predict after softmax and clean 49 | coarse_mask = F.softmax(coarse_mask, dim=1) 50 | mask_clean = clean_mask(coarse_mask, cls_label) 51 | mask = mask_clean.squeeze(0).permute(1, 2, 0).cpu().numpy() 52 | mask = mask.astype(np.float32) 53 | 54 | # get crf 55 | img = img.squeeze(0).permute(1, 2, 0).cpu().numpy() 56 | img = Normalize_back(img, flag=opt.dataset) 57 | crf_out = get_crf(opt, mask, img.astype(np.uint8)) 58 | 59 | # # save crf out 60 | predict_path_crf = predict_val_path + '/crf' 61 | check_dir(predict_path_crf) 62 | save_pred_anno_numpy(crf_out, predict_path_crf, filename, dict=postdam_color_dict, flag=False) 63 | 64 | 65 | def main(): 66 | opt = Point_Options().parse() 67 | log_path, checkpoint_path, predict_test_path, predict_train_path, predict_val_path = create_save_path(opt) 68 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt) 69 | 70 | # load train and val dataset 71 | train_dataset = Dataset_point(opt, train_txt_path, flag='predict_train', transform=None) 72 | loader = DataLoader(train_dataset, batch_size=1, num_workers=opt.num_workers, 73 | pin_memory=opt.pin) 74 | 75 | model = Net(opt, flag='test') 76 | checkpoint = torch.load(checkpoint_path+'/model_best_0.8553.pth', map_location=torch.device('cpu')) 77 | model.load_state_dict(checkpoint['state_dict']) 78 | print('resume success') 79 | model = model.to(device) 80 | model.eval() 81 | 82 | # numpy for crf 83 | compute_metric_ny = IOUMetric(num_classes=opt.num_classes) 84 | hist_ny = np.zeros([opt.num_classes, opt.num_classes]) 85 | 86 | for batch_idx, batch in enumerate(tqdm(loader)): 87 | img, label, cls_label, filename = batch 88 | with torch.no_grad(): 89 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True),\ 90 | cls_label.cuda(non_blocking=True) 91 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), 92 | merge_mode='mean') 93 | coarse_mask = tta_model(input) 94 | torch.cuda.synchronize() 95 | 96 | # save train predict after softmax and clean 97 | coarse_mask = F.softmax(coarse_mask, dim=1) 98 | mask_clean = clean_mask(coarse_mask, cls_label) 99 | 100 | mask = mask_clean.squeeze(0).permute(1, 2, 0).cpu().numpy() 101 | mask = mask.astype(np.float32) 102 | 103 | # get crf 104 | img = img.squeeze(0).permute(1, 2, 0).cpu().numpy() 105 | img = 
Normalize_back(img, flag=opt.dataset) 106 | crf_out = get_crf(opt, mask, img.astype(np.uint8)) 107 | # # # save crf out 108 | predict_path_crf = predict_train_path + '/crf' 109 | check_dir(predict_path_crf) 110 | save_pred_anno_numpy(crf_out, predict_path_crf, filename, dict=postdam_color_dict, flag=False) 111 | 112 | label_ny = label.squeeze(0).data.cpu().numpy() 113 | hist_ny += compute_metric_ny.get_hist(crf_out, label_ny) 114 | 115 | # crf out's metric 116 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist_ny) 117 | 118 | print('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa)) 119 | print('acc:%.4f acc_cls:%s' % (acc, str(acc_cls))) 120 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score))) 121 | 122 | result_path = log_path[:-4] + '/result_predict_train.txt' 123 | result_txt = open(result_path, 'a') 124 | result_txt.write('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa) + '\n') 125 | result_txt.write('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)) + '\n') 126 | result_txt.write('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)) + '\n') 127 | result_txt.write('---------------------------' + '\n') 128 | result_txt.close() 129 | 130 | 131 | def show(predict_path, test_path): 132 | path = predict_path + '/label_vis' 133 | list = os.listdir(path) 134 | 135 | img_path = test_path + '/img' 136 | point_label_path = test_path + '/point_label_vis' 137 | label_path = test_path + '/label_vis' 138 | 139 | for i in list: 140 | predict = read(os.path.join(path, i)) 141 | 142 | img = read(os.path.join(img_path, i))[:, :, :3] 143 | label = read(os.path.join(label_path, i)) 144 | point_label = read(os.path.join(point_label_path, i)) 145 | 146 | fig, axs = plt.subplots(1, 4, figsize=(14, 4)) 147 | 148 | axs[0].imshow(img.astype(np.uint8)) 149 | axs[0].axis("off") 150 | axs[1].imshow(label.astype(np.uint8)) 151 | axs[1].axis("off") 152 | axs[2].imshow(predict.astype(np.uint8)) 153 | axs[2].axis("off") 154 | axs[3].imshow(point_label.astype(np.uint8)) 155 | axs[3].axis("off") 156 | 157 | plt.suptitle(os.path.basename(i), y=0.94) 158 | plt.tight_layout() 159 | plt.show() 160 | plt.close() 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | predict_val() 166 | -------------------------------------------------------------------------------- /run/point/p_test.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from dataset import * 3 | from matplotlib import pyplot as plt 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | from datafiles.color_dict import * 7 | from models.tools import get_crf 8 | import random 9 | import ttach as tta 10 | print("PyTorch Version: ", torch.__version__) 11 | print('cuda', torch.version.cuda) 12 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 13 | device = torch.device("cuda:0") 14 | print('Device:', device) 15 | 16 | 17 | def main(): 18 | opt = Point_Options().parse() 19 | log_path, checkpoint_path, predict_path, _, _ = create_save_path(opt) 20 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt) 21 | 22 | # load train and val dataset 23 | test_dataset = Dataset_point(opt, test_txt_path, flag='test', transform=None) 24 | loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=opt.num_workers, 25 | pin_memory=opt.pin) 26 | 27 | model = Net(opt, flag='test') 28 | checkpoint = torch.load(checkpoint_path + '/model_best_0.8553.pth', map_location=torch.device('cpu')) 29 | model.load_state_dict(checkpoint['state_dict']) 30 | 
print('resume success') 31 | model = model.to(device) 32 | model.eval() 33 | 34 | compute_metric = IOUMetric(num_classes=opt.num_classes) 35 | hist = np.zeros([opt.num_classes, opt.num_classes]) 36 | 37 | for batch_idx, batch in enumerate(tqdm(loader)): 38 | img, label, cls_label, filename = batch 39 | with torch.no_grad(): 40 | input, label = img.cuda(non_blocking=True), label.cuda(non_blocking=True) 41 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), 42 | merge_mode='mean') 43 | output = tta_model(input) 44 | torch.cuda.synchronize() 45 | 46 | mask = F.softmax(output, dim=1) 47 | mask = mask.squeeze(0).permute(1, 2, 0).cpu().numpy() 48 | mask = mask.astype(np.float32) 49 | 50 | img = img.squeeze(0).permute(1, 2, 0).cpu().numpy() 51 | img = Normalize_back(img, flag=opt.dataset) 52 | 53 | crf_out = get_crf(opt, mask, img.astype(np.uint8)) 54 | # save_pred_anno_numpy(crf_out, predict_path, filename, dict=postdam_color_dict, flag=True) 55 | label = label.squeeze(0).data.cpu().numpy() 56 | 57 | hist += compute_metric.get_hist(crf_out, label) 58 | 59 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist) 60 | 61 | print('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa)) 62 | print('acc:%.4f acc_cls:%s' % (acc, str(acc_cls))) 63 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score))) 64 | 65 | 66 | if __name__ == '__main__': 67 | main() 68 | 69 | -------------------------------------------------------------------------------- /run/point/p_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torchvision import transforms 4 | from torch.autograd import Variable 5 | import shutil 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from collections import OrderedDict 9 | from options import * 10 | from utils import * 11 | from dataset import * 12 | 13 | print("PyTorch Version: ", torch.__version__) 14 | print('cuda', torch.version.cuda) 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | device = torch.device("cuda:0") 17 | print('Device:', device) 18 | 19 | 20 | def main(): 21 | opt = Point_Options().parse() 22 | log_path, checkpoint_path, _, _, _ = create_save_path(opt) 23 | train_txt_path, val_txt_path, _ = create_data_path(opt) 24 | 25 | # Define logger 26 | logger, tensorboard_log_dir = create_logger(log_path) 27 | logger.info(opt) 28 | # Define Transformation 29 | train_transform = transforms.Compose([ 30 | trans.RandomHorizontalFlip(), 31 | trans.RandomVerticleFlip(), 32 | trans.RandomRotate90(), 33 | ]) 34 | 35 | # load train and val dataset 36 | train_dataset = Dataset_point(opt, train_txt_path, flag='train', transform=train_transform) 37 | val_dataset = Dataset_point(opt, val_txt_path, flag='val', transform=None) 38 | 39 | # Create training and validation dataloaders 40 | train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, 41 | num_workers=opt.num_workers, pin_memory=opt.pin, drop_last=True) 42 | 43 | model = Net(opt) 44 | if opt.resume: 45 | checkpoint = torch.load(checkpoint_path+'/model_best.pth', map_location=torch.device('cpu')) 46 | model.load_state_dict(checkpoint['state_dict']) 47 | print('resume success') 48 | 49 | model = model.to(device) 50 | optimizer = optim.AdamW(model.parameters(), lr=opt.base_lr, amsgrad=True) 51 | best_metric = 1e16 52 | 53 | for epoch in range(opt.num_epochs): 54 | time_start = time.time() 55 | 56 | train(opt, epoch, model, train_loader, optimizer, logger) 
57 | metric = validate(opt, model, val_dataset, logger) 58 | 59 | logger.info('best_val_metric:%.4f current_val_metric:%.4f' % (best_metric, metric)) 60 | if metric < best_metric: 61 | logger.info('epoch:%d Save to model_best' % (epoch)) 62 | torch.save({'state_dict': model.state_dict()}, 63 | os.path.join(checkpoint_path, 'model_best.pth')) 64 | best_metric = metric 65 | 66 | end_time = time.time() 67 | time_cost = end_time - time_start 68 | logger.info("Epoch %d Time %d ----------------------" % (epoch, time_cost)) 69 | logger.info('\n') 70 | 71 | 72 | def train(opt, epoch, model, loader, optimizer, logger): 73 | model.train() 74 | 75 | loss_m = AverageMeter() 76 | seg_loss_m = AverageMeter() 77 | pse_loss_m = AverageMeter() 78 | penalty_m = AverageMeter() 79 | 80 | last_idx = len(loader) - 1 81 | 82 | for batch_idx, batch in enumerate(loader): 83 | step = epoch * len(loader) + batch_idx 84 | adjust_learning_rate(opt.base_lr, optimizer, step, len(loader), num_epochs=30) 85 | lr = get_lr(optimizer) 86 | 87 | last_batch = batch_idx == last_idx 88 | img, label, cls_label, _ = batch 89 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True), \ 90 | cls_label.cuda(non_blocking=True) 91 | seg_loss, penalty = model.forward_loss(input, label, cls_label) 92 | loss = seg_loss + penalty 93 | 94 | seg_loss_m.update(seg_loss.item(), input.size(0)) 95 | penalty_m.update(penalty.item(), input.size(0)) 96 | 97 | loss_m.update(loss.item(), input.size(0)) 98 | 99 | optimizer.zero_grad() 100 | loss.backward() 101 | optimizer.step() 102 | 103 | torch.cuda.synchronize() 104 | 105 | log_interval = len(loader) // 3 106 | if last_batch or batch_idx % log_interval == 0: 107 | logger.info('Train:{} [{:>4d}/{} ({:>3.0f}%)] ' 108 | 'Loss:({loss.avg:>6.4f}) ' 109 | 'segloss:({seg_loss.avg:>6.4f}) ' 110 | #'pseloss:({pse_loss.avg:>6.4f}) ' 111 | 'penal:({penal.avg:>6.4f}) ' 112 | 'LR:{lr:.3e} '.format( 113 | epoch, 114 | batch_idx, len(loader), 115 | 100. 
* batch_idx / last_idx,
116 |                     loss=loss_m,
117 |                     seg_loss=seg_loss_m,
118 |                     pse_loss=pse_loss_m,
119 |                     penal=penalty_m,
120 |                     lr=lr))
121 | 
122 |     return OrderedDict([('loss', loss_m.avg)])
123 | 
124 | 
125 | def validate(opt, model, val_dataset, logger):
126 |     val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, pin_memory=opt.pin)
127 |     model.eval()
128 |     loss_m = AverageMeter()
129 |     seg_loss_m = AverageMeter()
130 |     pse_loss_m = AverageMeter()
131 |     penalty_m = AverageMeter()
132 | 
133 |     for batch_idx, batch in enumerate(val_loader):
134 |         img, label, cls_label, _ = batch
135 |         with torch.no_grad():
136 |             input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True), \
137 |                                       cls_label.cuda(non_blocking=True)
138 |             out = model.forward(input)
139 | 
140 |             criterion = nn.CrossEntropyLoss(ignore_index=255)
141 |             # penalty
142 |             penalty = get_penalty(out, cls_label)
143 |             # label: point annotations, point-level supervision
144 |             seg_loss = criterion(out, label)
145 |             loss = seg_loss + penalty
146 | 
147 |             seg_loss_m.update(seg_loss.item(), input.size(0))
148 |             # pse_loss_m.update(pse_loss.item(), input.size(0))
149 |             penalty_m.update(penalty.item(), input.size(0))
150 |             loss_m.update(loss.item(), input.size(0))
151 | 
152 |     logger.info('VAL:'
153 |                 'Loss:({loss.avg:>6.4f}) '
154 |                 'segloss:({seg_loss.avg:>6.4f}) '
155 |                 # 'pseloss:({pse_loss.avg:>6.4f}) '
156 |                 'penal:({penal.avg:>6.4f}) '.format(
157 |                     loss=loss_m,
158 |                     seg_loss=seg_loss_m,
159 |                     pse_loss=pse_loss_m,
160 |                     penal=penalty_m
161 |                 ))
162 | 
163 |     return loss_m.avg
164 | 
165 | 
166 | if __name__ == '__main__':
167 |     main()
--------------------------------------------------------------------------------
/run/second/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/run/second/__init__.py
--------------------------------------------------------------------------------
/run/second/sec_predict_train.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from dataset import *
3 | from torchvision import transforms
4 | from matplotlib import pyplot as plt
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 | from datafiles.color_dict import *
8 | from models.DBFNet import clean_mask  # was models.MyModel, which does not exist in this repo; DBFNet.py provides clean_mask
9 | import tifffile
10 | import random
11 | import ttach as tta
12 | 
13 | print("PyTorch Version: ", torch.__version__)
14 | print('cuda', torch.version.cuda)
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
16 | device = torch.device("cuda:0")
17 | print('Device:', device)
18 | 
19 | 
20 | def predict_val():
21 |     opt = Sec_Options().parse()
22 |     log_path, checkpoint_path, predict_test_path, predict_train_path, predict_val_path = create_save_path(opt)
23 |     train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
24 | 
25 |     # load val dataset
26 |     val_dataset = Dataset_point(opt, val_txt_path, flag='val', transform=None)
27 |     loader = DataLoader(val_dataset, batch_size=1, num_workers=opt.num_workers,
28 |                         pin_memory=opt.pin)
29 | 
30 |     model = Seg_Net(opt)
31 |     checkpoint = torch.load(checkpoint_path + '/model_best_0.8184.pth', map_location=torch.device('cpu'))
32 |     model.load_state_dict(checkpoint['state_dict'])
33 |     print('resume success')
34 |     model = model.to(device)
35 |     model.eval()
36 | 
37 |     for batch_idx, batch in enumerate(tqdm(loader)):
38 |         img, label, cls_label, 
filename = batch 39 | with torch.no_grad(): 40 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True), \ 41 | cls_label.cuda(non_blocking=True) 42 | coarse_mask = model.forward(input) 43 | torch.cuda.synchronize() 44 | 45 | # save train predict after softmax and clean 46 | coarse_mask = F.softmax(coarse_mask, dim=1) 47 | mask_clean = clean_mask(coarse_mask, cls_label) 48 | mask_clean = torch.argmax(mask_clean, dim=1) 49 | save_pred_anno(mask_clean, predict_val_path, filename, dict=postdam_color_dict, flag=False) 50 | 51 | 52 | def main(): 53 | opt = Sec_Options().parse() 54 | save_path = os.path.join(opt.save_path, opt.dataset) 55 | 56 | log_path, checkpoint_path, predict_test_path, predict_train_path, predict_val_path = create_save_path(opt) 57 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt) 58 | 59 | # load train and val dataset 60 | train_dataset = Dataset_point(opt, train_txt_path, flag='predict_train', transform=None) 61 | loader = DataLoader(train_dataset, batch_size=1, num_workers=opt.num_workers, 62 | pin_memory=opt.pin) 63 | 64 | model = Seg_Net(opt) 65 | checkpoint = torch.load(checkpoint_path+'/model_best_0.8184.pth', map_location=torch.device('cpu')) 66 | model.load_state_dict(checkpoint['state_dict']) 67 | print('resume success') 68 | model = model.to(device) 69 | model.eval() 70 | 71 | compute_metric = IOUMetric_tensor(num_classes=opt.num_classes) 72 | hist = torch.zeros([opt.num_classes, opt.num_classes]).cuda() 73 | 74 | for batch_idx, batch in enumerate(tqdm(loader)): 75 | img, label, cls_label, filename = batch 76 | with torch.no_grad(): 77 | input, label, cls_label = img.cuda(non_blocking=True), label.cuda(non_blocking=True),\ 78 | cls_label.cuda(non_blocking=True) 79 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), 80 | merge_mode='mean') 81 | coarse_mask = tta_model.forward(input) 82 | torch.cuda.synchronize() 83 | 84 | output = torch.argmax(coarse_mask, dim=1) 85 | 86 | # save train predict after softmax and clean 87 | coarse_mask = F.softmax(coarse_mask, dim=1) 88 | mask_clean = clean_mask(coarse_mask, cls_label) 89 | mask_clean = torch.argmax(mask_clean, dim=1) 90 | save_pred_anno(mask_clean, predict_train_path, filename, dict=postdam_color_dict, flag=True) 91 | 92 | hist += compute_metric.get_hist(output, label) 93 | 94 | hist = hist.data.cpu().numpy() 95 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist) 96 | 97 | print('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa)) 98 | print('acc:%.4f acc_cls:%s' % (acc, str(acc_cls))) 99 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score))) 100 | 101 | result_path = log_path[:-4] + '/result_predict_train.txt' 102 | result_txt = open(result_path, 'a') 103 | result_txt.write('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa) + '\n') 104 | result_txt.write('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)) + '\n') 105 | result_txt.write('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)) + '\n') 106 | result_txt.write('---------------------------' + '\n') 107 | result_txt.close() 108 | 109 | 110 | def show(predict_path, test_path): 111 | path = predict_path + '/label_vis' 112 | list = os.listdir(path) 113 | 114 | img_path = test_path + '/img' 115 | point_label_path = test_path + '/point_label_vis' 116 | label_path = test_path + '/label_vis' 117 | 118 | for i in list: 119 | predict = read(os.path.join(path, i)) 120 | 121 | img = read(os.path.join(img_path, i))[:, :, :3] 122 | label = read(os.path.join(label_path, 
i))
123 |         point_label = read(os.path.join(point_label_path, i))
124 | 
125 |         fig, axs = plt.subplots(1, 4, figsize=(14, 4))
126 | 
127 |         axs[0].imshow(img.astype(np.uint8))
128 |         axs[0].axis("off")
129 |         axs[1].imshow(label.astype(np.uint8))
130 |         axs[1].axis("off")
131 |         axs[2].imshow(predict.astype(np.uint8))
132 |         axs[2].axis("off")
133 |         axs[3].imshow(point_label.astype(np.uint8))
134 |         axs[3].axis("off")
135 | 
136 |         plt.suptitle(os.path.basename(i), y=0.94)
137 |         plt.tight_layout()
138 |         plt.show()
139 |         plt.close()
140 | 
141 | 
142 | if __name__ == '__main__':
143 |     main()
144 | 
145 |     # opt = Point_iter_second_Options().parse()
146 |     # log_path, checkpoint_path, predict_path, predict_train_path, _ = create_save_path(opt)
147 |     #
148 |     # label_path = '/home/ggm/WLS/semantic/dataset/potsdam/train'
149 |     # show(predict_train_path, label_path)
150 | 
151 |     predict_val()
--------------------------------------------------------------------------------
/run/second/sec_test.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from dataset import *
3 | from matplotlib import pyplot as plt
4 | from torch.utils.data import DataLoader
5 | from tqdm import tqdm
6 | from datafiles.color_dict import *
7 | from models.tools import get_crf
8 | import random
9 | import ttach as tta
10 | from models.DBFNet import clean_mask  # needed by the vaihingen branches below; was missing
11 | print("PyTorch Version: ", torch.__version__)
12 | print('cuda', torch.version.cuda)
13 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
14 | device = torch.device("cuda:0")
15 | print('Device:', device)
16 | 
17 | 
18 | def main():
19 |     opt = Sec_Options().parse()
20 |     save_path = os.path.join(opt.save_path, opt.dataset)
21 |     log_path, checkpoint_path, predict_path, _, _ = create_save_path(opt)
22 |     train_txt_path, val_txt_path, test_txt_path = create_data_path(opt)
23 | 
24 |     # load train and val dataset
25 |     test_dataset = Dataset_point(opt, test_txt_path, flag='test', transform=None)
26 |     loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=opt.num_workers,
27 |                         pin_memory=opt.pin)
28 | 
29 |     model = Seg_Net(opt)
30 |     checkpoint = torch.load(checkpoint_path+'/model_best.pth', map_location=torch.device('cpu'))
31 |     model.load_state_dict(checkpoint['state_dict'])
32 |     print('resume success')
33 |     model = model.to(device)
34 |     model.eval()
35 | 
36 |     compute_metric = IOUMetric_tensor(num_classes=opt.num_classes)
37 |     hist = torch.zeros([opt.num_classes, opt.num_classes]).cuda()
38 | 
39 |     for batch_idx, batch in enumerate(tqdm(loader)):
40 |         img, label, cls_label, filename = batch
41 |         with torch.no_grad():
42 |             input, label = img.cuda(non_blocking=True), label.cuda(non_blocking=True)
43 |             tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(),
44 |                                                    merge_mode='mean')
45 |             output = tta_model(input)  # was model(input), which left the TTA wrapper unused
46 |             torch.cuda.synchronize()
47 | 
48 |             if opt.dataset == 'vaihingen':
49 |                 cls_label = cls_label.cuda(non_blocking=True)
50 |                 coarse_mask = F.softmax(output, dim=1)
51 |                 mask_clean = clean_mask(coarse_mask, cls_label)
52 |                 output = torch.argmax(mask_clean, dim=1)
53 |             else:
54 |                 output = torch.argmax(output, dim=1)
55 | 
56 |         # save_pred_anno(output, predict_path, filename, dict=postdam_color_dict, flag=True)
57 |         hist += compute_metric.get_hist(output, label)
58 | 
59 |     hist = hist.data.cpu().numpy()
60 |     iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist)
61 | 
62 |     print('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa))
63 |     print('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)))
64 |     print('mfscore:%.4f fscore:%s' % 
(m_f_score, str(f_score))) 65 | 66 | # result_path = log_path[:-4] + '/result_test.txt' 67 | # result_txt = open(result_path, 'a') 68 | # result_txt.write('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa) + '\n') 69 | # result_txt.write('acc:%.4f acc_cls:%s' % (acc, str(acc_cls)) + '\n') 70 | # result_txt.write('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score)) + '\n') 71 | # result_txt.write('---------------------------' + '\n') 72 | # result_txt.close() 73 | 74 | 75 | def p_test_withcrf(): 76 | opt = Sec_Options().parse() 77 | log_path, checkpoint_path, predict_path, _, _ = create_save_path(opt) 78 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt) 79 | 80 | # load train and val dataset 81 | test_dataset = Dataset_point(opt, test_txt_path, flag='test', transform=None) 82 | loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=opt.num_workers, 83 | pin_memory=opt.pin) 84 | 85 | model = Seg_Net(opt) 86 | checkpoint = torch.load(checkpoint_path + '/model_best.pth', map_location=torch.device('cpu')) 87 | model.load_state_dict(checkpoint['state_dict']) 88 | print('resume success') 89 | model = model.to(device) 90 | model.eval() 91 | 92 | compute_metric = IOUMetric(num_classes=opt.num_classes) 93 | hist = np.zeros([opt.num_classes, opt.num_classes]) 94 | 95 | for batch_idx, batch in enumerate(tqdm(loader)): 96 | img, label, cls_label, filename = batch 97 | with torch.no_grad(): 98 | input, label = img.cuda(non_blocking=True), label.cuda(non_blocking=True) 99 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), 100 | merge_mode='mean') 101 | output = tta_model(input) 102 | torch.cuda.synchronize() 103 | 104 | if opt.dataset == 'vaihingen': 105 | cls_label = cls_label.cuda(non_blocking=True) 106 | coarse_mask = F.softmax(output, dim=1) 107 | mask = clean_mask(coarse_mask, cls_label) 108 | else: 109 | mask = F.softmax(output, dim=1) 110 | 111 | mask = mask.squeeze(0).permute(1, 2, 0).cpu().numpy() 112 | mask = mask.astype(np.float32) 113 | 114 | img = img.squeeze(0).permute(1, 2, 0).cpu().numpy() 115 | img = Normalize_back(img, flag='potsdam') 116 | 117 | crf_out = get_crf(opt, mask, img.astype(np.uint8)) 118 | label = label.squeeze(0).data.cpu().numpy() 119 | 120 | # save crf out 121 | predict_path_crf = predict_path + '/crf' 122 | check_dir(predict_path_crf) 123 | # save_pred_anno_numpy(crf_out, predict_path_crf, filename, dict=nuclear_color_dict, flag=True) 124 | 125 | hist += compute_metric.get_hist(crf_out, label) 126 | 127 | iou, miou, kappa, acc, acc_cls, f_score, m_f_score = eval_hist(hist) 128 | 129 | print('miou:%.4f iou:%s kappa:%.4f' % (miou, str(iou), kappa)) 130 | print('acc:%.4f acc_cls:%s' % (acc, str(acc_cls))) 131 | print('mfscore:%.4f fscore:%s' % (m_f_score, str(f_score))) 132 | 133 | 134 | def show(predict_path, test_path): 135 | # crf_path = predict_path[:-10] + '/crf/label_vis' 136 | img_path = test_path + '/img' 137 | label_path = test_path + '/label_vis' 138 | 139 | list = os.listdir(predict_path) 140 | random.shuffle(list) 141 | 142 | for i in list: 143 | predict = read(os.path.join(predict_path, i)) 144 | img = read(os.path.join(img_path, i))[:, :, :3] 145 | # crf = read(os.path.join(crf_path, i)) 146 | label = read(os.path.join(label_path, i)) 147 | 148 | fig, axs = plt.subplots(1, 3, figsize=(14, 4)) 149 | 150 | axs[0].imshow(img.astype(np.uint8)) 151 | axs[0].axis("off") 152 | axs[1].imshow(label.astype(np.uint8)) 153 | axs[1].axis("off") 154 | 155 | axs[2].imshow(predict.astype(np.uint8)) 156 | 
axs[2].axis("off") 157 | 158 | # axs[3].imshow(crf.astype(np.uint8)) 159 | # axs[3].axis("off") 160 | 161 | plt.suptitle(os.path.basename(i), y=0.94) 162 | plt.tight_layout() 163 | plt.show() 164 | plt.close() 165 | 166 | 167 | if __name__ == '__main__': 168 | main() 169 | p_test_withcrf() 170 | 171 | # predict_path = '/home/ggm/WLS/semantic/PointAnno/save/nuclear/Second/predict_test/label_vis' 172 | # test_path = '/home/ggm/WLS/semantic/dataset/nuclear/test' 173 | # show(predict_path, test_path) -------------------------------------------------------------------------------- /run/second/sec_train.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from torchvision import transforms 3 | from torch.autograd import Variable 4 | import shutil 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | from collections import OrderedDict 8 | from options import * 9 | from utils import * 10 | from dataset import * 11 | 12 | 13 | print("PyTorch Version: ", torch.__version__) 14 | print('cuda', torch.version.cuda) 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | device = torch.device("cuda:0") 17 | print('Device:', device) 18 | 19 | 20 | def main(): 21 | opt = Sec_Options().parse() 22 | save_path = os.path.join(opt.save_path, opt.dataset) 23 | log_path, checkpoint_path, _, _, _ = create_save_path(opt) 24 | train_txt_path, val_txt_path, _ = create_data_path(opt) 25 | 26 | # Define logger 27 | logger, tensorboard_log_dir = create_logger(log_path) 28 | logger.info(opt) 29 | # Define Transformation 30 | train_transform = transforms.Compose([ 31 | trans.RandomHorizontalFlip(), 32 | trans.RandomVerticleFlip(), 33 | trans.RandomRotate90(), 34 | ]) 35 | 36 | # load train and val dataset 37 | train_dataset = Dataset_sec(opt, save_path, train_txt_path, flag='train', transform=train_transform) 38 | val_dataset = Dataset_sec(opt, save_path, val_txt_path, flag='val', transform=None) 39 | 40 | # Create training and validation dataloaders 41 | train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, 42 | num_workers=opt.num_workers, pin_memory=opt.pin, drop_last=True) 43 | 44 | model = Seg_Net(opt) 45 | if opt.resume: 46 | checkpoint = torch.load(checkpoint_path+'/model_best.pth', map_location=torch.device('cpu')) 47 | model.load_state_dict(checkpoint['state_dict']) 48 | print('resume success') 49 | 50 | model = model.to(device) 51 | optimizer = optim.AdamW(model.parameters(), lr=opt.base_lr, amsgrad=True) 52 | best_metric = -1e16 53 | 54 | for epoch in range(opt.num_epochs): 55 | time_start = time.time() 56 | 57 | train(opt, epoch, model, train_loader, optimizer, logger) 58 | metric = validate(opt, epoch, model, val_dataset, logger) 59 | 60 | logger.info('best_val_metric:%.4f current_val_metric:%.4f' % (best_metric, metric)) 61 | if metric > best_metric: 62 | logger.info('epoch:%d Save to model_best' % (epoch)) 63 | torch.save({'state_dict': model.state_dict()}, 64 | os.path.join(checkpoint_path, 'model_best.pth')) 65 | best_metric = metric 66 | 67 | end_time = time.time() 68 | time_cost = end_time - time_start 69 | logger.info("Epoch %d Time %d ----------------------" % (epoch, time_cost)) 70 | logger.info('\n') 71 | 72 | 73 | def train(opt, epoch, model, loader, optimizer, logger): 74 | start_iter_epoch = 15 75 | 76 | model.train() 77 | loss_m = AverageMeter() 78 | soft_loss_m = AverageMeter() 79 | last_idx = len(loader) - 1 80 | 81 | for batch_idx, batch in enumerate(loader): 82 | # step = epoch * 
len(loader) + batch_idx 83 | # adjust_learning_rate(opt.base_lr, optimizer, step, len(loader)) 84 | lr = get_lr(optimizer) 85 | 86 | last_batch = batch_idx == last_idx 87 | img, label, name = batch 88 | input, label = img.cuda(non_blocking=True), label.cuda(non_blocking=True) 89 | 90 | output = model(input) 91 | 92 | criterion = nn.CrossEntropyLoss(ignore_index=255) 93 | loss = criterion(output, label) 94 | # if epoch <= start_iter_epoch: 95 | # loss = criterion(output, label) 96 | # 97 | # else: 98 | # soft_loss = get_soft_loss(output, soft_label) 99 | # loss = criterion(output, label) + soft_loss 100 | # soft_loss_m.update(soft_loss.item(), input.size(0)) 101 | 102 | loss_m.update(loss.item(), input.size(0)) 103 | optimizer.zero_grad() 104 | loss.backward() 105 | optimizer.step() 106 | torch.cuda.synchronize() 107 | 108 | log_interval = len(loader) // 3 109 | if last_batch or batch_idx % log_interval == 0: 110 | logger.info('Train:{} [{:>4d}/{} ({:>3.0f}%)] ' 111 | 'Loss:({loss.avg:>6.4f}) ' 112 | #'soft:({soft.avg:>6.4f}) ' 113 | 'LR:{lr:.3e} '.format( 114 | epoch, 115 | batch_idx, len(loader), 116 | 100. * batch_idx / last_idx, 117 | loss=loss_m, 118 | lr=lr)) 119 | 120 | # if epoch >= start_iter_epoch: 121 | # logger.info('predict_train') 122 | # predict_train(opt, model) 123 | 124 | return OrderedDict([('loss', loss_m.avg)]) 125 | 126 | 127 | def predict_train(opt, model): 128 | train_txt_path, val_txt_path, test_txt_path = create_data_path(opt) 129 | save_path = os.path.join(opt.save_path, opt.dataset) 130 | train_dataset = Dataset_soft(opt, save_path, train_txt_path, flag='train', transform=None) 131 | loader = DataLoader(train_dataset, batch_size=1, num_workers=opt.num_workers, 132 | pin_memory=opt.pin) 133 | 134 | model.eval() 135 | for batch_idx, batch in enumerate(tqdm(loader)): 136 | img, label, cls_label, soft_label, filename = batch 137 | with torch.no_grad(): 138 | input, cls_label = img.cuda(non_blocking=True), cls_label.cuda(non_blocking=True) 139 | coarse_mask = model.forward(input) 140 | mask = F.softmax(coarse_mask, dim=1) 141 | mask = clean_mask(mask, cls_label) 142 | mask = mask.squeeze(0).permute(1, 2, 0).data.cpu().numpy() 143 | 144 | save_mask_path = save_path + '/Second/soft_label' 145 | check_dir(save_mask_path) 146 | path = save_mask_path + '/' + filename[0][:-4] + '.npy' 147 | 148 | if os.path.isfile(path): 149 | soft_label = np.load(path) 150 | new_soft_label = (soft_label + mask) / 2 151 | np.save(path, new_soft_label) 152 | else: 153 | np.save(path, mask) 154 | 155 | 156 | def validate(opt, epoch, model, val_dataset, logger): 157 | compute_metric = IOUMetric_tensor(num_classes=opt.num_classes) 158 | hist = torch.zeros([opt.num_classes, opt.num_classes]).cuda() 159 | 160 | val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, pin_memory=opt.pin) 161 | model.eval() 162 | 163 | for batch_idx, batch in enumerate(val_loader): 164 | with torch.no_grad(): 165 | img, label, name = batch 166 | input, label = img.cuda(non_blocking=True), label.cuda(non_blocking=True) 167 | output = model(input) 168 | torch.cuda.synchronize() 169 | 170 | output = torch.argmax(output, dim=1) 171 | hist += compute_metric.get_hist(output, label) 172 | 173 | iou = torch.diag(hist) / (hist.sum(dim=1) + hist.sum(dim=0) - torch.diag(hist)) 174 | miou = torch.mean(iou).float() 175 | 176 | logger.info('epoch:%d current_miou:%.4f current_iou:%s' % (epoch, miou, str(iou))) 177 | 178 | return miou 179 | 180 | 181 | if __name__ == '__main__': 
182 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.util import * 2 | from utils.metric import * -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/metric.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/metric.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/paint.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/paint.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/paint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/paint.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/paint.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/paint.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/query.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luffy03/DBFNet/5ec2458098c206aa9780e693ed81320e548da4a2/utils/__pycache__/query.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util.cpython-35.pyc: -------------------------------------------------------------------------------- 
/utils/metric.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import numpy as np
4 | import torch
5 | import math
6 |
7 |
8 | def cal_kappa(hist):
9 |     if hist.sum() == 0:
10 |         po = 0
11 |         pe = 1
12 |         kappa = 0
13 |     else:
14 |         po = np.diag(hist).sum() / hist.sum()
15 |         pe = np.matmul(hist.sum(1), hist.sum(0).T) / hist.sum() ** 2
16 |         if pe == 1:
17 |             kappa = 0
18 |         else:
19 |             kappa = (po - pe) / (1 - pe)
20 |     return kappa
21 |
22 |
23 | def cal_fscore(hist):
24 |     TP = np.diag(hist)
25 |     FP = hist.sum(axis=0) - np.diag(hist)
26 |     FN = hist.sum(axis=1) - np.diag(hist)
27 |     TN = hist.sum() - (FP + FN + TP)
28 |
29 |     precision = TP / (TP + FP)
30 |     recall = TP / (TP + FN)
31 |
32 |     f_score = 2 * precision * recall / (precision + recall)
33 |     m_f_score = np.mean(f_score)
34 |     return f_score, m_f_score
35 |
36 |
37 | class IOUMetric:
38 |     """
39 |     Class to calculate mean-iou using fast_hist method
40 |     """
41 |
42 |     def __init__(self, num_classes):
43 |         self.num_classes = num_classes
44 |         self.hist = np.zeros((num_classes, num_classes))
45 |
46 |     def get_hist(self, label_pred, label_true):
47 |         # keep only the classes that should be scored (drops ignore/background values)
48 |         mask = (label_true >= 0) & (label_true < self.num_classes)
49 |         # np.bincount counts occurrences of 0..n**2-1; reshaped to (n, n) this is the confusion matrix
50 |         hist = np.bincount(
51 |             self.num_classes * label_true[mask].astype(int) +
52 |             label_pred[mask], minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
53 |         return hist
54 |
55 |     # input: predictions and ground truth
56 |     # semantic segmentation assigns a label to every pixel
57 |     def evaluate(self, predictions, gts):
58 |         for lp, lt in zip(predictions, gts):
59 |             assert len(lp.flatten()) == len(lt.flatten())
60 |             self.hist += self.get_hist(lp.flatten(), lt.flatten())
61 |         # miou
62 |         iou = np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0) - np.diag(self.hist))
63 |         miou = np.nanmean(iou)
64 |         # dice
65 |         dice = 2 * np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0))
66 |         mdice = np.nanmean(dice)
67 |
68 |         # ----------------- other metrics ------------------------------
69 |         # mean acc
70 |         acc = np.diag(self.hist).sum() / self.hist.sum()
71 |         acc_cls = np.nanmean(np.diag(self.hist) / self.hist.sum(axis=1))
72 |         freq = self.hist.sum(axis=1) / self.hist.sum()
73 |         fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()
74 |         return acc, acc_cls, iou, miou, dice, mdice, fwavacc
75 |
76 |
77 | class IOUMetric_tensor:
78 |     """
79 |     Class to calculate mean-iou with tensor_type using fast_hist method
80 |     """
81 |
82 |     def __init__(self, num_classes):
83 |         self.num_classes = num_classes
84 |         self.hist = torch.zeros([num_classes, num_classes])
85 |
86 |     def get_hist(self, label_pred, label_true):
87 |         # keep only the classes that should be scored (drops ignore/background values)
88 |         mask = (label_true >= 0) & (label_true < self.num_classes)
89 |         # torch.bincount counts occurrences of 0..n**2-1; reshaped to (n, n) this is the confusion matrix
90 |         hist = torch.bincount(
91 |             self.num_classes * label_true[mask] +
92 |             label_pred[mask], minlength=self.num_classes ** 2).view(self.num_classes, self.num_classes)
93 |         return hist
94 |
95 |     # input: predictions and ground truth
96 |     # semantic segmentation assigns a label to every pixel
97 |     def evaluate(self, predictions, gts):
98 |         for lp, lt in zip(predictions, gts):
99 |             assert len(lp.flatten()) == len(lt.flatten())
100 |             self.hist += self.get_hist(lp.flatten(), lt.flatten())
101 |         # miou
102 |         iou = torch.diag(self.hist) / (self.hist.sum(dim=1) + self.hist.sum(dim=0) - torch.diag(self.hist))
103 |         miou = torch.mean(iou)
104 |         # dice
105 |         dice = 2 * torch.diag(self.hist) / (self.hist.sum(dim=1) + self.hist.sum(dim=0))
106 |         mdice = torch.mean(dice)
107 |
108 |         # ----------------- other metrics ------------------------------
109 |         # mean acc
110 |         acc = torch.diag(self.hist).sum() / self.hist.sum()
111 |         acc_cls = torch.mean(torch.diag(self.hist) / self.hist.sum(dim=1))  # fixed: was np.diag on a torch tensor
112 |         freq = self.hist.sum(dim=1) / self.hist.sum()
113 |         fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()
114 |         return acc, acc_cls, iou, miou, dice, mdice, fwavacc
115 |
116 |
117 | def eval_hist(hist):
118 |     # hist must be numpy
119 |     kappa = cal_kappa(hist)
120 |     iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
121 |     miou = np.nanmean(iou)
122 |
123 |     # f_score
124 |     f_score, m_f_score = cal_fscore(hist)
125 |
126 |     # mean acc
127 |     acc = np.diag(hist).sum() / hist.sum()
128 |     acc_cls = np.diag(hist) / hist.sum(axis=1)
129 |
130 |     return iou, miou, kappa, acc, acc_cls, f_score, m_f_score
131 |
132 |
133 | def cls_accuracy(output, target):
134 |     """Computes precision, recall, overall accuracy and F1 for binary multi-label predictions"""
135 |     y_true = target.data.cpu().numpy()
136 |     y_pred = output.data.cpu().numpy()
137 |
138 |     # True Positive: entries that are 1 in both y_true and y_pred
139 |     TP = np.sum(np.multiply(y_true, y_pred))
140 |
141 |     # False Positive: 0 in y_true but predicted as 1
142 |     FP = np.sum(np.logical_and(np.equal(y_true, 0), np.equal(y_pred, 1)))
143 |     # False Negative: 1 in y_true but predicted as 0
144 |     FN = np.sum(np.logical_and(np.equal(y_true, 1), np.equal(y_pred, 0)))
145 |     # True Negative: 0 in both y_true and y_pred
146 |     TN = np.sum(np.logical_and(np.equal(y_true, 0), np.equal(y_pred, 0)))
147 |
148 |     # derive accuracy, precision, recall and F1 from the counts above
149 |     kappa = (TP + TN) / (TP + FP + FN + TN)  # fraction where y_pred and y_true agree; this is overall accuracy, not Cohen's kappa, despite the name
150 |     P = TP / (TP + FP)  # precision: predicted 1s that are also 1 in y_true
151 |     R = TP / (TP + FN)  # recall: true 1s that are also predicted as 1
152 |     F1 = 2 * P * R / (P + R)
153 |
154 |     return P, R, kappa, F1
155 |
156 |
157 | def per_cls_accuracy(output, target):
158 |     y_true = target.data.cpu().numpy()
159 |     y_pred = output.data.cpu().numpy()
160 |     acc_all = []
161 |     for cls in range(y_true.shape[1]):
162 |         true = np.sum(y_true[:, cls] == y_pred[:, cls])
163 |         acc = true / y_true.shape[0]
164 |         acc_all.append(acc)
165 |     return acc_all
166 |
167 |
168 | if __name__ == '__main__':
169 |     a = torch.randint(0, 2, size=[30, 4])
170 |     b = torch.randint(0, 2, size=[30, 4])
171 |
172 |     P, R, kappa, F1 = cls_accuracy(a, b)
173 |     print(P, R, kappa, F1)
174 |     acc_all = per_cls_accuracy(a, b)
175 |     print(acc_all)
--------------------------------------------------------------------------------
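Both IOUMetric classes above rely on the fast-hist trick: encode each (true, pred) pair as num_classes * true + pred, bincount the codes, and reshape to get the confusion matrix in one vectorized pass. A toy check with made-up labels (not repo data):

import numpy as np

num_classes = 3
label_true = np.array([0, 0, 1, 2, 2, 2])
label_pred = np.array([0, 1, 1, 2, 0, 2])
hist = np.bincount(num_classes * label_true + label_pred,
                   minlength=num_classes ** 2).reshape(num_classes, num_classes)
print(hist)
# rows = ground truth, cols = prediction:
# [[1 1 0]
#  [0 1 0]
#  [1 0 2]]
iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
print(iou)  # per-class IoU: [0.3333... 0.5 0.6666...]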
/utils/paint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from PIL import Image
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from datafiles.color_dict import color_dict
7 |
8 |
9 | def hex_to_rgb(value):
10 |     lv = len(value)  # assumes a plain 6-character hex string without a leading '#'
11 |     return tuple(int(value[i:i + lv // 3], 16) for i in range(0, lv, lv // 3))
12 |
13 |
14 | def create_visual_anno(anno, dict, flag=None):
15 |     # assert np.max(anno) <= num_classes-1, "only %d classes are supported, add new color in label2color_dict" % (num_classes)
16 |
17 |     if flag == 'hex':
18 |         rgb_dict = {}
19 |         for keys, hex_value in dict.items():
20 |             rgb_value = hex_to_rgb(hex_value)
21 |             rgb_dict[keys] = rgb_value
22 |     else:
23 |         rgb_dict = dict
24 |
25 |     # visualize
26 |     visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8)
27 |     for i in range(visual_anno.shape[0]):  # i for h
28 |         for j in range(visual_anno.shape[1]):
29 |             color = rgb_dict[anno[i, j]]
30 |             visual_anno[i, j, 0] = color[0]
31 |             visual_anno[i, j, 1] = color[1]
32 |             visual_anno[i, j, 2] = color[2]
33 |
34 |     return visual_anno
35 |
36 |
--------------------------------------------------------------------------------
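create_visual_anno above colorizes one pixel at a time, which gets slow on large tiles. A sketch of an equivalent vectorized lookup (create_visual_anno_fast is a hypothetical helper, assuming integer class ids starting at 0 and an RGB dict like the one built above):

import numpy as np

def create_visual_anno_fast(anno, rgb_dict):
    # build an (N, 3) palette indexed by class id, then index it with the whole map at once
    palette = np.zeros((max(rgb_dict) + 1, 3), dtype=np.uint8)
    for cls_id, color in rgb_dict.items():
        palette[cls_id] = color
    return palette[anno]  # (H, W) int map -> (H, W, 3) uint8 image

anno = np.array([[0, 1], [1, 0]])
print(create_visual_anno_fast(anno, {0: (255, 0, 0), 1: (0, 255, 0)}))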
/utils/query.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | def ask_for_query(index, path):
7 |     index = list(index)
8 |     query_all = []
9 |
10 |     path_list = os.listdir(path)
11 |
12 |     for idx in index:
13 |         query = []
14 |         for path_index in path_list:
15 |             if (set(list(path_index)) & set(index)) == set(idx):
16 |                 query.append(path_index)
17 |         query_all.append((idx, query))
18 |
19 |     query_dict = dict(query_all)
20 |
21 |     return query_dict
22 |
23 |
24 | def ask_for_query_new(index, path):
25 |     index = list(index)
26 |     query_all = []
27 |
28 |     path_list = os.listdir(path)
29 |
30 |     for i in path_list:
31 |         # the characters shared between this filename and the query index
32 |         I = (set(list(i)) & set(index))
33 |         if I != set():
34 |             if len(index) > 1:
35 |                 if I != set(index):
36 |                     query_all.append(i)
37 |             else:
38 |                 query_all.append(i)
39 |
40 |     return query_all
41 |
42 |
43 | if __name__ == '__main__':
44 |     path = '/home/ggm/WLS/semantic/dataset/potsdam/train/img_index'
45 |
46 |     index = '012'
47 |     query = ask_for_query_new(index, path)
48 |     print(query)
49 |
50 |
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import logging
6 | import time
7 | import json
8 | import cv2
9 | # (duplicate 'import os' removed)
10 | from utils.paint import create_visual_anno
11 | from pathlib import Path
12 | from models import *
13 | from PIL import Image
14 | import tifffile
15 | import numpy as np
16 | import ttach as tta
17 | from math import *
18 |
19 |
20 | def read(path):
21 |     if path.endswith('.tif'):
22 |         return tifffile.imread(path)
23 |     else:
24 |         img = Image.open(path)
25 |         return np.asarray(img)
26 |
27 |
28 | def imsave(path, img):
29 |     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # expects RGB input; the swap gives cv2.imwrite the BGR order it wants
30 |     cv2.imwrite(path[:-4] + '.png', img)
31 |
32 |
33 | def check_dir(dir):
34 |     if not os.path.exists(dir):
35 |         os.mkdir(dir)
36 |
37 |
38 | def get_lr(optimizer):
39 |     for param_group in optimizer.param_groups:
40 |         return param_group['lr']
41 |
42 |
43 | class AverageMeter:
44 |     """Computes and stores the average and current value"""
45 |     def __init__(self):
46 |         self.reset()
47 |
48 |     def reset(self):
49 |         self.val = 0
50 |         self.avg = 0
51 |         self.sum = 0
52 |         self.count = 0
53 |
54 |     def update(self, val, n=1):
55 |         self.val = val
56 |         self.sum += val * n
57 |         self.count += n
58 |         self.avg = self.sum / self.count
59 |
60 |
61 | def adjust_learning_rate(learning_rate, optimizer, step, length, num_epochs=20):
62 |     """Sets the learning rate to the initial LR decayed by 10 every num_epochs epochs"""
63 |     stride = num_epochs * length
64 |     lr = learning_rate * (0.1 ** (step // stride))
65 |     if step % stride == 0:
66 |         print("learning_rate change to:%.8f" % lr)
67 |     for param_group in optimizer.param_groups:
68 |         param_group['lr'] = lr
69 |
70 |
71 | def save_pred_anno(out, save_pred_dir, filename, dict, flag=False):
72 |     out = out.squeeze(0)
73 |     check_dir(save_pred_dir)
74 |
75 |     # save predict dir
76 |     save_value_dir = save_pred_dir + '/label'
77 |     check_dir(save_value_dir)
78 |     save_value = os.path.join(save_value_dir, filename[0].split('/')[-1])
79 |     label = out.data.cpu().numpy()
80 |     cv2.imwrite(save_value, label)
81 |
82 |     if flag is True:
83 |         # save predict_visual dir
84 |         save_anno_dir = save_pred_dir + '/label_vis'
85 |         check_dir(save_anno_dir)
86 |         save_anno = os.path.join(save_anno_dir, filename[0].split('/')[-1])
87 |
88 |         img_visual = create_visual_anno(out.data.cpu().numpy(), dict=dict, flag=flag)
89 |         b, g, r = cv2.split(img_visual)
90 |         img_visual_rgb = cv2.merge([r, g, b])
91 |         cv2.imwrite(save_anno, img_visual_rgb)
92 |
93 |
94 | def save_pred_anno_numpy(out, save_pred_dir, filename, dict, flag=False):
95 |     check_dir(save_pred_dir)
96 |
97 |     # save predict dir
98 |     save_value_dir = save_pred_dir + '/label'
99 |     check_dir(save_value_dir)
100 |     save_value = os.path.join(save_value_dir, filename[0].split('/')[-1])
101 |
102 |     cv2.imwrite(save_value, out)
103 |
104 |     if flag is True:
105 |         # save predict_visual dir
106 |         save_anno_dir = save_pred_dir + '/label_vis'
107 |         check_dir(save_anno_dir)
108 |         save_anno = os.path.join(save_anno_dir, filename[0].split('/')[-1])
109 |
110 |         img_visual = create_visual_anno(out, dict=dict, flag=flag)
111 |         b, g, r = cv2.split(img_visual)
112 |         img_visual_rgb = cv2.merge([r, g, b])
113 |         cv2.imwrite(save_anno, img_visual_rgb)
114 |
115 |
116 | def save2json(metric_dict, save_path):
117 |     file_ = open(save_path, 'w')
118 |     file_.write(json.dumps(metric_dict, ensure_ascii=False, indent=2))
119 |     file_.close()
120 |
121 |
122 | def create_save_path(opt):
123 |     save_path = os.path.join(opt.save_path, opt.dataset)
124 |     exp_path = os.path.join(save_path, opt.experiment_name)
125 |
126 |     log_path = os.path.join(exp_path, 'log')
127 |     checkpoint_path = os.path.join(exp_path, 'checkpoint')
128 |     predict_test_path = os.path.join(exp_path, 'predict_test')
129 |     predict_train_path = os.path.join(exp_path, 'predict_train')
130 |     predict_val_path = os.path.join(exp_path, 'predict_val')
131 |
132 |     check_dir(save_path), check_dir(exp_path), check_dir(log_path), check_dir(checkpoint_path), \
133 |         check_dir(predict_test_path), check_dir(predict_train_path), check_dir(predict_val_path)
134 |
135 |     return log_path, checkpoint_path, predict_test_path, predict_train_path, predict_val_path
136 |
137 |
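# For illustration only (not part of the original file): with, say,
# opt.save_path='/save', opt.dataset='potsdam', opt.experiment_name='exp1',
# create_save_path above lays out
#   /save/potsdam/exp1/log
#   /save/potsdam/exp1/checkpoint
#   /save/potsdam/exp1/predict_test
#   /save/potsdam/exp1/predict_train
#   /save/potsdam/exp1/predict_val
# Note that check_dir wraps os.mkdir (no parents), so the directories are
# created level by level in the order the function calls check_dir.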
138 | def create_data_path(opt):
139 |     data_inform_path = os.path.join(opt.data_inform_path, opt.dataset)
140 |
141 |     # for vaihingen second stage
142 |     # train_txt_path = os.path.join(data_inform_path, 'sec_train.txt')
143 |     # val_txt_path = os.path.join(data_inform_path, 'sec_val.txt')
144 |
145 |     train_txt_path = os.path.join(data_inform_path, 'seg_train.txt')
146 |     val_txt_path = os.path.join(data_inform_path, 'seg_val.txt')
147 |     test_txt_path = os.path.join(data_inform_path, 'seg_test.txt')
148 |
149 |     return train_txt_path, val_txt_path, test_txt_path
150 |
151 |
152 | def create_logger(log_path):
153 |     time_str = time.strftime('%Y-%m-%d-%H-%M')
154 |
155 |     log_file = '{}.log'.format(time_str)
156 |
157 |     final_log_file = os.path.join(log_path, log_file)
158 |     head = '%(asctime)-15s %(message)s'
159 |     logging.basicConfig(filename=str(final_log_file),
160 |                         format=head)
161 |     logger = logging.getLogger()
162 |     logger.setLevel(logging.INFO)
163 |     console = logging.StreamHandler()
164 |     logging.getLogger('').addHandler(console)
165 |
166 |     tensorboard_log_dir = Path(log_path) / 'scalar' / time_str
167 |     print('=> creating {}'.format(tensorboard_log_dir))
168 |     tensorboard_log_dir.mkdir(parents=True, exist_ok=True)
169 |
170 |     return logger, str(tensorboard_log_dir)
171 |
172 |
173 | def resize_label(label, size):
174 |     if len(label.size()) == 3:
175 |         label = label.unsqueeze(1)
176 |     label = F.interpolate(label, size=(size, size), mode='bilinear', align_corners=True)  # note: bilinear suits soft/continuous labels; use mode='nearest' for hard class indices
177 |
178 |     return label
179 |
180 |
181 | def get_mean_std(flag):
182 |     if flag == 'potsdam':
183 |         means = [86.42521457, 92.37607528, 85.74658389]
184 |         std = [35.58409409, 35.45218542, 36.91464009]
185 |     elif flag == 'vaihingen':
186 |         means = [119.14901543, 83.04203606, 81.79810095]
187 |         std = [55.63038161, 40.67145608, 38.61447761]
188 |
189 |     else:
190 |         means = 0
191 |         std = 0
192 |         print('error: unknown dataset flag')
193 |     return means, std
194 |
195 |
196 | def Normalize(img, flag='potsdam'):
197 |     means, std = get_mean_std(flag)
198 |     img = (img - means) / std
199 |
200 |     return img
201 |
202 |
203 | def Normalize_back(img, flag='potsdam'):
204 |     means, std = get_mean_std(flag)
205 |
206 |     means = means[:3]
207 |     std = std[:3]
208 |
209 |     img = img * std + means
210 |
211 |     return img
212 |
213 |
214 | def pad_image(img, target_size):
215 |     """Pad an image (N, C, H, W) up to the target size."""
216 |     rows_missing = target_size[0] - img.shape[2]
217 |     cols_missing = target_size[1] - img.shape[3]
218 |     padded_img = F.pad(img, (0, cols_missing, 0, rows_missing), 'constant', 0)  # fixed: F.pad takes (left, right, top, bottom), so the width padding must come first
219 |     return padded_img
220 |
221 |
222 | def pre_slide(model, image, num_classes=7, tile_size=(512, 512), tta=False):  # note: the 'tta' flag shadows the ttach import inside this function
223 |     image_size = image.shape  # at least (1, 3, 512, 512), e.g. (1, 3, 1024, 1024)
224 |     overlap = 1 / 2  # adjacent windows overlap by half a tile
225 |
226 |     stride = ceil(tile_size[0] * (1 - overlap))  # sliding stride: 512 * (1 - 1/2) = 256
227 |     tile_rows = int(ceil((image_size[2] - tile_size[0]) / stride) + 1)  # window rows, e.g. (1024 - 512) / 256 + 1 = 3
228 |     tile_cols = int(ceil((image_size[3] - tile_size[1]) / stride) + 1)  # window columns, computed the same way
229 |
230 |     full_probs = torch.zeros((1, num_classes, image_size[2], image_size[3])).cuda()  # accumulated class probabilities for the full image
231 |
232 |     count_predictions = torch.zeros((1, 1, image_size[2], image_size[3])).cuda()  # how many windows covered each pixel
233 |     tile_counter = 0  # number of windows processed
234 |
235 |     for row in range(tile_rows):
236 |         for col in range(tile_cols):
237 |             x1 = int(col * stride)  # window start x
238 |             y1 = int(row * stride)  # window start y
239 |             x2 = min(x1 + tile_size[1], image_size[3])  # window end x, clipped to the image
240 |             y2 = min(y1 + tile_size[0], image_size[2])  # window end y, clipped to the image
241 |             x1 = max(int(x2 - tile_size[1]), 0)  # re-anchor the start so the window keeps its full size
242 |             y1 = max(int(y2 - tile_size[0]), 0)
243 |
244 |             img = image[:, :, y1:y2, x1:x2]  # crop the current window
245 |             padded_img = pad_image(img, tile_size)  # pad so the crop is exactly tile_size
246 |
247 |             tile_counter += 1
248 |             # print("Predicting tile %i" % tile_counter)
249 |
250 |             # feed the crop through the network; it outputs a probability map
251 |             # use softmax
252 |             if tta is True:
253 |                 padded = tta_predict(model, padded_img)
254 |             else:
255 |                 padded = model(padded_img)
256 |                 padded = F.softmax(padded, dim=1)
257 |
258 |             pre = padded[:, :, 0:img.shape[2], 0:img.shape[3]]  # drop the padded margin
259 |
260 |             count_predictions[:, :, y1:y2, x1:x2] += 1  # bump the coverage count inside the window
261 |             full_probs[:, :, y1:y2, x1:x2] += pre  # accumulate the window prediction
262 |
263 |     # average the predictions in the overlapping regions
264 |     full_probs /= count_predictions  # dividing by the coverage count gives the mean probability
265 |
266 |     return full_probs  # averaged probabilities for the whole image, shape (1, num_classes, H, W)
267 |
268 |
269 | def tta_predict(model, img):
270 |     tta_transforms = tta.Compose(
271 |         [
272 |             tta.HorizontalFlip(),
273 |             tta.Rotate90(angles=[0, 90, 180, 270]),
274 |         ])
275 |
276 |     xs = []
277 |
278 |     for t in tta_transforms:
279 |         aug_img = t.augment_image(img)
280 |         aug_x = model(aug_img)
281 |         aug_x = F.softmax(aug_x, dim=1)
282 |
283 |         x = t.deaugment_mask(aug_x)
284 |         xs.append(x)
285 |
286 |     xs = torch.cat(xs, 0)
287 |     x = torch.mean(xs, dim=0, keepdim=True)
288 |
289 |     return x
--------------------------------------------------------------------------------
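A minimal usage sketch for pre_slide (hypothetical model and sizes; assumes the repo is on PYTHONPATH and a CUDA device is available, since pre_slide allocates its buffers with .cuda()):

import torch
import torch.nn as nn
from utils.util import pre_slide

# stand-in for a real segmentation network: any module mapping (N, 3, H, W) -> (N, num_classes, H, W)
model = nn.Conv2d(3, 7, kernel_size=1).cuda().eval()

image = torch.randn(1, 3, 1024, 1024).cuda()  # a tile larger than the 512x512 window
with torch.no_grad():
    probs = pre_slide(model, image, num_classes=7, tile_size=(512, 512), tta=False)
pred = probs.argmax(dim=1)  # (1, 1024, 1024) class-index map
print(pred.shape)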