├── 2024.03.28.pptx
├── README.md
├── cfg
│   ├── BCDD.py
│   ├── CDD-linux
│   ├── CDD.py
│   └── __init__.py
├── checkpoints
│   └── just ignore it
├── dataset
│   ├── __init__.py
│   └── rs.py
├── example
│   └── CDD
│       ├── script.ipynb
│       ├── test.txt
│       ├── test
│       │   └── the same settings as train
│       ├── test1.txt
│       ├── train.txt
│       ├── train
│       │   ├── A
│       │   │   └── replace this with data
│       │   ├── B
│       │   │   └── replace this with data
│       │   └── OUT1
│       │       └── replace this with data
│       ├── train1.txt
│       ├── val.txt
│       ├── val
│       │   └── the same settings as train
│       └── val1.txt
├── layer
│   ├── __init__.py
│   ├── function.py
│   ├── loss.py
│   └── loss_test.ipynb
├── model
│   ├── __init__.py
│   ├── files.py
│   ├── model_store.py
│   └── siameseNet
│       ├── __init__.py
│       ├── attention.py
│       ├── d_aa.py
│       ├── dares.py
│       ├── res.py
│       ├── res50.py
│       ├── resbase.py
│       ├── resnet.py
│       └── test.ipynb
├── pretrained
│   └── note.txt
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── metric.py
    ├── transforms.py
    └── utils.py

/2024.03.28.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/silence-tang/DASNet-V2/e56f6e095e1e232b85365ef500267665c4c47fb5/2024.03.28.pptx
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DASNet-V2
2 | An improved version of DASNet, achieving 0.96+ F1_score.
3 | Stars and follows are welcome; let's learn from each other and improve together.
4 | 
5 | 
6 | ## 1. Preface
7 | - Reference paper: DASNet: Dual attentive fully convolutional siamese networks for change detection in high-resolution satellite images
8 | - Reference code: https://github.com/lehaifeng/DASNet
9 | - Note: while reproducing the paper I found several errors in the open-source code; all of them have been fixed here.
10 | 
11 | ## 2. Improvement ideas
12 | - The original model feels top-heavy: a very heavy front end is used merely to compute an embedding map that fuses spatial and channel attention, understandably in order to obtain a more effective representation, but the modules after it are weak: the loss is computed directly between the ground truth and the embedding map, which is probably too coarse. **Moreover, the original code computes the loss at a 32x32 scale, where the precision of the ground truth has already been lost, so the knowledge the model learns may only be enough for detecting low-resolution images.**
13 | - Improvement: borrow the encoder-decoder structure from semantic segmentation. On one hand, the embedding map can be upsampled step by step from 32x32 to 256x256 before computing a loss; on the other hand, multi-scale losses can be introduced by computing a loss between the ground truth and the output feature map of each decoder layer, then jointly optimizing the network with all of them, so that the model learns multi-scale information (a condensed decoder sketch follows this section). Besides, change detection is essentially pixel-level classification and thus no different in nature from semantic segmentation (semantic segmentation is a multi-class task per pixel, while change detection is a special case: a binary task per pixel), so any model from the semantic segmentation literature can be borrowed.
14 | - Note: the current version does not yet train with the decoder's multi-layer losses jointly; only the 256x256 output of the decoder's last layer is used for training.
15 | 
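16 | Below is a condensed sketch of the decoder idea; the full implementation lives in `model/siameseNet/dares.py` (which uses LeakyReLU after each transposed convolution), and the shapes assume the 512-channel 32x32 embedding map produced by the encoder:
17 | 
18 | ```python
19 | import torch
20 | import torch.nn as nn
21 | 
22 | def up_block(c_in, c_out):
23 |     # ConvTranspose2d(kernel_size=4, stride=2, padding=1) doubles H and W
24 |     return nn.Sequential(
25 |         nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
26 |         nn.BatchNorm2d(c_out),
27 |         nn.LeakyReLU())
28 | 
29 | decoder = nn.Sequential(
30 |     up_block(512, 256),  # 32x32   -> 64x64
31 |     up_block(256, 128),  # 64x64   -> 128x128
32 |     up_block(128, 32),   # 128x128 -> 256x256
33 | )
34 | 
35 | x = torch.randn(1, 512, 32, 32)  # embedding map from the siamese encoder
36 | print(decoder(x).shape)          # torch.Size([1, 32, 256, 256])
37 | ```
38 | 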
39 | ## 3. Other optimizations
40 | - √ Every 1000 batches, output a P-R plot plus metric.json, named `epoch_batch_idx_P_R.jpg` and `epoch_batch_idx_metric.json`.
41 | - √ Every 20 batches, save the current epoch, batch_id, f_score and best f_score to the log.
42 | - √ Every 20 batches, print the learning_rate to the console.
43 | - √ Name the saved best model with `epoch_batchid` so it is easy to locate.
44 | - √ Added the --start_epoch and --resume command-line arguments so training can be paused and resumed.
45 | - √ Cleaned up the checkpoints/ directory structure.
46 | 
47 | ## 4. Configuration
48 | - Hardware: a single NVIDIA RTX 2080Ti
49 | - Requirements: see the original paper's GitHub project
50 | - Dataset: Change Detection Dataset
51 | - Pretrained backbone: resnet50-19c8e357.pth
52 | - Train from scratch: enter the DASNet-V2 directory and run `python train.py` (add other command-line arguments as needed)
53 | 
54 | 
55 | ## 5. Caveats
56 | 
57 | - The best-model path is set in the cfg folder; configure the corresponding path for linux/windows.
58 | - Labels must be binary (0/1) single-channel images. After processing them, generate new train1.txt, val1.txt and test1.txt in which the mask paths point to OUT1, then update the cfg.
59 | - max f_score is roughly proportional to AUC, so it can be used to gauge model quality, i.e. the best model can be selected by max f_score.
60 | - The raw input images are centered by subtracting the channel means.
61 | 
62 | ## 6. Experimental results
63 | ### DASNet
64 | - Base model
65 | - Trained for 60 epochs: lr = 1e-4 for the first 40 epochs, lr = 1e-5 for the last 20
66 | - **best_max_f = 0.9299360653186844**
67 | - best_epoch: 49
68 | - best_batch_idx: 500
69 | 
70 | 
71 | ![VnwCy.png](https://i.328888.xyz/img/2022/12/04/VnwCy.png)
72 | 
73 | ### DASNet-V2
74 | - SiameseNet + Decoder
75 | - Decoder = 3 * (TransposedConv + BN + ReLU)
76 | - Trained for 80 epochs: lr = 1e-4 for epochs 0-20, 5e-5 for 21-40, 1e-5 for 41-60, 5e-6 for 61-70, 1e-6 for 71-80
77 | - batch size = 4
78 | - **best_max_f = 0.9566290556389714**
79 | - best_epoch: 76
80 | - best_batch_idx: 2499
81 | 
82 | 
83 | ![Vn2X5.png](https://i.328888.xyz/img/2022/12/04/Vn2X5.png)
84 | 
85 | ### DASNet-V2 (replace l_2 loss with cossim loss)
86 | - cossim loss with two-sided thresholds $m_1=0.1, m_2=0.8$
87 | - SiameseNet + Decoder
88 | - Trained for 70 epochs: lr = 1e-4 for epochs 0-20, 5e-5 for 21-40, 1e-5 for 41-60, 5e-6 for 61-70
89 | - **best_max_f = 0.9601**
90 | - best_epoch: 63, best_batch_idx: 2000
91 | 
92 | ## 7. Additional experiments
93 | All of the following build on DASNet + Decoder. To shorten a single run, the epoch count can be cut to 45 (training longer only adds roughly another 0.2 F1 points).
94 | 
95 | - **Control group**
96 | - **SiameseNet + Decoder, Distance Metric = $L_2$**
97 | - Trained for 80 epochs: lr = 1e-4 for epochs 0-20, 5e-5 for 21-40, 1e-5 for 41-60, 5e-6 for 61-70, 1e-6 for 71-80
98 | - **best_max_f = 0.9566**
99 | - best_epoch: 76, best_batch_idx: 2499
100 | - **EXP1: change the distance metric in the loss to cosine similarity, with the two-sided thresholds set to $m_1=0.1, m_2=0.8$, and see whether performance rises or falls. Also inspect the heat maps computed with cosine similarity.**
101 | - **SiameseNet + Decoder, Distance Metric = Cosine Sim**
102 | - Trained for 70 epochs: lr = 1e-4 for epochs 0-20, 5e-5 for 21-40, 1e-5 for 41-60, 5e-6 for 61-70
103 | - **best_max_f = 0.9601**
104 | - best_epoch: 63, best_batch_idx: 2000
105 | 
106 | ![8c5MV.png](https://i.328888.xyz/2023/01/31/8c5MV.png)
107 | ![8cscb.png](https://i.328888.xyz/2023/01/31/8cscb.png)
108 | 
109 | In the figures above, the first two rows are the cossim results and the third row uses the original $L_2$ loss. Clearly, after switching the loss formulation to cossim, the visualized change-detection results are also more satisfying than before: much of the patchy noise is gone, and the color contrast between changed and unchanged regions is sharper, which means the improved model separates changed from unchanged regions more decisively.
110 | 
111 | - **EXP2: the loss has three terms, $Loss = \lambda_1 L_{sa} + \lambda_2 L_{ca} + \lambda_3 L_{saca}$, and in the experiments the three terms are always very close, so using only the final $L_{saca}$ should also work?**
112 | - **SiameseNet + Decoder, Distance Metric = $L_2$**, $Loss = L_{saca}$
113 | - Trained for 80 epochs: lr = 1e-4 for epochs 0-20, 5e-5 for 21-40, 1e-5 for 41-60, 5e-6 for 61-70, 1e-6 for 71-80
114 | - **best_max_f = 0.9546**
115 | - best_epoch: 68, best_batch_idx: 2000
116 | 
117 | ![8miyy.png](https://i.328888.xyz/2023/01/31/8miyy.png)
118 | 
119 | The results show that even with the total loss replaced by $L_{saca}$ alone, model performance barely suffers: the F1_score drops by only about 0.2 points.
120 | 
121 | - **EXP3: ablation study**
122 | - No SAM, no CAM, everything else unchanged
123 | - **SiameseNet + Decoder, Distance Metric = $L_2$**
124 | - Trained for 80 epochs: lr = 1e-4 for epochs 0-20, 5e-5 for 21-40, 1e-5 for 41-60, 5e-6 for 61-70, 1e-6 for 71-80
125 | - **best_max_f = 0.9543**
126 | - best_epoch: 75, best_batch_idx: 2499
127 | 
128 | ![8mb7o.png](https://i.328888.xyz/2023/01/31/8mb7o.png)
129 | 
130 | - No SAM, with CAM, everything else unchanged: can be skipped, since removing both SAM and CAM only costs about 0.2 points
131 | - With SAM, no CAM, everything else unchanged: can be skipped for the same reason
132 | 
133 | ## 8. Conclusions
134 | - With the redesigned decoder attached, the original DASNet improves markedly (F1_score rises from 0.92+ to 0.95+)
135 | - Switching the loss formulation from $L_2$ loss to cosine similarity pushes the DASNet F1_score to 0.96+
136 | - The ablations show that jointly computing the loss over three feature maps, as in the original paper, contributes little; using only the fused CAM+SAM loss as the total loss barely hurts performance (a drop of about 0.2 points)
137 | - The ablations also show that the original paper's highlights (the CAM and SAM mechanisms) do not significantly improve a DASNet that already has the decoder. Removing both modules and simply passing the ResNet-backbone features through the decoder before the loss computation lowers the F1_score by only about 0.2 points.
138 | - Without the SAM and CAM modules the average GPU memory usage is about 7.5 GB, versus nearly 9 GB with them. So the encoder-decoder, no-CAM/SAM variant of DASNet-V2 beats the F1_score reported in the original paper at lower computational cost (0.9543 > 0.9299).
139 | 
140 | 
--------------------------------------------------------------------------------
/cfg/BCDD.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | BASE_PATH = '/home/lhf/yzy/cd_res'
4 | PRETRAIN_MODEL_PATH = os.path.join(BASE_PATH,'pretrain')
5 | DATA_PATH = 
'/home/lhf/yzy/changedetection/SceneChangeDet/BCD' 6 | TRAIN_DATA_PATH = os.path.join(DATA_PATH) 7 | TRAIN_LABEL_PATH = os.path.join(DATA_PATH) 8 | TRAIN_TXT_PATH = os.path.join(TRAIN_DATA_PATH,'train.txt') 9 | VAL_DATA_PATH = os.path.join(DATA_PATH) 10 | VAL_LABEL_PATH = os.path.join(DATA_PATH) 11 | VAL_TXT_PATH = os.path.join(VAL_DATA_PATH,'val.txt') 12 | SAVE_PATH = '/home/lhf/yzy/cdout/bone/resnet50/BCD2' 13 | SAVE_CKPT_PATH = os.path.join(SAVE_PATH,'ckpt') 14 | if not os.path.exists(SAVE_CKPT_PATH): 15 | os.mkdir(SAVE_CKPT_PATH) 16 | SAVE_PRED_PATH = os.path.join(SAVE_PATH,'prediction') 17 | if not os.path.exists(SAVE_PRED_PATH): 18 | os.mkdir(SAVE_PRED_PATH) 19 | TRAINED_BEST_PERFORMANCE_CKPT = os.path.join(SAVE_CKPT_PATH,'model_best.pth') 20 | INIT_LEARNING_RATE = 1e-4 21 | DECAY = 5e-5 22 | MOMENTUM = 0.90 23 | MAX_ITER = 40000 24 | BATCH_SIZE = 1 25 | TRANSFROM_SCALES= (256,256) 26 | T0_MEAN_VALUE = (98.62,113.27,123.59) 27 | T1_MEAN_VALUE = (117.38 ,123.09 , 123.20) 28 | -------------------------------------------------------------------------------- /cfg/CDD-linux: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # BASE_PATH = '/home/lhf/yzy/cd_res' 4 | PRETRAIN_MODEL_PATH = '/root/autodl-tmp/DASNet-master/pretrained' 5 | DATA_PATH = '/root/autodl-tmp/DASNet-master/example/CDD' 6 | 7 | TRAIN_DATA_PATH = os.path.join(DATA_PATH) 8 | TRAIN_LABEL_PATH = os.path.join(DATA_PATH) 9 | # TRAIN_TXT_PATH = os.path.join(TRAIN_DATA_PATH, 'train.txt') 10 | TRAIN_TXT_PATH = os.path.join(TRAIN_DATA_PATH, 'train1.txt') 11 | 12 | VAL_DATA_PATH = os.path.join(DATA_PATH) 13 | VAL_LABEL_PATH = os.path.join(DATA_PATH) 14 | # VAL_TXT_PATH = os.path.join(VAL_DATA_PATH, 'val.txt') 15 | VAL_TXT_PATH = os.path.join(VAL_DATA_PATH, 'val1.txt') 16 | 17 | SAVE_PATH = '/root/autodl-tmp/DASNet-master/checkpoints' 18 | SAVE_CKPT_PATH = os.path.join(SAVE_PATH,'ckpt') 19 | if not os.path.exists(SAVE_CKPT_PATH): 20 | os.mkdir(SAVE_CKPT_PATH) 21 | 22 | SAVE_PRED_PATH = os.path.join(SAVE_PATH,'prediction') 23 | if not os.path.exists(SAVE_PRED_PATH): 24 | os.mkdir(SAVE_PRED_PATH) 25 | 26 | TRAINED_BEST_PERFORMANCE_CKPT = os.path.join(SAVE_CKPT_PATH, 'model_best_exp1.pth') 27 | TRAINED_RESUME_PATH = SAVE_PRED_PATH 28 | INIT_LEARNING_RATE = 1e-4 29 | DECAY = 5e-5 30 | MOMENTUM = 0.90 31 | MAX_EPOCH = 80 32 | TRAIN_BATCH_SIZE = 4 33 | VAL_BATCH_SIZE = 1 34 | THRESHS = [0.1, 0.3, 0.5] 35 | THRESH = 0.1 36 | LOSS_PARAM_CONV = 3 37 | LOSS_PARAM_FC = 3 38 | TRANSFROM_SCALES= (256, 256) 39 | T0_MEAN_VALUE = (87.72, 100.2, 90.43) 40 | T1_MEAN_VALUE = (120.24, 127.93, 121.18) 41 | -------------------------------------------------------------------------------- /cfg/CDD.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # BASE_PATH = '/home/lhf/yzy/cd_res' 4 | PRETRAIN_MODEL_PATH = r'C:\Users\HP\Desktop\DASNet\DASNet-master\pretrained' 5 | DATA_PATH = r'C:\Users\HP\Desktop\DASNet\DASNet-master\example\CDD' 6 | 7 | TRAIN_DATA_PATH = os.path.join(DATA_PATH) 8 | TRAIN_LABEL_PATH = os.path.join(DATA_PATH) 9 | # TRAIN_TXT_PATH = os.path.join(TRAIN_DATA_PATH, 'train.txt') 10 | TRAIN_TXT_PATH = os.path.join(TRAIN_DATA_PATH, 'train1.txt') 11 | 12 | VAL_DATA_PATH = os.path.join(DATA_PATH) 13 | VAL_LABEL_PATH = os.path.join(DATA_PATH) 14 | # VAL_TXT_PATH = os.path.join(VAL_DATA_PATH, 'val.txt') 15 | VAL_TXT_PATH = os.path.join(VAL_DATA_PATH, 'val1.txt') 16 | 17 | SAVE_PATH = r'C:\Users\HP\Desktop\DASNet\DASNet-master\checkpoints' 
18 | SAVE_CKPT_PATH = os.path.join(SAVE_PATH,'ckpt') 19 | if not os.path.exists(SAVE_CKPT_PATH): 20 | os.mkdir(SAVE_CKPT_PATH) 21 | 22 | SAVE_PRED_PATH = os.path.join(SAVE_PATH,'prediction') 23 | if not os.path.exists(SAVE_PRED_PATH): 24 | os.mkdir(SAVE_PRED_PATH) 25 | 26 | TRAINED_BEST_PERFORMANCE_CKPT = os.path.join(SAVE_CKPT_PATH, 'model_best_exp1.pth') 27 | INIT_LEARNING_RATE = 1e-4 28 | DECAY = 5e-5 29 | MOMENTUM = 0.90 30 | MAX_ITER = 40000 31 | BATCH_SIZE = 1 32 | THRESHS = [0.1, 0.3, 0.5] 33 | THRESH = 0.1 34 | LOSS_PARAM_CONV = 3 35 | LOSS_PARAM_FC = 3 36 | TRANSFROM_SCALES= (256, 256) 37 | T0_MEAN_VALUE = (87.72, 100.2, 90.43) 38 | T1_MEAN_VALUE = (120.24, 127.93, 121.18) 39 | -------------------------------------------------------------------------------- /cfg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/silence-tang/DASNet-V2/e56f6e095e1e232b85365ef500267665c4c47fb5/cfg/__init__.py -------------------------------------------------------------------------------- /checkpoints/just ignore it: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/silence-tang/DASNet-V2/e56f6e095e1e232b85365ef500267665c4c47fb5/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/rs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.dataset import Dataset 3 | import numpy as np 4 | import os 5 | import scipy.io 6 | import scipy.misc as m 7 | from PIL import Image 8 | import matplotlib.pyplot as plt 9 | plt.switch_backend('agg') 10 | import utils.transforms as trans 11 | import cv2 12 | import cfg.CDD as cfg 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif', 17 | ] 18 | 19 | # def is_image_file(filename): 20 | # return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | # def pil_loader(path): 23 | # # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 24 | # with open(path, 'rb') as f: 25 | # with Image.open(f) as img: 26 | # return img.convert('RGB') 27 | 28 | # def accimage_loader(path): 29 | # import accimage 30 | # try: 31 | # return accimage.Image(path) 32 | # except IOError: 33 | # # Potentially a decoding problem, fall back to PIL.Image 34 | # return pil_loader(path) 35 | 36 | # def default_loader(path): 37 | # from torchvision import get_image_backend 38 | # if get_image_backend() == 'accimage': 39 | # return accimage_loader(path) 40 | # else: 41 | # return pil_loader(path) 42 | 43 | # palette = [0, 0, 0, 255, 0, 0] 44 | 45 | # def colorize_mask(mask): 46 | # # mask: numpy array of the mask 47 | # new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 48 | # new_mask.putpalette(palette) 49 | 50 | # return new_mask 51 | 52 | # def get_pascal_labels(): 53 | # return np.asarray([[0,0,0],[0,0,255]]) 54 | 55 | # def decode_segmap(temp, plot=False): 56 | # label_colours = get_pascal_labels() 57 | # r = temp.copy() 58 | # g = temp.copy() 59 | # b = temp.copy() 60 | # for l in range(0, 2): 61 | # r[temp == l] = label_colours[l, 0] 62 | # g[temp == l] = label_colours[l, 1] 63 | # b[temp 
== l] = label_colours[l, 2]
64 | 
65 | #     rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
66 | #     rgb[:, :, 0] = r
67 | #     rgb[:, :, 1] = g
68 | #     rgb[:, :, 2] = b
69 | #     #rgb = np.resize(rgb,(321,321,3))
70 | #     if plot:
71 | #         plt.imshow(rgb)
72 | #         plt.show()
73 | #     else:
74 | #         return rgb
75 | 
76 | #---------------------------------------------------------------------------------------------------------------------
77 | class Dataset(Dataset):
78 |     def __init__(self, img_path, label_path, file_name_txt_path, split_flag, transform=True, transform_med = None):
79 |         self.label_path = label_path
80 |         self.img_path = img_path
81 |         self.img_txt_path = file_name_txt_path
82 |         self.imgs_path_list = np.loadtxt(self.img_txt_path, dtype=str)
83 |         self.flag = split_flag
84 |         self.transform = transform
85 |         self.transform_med = transform_med
86 |         self.img_label_path_pairs = self.get_img_label_path_pairs()
87 | 
88 |     # build a dict of (A-image, B-image, OUT-mask) path tuples
89 |     def get_img_label_path_pairs(self):
90 |         img_label_pair_list = {}
91 |         # training set
92 |         if self.flag =='train':
93 |             for idx, did in enumerate(open(self.img_txt_path)):
94 |                 try:
95 |                     image1_name, image2_name, mask_name = did.strip("\n").split(' ')
96 |                 except ValueError:  # Adhoc for test.
97 |                     image_name = mask_name = did.strip("\n")
98 |                 extract_name = image1_name[image1_name.rindex('/') + 1: image1_name.rindex('.')]
99 |                 img1_file = os.path.join(self.img_path, image1_name)
100 |                 img2_file = os.path.join(self.img_path, image2_name)
101 |                 lbl_file = os.path.join(self.label_path, mask_name)
102 |                 img_label_pair_list.setdefault(idx, [img1_file, img2_file, lbl_file, image2_name])
103 |         # validation set
104 |         if self.flag == 'val':
105 |             self.label_ext = '.png'
106 |             for idx , did in enumerate(open(self.img_txt_path)):
107 |                 try:
108 |                     image1_name, image2_name, mask_name = did.strip("\n").split(' ')
109 |                 except ValueError:  # Adhoc for test.
110 |                     image_name = mask_name = did.strip("\n")
111 |                 extract_name = image1_name[image1_name.rindex('/') +1: image1_name.rindex('.')]
112 |                 img1_file = os.path.join(self.img_path, image1_name)
113 |                 img2_file = os.path.join(self.img_path, image2_name)
114 |                 lbl_file = os.path.join(self.label_path, mask_name)
115 |                 img_label_pair_list.setdefault(idx, [img1_file, img2_file, lbl_file, image2_name])
116 | 
117 |         return img_label_pair_list
118 | 
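119 |     # Each line of the txt file holds three space-separated relative paths,
120 |     # as generated by example/CDD/script.ipynb, e.g.:
121 |     #     train/A/00000.jpg train/B/00000.jpg train/OUT1/00000.jpg
122 | 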
123 |     def data_transform(self, img1, img2, lbl):
124 |         img1, img2 = img1[:, :, ::-1], img2[:, :, ::-1]  # RGB -> BGR
125 |         img1, img2 = img1.astype(np.float64), img2.astype(np.float64)
126 |         img1 -= cfg.T0_MEAN_VALUE
127 |         img2 -= cfg.T1_MEAN_VALUE  # T0_MEAN_VALUE = (87.72, 100.2, 90.43)
128 |         img1, img2 = img1.transpose(2, 0, 1), img2.transpose(2, 0, 1)
129 |         img1, img2 = torch.from_numpy(img1).float(), torch.from_numpy(img2).float()
130 |         lbl = torch.from_numpy(lbl).long()
131 |         # lbl_reverse = torch.from_numpy(lbl_reverse).long()
132 |         return img1, img2, lbl
133 | 
134 |     def __getitem__(self, index):
135 |         img1_path, img2_path, label_path, filename = self.img_label_path_pairs[index]
136 |         # load images
137 |         img1 = Image.open(img1_path)
138 |         img2 = Image.open(img2_path)
139 |         height, width, _ = np.array(img1, dtype=np.uint8).shape
140 |         if self.transform_med != None:
141 |             img1 = self.transform_med(img1)
142 |             img2 = self.transform_med(img2)
143 |         img1 = np.array(img1, dtype=np.uint8)
144 |         img2 = np.array(img2, dtype=np.uint8)
145 |         # load labels
146 |         if self.flag == 'train':
147 |             label = Image.open(label_path)
148 |             if self.transform_med != None:
149 |                 label = self.transform_med(label)
150 |             label = np.array(label,dtype=np.int32)
151 | 
152 |         if self.flag == 'val':
153 |             label = Image.open(label_path)
154 |             if self.transform_med != None:
155 |                 label = self.transform_med(label)
156 |             label = np.array(label, dtype=np.int32)
157 | 
158 |         if self.transform:
159 |             img1, img2, label = self.data_transform(img1, img2, label)
160 |         # return the sample; the images are array-backed tensors
161 |         # the underlying pair list looks like ['.../A/00000.jpg', '.../B/00000.jpg', '.../OUT/00000.jpg']
162 |         return img1, img2, label, str(filename), int(height), int(width)
163 | 
164 |     # total number of pairs in the dataset
165 |     def __len__(self):
166 |         return len(self.img_label_path_pairs)
167 | 
168 | 
--------------------------------------------------------------------------------
/example/CDD/script.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 10,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import matplotlib.pyplot as plt\n",
10 |     "import numpy as np\n",
11 |     "import cv2\n",
12 |     "import torch"
13 |    ]
14 |   },
15 |   {
16 |    "cell_type": "code",
17 |    "execution_count": 12,
18 |    "metadata": {},
19 |    "outputs": [],
20 |    "source": [
21 |     "for i in range(10000):\n",
22 |     "    in_path_name = 'train/OUT/' + str(i).rjust(5, '0') + '.jpg'\n",
23 |     "    out_path_name = 'train/OUT1/' + str(i).rjust(5, '0') + '.jpg'\n",
24 |     "    tmp = cv2.imread(in_path_name, flags=0)\n",
25 |     "    tmp_tensor = torch.tensor(tmp >= 128)\n",
26 |     "    tmp_tensor = tmp_tensor.int()\n",
27 |     "    tmp_after_img = tmp_tensor.numpy()\n",
28 |     "    cv2.imwrite(out_path_name, tmp_after_img)"
29 |    ]
30 |   },
31 |   {
32 |    "cell_type": "code",
33 |    "execution_count": null,
34 |    "metadata": {},
35 |    "outputs": [],
36 |    "source": [
37 |     "for i in range(2998):\n",
38 |     "    in_path_name = 'val/OUT/' + str(i).rjust(5, '0') + '.jpg'\n",
39 |     "    out_path_name = 'val/OUT1/' + str(i).rjust(5, '0') + '.jpg'\n",
40 |     "    tmp = cv2.imread(in_path_name, flags=0) # read as a single-channel grayscale image\n",
41 |     "    tmp_tensor = torch.tensor(tmp >= 128)\n",
42 |     "    tmp_tensor = tmp_tensor.int()\n",
43 |     "    tmp_after_img = tmp_tensor.numpy()\n",
44 |     "    cv2.imwrite(out_path_name, tmp_after_img)"
45 |    ]
46 |   },
47 |   {
48 |    "cell_type": "code",
49 |    "execution_count": 11,
50 |    "metadata": {},
51 |    "outputs": [],
52 |    "source": [
53 |     "for i in range(3000):\n",
54 |     "    in_path_name = 'test/OUT/' + str(i).rjust(5, '0') + '.jpg'\n",
55 |     "    out_path_name = 'test/OUT1/' + str(i).rjust(5, '0') + '.jpg'\n",
56 |     "    tmp = cv2.imread(in_path_name, flags=0)\n",
57 |     "    tmp_tensor = torch.tensor(tmp >= 128)\n",
58 |     "    tmp_tensor = tmp_tensor.int()\n",
59 |     "    tmp_after_img = tmp_tensor.numpy()\n",
60 |     "    cv2.imwrite(out_path_name, tmp_after_img)"
61 |    ]
62 |   },
63 |   {
64 |    "cell_type": "code",
65 |    "execution_count": 15,
66 |    "metadata": {},
67 |    "outputs": [],
68 |    "source": [
69 |     "new_file = open('train1.txt', 'w+')\n",
70 |     "for line in open('train.txt'):\n",
71 |     "    img0_name, img1_name, mask_name = line.strip(\"\\n\").split(' ') # train/OUT/00000.jpg\n",
72 |     "    mask_name_new = 'train/OUT1/' + mask_name.split(\"/\")[2] \n",
73 |     "    line_new = img0_name + ' ' + img1_name + ' ' + mask_name_new\n",
74 |     "    # print(line_new)\n",
75 |     "    new_file.writelines(line_new + '\\n')\n",
76 |     "    \n",
77 |     "new_file.close()"
78 |    ]
79 |   },
80 |   {
81 |    "cell_type": "code",
82 |    "execution_count": 16,
83 |    "metadata": {},
84 |    "outputs": [],
85 |    "source": [
86 |     "new_file = open('val1.txt', 'w+')\n",
87 |     "for line in open('val.txt'):\n",
88 |     "    img0_name, img1_name, mask_name = line.strip(\"\\n\").split(' ') # train/OUT/00000.jpg\n",
89 |     "    mask_name_new = 'val/OUT1/' + mask_name.split(\"/\")[2] \n",
90 |     "    line_new = img0_name + ' ' + img1_name + ' ' + mask_name_new\n",
91 |     "    # print(line_new)\n",
92 |     "    new_file.writelines(line_new + '\\n')\n",
93 |     "    \n",
94 |     "new_file.close()"
95 |    ]
96 |   },
97 |   {
98 |    "cell_type": "code",
99 |    "execution_count": 17,
100 |    "metadata": {},
101 |    "outputs": [],
102 |    "source": [
103 |     "new_file = open('test1.txt', 'w+')\n",
104 |     "for line in open('test.txt'):\n",
105 |     "    img0_name, img1_name, mask_name = line.strip(\"\\n\").split(' ') # train/OUT/00000.jpg\n",
106 |     "    mask_name_new = 'test/OUT1/' + mask_name.split(\"/\")[2] \n",
107 |     "    line_new = img0_name + ' ' + img1_name + ' ' + mask_name_new\n",
108 |     "    # print(line_new)\n",
109 |     "    new_file.writelines(line_new + '\\n')\n",
110 |     "    \n",
111 |     "new_file.close()"
112 |    ]
113 |   }
114 |  ],
115 |  "metadata": {
116 |   "kernelspec": {
117 |    "display_name": "Python 3.6.13 ('gluon')",
118 |    "language": "python",
119 |    "name": "python3"
120 |   },
121 |   "language_info": {
122 |    "codemirror_mode": {
123 |     "name": "ipython",
124 |     "version": 3
125 |    },
126 |    "file_extension": ".py",
127 |    "mimetype": "text/x-python",
128 |    "name": "python",
129 |    "nbconvert_exporter": "python",
130 |    "pygments_lexer": "ipython3",
131 |    "version": "3.6.13"
132 |   },
133 |   "orig_nbformat": 4,
134 |   "vscode": {
135 |    "interpreter": {
136 |     "hash": "9862ae77e9daaaf9c9239620ed827aad4ce184b3776eb7a3f75df899d88e405b"
137 |    }
138 |   }
139 |  },
140 |  "nbformat": 4,
141 |  "nbformat_minor": 2
142 | }
143 | 
--------------------------------------------------------------------------------
/example/CDD/test/the same settings as train:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/example/CDD/train/A/replace this with data: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /example/CDD/train/B/replace this with data: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /example/CDD/train/OUT1/replace this with data: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /example/CDD/val/the same settings as train: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/silence-tang/DASNet-V2/e56f6e095e1e232b85365ef500267665c4c47fb5/layer/__init__.py -------------------------------------------------------------------------------- /layer/function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | #### source code from 5 | #### https://github.com/ignacio-rocco/cnngeometric_pytorch 6 | #### 7 | class FeatureCorrelation(nn.Module): 8 | def __init__(self,scale): 9 | super(FeatureCorrelation, self).__init__() 10 | self.scale = scale 11 | 12 | def forward(self, feature_A, feature_B): 13 | b, c, h, w = feature_A.size() 14 | # reshape features for matrix multiplication 15 | feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h * w) 16 | feature_B = feature_B.view(b, c, h * w).transpose(1, 2) 17 | # perform matrix mult. 
18 |         feature_mul = torch.bmm(feature_B, feature_A)
19 |         correlation_tensor = self.scale * feature_mul.view(b, h, w, h * w).transpose(2, 3).transpose(1, 2)
20 |         return correlation_tensor
21 | 
22 | class l2normalization(nn.Module):
23 |     def __init__(self,scale):
24 | 
25 |         super(l2normalization, self).__init__()
26 |         self.scale = scale
27 | 
28 |     def forward(self, x, dim=1):
29 |         '''out = scale * x / sqrt(\sum x_i^2)'''
30 |         #f = x.data.cpu().numpy()
31 |         #scal = self.scale * x * x.pow(2).sum(dim).clamp(min=1e-12).rsqrt().expand_as(x)
32 |         #sca = scal.data.cpu().numpy()
33 |         return self.scale * x * x.pow(2).sum(dim, keepdim=True).clamp(min=1e-12).rsqrt().expand_as(x)
34 | 
35 | class l1normalization(nn.Module):
36 |     def __init__(self,scale):
37 |         super(l1normalization, self).__init__()
38 |         self.scale = scale
39 | 
40 |     def forward(self,x,dim=1):
41 |         # out = scale * x / sum(abs(x))
42 |         return self.scale * x / x.abs().sum(dim, keepdim=True).clamp(min=1e-12).expand_as(x)
43 | 
44 | class scale_feature(nn.Module):
45 |     def __init__(self,scale):
46 |         super(scale_feature, self).__init__()
47 |         self.scale = scale
48 | 
49 |     def forward(self,x):
50 |         return self.scale * x
51 | 
52 | class Mahalanobis_Distance(nn.Module):
53 |     def __init__(self):
54 |         super(Mahalanobis_Distance, self).__init__()
55 | 
56 |     def cal_con(self):
57 |         pass
58 | 
59 |     def cal_invert_matrix(self):
60 |         pass
61 | 
62 |     def forward(self,x1,x2):
63 |         dis_abs = x1 - x2
64 |         # covariance estimation (cal_con / cal_invert_matrix) is still a stub,
65 |         # so fail loudly instead of silently returning None
66 |         raise NotImplementedError('Mahalanobis_Distance is unfinished')
67 | 
68 | 
69 | 
--------------------------------------------------------------------------------
/layer/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import numpy as np
6 | 
7 | # -------------------------------------------------------------------------------
8 | class ContrastiveLoss1(nn.Module):
9 |     def __init__(self, margin1 = 0.3, margin2=2.2, eps=1e-6):
10 |         super(ContrastiveLoss1, self).__init__()
11 |         self.margin1 = margin1
12 |         self.margin2 = margin2
13 |         self.eps = eps
14 | 
15 |     def forward(self, x1, x2, y):
16 |         diff = torch.abs(x1 - x2)
17 |         dist_sq = torch.pow(diff + self.eps, 2).sum(dim=1)  # add a tiny epsilon for numerical stability
18 |         dist = torch.sqrt(dist_sq)  # L2 distance
19 |         mdist_pos = torch.clamp(dist-self.margin1, min=0.0)
20 |         mdist_neg = torch.clamp(self.margin2-dist, min=0.0)
21 |         loss_pos = (1 - y) * (mdist_pos.pow(2))
22 |         loss_neg = y * (mdist_neg.pow(2))
23 |         loss = torch.mean(loss_pos + loss_neg)  # note: no class weighting here (CLNew below adds it)
24 |         return loss
25 | 
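26 | # In equation form, the contrastive loss above is
27 | #     L = mean( (1 - y) * max(d - m1, 0)^2 + y * max(m2 - d, 0)^2 )
28 | # where d = ||x1 - x2||_2 along the channel dim and y = 1 marks changed pixels:
29 | # unchanged pairs are pulled within distance m1, changed pairs pushed beyond m2.
30 | # CLNew below additionally weights the two terms by inverse class frequency
31 | # (0.147 is presumably the changed-pixel ratio of the training masks).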
32 | 
33 | 
34 | # -------------------------------------------------------------------------------
35 | class CLNew(nn.Module):
36 |     def __init__(self, margin1 = 0.3, margin2=2.2, eps=1e-6):
37 |         super(CLNew, self).__init__()
38 |         self.margin1 = margin1
39 |         self.margin2 = margin2
40 |         self.eps = eps
41 | 
42 |     def forward(self, x1, x2, y):
43 |         diff = torch.abs(x1 - x2)  # 8x32x256x256
44 |         dist_sq = torch.pow(diff + self.eps, 2).sum(dim=1)  # tiny epsilon; 32x32 -> 256x256 after the decoder
45 |         dist = torch.sqrt(dist_sq)  # L2 distance, 32x32 -> 256x256
46 |         mdist_pos = torch.clamp(dist-self.margin1, min=0.0)
47 |         mdist_neg = torch.clamp(self.margin2-dist, min=0.0)
48 |         w1 = 1 / 0.147
49 |         w2 = 1 / (1 - 0.147)
50 |         loss_pos = w2 * ((1 - y) * (mdist_pos.pow(2)))  # element-wise; w2 is small, so unchanged pairs get less attention
51 |         loss_neg = w1 * (y * (mdist_neg.pow(2)))  # element-wise; w1 is large, so the network focuses on changed pairs
52 |         loss = torch.mean(loss_pos + loss_neg)
53 |         return loss
54 | 
55 | 
56 | class KLCoefficient(nn.Module):
57 |     def __init__(self):
58 |         super(KLCoefficient, self).__init__()
59 | 
60 |     def forward(self,hist1,hist2):
61 |         kl = F.kl_div(hist1,hist2)
62 |         dist = 1. / (1. + kl)  # similarity coefficient: 1 / (1 + KL)
63 |         return dist
--------------------------------------------------------------------------------
/layer/loss_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 2,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import torch\n",
10 |     "import torch.nn as nn\n",
11 |     "import torch.nn.functional as F\n",
12 |     "from torch.autograd import Variable\n",
13 |     "import numpy as np"
14 |    ]
15 |   },
16 |   {
17 |    "cell_type": "code",
18 |    "execution_count": 8,
19 |    "metadata": {},
20 |    "outputs": [
21 |     {
22 |      "data": {
23 |       "text/plain": [
24 |        "tensor(1947.2969)"
25 |       ]
26 |      },
27 |      "execution_count": 8,
28 |      "metadata": {},
29 |      "output_type": "execute_result"
30 |     }
31 |    ],
32 |    "source": [
33 |     "x1 = torch.randn((1, 512, 32, 32))\n",
34 |     "x2 = torch.randn((1, 512, 32, 32))\n",
35 |     "y = (torch.randn((1, 32, 32)) > 0.5).float()\n",
36 |     "\n",
37 |     "diff = torch.abs(x1 - x2)\n",
38 |     "\n",
39 |     "dist_sq = torch.pow(diff + 1e-6, 2).sum(dim=1) # tiny epsilon; sum(dim=1) sums over channels, so the output is (1,32,32)\n",
40 |     "\n",
41 |     "dist = torch.sqrt(dist_sq) # L2 distance\n",
42 |     "\n",
43 |     "mdist_pos = torch.clamp(0.3-dist, min=0.0) # m1 = 0.3, unchanged\n",
44 |     "mdist_neg = torch.clamp(dist-2.2, min=0.0) # m2 = 2.2, changed\n",
45 |     "\n",
46 |     "w1 = 1 / 0.147\n",
47 |     "w2 = 1 / (1 - 0.147)\n",
48 |     "loss_pos = w2 * ((1 - y) * (mdist_pos.pow(2))) # element-wise; w2 is small, down-weighting unchanged pairs\n",
49 |     "loss_neg = w1 * (y * (mdist_neg.pow(2))) # element-wise; w1 is large, emphasizing changed pairs\n",
50 |     "\n",
51 |     "loss = torch.mean(loss_pos + loss_neg) # the original code used no class weighting\n",
52 |     "loss"
53 |    ]
54 |   },
55 |   {
56 |    "cell_type": "code",
57 |    "execution_count": 122,
58 |    "metadata": {},
59 |    "outputs": [
60 |     {
61 |      "data": {
62 |       "text/plain": [
63 |        "(tensor([[[0., 0., 0., 1.],\n",
64 |        "          [0., 0., 0., 1.],\n",
65 |        "          [1., 0., 0., 1.],\n",
66 |        "          [1., 1., 0., 0.]]]),\n",
67 |        " tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1429, 0.5714, 1.0000],\n",
68 |        "          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1429, 0.5714, 1.0000],\n",
69 |        "          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1429, 0.5714, 1.0000],\n",
70 |        "          [0.2857, 0.1633, 0.0408, 0.0000, 0.0000, 0.1429, 0.5714, 1.0000],\n",
71 |        "          [0.7143, 0.4082, 0.1020, 0.0000, 0.0000, 0.1429, 0.5714, 1.0000],\n",
72 |        "          [1.0000, 0.6327, 0.2653, 0.1020, 0.0408, 0.1224, 0.4898, 0.8571],\n",
73 |        "          [1.0000, 0.8163, 0.6327, 0.4082, 0.1633, 0.0612, 0.2449, 0.4286],\n",
74 |        "          [1.0000, 1.0000, 1.0000, 0.7143, 0.2857, 0.0000, 0.0000, 0.0000]]]))"
75 |       ]
76 |      },
77 |      "execution_count": 122,
78 |      "metadata": {},
79 |      "output_type": "execute_result"
80 |     }
81 |    ],
82 |    "source": [
83 |     "# the original label may be smaller than the feature map, so interpolate it into a resized label\n",
84 |     "def rz_label(label, size):\n",
85 |     "    gt_e = torch.unsqueeze(label, dim=1)\n",
86 |     "    interp = nn.functional.interpolate(gt_e, (size[0],size[1]), mode='bilinear', align_corners=True)\n",
87 |     "    gt_rz = torch.squeeze(interp, dim=1)\n",
88 |     "    return gt_rz\n",
89 |     "\n",
90 |     "label = (torch.randn((1,4,4)) > 0.5).float()\n",
91 |     "label, rz_label(label, (8,8))"
92 |    ]
93 |   }
94 |  ],
95 |  "metadata": {
96 |   "kernelspec": {
97 |    "display_name": "Python 3.6.13 ('gluon')",
98 |    "language": "python",
99 |    "name": "python3"
100 |   },
101 |   "language_info": {
102 |    "codemirror_mode": {
103 |     "name": "ipython",
104 |     "version": 3
105 |    },
106 |    "file_extension": ".py",
107 |    "mimetype": "text/x-python",
108 |    "name": "python",
109 |    "nbconvert_exporter": "python",
110 |    "pygments_lexer": "ipython3",
111 |    "version": "3.6.13"
112 |   },
113 |   "orig_nbformat": 4,
114 |   "vscode": {
115 |    "interpreter": {
116 |     "hash": "9862ae77e9daaaf9c9239620ed827aad4ce184b3776eb7a3f75df899d88e405b"
117 |    }
118 |   }
119 |  },
120 |  "nbformat": 4,
121 |  "nbformat_minor": 2
122 | }
123 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/silence-tang/DASNet-V2/e56f6e095e1e232b85365ef500267665c4c47fb5/model/__init__.py
--------------------------------------------------------------------------------
/model/files.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import errno
4 | import shutil
5 | import hashlib
6 | from tqdm import tqdm
7 | import torch
8 | 
9 | __all__ = ['save_checkpoint', 'download', 'mkdir', 'check_sha1']
10 | 
11 | def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
12 |     """Saves checkpoint to disk"""
13 |     directory = "%s/%s_model/%s/"%(args.dataset, args.model, args.checkname)
14 |     if not os.path.exists(directory):
15 |         os.makedirs(directory)
16 |     filename = directory + filename
17 |     torch.save(state, filename)
18 |     if is_best:
19 |         shutil.copyfile(filename, directory + 'model_best.pth.tar')
20 | 
21 | 
22 | def download(url, path=None, overwrite=False, sha1_hash=None):
23 |     """Download a given URL
24 |     Parameters
25 |     ----------
26 |     url : str
27 |         URL to download
28 |     path : str, optional
29 |         Destination path to store downloaded file. By default stores to the
30 |         current directory with same name as in url.
31 |     overwrite : bool, optional
32 |         Whether to overwrite destination file if already exists.
33 |     sha1_hash : str, optional
34 |         Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
35 |         but doesn't match.
36 |     Returns
37 |     -------
38 |     str
39 |         The file path of the downloaded file.
40 | """ 41 | if path is None: 42 | fname = url.split('/')[-1] 43 | else: 44 | path = os.path.expanduser(path) 45 | if os.path.isdir(path): 46 | fname = os.path.join(path, url.split('/')[-1]) 47 | else: 48 | fname = path 49 | 50 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): 51 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 52 | if not os.path.exists(dirname): 53 | os.makedirs(dirname) 54 | 55 | print('Downloading %s from %s...'%(fname, url)) 56 | r = requests.get(url, stream=True) 57 | if r.status_code != 200: 58 | raise RuntimeError("Failed downloading url %s"%url) 59 | total_length = r.headers.get('content-length') 60 | with open(fname, 'wb') as f: 61 | if total_length is None: # no content length header 62 | for chunk in r.iter_content(chunk_size=1024): 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | else: 66 | total_length = int(total_length) 67 | for chunk in tqdm(r.iter_content(chunk_size=1024), 68 | total=int(total_length / 1024. + 0.5), 69 | unit='KB', unit_scale=False, dynamic_ncols=True): 70 | f.write(chunk) 71 | 72 | if sha1_hash and not check_sha1(fname, sha1_hash): 73 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 74 | 'The repo may be outdated or download may be incomplete. ' \ 75 | 'If the "repo_url" is overridden, consider switching to ' \ 76 | 'the default repo.'.format(fname)) 77 | 78 | return fname 79 | 80 | 81 | def check_sha1(filename, sha1_hash): 82 | """Check whether the sha1 hash of the file content matches the expected hash. 83 | Parameters 84 | ---------- 85 | filename : str 86 | Path to the file. 87 | sha1_hash : str 88 | Expected sha1 hash in hexadecimal digits. 89 | Returns 90 | ------- 91 | bool 92 | Whether the file content matches the expected hash. 
93 | """ 94 | sha1 = hashlib.sha1() 95 | with open(filename, 'rb') as f: 96 | while True: 97 | data = f.read(1048576) 98 | if not data: 99 | break 100 | sha1.update(data) 101 | 102 | return sha1.hexdigest() == sha1_hash 103 | 104 | 105 | def mkdir(path): 106 | """make dir exists okay""" 107 | try: 108 | os.makedirs(path) 109 | except OSError as exc: # Python >2.5 110 | if exc.errno == errno.EEXIST and os.path.isdir(path): 111 | pass 112 | else: 113 | raise 114 | -------------------------------------------------------------------------------- /model/model_store.py: -------------------------------------------------------------------------------- 1 | """Model store which provides pretrained models.""" 2 | from __future__ import print_function 3 | __all__ = ['get_model_file', 'purge'] 4 | import os 5 | import zipfile 6 | 7 | from files import download, check_sha1 8 | 9 | _model_sha1 = {name: checksum for checksum, name in [ 10 | ('853f2fb07aeb2927f7696e166b215609a987fd44', 'resnet50'), 11 | ('5be5422ad7cb6a2e5f5a54070d0aa9affe69a9a4', 'resnet101'), 12 | ('6cb047cda851de6aa31963e779fae5f4c299056a', 'deepten_minc'), 13 | ('fc8c0b795abf0133700c2d4265d2f9edab7eb6cc', 'fcn_resnet50_ade'), 14 | ('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'), 15 | ('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'), 16 | ('558e8904e123813f23dc0347acba85224650fe5f', 'encnet_resnet50_ade'), 17 | ('7846a2f065e90ce70d268ba8ada1a92251587734', 'encnet_resnet50_pcontext'), 18 | ('6f7c372259988bc2b6d7fc0007182e7835c31a11', 'encnet_resnet101_pcontext'), 19 | ]} 20 | 21 | encoding_repo_url = 'https://hangzh.s3.amazonaws.com/' 22 | _url_format = '{repo_url}encoding/models/{file_name}.zip' 23 | 24 | def short_hash(name): 25 | if name not in _model_sha1: 26 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 27 | return _model_sha1[name][:8] 28 | 29 | def get_model_file(name, root='./pretrain_models'): 30 | r"""Return location for the pretrained on local file system. 31 | 32 | This function will download from online model zoo when model cannot be found or has mismatch. 33 | The root directory will be created if it doesn't exist. 34 | 35 | Parameters 36 | ---------- 37 | name : str 38 | Name of the model. 39 | root : str, default './pretrain_models' 40 | Location for keeping the model parameters. 41 | 42 | Returns 43 | ------- 44 | file_path 45 | Path to the requested pretrained model file. 46 | """ 47 | file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name)) 48 | root = os.path.expanduser(root) 49 | file_path = os.path.join(root, file_name+'.pth') 50 | sha1_hash = _model_sha1[name] 51 | if os.path.exists(file_path): 52 | if check_sha1(file_path, sha1_hash): 53 | return file_path 54 | else: 55 | print('Mismatch in the content of model file detected. Downloading again.') 56 | else: 57 | print('Model file is not found. Downloading.') 58 | 59 | if not os.path.exists(root): 60 | os.makedirs(root) 61 | 62 | zip_file_path = os.path.join(root, file_name+'.zip') 63 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url) 64 | if repo_url[-1] != '/': 65 | repo_url = repo_url + '/' 66 | download(_url_format.format(repo_url=repo_url, file_name=file_name), 67 | path=zip_file_path, 68 | overwrite=True) 69 | with zipfile.ZipFile(zip_file_path) as zf: 70 | zf.extractall(root) 71 | os.remove(zip_file_path) 72 | 73 | if check_sha1(file_path, sha1_hash): 74 | return file_path 75 | else: 76 | raise ValueError('Downloaded file has different hash. 
Please try again.')
77 | 
78 | def purge(root='./pretrain_models'):
79 |     r"""Purge all pretrained model files in local file store.
80 | 
81 |     Parameters
82 |     ----------
83 |     root : str, default './pretrain_models'
84 |         Location for keeping the model parameters.
85 |     """
86 |     root = os.path.expanduser(root)
87 |     files = os.listdir(root)
88 |     for f in files:
89 |         if f.endswith(".pth"):
90 |             os.remove(os.path.join(root, f))
91 | 
92 | def pretrained_model_list():
93 |     return list(_model_sha1.keys())
94 | 
--------------------------------------------------------------------------------
/model/siameseNet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/silence-tang/DASNet-V2/e56f6e095e1e232b85365ef500267665c4c47fb5/model/siameseNet/__init__.py
--------------------------------------------------------------------------------
/model/siameseNet/attention.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: CASIA IVA
3 | # Email: jliu@nlpr.ia.ac.cn
4 | # Copyright (c) 2018
5 | ###########################################################################
6 | 
7 | import numpy as np
8 | import torch
9 | import math
10 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \
11 |     NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding
12 | from torch.nn import functional as F
13 | from torch.autograd import Variable
14 | torch_ver = torch.__version__[:3]
15 | 
16 | __all__ = ['PAM_Module', 'CAM_Module']
17 | 
18 | # attention between local areas (each pixel pair)
19 | # models long-range contextual semantic information between local features
20 | class PAM_Module(Module):
21 |     """ Position attention module"""
22 |     # Ref from SAGAN
23 |     def __init__(self, in_dim):
24 |         super(PAM_Module, self).__init__()
25 |         self.chanel_in = in_dim  # 512
26 |         self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)  # 512,64
27 |         self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)  # 512,64
28 |         self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)  # 512,512
29 |         self.gamma = Parameter(torch.zeros(1))
30 |         self.softmax = Softmax(dim=-1)
31 | 
32 |     def forward(self, x):
33 |         m_batchsize, C, height, width = x.size()  # (B, 512, H, W)
34 |         # Q
35 |         proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)  # (B, HxW, 64)
36 |         # K
37 |         proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)  # (B, 64, HxW)
38 |         # original attention matrix
39 |         energy = torch.bmm(proj_query, proj_key)  # (B, HxW, HxW)
40 |         # softmax sharpens the differences between regions
41 |         attention = self.softmax(energy)  # (B, HxW, HxW)
42 |         # V
43 |         proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)  # (B, 512, HxW)
44 |         # V matrix weighted by the attention scores
45 |         out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, 512, HxW)
46 |         out = out.view(m_batchsize, C, height, width)  # (B, 512, H, W)
47 |         out = self.gamma * out + x  # (B, 512, H, W)
48 |         return out
49 | 
50 | # attention between channels (each channel pair, also can be regarded as each ground object pair)
51 | # models the semantic dependencies between channels
52 | class CAM_Module(Module):
53 |     """ Channel attention module"""
54 |     def __init__(self, in_dim):
55 |         super(CAM_Module, self).__init__()
56 |         self.chanel_in = in_dim
57 |         self.gamma = Parameter(torch.zeros(1))
58 |         self.softmax = 
Softmax(dim=-1) 59 | 60 | def forward(self,x): 61 | m_batchsize, C, height, width = x.size() 62 | proj_query = x.view(m_batchsize, C, -1) # (B, 512, HxW) 63 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) # (B, HxW, 512) 64 | energy = torch.bmm(proj_query, proj_key) # (B, 512, 512) 65 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy 66 | attention = self.softmax(energy_new) # (B, 512, 512) 67 | proj_value = x.view(m_batchsize, C, -1) # (B, 512, HxW) 68 | out = torch.bmm(attention, proj_value) # (B, 512, HxW) 69 | out = out.view(m_batchsize, C, height, width) # (B, 512, H, W) 70 | out = self.gamma * out + x # (B, 512, H, W) 71 | return out 72 | 73 | -------------------------------------------------------------------------------- /model/siameseNet/d_aa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | import layer.function as fun 6 | from attention import * 7 | 8 | def convert_dict_names_for_fucking_faults(): 9 | 10 | deeplab_v2_dict_names_mapping = { 11 | 12 | 'conv1.0' : 'conv1_1', 13 | 'conv1.2' : 'conv1_2', 14 | 'conv2.0' : 'conv2_1', 15 | 'conv2.2' : 'conv2_2', 16 | 'conv3.0' : 'conv3_1', 17 | 'conv3.2' : 'conv3_2', 18 | 'conv3.4' : 'conv3_3', 19 | 'conv4.0' : 'conv4_1', 20 | 'conv4.2' : 'conv4_2', 21 | 'conv4.4' : 'conv4_3', 22 | 'conv5.0' : 'conv5_1', 23 | 'conv5.2' : 'conv5_2', 24 | 'conv5.4' : 'conv5_3'} 25 | 26 | return deeplab_v2_dict_names_mapping 27 | 28 | 29 | class deeplab_V2(nn.Module): 30 | def __init__(self): 31 | super(deeplab_V2, self).__init__() 32 | self.conv1 = nn.Sequential( 33 | nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,padding=1), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1), 36 | nn.ReLU(inplace=True), 37 | nn.MaxPool2d(kernel_size=3,stride=2,padding=1,ceil_mode=True), 38 | 39 | ) 40 | self.conv2 = nn.Sequential( 41 | nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,padding=1), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,padding=1), 44 | nn.ReLU(inplace=True), 45 | nn.MaxPool2d(kernel_size=3,stride=2,padding=1,ceil_mode=True), 46 | 47 | ) 48 | self.conv3 = nn.Sequential( 49 | nn.Conv2d(in_channels=128,out_channels=256,kernel_size=3,padding=1), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1), 54 | nn.ReLU(inplace=True), 55 | nn.MaxPool2d(kernel_size=3,stride=2,padding=1,ceil_mode=True), 56 | 57 | ) 58 | self.conv4 = nn.Sequential( 59 | nn.Conv2d(in_channels=256,out_channels=512,kernel_size=3,padding=1), 60 | nn.ReLU(inplace=True), 61 | nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1), 64 | nn.ReLU(inplace=True), 65 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 66 | 67 | ) 68 | self.conv5 = nn.Sequential( 69 | nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,dilation=2,padding=2), 70 | nn.ReLU(inplace=True), 71 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, dilation=2, padding=2), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, dilation=2, padding=2), 74 | nn.ReLU(inplace=True), 75 | 76 | ) 77 | 
inter_channels = 512 // 4 78 | self.conv5a = nn.Sequential(nn.Conv2d(512, inter_channels, 3, padding=1, bias=False), 79 | 80 | nn.ReLU()) 81 | 82 | self.conv5c = nn.Sequential(nn.Conv2d(512, inter_channels, 3, padding=1, bias=False), 83 | 84 | nn.ReLU()) 85 | 86 | self.sa = PAM_Module(inter_channels) 87 | self.sc = CAM_Module(inter_channels) 88 | self.conv51 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 89 | 90 | nn.ReLU()) 91 | self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 92 | 93 | nn.ReLU()) 94 | 95 | self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(128, 128, 1)) 96 | self.conv7 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(128, 128, 1)) 97 | 98 | self.conv8 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(128, 128, 1)) 99 | 100 | 101 | 102 | ####### multi-scale contexts ####### 103 | ####### dialtion = 6 ########## 104 | self.fc6_1 = nn.Sequential( 105 | nn.Conv2d(in_channels=512,out_channels=1024,kernel_size=3,dilation=6,padding=6), 106 | nn.ReLU(inplace=True), 107 | nn.Dropout2d(p=0.5) 108 | ) 109 | self.fc7_1 = nn.Sequential( 110 | nn.Conv2d(in_channels=1024,out_channels=1024,kernel_size=1), 111 | nn.ReLU(inplace=True), 112 | nn.Dropout2d(p=0.5) 113 | ) 114 | ####### dialtion = 12 ########## 115 | self.fc6_2 = nn.Sequential( 116 | nn.Conv2d(in_channels=512,out_channels=1024,kernel_size=3,dilation=12,padding=12), 117 | nn.ReLU(inplace=True), 118 | nn.Dropout2d(p=0.5) 119 | ) 120 | self.fc7_2 = nn.Sequential( 121 | nn.Conv2d(in_channels=1024,out_channels=1024,kernel_size=1), 122 | nn.ReLU(inplace=True), 123 | nn.Dropout2d(p=0.5) 124 | ) 125 | ####### dialtion = 18 ########## 126 | self.fc6_3 = nn.Sequential( 127 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, dilation=18, padding=18), 128 | nn.ReLU(inplace=True), 129 | nn.Dropout2d(p=0.5) 130 | ) 131 | self.fc7_3 = nn.Sequential( 132 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), 133 | nn.ReLU(inplace=True), 134 | nn.Dropout2d(p=0.5) 135 | ) 136 | ####### dialtion = 24 ########## 137 | self.fc6_4 = nn.Sequential( 138 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, dilation=24, padding=24), 139 | nn.ReLU(inplace=True), 140 | nn.Dropout2d(p=0.5) 141 | ) 142 | self.fc7_4 = nn.Sequential( 143 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), 144 | nn.ReLU(inplace=True), 145 | nn.Dropout2d(p=0.5) 146 | ) 147 | self.embedding_layer = nn.Conv2d(in_channels=512,out_channels=512,kernel_size=1) 148 | #self.fc8 = nn.Softmax2d() 149 | #self.fc8 = fun.l2normalization(scale=1) 150 | 151 | def forward(self,x): 152 | 153 | x = self.conv1(x) 154 | x = self.conv2(x) 155 | conv3_feature = self.conv3(x) 156 | conv4_feature = self.conv4(conv3_feature) 157 | conv5_feature = self.conv5(conv4_feature) 158 | feat1 = self.conv5a(conv5_feature) 159 | sa_feat = self.sa(feat1) 160 | sa_conv = self.conv51(sa_feat) 161 | sa_output = self.conv6(sa_conv) 162 | feat2 = self.conv5c(conv5_feature) 163 | sc_feat = self.sc(feat2) 164 | sc_conv = self.conv52(sc_feat) 165 | sc_output = self.conv7(sc_conv) 166 | 167 | feat_sum = sa_conv + sc_conv 168 | 169 | sasc_output = self.conv8(feat_sum) 170 | 171 | return sa_output,sc_output,sasc_output 172 | 173 | 174 | class SiameseNet(nn.Module): 175 | def __init__(self,norm_flag = 'l2'): 176 | super(SiameseNet, self).__init__() 177 | self.CNN = deeplab_V2() 178 | if norm_flag == 'l2': 179 | self.norm = F.normalize 180 | if norm_flag == 'exp': 181 | self.norm 
= nn.Softmax2d() 182 | ''''''''' 183 | def forward(self,t0,t1): 184 | out_t0_embedding = self.CNN(t0) 185 | out_t1_embedding = self.CNN(t1) 186 | #out_t0_conv5_norm,out_t1_conv5_norm = self.norm(out_t0_conv5),self.norm(out_t1_conv5) 187 | #out_t0_fc7_norm,out_t1_fc7_norm = self.norm(out_t0_fc7),self.norm(out_t1_fc7) 188 | out_t0_embedding_norm,out_t1_embedding_norm = self.norm(out_t0_embedding),self.norm(out_t1_embedding) 189 | return [out_t0_embedding_norm,out_t1_embedding_norm] 190 | ''''''''' 191 | 192 | def forward(self,t0,t1): 193 | 194 | 195 | 196 | out_t0_conv5,out_t0_fc7,out_t0_embedding = self.CNN(t0) 197 | out_t1_conv5,out_t1_fc7,out_t1_embedding = self.CNN(t1) 198 | out_t0_conv5_norm,out_t1_conv5_norm = self.norm(out_t0_conv5,2,dim=1),self.norm(out_t1_conv5,2,dim=1) 199 | out_t0_fc7_norm,out_t1_fc7_norm = self.norm(out_t0_fc7,2,dim=1),self.norm(out_t1_fc7,2,dim=1) 200 | out_t0_embedding_norm,out_t1_embedding_norm = self.norm(out_t0_embedding,2,dim=1),self.norm(out_t1_embedding,2,dim=1) 201 | return [out_t0_conv5_norm,out_t1_conv5_norm],[out_t0_fc7_norm,out_t1_fc7_norm],[out_t0_embedding_norm,out_t1_embedding_norm] 202 | 203 | ''''''''' 204 | def forward(self,t0,t1): 205 | out_t0_conv4,out_t0_conv5,out_t0_fc7,out_t0_embedding = self.CNN(t0) 206 | out_t1_conv4,out_t1_conv5,out_t1_fc7,out_t1_embedding = self.CNN(t1) 207 | out_t0_conv4_norm,out_t1_conv4_norm = self.norm(out_t0_conv5),self.norm(out_t1_conv5) 208 | out_t0_conv5_norm,out_t1_conv5_norm = self.norm(out_t0_conv5),self.norm(out_t1_conv5) 209 | out_t0_fc7_norm,out_t1_fc7_norm = self.norm(out_t0_fc7),self.norm(out_t1_fc7) 210 | out_t0_embedding_norm,out_t1_embedding_norm = self.norm(out_t0_embedding),self.norm(out_t1_embedding) 211 | return [out_t0_conv4_norm,out_t1_conv4_norm],[out_t0_conv5_norm,out_t1_conv5_norm],[out_t0_fc7_norm,out_t1_fc7_norm],[out_t0_embedding_norm,out_t1_embedding_norm] 212 | ''''''''' 213 | ''''''''' 214 | def forward(self,t0,t1): 215 | out_t0_conv4,out_t0_conv5,out_t0_fc7 = self.CNN(t0) 216 | out_t1_conv4,out_t1_conv5,out_t1_fc7 = self.CNN(t1) 217 | out_t0_conv4_norm,out_t1_conv4_norm = self.norm(out_t0_conv4),self.norm(out_t1_conv4) 218 | out_t0_conv5_norm,out_t1_conv5_norm = self.norm(out_t0_conv5),self.norm(out_t1_conv5) 219 | out_t0_fc7_norm,out_t1_fc7_norm = self.norm(out_t0_fc7),self.norm(out_t1_fc7) 220 | return [out_t0_conv4_norm,out_t1_conv4_norm],[out_t0_conv5_norm,out_t1_conv5_norm],[out_t0_fc7_norm,out_t1_fc7_norm] 221 | ''''''''' 222 | 223 | def init_parameters_from_deeplab(self,pretrain_vgg16_1024): 224 | 225 | ##### init parameter using pretrain vgg16 model ########### 226 | pretrain_dict_names = convert_dict_names_for_fucking_faults() 227 | keys = sorted(pretrain_dict_names.keys()) 228 | conv_blocks = [self.CNN.conv1, 229 | self.CNN.conv2, 230 | self.CNN.conv3, 231 | self.CNN.conv4, 232 | self.CNN.conv5] 233 | ranges = [[0,2], [0,2], [0,2,4], [0,2,4], [0,2,4]] 234 | for key in keys: 235 | dic_name = pretrain_dict_names[key] 236 | base_conv_name,conv_index,sub_index = dic_name[:5],int(dic_name[4]),int(dic_name[-1]) 237 | conv_blocks[conv_index -1][ranges[sub_index -1][sub_index -1]].weight.data = pretrain_vgg16_1024[key + '.weight'] 238 | conv_blocks[conv_index- 1][ranges[sub_index -1][sub_index -1]].bias.data = pretrain_vgg16_1024[key + '.bias'] 239 | 240 | ####### init fc parameters (transplant) ############## 241 | self.CNN.fc6_1[0].weight.data = pretrain_vgg16_1024['fc6_1.0.weight'].view(self.CNN.fc6_1[0].weight.size()) 242 | self.CNN.fc6_1[0].bias.data = 
pretrain_vgg16_1024['fc6_1.0.bias'].view(self.CNN.fc6_1[0].bias.size()) 243 | 244 | self.CNN.fc7_1[0].weight.data = pretrain_vgg16_1024['fc7_1.0.weight'].view(self.CNN.fc7_1[0].weight.size()) 245 | self.CNN.fc7_1[0].bias.data = pretrain_vgg16_1024['fc7_1.0.bias'].view(self.CNN.fc7_1[0].bias.size()) 246 | 247 | self.CNN.fc6_2[0].weight.data = pretrain_vgg16_1024['fc6_2.0.weight'].view(self.CNN.fc6_2[0].weight.size()) 248 | self.CNN.fc6_2[0].bias.data = pretrain_vgg16_1024['fc6_2.0.bias'].view(self.CNN.fc6_2[0].bias.size()) 249 | 250 | self.CNN.fc7_2[0].weight.data = pretrain_vgg16_1024['fc7_2.0.weight'].view(self.CNN.fc7_2[0].weight.size()) 251 | self.CNN.fc7_2[0].bias.data = pretrain_vgg16_1024['fc7_2.0.bias'].view(self.CNN.fc7_2[0].bias.size()) 252 | 253 | self.CNN.fc6_3[0].weight.data = pretrain_vgg16_1024['fc6_3.0.weight'].view(self.CNN.fc6_3[0].weight.size()) 254 | self.CNN.fc6_3[0].bias.data = pretrain_vgg16_1024['fc6_3.0.bias'].view(self.CNN.fc6_3[0].bias.size()) 255 | 256 | self.CNN.fc7_3[0].weight.data = pretrain_vgg16_1024['fc7_3.0.weight'].view(self.CNN.fc7_3[0].weight.size()) 257 | self.CNN.fc7_3[0].bias.data = pretrain_vgg16_1024['fc7_3.0.bias'].view(self.CNN.fc7_3[0].bias.size()) 258 | 259 | self.CNN.fc6_4[0].weight.data = pretrain_vgg16_1024['fc6_4.0.weight'].view(self.CNN.fc6_4[0].weight.size()) 260 | self.CNN.fc6_4[0].bias.data = pretrain_vgg16_1024['fc6_4.0.bias'].view(self.CNN.fc6_4[0].bias.size()) 261 | 262 | self.CNN.fc7_4[0].weight.data = pretrain_vgg16_1024['fc7_4.0.weight'].view(self.CNN.fc7_4[0].weight.size()) 263 | self.CNN.fc7_4[0].bias.data = pretrain_vgg16_1024['fc7_4.0.bias'].view(self.CNN.fc7_4[0].bias.size()) 264 | 265 | #init.kaiming_uniform(self.CNN.embedding_layer.weight.data,mode='fan_in') 266 | #init.constant(self.CNN.embedding_layer.bias.data,0) 267 | 268 | def init_parameters(self,pretrain_vgg16_1024): 269 | 270 | ##### init parameter using pretrain vgg16 model ########### 271 | conv_blocks = [self.CNN.conv1, 272 | self.CNN.conv2, 273 | self.CNN.conv3, 274 | self.CNN.conv4, 275 | self.CNN.conv5] 276 | 277 | ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] 278 | features = list(pretrain_vgg16_1024.features.children()) 279 | for idx, conv_block in enumerate(conv_blocks): 280 | for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block): 281 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 282 | # print idx, l1, l2 283 | assert l1.weight.size() == l2.weight.size() 284 | assert l1.bias.size() == l2.bias.size() 285 | l2.weight.data = l1.weight.data 286 | l2.bias.data = l1.bias.data 287 | 288 | ####### init fc parameters (transplant) ############## 289 | # 290 | self.CNN.fc6[0].weight.data = pretrain_vgg16_1024.classifier[0].weight.data.view(self.CNN.fc6[0].weight.size()) 291 | self.CNN.fc6[0].bias.data = pretrain_vgg16_1024.classifier[0].bias.data.view(self.CNN.fc6[0].bias.size()) 292 | 293 | self.CNN.fc7[0].weight.data = pretrain_vgg16_1024.classifier[3].weight.data.view(self.CNN.fc7[0].weight.size()) 294 | self.CNN.fc7[0].bias.data = pretrain_vgg16_1024.classifier[3].bias.data.view(self.CNN.fc7[0].bias.size()) 295 | 296 | if __name__ == '__main__': 297 | net= deeplab_V2() 298 | print('hh') -------------------------------------------------------------------------------- /model/siameseNet/dares.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.functional import 
upsample, normalize
7 | from .attention import PAM_Module
8 | from .attention import CAM_Module
9 | from .resbase import BaseNet
10 | import torch.nn.functional as F
11 | 
12 | __all__ = ['DANet']
13 | 
14 | 
15 | class DANetHead(nn.Module):
16 |     def __init__(self, in_channels, out_channels, norm_layer):
17 |         super(DANetHead, self).__init__()
18 |         inter_channels = in_channels // 4  # 2048 / 4 = 512
19 |         # CBR block
20 |         self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
21 |                                     norm_layer(inter_channels),
22 |                                     nn.ReLU())
23 |         # CBR block
24 |         self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
25 |                                     norm_layer(inter_channels),
26 |                                     nn.ReLU())
27 |         # PAM_Module & CAM_Module
28 |         self.sa = PAM_Module(inter_channels)
29 |         self.sc = CAM_Module(inter_channels)
30 |         self.conv51 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
31 |                                     norm_layer(inter_channels),
32 |                                     nn.ReLU())
33 |         self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
34 |                                     norm_layer(inter_channels),
35 |                                     nn.ReLU())
36 |         self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1))
37 |         self.conv7 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1))
38 |         self.conv8 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1))
39 | 
40 |     def forward(self, x):
41 |         # spatial attention branch
42 |         feat1 = self.conv5a(x)
43 |         sa_feat = self.sa(feat1)  # (B, 512, H, W)
44 |         sa_conv = self.conv51(sa_feat)  # (B, 512, H, W)
45 |         sa_output = self.conv6(sa_conv)  # (B, 512, H, W)
46 |         # channel attention branch
47 |         feat2 = self.conv5c(x)  # (B, 512, H, W)
48 |         sc_feat = self.sc(feat2)  # (B, 512, H, W)
49 |         sc_conv = self.conv52(sc_feat)  # (B, 512, H, W)
50 |         sc_output = self.conv7(sc_conv)  # (B, 512, H, W)
51 |         # feature fusion (pixel-wise add, then conv8)
52 |         feat_sum = sa_conv + sc_conv  # (B, 512, H, W)
53 |         sasc_output = self.conv8(feat_sum)  # (B, 512, H, W)
54 |         return sa_output, sc_output, sasc_output  # each (B, 512, H, W)
55 | 
56 | 
57 | class SiameseDecoder(nn.Module):
58 |     def __init__(self, in_channels, out_channels):
59 |         super(SiameseDecoder, self).__init__()
60 | 
61 |         # UpBlock = TransposedConv + BatchNorm + LeakyReLU
62 |         # in: 1x512x32x32, out: 1x256x64x64
63 | 
64 |         self.UpBlock1 = nn.Sequential(nn.ConvTranspose2d(in_channels, in_channels//2, padding=1, stride=2, kernel_size=(4,4)),
65 |                                       nn.BatchNorm2d(in_channels//2),
66 |                                       nn.LeakyReLU())
67 | 
68 |         # in: 1x256x64x64, out: 1x128x128x128
69 |         self.UpBlock2 = nn.Sequential(nn.ConvTranspose2d(in_channels//2, in_channels//4, padding=1, stride=2, kernel_size=(4,4)),
70 |                                       nn.BatchNorm2d(in_channels//4),
71 |                                       nn.LeakyReLU())
72 | 
73 |         # in: 1x128x128x128, out: 1x32x256x256 (SiameseNet below builds this with out_channels=32)
74 |         self.UpBlock3 = nn.Sequential(nn.ConvTranspose2d(in_channels//4, out_channels, padding=1, stride=2, kernel_size=(4,4)),
75 |                                       nn.BatchNorm2d(out_channels),
76 |                                       nn.LeakyReLU())
77 | 
78 | 
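        # Why each UpBlock exactly doubles the spatial size: for ConvTranspose2d
        # with dilation=1 and output_padding=0,
        #     H_out = (H_in - 1) * stride - 2 * padding + kernel_size
        # so with stride=2, padding=1, kernel_size=4:
        #     H_out = (32 - 1) * 2 - 2 + 4 = 64,
        # and the three stacked blocks take 32x32 -> 64x64 -> 128x128 -> 256x256.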
        # self.UpBlock1 = nn.Sequential(nn.Conv2d(in_channels, in_channels//2, 3, padding=1, bias=False),
        #                               nn.BatchNorm2d(in_channels//2),
        #                               nn.ReLU(),
        #                               nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=True)
        #                               )
84 | 
85 |         # # in: 1x256x64x64, out: 1x128x128x128
86 |         # self.UpBlock2 = nn.Sequential(nn.Conv2d(in_channels//2, in_channels//4, 3, padding=1, bias=False),
87 |         #                               nn.BatchNorm2d(in_channels//4),
88 |         #                               nn.ReLU(),
89 |         #                               nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=True)
90 |         #                               )
91 | 
92 |         # # in: 1x128x128x128, out: 1x32x256x256
93 |         # self.UpBlock3 = nn.Sequential(nn.Conv2d(in_channels//4, out_channels, 3, padding=1, bias=False),
94 |         #                               nn.BatchNorm2d(out_channels),
95 |         #                               nn.ReLU(),
96 |         #                               nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=True)
97 |         #                               )
98 | 
99 |     def forward(self, x):
100 |         for upblock in [self.UpBlock1, self.UpBlock2, self.UpBlock3]:
101 |             x = upblock(x)
102 |         return x
103 | 
104 | 
105 | class DANet(BaseNet):
106 |     """
107 |     Paper: Dual Attention Network for Scene Segmentation (Fu et al., CVPR 2019)
108 |     Backbone: default: 'resnet50'; one of 'resnet50', 'resnet101' or 'resnet152'
109 |     """
110 |     def __init__(self, nclass, backbone, norm_layer=nn.BatchNorm2d, **kwargs):
111 |         super(DANet, self).__init__(nclass, backbone, norm_layer=norm_layer, **kwargs)
112 |         self.head = DANetHead(2048, nclass, norm_layer)
113 | 
114 |     def forward(self, x):
115 |         # base_forward() is the forward pass of the ResNet backbone
116 |         _, _, _, c4 = self.base_forward(x)  # c4 is the output of layer4
117 |         x = self.head(c4)
118 |         x = list(x)
119 |         return x[0], x[1], x[2]
120 | 
121 | 
122 | def cnn():
123 |     model = DANet(512, backbone='resnet50')
124 |     return model
125 | 
126 | 
127 | class SiameseNet(nn.Module):
128 |     def __init__(self, norm_flag = 'l2'):
129 |         super(SiameseNet, self).__init__()
130 |         self.CNN = cnn()
131 |         self.decoder = SiameseDecoder(512, 32)
132 | 
133 |         if norm_flag == 'l2':
134 |             self.norm = F.normalize
135 |         if norm_flag == 'exp':
136 |             self.norm = nn.Softmax2d()
137 | 
138 |     def forward(self, t0, t1):
139 |         # CNN(t0) returns the sa, sc, sasc features of the before image
140 |         out_t0_conv5, out_t0_fc7, out_t0_embedding = self.CNN(t0)
141 |         # CNN(t1) returns the sa, sc, sasc features of the after image
142 |         # each 1x512x32x32
143 |         out_t1_conv5, out_t1_fc7, out_t1_embedding = self.CNN(t1)
144 | 
145 |         # decoder forward
146 |         # 1x512x32x32 --> 1x32x256x256
147 |         out_t0_conv5, out_t1_conv5 = self.decoder(out_t0_conv5), self.decoder(out_t1_conv5)
148 |         out_t0_fc7, out_t1_fc7 = self.decoder(out_t0_fc7), self.decoder(out_t1_fc7)
149 |         out_t0_embedding, out_t1_embedding = self.decoder(out_t0_embedding), self.decoder(out_t1_embedding)
150 | 
151 |         # L2-normalize the sa features of t0 and t1 along the channel dim
152 |         out_t0_conv5_norm, out_t1_conv5_norm = self.norm(out_t0_conv5, 2, dim=1), self.norm(out_t1_conv5, 2, dim=1)
153 |         # L2-normalize the sc features of t0 and t1
154 |         out_t0_fc7_norm, out_t1_fc7_norm = self.norm(out_t0_fc7, 2, dim=1), self.norm(out_t1_fc7, 2, dim=1)
155 |         # L2-normalize the sasc features of t0 and t1
156 |         out_t0_embedding_norm, out_t1_embedding_norm = self.norm(out_t0_embedding, 2, dim=1), self.norm(out_t1_embedding, 2, dim=1)
157 |         # return the three pairs of normalized feature maps
158 |         # originally 1x512x32x32 --> 1x32x256x256
159 |         return [out_t0_conv5_norm, out_t1_conv5_norm], [out_t0_fc7_norm, out_t1_fc7_norm], [out_t0_embedding_norm, out_t1_embedding_norm]
160 | 
161 | 
162 | if __name__ == '__main__':
163 |     model = SiameseNet(norm_flag='l2').cuda()
164 |     print('gg')
-------------------------------------------------------------------------------- /model/siameseNet/res.py: --------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: CASIA IVA
3 | # Email: jliu@nlpr.ia.ac.cn
4 | # Copyright (c) 2018
5 | ###########################################################################
6 | from __future__ import division
7 | import os
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn.functional import upsample, normalize
12 | from attention import PAM_Module
13 | from attention import CAM_Module
14 | from resbase import BaseNet
15 | import torch.nn.functional as F
16 | 
17 | 
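# Note: unlike dares.py above, this file keeps no PAM/CAM attention and no decoder --
# Net below is just a dropout + 1x1 conv head on the dilated backbone, so it is
# presumably the attention-free baseline; it is also written to run standalone
# (plain, non-relative imports above), not as part of the model.siameseNet package.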
__all__ = ['Net'] 18 | 19 | class Net(BaseNet): 20 | r"""Fully Convolutional Networks for Semantic Segmentation 21 | 22 | Parameters 23 | ---------- 24 | nclass : int 25 | Number of categories for the training dataset. 26 | backbone : string 27 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 28 | 'resnet101' or 'resnet152'). 29 | norm_layer : object 30 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 31 | 32 | 33 | Reference: 34 | 35 | Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks 36 | for semantic segmentation." *CVPR*, 2015 37 | 38 | """ 39 | 40 | def __init__(self, nclass, backbone, norm_layer=nn.BatchNorm2d, **kwargs): 41 | super(Net, self).__init__(nclass, backbone, norm_layer=norm_layer, **kwargs) 42 | self.conv8 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(2048, 512, 1)) 43 | 44 | def forward(self, x): 45 | _, _, c3, c4 = self.base_forward(x) 46 | x = self.conv8(c4) 47 | 48 | return x 49 | 50 | def cnn(): 51 | model = Net(512, backbone='resnet101') 52 | return model 53 | 54 | 55 | class SiameseNet(nn.Module): 56 | def __init__(self,norm_flag = 'l2'): 57 | super(SiameseNet, self).__init__() 58 | self.CNN = cnn() 59 | if norm_flag == 'l2': 60 | self.norm = F.normalize 61 | if norm_flag == 'exp': 62 | self.norm = nn.Softmax2d() 63 | ''''''''' 64 | def forward(self,t0,t1): 65 | out_t0_embedding = self.CNN(t0) 66 | out_t1_embedding = self.CNN(t1) 67 | #out_t0_conv5_norm,out_t1_conv5_norm = self.norm(out_t0_conv5),self.norm(out_t1_conv5) 68 | #out_t0_fc7_norm,out_t1_fc7_norm = self.norm(out_t0_fc7),self.norm(out_t1_fc7) 69 | out_t0_embedding_norm,out_t1_embedding_norm = self.norm(out_t0_embedding),self.norm(out_t1_embedding) 70 | return [out_t0_embedding_norm,out_t1_embedding_norm] 71 | ''''''''' 72 | 73 | def forward(self,t0,t1): 74 | out_t0_embedding = self.CNN(t0) 75 | out_t1_embedding = self.CNN(t1) 76 | out_t0_embedding_norm,out_t1_embedding_norm = self.norm(out_t0_embedding,2,dim=1),self.norm(out_t1_embedding,2,dim=1) 77 | return [out_t0_embedding_norm,out_t1_embedding_norm] 78 | 79 | 80 | if __name__ == '__main__': 81 | m = SiameseNet() 82 | print('gg') -------------------------------------------------------------------------------- /model/siameseNet/res50.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torch 6 | import numpy as np 7 | from attention import * 8 | # affine_par = True # True: BN has learnable affine parameters, False: without learnable affine parameters of BatchNorm Layer 9 | 10 | 11 | def outS(i): 12 | i = int(i) 13 | i = (i + 1) / 2 14 | i = int(np.ceil((i + 1) / 2.0)) 15 | i = (i + 1) / 2 16 | return i 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | "3x3 convolution with padding" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes, affine=True) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes, affine=True) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def 
forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 63 | self.bn1 = nn.BatchNorm2d(planes, affine=True) 64 | for i in self.bn1.parameters(): 65 | i.requires_grad = True 66 | 67 | padding = dilation 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 69 | padding=padding, bias=False, dilation=dilation) 70 | self.bn2 = nn.BatchNorm2d(planes, affine=True) 71 | for i in self.bn2.parameters(): 72 | i.requires_grad = True 73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.bn3 = nn.BatchNorm2d(planes * 4, affine=True) 75 | for i in self.bn3.parameters(): 76 | i.requires_grad = True 77 | self.relu = nn.ReLU(inplace=True) 78 | self.downsample = downsample 79 | self.stride = stride 80 | 81 | def forward(self, x): 82 | residual = x 83 | 84 | out = self.conv1(x) 85 | out = self.bn1(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv2(out) 89 | out = self.bn2(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv3(out) 93 | out = self.bn3(out) 94 | 95 | if self.downsample is not None: 96 | residual = self.downsample(x) 97 | 98 | out += residual 99 | out = self.relu(out) 100 | 101 | return out 102 | 103 | 104 | class Residual_Covolution(nn.Module): 105 | def __init__(self, icol, ocol, num_classes): 106 | super(Residual_Covolution, self).__init__() 107 | self.conv1 = nn.Conv2d(icol, ocol, kernel_size=3, stride=1, padding=12, dilation=12, bias=True) 108 | self.conv2 = nn.Conv2d(ocol, num_classes, kernel_size=3, stride=1, padding=12, dilation=12, bias=True) 109 | self.conv3 = nn.Conv2d(num_classes, ocol, kernel_size=1, stride=1, padding=0, dilation=1, bias=True) 110 | self.conv4 = nn.Conv2d(ocol, icol, kernel_size=1, stride=1, padding=0, dilation=1, bias=True) 111 | self.relu = nn.ReLU(inplace=True) 112 | 113 | def forward(self, x): 114 | dow1 = self.conv1(x) 115 | dow1 = self.relu(dow1) 116 | seg = self.conv2(dow1) 117 | inc1 = self.conv3(seg) 118 | add1 = dow1 + self.relu(inc1) 119 | inc2 = self.conv4(add1) 120 | out = x + self.relu(inc2) 121 | return out, seg 122 | 123 | 124 | 125 | class ResNet(nn.Module): 126 | def __init__(self, block, layers): 127 | self.inplanes = 64 128 | super(ResNet, self).__init__() 129 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 130 | bias=False) 131 | self.bn1 = nn.BatchNorm2d(64, affine=True) 132 | for i in self.bn1.parameters(): 133 | i.requires_grad = True 134 | self.relu = nn.ReLU(inplace=True) 135 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 136 | self.layer1 = self._make_layer(block, 64, layers[0]) 137 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 138 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 139 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 140 | inter_channels = 2048 // 4 141 | self.conv5a = nn.Sequential(nn.Conv2d(2048, inter_channels, 3, padding=1, bias=False), 142 | 143 | nn.ReLU()) 144 | 145 | 
self.conv5c = nn.Sequential(nn.Conv2d(2048, inter_channels, 3, padding=1, bias=False), 146 | 147 | nn.ReLU()) 148 | 149 | self.sa = PAM_Module(inter_channels) 150 | self.sc = CAM_Module(inter_channels) 151 | self.conv51 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 152 | 153 | nn.ReLU()) 154 | self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 155 | 156 | nn.ReLU()) 157 | 158 | self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, 512, 1)) 159 | self.conv7 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, 512, 1)) 160 | 161 | self.conv8 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, 512, 1)) 162 | self.embedding_layer = nn.Conv2d(in_channels=2048,out_channels=512,kernel_size=1) 163 | 164 | #self.embedding_layer = nn.Conv2d(512, num_classes, kernel_size=1) 165 | # self.softmax = nn.Softmax() 166 | 167 | # init weights 168 | for m in self.modules(): 169 | if isinstance(m, nn.Conv2d): 170 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 171 | m.weight.data.normal_(0, 0.01) 172 | elif isinstance(m, nn.BatchNorm2d): 173 | m.weight.data.fill_(1) 174 | m.bias.data.zero_() 175 | # for i in m.parameters(): 176 | # i.requires_grad = False 177 | 178 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 179 | downsample = None 180 | if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4: 181 | downsample = nn.Sequential( 182 | nn.Conv2d(self.inplanes, planes * block.expansion, 183 | kernel_size=1, stride=stride, bias=False), 184 | nn.BatchNorm2d(planes * block.expansion, affine=True)) 185 | for i in downsample._modules['1'].parameters(): 186 | i.requires_grad = True 187 | layers = [] 188 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample)) 189 | self.inplanes = planes * block.expansion 190 | for i in range(1, blocks): 191 | layers.append(block(self.inplanes, planes, dilation=dilation)) 192 | 193 | return nn.Sequential(*layers) 194 | 195 | def forward(self, x): 196 | input_size = x.size()[2:] 197 | x = self.conv1(x) # 7x7Conv 198 | x = self.bn1(x) 199 | x = self.relu(x) 200 | x = self.maxpool(x) 201 | x = self.layer1(x) # res2 202 | x = self.layer2(x) # res3 203 | x = self.layer3(x) # res4 204 | conv_feature = self.layer4(x) # res5 205 | feat1 = self.conv5a(conv_feature) 206 | sa_feat = self.sa(feat1) 207 | sa_conv = self.conv51(sa_feat) 208 | sa_output = self.conv6(sa_conv) 209 | feat2 = self.conv5c(conv_feature) 210 | sc_feat = self.sc(feat2) 211 | sc_conv = self.conv52(sc_feat) 212 | sc_output = self.conv7(sc_conv) 213 | 214 | feat_sum = sa_conv + sc_conv 215 | 216 | sasc_output = self.conv8(feat_sum) 217 | # embedding_feature = self.embedding_layer(conv_feature) 218 | 219 | 220 | # return sa_output,sc_output, sasc_output 221 | return sa_output, sc_output,sasc_output # 222 | 223 | 224 | def PSPNet(): 225 | """ """ 226 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 227 | return model 228 | 229 | class SiameseNet(nn.Module): 230 | def __init__(self,norm_flag = 'l2'): 231 | super(SiameseNet, self).__init__() 232 | self.CNN = ResNet(Bottleneck, [3, 4, 6, 3]) 233 | if norm_flag == 'l2': 234 | self.norm = F.normalize 235 | if norm_flag == 'exp': 236 | self.norm = nn.Softmax2d() 237 | 238 | def forward(self, t0, t1): 239 | out_t0_conv5,out_t0_fc7,out_t0_embedding = self.CNN(t0) 240 | out_t1_conv5,out_t1_fc7,out_t1_embedding = self.CNN(t1) 241 | out_t0_conv5_norm,out_t1_conv5_norm = 
self.norm(out_t0_conv5,2,dim=1),self.norm(out_t1_conv5,2,dim=1) 242 | out_t0_fc7_norm,out_t1_fc7_norm = self.norm(out_t0_fc7,2,dim=1),self.norm(out_t1_fc7,2,dim=1) 243 | out_t0_embedding_norm,out_t1_embedding_norm = self.norm(out_t0_embedding,2,dim=1),self.norm(out_t1_embedding,2,dim=1) 244 | return [out_t0_conv5_norm,out_t1_conv5_norm],[out_t0_fc7_norm,out_t1_fc7_norm],[out_t0_embedding_norm,out_t1_embedding_norm] 245 | 246 | # out_t0_conv5, out_t0_embedding = self.CNN(t0) 247 | # out_t1_conv5, out_t1_embedding = self.CNN(t1) 248 | # out_t0_conv5_norm, out_t1_conv5_norm = self.norm(out_t0_conv5, 2, dim=1), self.norm(out_t1_conv5, 2, dim=1) 249 | # out_t0_embedding_norm, out_t1_embedding_norm = self.norm(out_t0_embedding, 2, dim=1), self.norm(out_t1_embedding, 2, 250 | # dim=1) 251 | # return [out_t0_conv5_norm, out_t1_conv5_norm], [out_t0_embedding_norm,out_t1_embedding_norm] -------------------------------------------------------------------------------- /model/siameseNet/resbase.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.functional import upsample 9 | from torch.nn.parallel.data_parallel import DataParallel 10 | from torch.nn.parallel.parallel_apply import parallel_apply 11 | from torch.nn.parallel.scatter_gather import scatter 12 | from . import resnet 13 | 14 | up_kwargs = {'mode': 'bilinear', 'align_corners': True} 15 | 16 | __all__ = ['BaseNet'] 17 | 18 | class BaseNet(nn.Module): 19 | def __init__(self, nclass, backbone, dilated=True, norm_layer=None, root='pretrained', multi_grid=False, multi_dilation=None): 20 | super(BaseNet, self).__init__() 21 | 22 | # copying modules from pretrained models 23 | if backbone == 'resnet34': 24 | self.pretrained = resnet.resnet34(pretrained=False, dilated=dilated, 25 | norm_layer=norm_layer, root=root, 26 | multi_grid=multi_grid, multi_dilation=multi_dilation) 27 | elif backbone == 'resnet50': 28 | self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated, 29 | norm_layer=norm_layer, root=root, 30 | multi_grid=multi_grid, multi_dilation=multi_dilation) 31 | elif backbone == 'resnet101': 32 | self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated, 33 | norm_layer=norm_layer, root=root, 34 | multi_grid=multi_grid,multi_dilation=multi_dilation) 35 | elif backbone == 'resnet152': 36 | self.pretrained = resnet.resnet152(pretrained=False, dilated=dilated, 37 | norm_layer=norm_layer, root=root, 38 | multi_grid=multi_grid, multi_dilation=multi_dilation) 39 | else: 40 | raise RuntimeError('unknown backbone: {}'.format(backbone)) 41 | # bilinear upsample options 42 | self._up_kwargs = up_kwargs 43 | 44 | def base_forward(self, x): 45 | x = self.pretrained.conv1(x) 46 | x = self.pretrained.bn1(x) 47 | x = self.pretrained.relu(x) 48 | x = self.pretrained.maxpool(x) 49 | c1 = self.pretrained.layer1(x) 50 | c2 = self.pretrained.layer2(c1) 51 | c3 = self.pretrained.layer3(c2) 52 | c4 = self.pretrained.layer4(c3) 53 | return c1, c2, c3, c4 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /model/siameseNet/resnet.py: -------------------------------------------------------------------------------- 1 | """Dilated ResNet""" 
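# The "dilated" option below swaps the stride-2 downsampling of layer3/layer4 for
# dilated convolutions, so the backbone keeps an overall stride of 8 instead of 32.
# A quick sanity check of the resulting feature shape (hypothetical snippet, not in
# the original repo):
#     m = resnet50(pretrained=False, dilated=True)
#     x = m.maxpool(m.relu(m.bn1(m.conv1(torch.randn(1, 3, 256, 256)))))
#     feats = m.layer4(m.layer3(m.layer2(m.layer1(x))))
#     print(feats.shape)  # torch.Size([1, 2048, 32, 32]), i.e. 256 / 8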
2 | import math 3 | import torch 4 | import torch.utils.model_zoo as model_zoo 5 | import torch.nn as nn 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'BasicBlock', 'Bottleneck'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | "3x3 convolution with padding" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | """ResNet BasicBlock 27 | """ 28 | expansion = 1 29 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1, 30 | norm_layer=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 33 | padding=dilation, dilation=dilation, bias=False) 34 | self.bn1 = norm_layer(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, 37 | padding=previous_dilation, dilation=previous_dilation, bias=False) 38 | self.bn2 = norm_layer(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | """ResNet Bottleneck 63 | """ 64 | # pylint: disable=unused-argument 65 | expansion = 4 66 | def __init__(self, inplanes, planes, stride=1, dilation=1, 67 | downsample=None, previous_dilation=1, norm_layer=None): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn1 = norm_layer(planes) 71 | self.conv2 = nn.Conv2d( 72 | planes, planes, kernel_size=3, stride=stride, 73 | padding=dilation, dilation=dilation, bias=False) 74 | self.bn2 = norm_layer(planes) 75 | self.conv3 = nn.Conv2d( 76 | planes, planes * 4, kernel_size=1, bias=False) 77 | self.bn3 = norm_layer(planes * 4) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.downsample = downsample 80 | self.dilation = dilation 81 | self.stride = stride 82 | 83 | def _sum_each(self, x, y): 84 | assert(len(x) == len(y)) 85 | z = [] 86 | for i in range(len(x)): 87 | z.append(x[i]+y[i]) 88 | return z 89 | 90 | def forward(self, x): 91 | residual = x 92 | 93 | out = self.conv1(x) 94 | out = self.bn1(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv2(out) 98 | out = self.bn2(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv3(out) 102 | out = self.bn3(out) 103 | 104 | if self.downsample is not None: 105 | residual = self.downsample(x) 106 | 107 | out += residual 108 | out = self.relu(out) 109 | 110 | return out 111 | 112 | 113 | class ResNet(nn.Module): 114 | """Dilated Pre-trained ResNet Model, which preduces the stride of 8 featuremaps at conv5. 115 | 116 | Parameters 117 | ---------- 118 | block : Block 119 | Class for the residual block. 
Options are BasicBlockV1, BottleneckV1. 120 | layers : list of int 121 | Numbers of layers in each block 122 | classes : int, default 1000 123 | Number of classification classes. 124 | dilated : bool, default False 125 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model, 126 | typically used in Semantic Segmentation. 127 | norm_layer : object 128 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 129 | for Synchronized Cross-GPU BachNormalization). 130 | 131 | Reference: 132 | 133 | - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. 134 | 135 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 136 | """ 137 | # pylint: disable=unused-variable 138 | def __init__(self, block, layers, num_classes=1000, dilated=True, norm_layer=nn.BatchNorm2d, multi_grid=False, multi_dilation=None): 139 | self.inplanes = 64 140 | super(ResNet, self).__init__() 141 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 142 | bias=False) 143 | self.bn1 = norm_layer(64) 144 | self.relu = nn.ReLU(inplace=True) 145 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 146 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 147 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 148 | if dilated: 149 | if multi_grid: 150 | self.layer3 = self._make_layer(block,256,layers[2],stride=1, 151 | dilation=2, norm_layer=norm_layer) 152 | self.layer4 = self._make_layer(block,512,layers[3],stride=1, 153 | dilation=4,norm_layer=norm_layer, 154 | multi_grid=multi_grid, multi_dilation=multi_dilation) 155 | else: 156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 157 | dilation=2, norm_layer=norm_layer) 158 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 159 | dilation=4, norm_layer=norm_layer) 160 | else: 161 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 162 | norm_layer=norm_layer) 163 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 164 | norm_layer=norm_layer) 165 | self.avgpool = nn.AvgPool2d(7) 166 | self.fc = nn.Linear(512 * block.expansion, num_classes) 167 | 168 | for m in self.modules(): 169 | if isinstance(m, nn.Conv2d): 170 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 171 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
172 |             elif isinstance(m, norm_layer):
173 |                 m.weight.data.fill_(1)
174 |                 m.bias.data.zero_()
175 | 
176 |     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, multi_grid=False, multi_dilation=None):
177 |         downsample = None
178 |         if stride != 1 or self.inplanes != planes * block.expansion:
179 |             downsample = nn.Sequential(
180 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
181 |                           kernel_size=1, stride=stride, bias=False),
182 |                 norm_layer(planes * block.expansion),
183 |             )
184 | 
185 |         layers = []
186 |         if multi_grid == False:
187 |             if dilation == 1 or dilation == 2:
188 |                 layers.append(block(self.inplanes, planes, stride, dilation=1,
189 |                                     downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
190 |             elif dilation == 4:
191 |                 layers.append(block(self.inplanes, planes, stride, dilation=2,
192 |                                     downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
193 |             else:
194 |                 raise RuntimeError("=> unknown dilation size: {}".format(dilation))
195 |         else:
196 |             layers.append(block(self.inplanes, planes, stride, dilation=multi_dilation[0],
197 |                                 downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
198 |         self.inplanes = planes * block.expansion
199 |         if multi_grid:
200 |             div = len(multi_dilation)
201 |             for i in range(1,blocks):
202 |                 layers.append(block(self.inplanes, planes, dilation=multi_dilation[i%div], previous_dilation=dilation,
203 |                                     norm_layer=norm_layer))
204 |         else:
205 |             for i in range(1, blocks):
206 |                 layers.append(block(self.inplanes, planes, dilation=dilation, previous_dilation=dilation,
207 |                                     norm_layer=norm_layer))
208 | 
209 |         return nn.Sequential(*layers)
210 | 
211 |     def forward(self, x):
212 |         x = self.conv1(x)
213 |         x = self.bn1(x)
214 |         x = self.relu(x)
215 |         x = self.maxpool(x)
216 | 
217 |         x = self.layer1(x)
218 |         x = self.layer2(x)
219 |         x = self.layer3(x)
220 |         x = self.layer4(x)
221 | 
222 |         x = self.avgpool(x)
223 |         x = x.view(x.size(0), -1)
224 |         x = self.fc(x)
225 | 
226 |         return x
227 | 
228 | 
229 | def resnet18(pretrained=False, **kwargs):
230 |     """Constructs a ResNet-18 model.
231 | 
232 |     Args:
233 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
234 |     """
235 |     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
236 |     if pretrained:
237 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
238 |     return model
239 | 
240 | 
241 | def resnet34(pretrained=False, root='/home/lhf/yzy/cd_res/pretrain', **kwargs):
242 |     """Constructs a ResNet-34 model.
243 | 
244 |     Args:
245 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
246 |     """
247 |     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
248 |     if pretrained:
249 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
250 |     return model
251 | 
252 | 
253 | def resnet50(pretrained=True, root='pretrained', **kwargs):  # note: the default here must stay a plain relative path, not a raw string like r'xxxxx'
254 |     """Constructs a ResNet-50 model.
255 | 
256 |     Args:
257 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
258 |     """
259 |     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
260 |     if pretrained:
261 |         model_pth = '/root/autodl-tmp/DASNet-master/pretrained/resnet50-19c8e357.pth'
262 |         model.load_state_dict(torch.load(model_pth))
263 |         print('load pretrained resnet50:', model_pth)
264 |     return model
265 | 
266 | 
267 | def resnet101(pretrained=False, root='/home/lhf/yzy/cd_res/pretrain', **kwargs):
268 |     """Constructs a ResNet-101 model.
269 | 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | """ 273 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 274 | #Remove the following lines of comments 275 | #if u want to train from a pretrained model 276 | if pretrained: 277 | model_pth = '/home/lhf/yzy/cd_res/pretrain/resnet101-5d3b4d8f.pth' 278 | model.load_state_dict(torch.load(model_pth)) 279 | print('load pretrained resnet101:', model_pth) 280 | return model 281 | 282 | 283 | def resnet152(pretrained=False, root='~/.encoding/models', **kwargs): 284 | """Constructs a ResNet-152 model. 285 | 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | """ 289 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 290 | if pretrained: 291 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 292 | model.load_state_dict(torch.load( 293 | '/home/lhf/yzy/cd_res/pretrain/resnet152-b121ed2d.pth'), strict=False) 294 | return model 295 | 296 | if __name__ == '__main__': 297 | m1 = resnet50(pretrained=True) 298 | print('hh') -------------------------------------------------------------------------------- /model/siameseNet/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import os\n", 11 | "import sys \n", 12 | "sys.path.append(\"../..\")\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.utils.data as Data\n", 16 | "from torch.nn import functional as F\n", 17 | "import utils.transforms as trans\n", 18 | "import utils.utils as util\n", 19 | "import layer.loss as ls\n", 20 | "import utils.metric as mc\n", 21 | "import shutil\n", 22 | "import cv2\n", 23 | "\n", 24 | "import cfg.CDD as cfg\n", 25 | "import dataset.rs as dates\n", 26 | "import time\n", 27 | "import datetime\n", 28 | "from datetime import datetime\n", 29 | "import logging" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 32, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "thres = np.linspace(0, 1, 5) # array([0. , 0.25, 0.5 , 0.75, 1. ])\n", 39 | "gtBin = np.array([[True, False, False],\n", 40 | " [True, True, True],\n", 41 | " [False, False, True]])\n", 42 | "\n", 43 | "cur_prob = np.array([[0.7, 0.2, 0.4],\n", 44 | " [0.6, 0.8, 0.5],\n", 45 | " [0.3, 0.5, 0.9]])\n", 46 | "\n", 47 | "thresInf = np.concatenate(([-np.Inf], thres, [np.Inf])) # array([-inf, 0. , 0.25, 0.5 , 0.75, 1. 
, inf])\n",
48 |     "\n",
49 |     "fnArray = cur_prob[(gtBin == True)] # array([0.7, 0.6, 0.8, 0.5, 0.9])\n",
50 |     "fnHist = np.histogram(fnArray, bins=thresInf)[0] # array([0, 0, 0, 3, 2, 0], dtype=int64)\n",
51 |     "fnCum = np.cumsum(fnHist) # array([0, 0, 0, 3, 5, 5], dtype=int64)\n",
52 |     "FN = fnCum[0:0 + len(thres)] # array([0, 0, 0, 3, 5], dtype=int64)\n",
53 |     "\n",
54 |     "fpArray = cur_prob[(gtBin == False)] # array([0.2, 0.4, 0.3, 0.5])\n",
55 |     "fpHist = np.histogram(fpArray, bins=thresInf)[0] # array([0, 1, 2, 1, 0, 0], dtype=int64)\n",
56 |     "# reverse, take the cumulative sum, then reverse back\n",
57 |     "fpCum = np.flipud(np.cumsum(np.flipud(fpHist))) # array([4, 4, 3, 1, 0, 0], dtype=int64)\n",
58 |     "FP = fpCum[1:1 + len(thres)] # array([4, 3, 1, 0, 0], dtype=int64)\n",
59 |     "\n",
60 |     "posNum = np.sum(gtBin == True)\n",
61 |     "negNum = np.sum(gtBin == False)"
62 |    ]
63 |   }
64 |  ],
65 |  "metadata": {
66 |   "kernelspec": {
67 |    "display_name": "Python 3.6.13 ('gluon')",
68 |    "language": "python",
69 |    "name": "python3"
70 |   },
71 |   "language_info": {
72 |    "codemirror_mode": {
73 |     "name": "ipython",
74 |     "version": 3
75 |    },
76 |    "file_extension": ".py",
77 |    "mimetype": "text/x-python",
78 |    "name": "python",
79 |    "nbconvert_exporter": "python",
80 |    "pygments_lexer": "ipython3",
81 |    "version": "3.6.13"
82 |   },
83 |   "orig_nbformat": 4,
84 |   "vscode": {
85 |    "interpreter": {
86 |     "hash": "9862ae77e9daaaf9c9239620ed827aad4ce184b3776eb7a3f75df899d88e405b"
87 |    }
88 |   }
89 |  },
90 |  "nbformat": 4,
91 |  "nbformat_minor": 2
92 | }
93 | 
-------------------------------------------------------------------------------- /pretrained/note.txt: --------------------------------------------------------------------------------
1 | resnet50-19c8e357.pth can be downloaded from the web yourself (see the torchvision URL in model/siameseNet/resnet.py)
2 | 
-------------------------------------------------------------------------------- /test.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.utils.data as Data
6 | from torch.autograd import Variable
7 | from torch.nn import functional as F
8 | import utils.transforms as trans
9 | import utils.utils as util
10 | import utils.metric as mc
11 | import time
12 | import datetime
13 | import cv2
14 | import dataset.rs as dates
15 | import model.siameseNet.dares as models
16 | 
17 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
18 | 
19 | def check_dir(dir):
20 |     if not os.path.exists(dir):
21 |         os.mkdir(dir)
22 | 
23 | def various_distance(out_vec_t0, out_vec_t1, dist_flag):
24 |     if dist_flag == 'l2':
25 |         distance = F.pairwise_distance(out_vec_t0, out_vec_t1, p=2)
26 |     if dist_flag == 'l1':
27 |         distance = F.pairwise_distance(out_vec_t0, out_vec_t1, p=1)
28 |     if dist_flag == 'cos':
29 |         distance = 1 - F.cosine_similarity(out_vec_t0, out_vec_t1)
30 |     return distance
31 | 
32 | def single_layer_similar_heatmap_visual(output_t0, output_t1, save_change_map_dir, epoch, filename, layer_flag, dist_flag):
33 |     fname = filename[7:12]
34 |     n, c, h, w = output_t0.data.shape
35 |     out_t0_rz = torch.transpose(output_t0.view(c, h * w), 1, 0)
36 |     out_t1_rz = torch.transpose(output_t1.view(c, h * w), 1, 0)
37 |     distance = various_distance(out_t0_rz, out_t1_rz, dist_flag=dist_flag)
38 |     similar_distance_map = distance.view(h, w).data.cpu().numpy()
39 |     similar_distance_map_rz = nn.functional.interpolate(torch.from_numpy(similar_distance_map[np.newaxis, np.newaxis, :]), size=[256,256], mode='bilinear', align_corners=True)
40 |     similar_dis_map_colorize = cv2.applyColorMap(np.uint8(255 * 
similar_distance_map_rz.data.cpu().numpy()[0][0]), cv2.COLORMAP_JET) 41 | save_change_map_dir_ = os.path.join(save_change_map_dir, 'epoch_' + str(epoch)) 42 | check_dir(save_change_map_dir_) 43 | save_change_map_dir_layer = os.path.join(save_change_map_dir_,layer_flag) 44 | check_dir(save_change_map_dir_layer) 45 | save_weight_fig_dir = os.path.join(save_change_map_dir_layer, fname + '.jpg') 46 | cv2.imwrite(save_weight_fig_dir, similar_dis_map_colorize) 47 | return similar_distance_map_rz.data.cpu().numpy() 48 | 49 | def validate(net, val_dataloader,save_change_map_dir,save_roc_dir): 50 | epoch = 1 51 | net.eval() 52 | with torch.no_grad(): 53 | cont_conv5_total, cont_fc_total, cont_embedding_total, num = 0.0, 0.0, 0.0, 0.0 54 | metric_for_conditions = util.init_metric_for_class_for_cmu(1) 55 | for batch_idx, batch in enumerate(val_dataloader): 56 | inputs1, input2, targets, filename, height, width = batch 57 | height, width, filename = height.numpy()[0], width.numpy()[0], filename[0] 58 | inputs1, input2, targets = inputs1.cuda(), input2.cuda(), targets.cuda() 59 | fname = filename.split('/')[1][:-4] 60 | out_conv5, out_fc, out_embedding = net(inputs1, input2) 61 | out_conv5_t0, out_conv5_t1 = out_conv5 62 | out_fc_t0, out_fc_t1 = out_fc 63 | out_embedding_t0, out_embedding_t1 = out_embedding 64 | conv5_distance_map = single_layer_similar_heatmap_visual(out_conv5_t0, out_conv5_t1, save_change_map_dir,epoch, filename, 'conv5', 'l2') 65 | fc_distance_map = single_layer_similar_heatmap_visual(out_fc_t0, out_fc_t1, save_change_map_dir, epoch,filename, 'fc', 'l2') 66 | embedding_distance_map = single_layer_similar_heatmap_visual(out_embedding_t0, out_embedding_t1,save_change_map_dir, epoch, filename,'embedding', 'l2') 67 | cont_conv5 = mc.RMS_Contrast(conv5_distance_map) 68 | cont_fc = mc.RMS_Contrast(fc_distance_map) 69 | cont_embedding = mc.RMS_Contrast(embedding_distance_map) 70 | cont_conv5_total += cont_conv5 71 | cont_fc_total += cont_fc 72 | cont_embedding_total += cont_embedding 73 | num += 1 74 | prob_change = embedding_distance_map[0][0] 75 | gt = targets.data.cpu().numpy() 76 | FN, FP, posNum, negNum = mc.eval_image_rewrite(gt[0], prob_change, cl_index=1) 77 | metric_for_conditions[0]['total_fp'] += FP 78 | metric_for_conditions[0]['total_fn'] += FN 79 | metric_for_conditions[0]['total_posnum'] += posNum 80 | metric_for_conditions[0]['total_negnum'] += negNum 81 | cont_conv5_mean, cont_fc_mean, cont_embedding_mean = cont_conv5_total / num, cont_fc_total / num, cont_embedding_total / num 82 | 83 | thresh = np.array(range(0, 256)) / 255.0 84 | conds = metric_for_conditions.keys() 85 | for cond_name in conds: 86 | total_posnum = metric_for_conditions[cond_name]['total_posnum'] 87 | total_negnum = metric_for_conditions[cond_name]['total_negnum'] 88 | total_fn = metric_for_conditions[cond_name]['total_fn'] 89 | total_fp = metric_for_conditions[cond_name]['total_fp'] 90 | metric_dict = mc.pxEval_maximizeFMeasure(total_posnum, total_negnum,total_fn, total_fp, thresh=thresh) 91 | metric_for_conditions[cond_name].setdefault('metric', metric_dict) 92 | metric_for_conditions[cond_name].setdefault('contrast_conv5', cont_conv5_mean) 93 | metric_for_conditions[cond_name].setdefault('contrast_fc', cont_fc_mean) 94 | metric_for_conditions[cond_name].setdefault('contrast_embedding', cont_embedding_mean) 95 | 96 | f_score_total = 0.0 97 | for cond_name in conds: 98 | pr, recall, f_score = metric_for_conditions[cond_name]['metric']['precision'], \ 99 | 
metric_for_conditions[cond_name]['metric']['recall'], \
100 |                               metric_for_conditions[cond_name]['metric']['MaxF']
101 |         roc_save_epoch_dir = os.path.join(save_roc_dir, str(epoch))
102 |         check_dir(roc_save_epoch_dir)
103 |         roc_save_epoch_cat_dir = os.path.join(roc_save_epoch_dir)
104 |         check_dir(roc_save_epoch_cat_dir)
105 |         mc.save_PTZ_metric2disk(metric_for_conditions[cond_name], roc_save_epoch_cat_dir)
106 |         roc_save_dir = os.path.join(roc_save_epoch_cat_dir,
107 |                                     '_' + str(cond_name) + '_roc.png')
108 |         mc.plotPrecisionRecall(pr, recall, roc_save_dir, benchmark_pr=None)
109 |         f_score_total += f_score
110 |     print(f_score_total / (len(conds)))
111 |     return f_score_total / len(conds)
112 | 
113 | 
114 | def main():
115 |     DATA_PATH = r'C:\Users\HP\Desktop\DASNet\DASNet-master\example\CDD'
116 |     val_transform_det = trans.Compose([trans.Scale(256,256)])  # rescale
117 |     val_data = dates.Dataset(DATA_PATH, DATA_PATH, os.path.join(DATA_PATH, 'test1.txt'), 'val', transform=True, transform_med=val_transform_det)
118 |     # num_workers
119 |     val_loader = Data.DataLoader(val_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
120 |     model = models.SiameseNet(norm_flag='l2')
121 |     # set the checkpoint path yourself
122 |     checkpoint = torch.load('the path to best model', map_location='cpu')
123 |     model.load_state_dict(checkpoint['net'])  # train.py saves the weights under the 'net' key
124 |     print('load success')
125 |     model = model.cuda()
126 |     # set the output directories yourself
127 |     save_change_map_dir = r'checkpoints\prediction\contrastive_loss\changemaps'
128 |     save_roc_dir = r'checkpoints\prediction\contrastive_loss\roc'
129 |     time_start = time.time()
130 |     # validate() already prints the f_score, so the return value is informational only
131 |     current_metric = validate(model, val_loader, save_change_map_dir, save_roc_dir)
132 |     elapsed = round(time.time() - time_start)
133 |     elapsed = str(datetime.timedelta(seconds=elapsed))
134 |     print('Elapsed {}'.format(elapsed))
135 | 
136 | if __name__ == '__main__':
137 |     main()
138 | 
139 | 
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.utils.data as Data
6 | from torch.nn import functional as F
7 | import utils.transforms as trans
8 | import utils.utils as util
9 | import layer.loss as ls
10 | import utils.metric as mc
11 | import cv2
12 | import model.siameseNet.dares as models
13 | import cfg.CDD as cfg
14 | import dataset.rs as dates
15 | import time
16 | from datetime import datetime
17 | import logging
18 | import configargparse
19 | from torch.cuda.amp import autocast, GradScaler
20 | 
21 | 
22 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
23 | 
24 | 
25 | def check_dir(dir):
26 |     if not os.path.exists(dir):
27 |         os.mkdir(dir)
28 | 
29 | def various_distance(out_vec_t0, out_vec_t1, dist_flag):
30 |     if dist_flag == 'l2':
31 |         distance = F.pairwise_distance(out_vec_t0, out_vec_t1, p=2)
32 |     if dist_flag == 'l1':
33 |         distance = F.pairwise_distance(out_vec_t0, out_vec_t1, p=1)
34 |     if dist_flag == 'cos':
35 |         distance = 1 - F.cosine_similarity(out_vec_t0, out_vec_t1)
36 |     return distance
37 | 
38 | def single_layer_similar_heatmap_visual(idx, output_t0, output_t1, save_change_map_dir, epoch, batch_idx, filename, layer_flag, dist_flag):
39 |     n, c, h, w = output_t0.data.shape  # (1, 512, 32, 32) -> (1, 32, 256, 256)
40 |     # flatten to an (h*w, c) matrix, e.g. 1024x512 before the decoder or 65536x32 after it
41 |     out_t0_rz = torch.transpose(output_t0.view(c, h * w), 1, 0)
42 |     out_t1_rz = torch.transpose(output_t1.view(c, h * w), 1, 0)
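    # After view(c, h*w) + transpose, each row of out_t0_rz / out_t1_rz is one spatial
    # location's c-dim embedding, so the distance below is computed once per pixel
    # (h * w scalar values in total).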
    # per-pixel distance across the channel dim (e.g. the two embeddings at (0, 0)); with L2-normalized features every distance is at most 2
44 |     distance = various_distance(out_t0_rz, out_t1_rz, dist_flag=dist_flag)
45 |     similar_distance_map = distance.view(h, w).data.cpu().numpy()  # 256x256
46 |     # interpolate to 256x256
47 |     similar_distance_map_rz = nn.functional.interpolate(torch.from_numpy(similar_distance_map[np.newaxis, np.newaxis, :]), size=[cfg.TRANSFROM_SCALES[1], cfg.TRANSFROM_SCALES[0]], mode='bilinear', align_corners=True)
48 |     # render the heat map and save it
49 |     save_change_map_dir_ = os.path.join(save_change_map_dir, 'epoch_' + str(epoch))
50 |     check_dir(save_change_map_dir_)
51 |     save_change_map_dir_layer = os.path.join(save_change_map_dir_, layer_flag)
52 |     check_dir(save_change_map_dir_layer)
53 |     save_weight_fig_dir = os.path.join(save_change_map_dir_layer, str(batch_idx) + '_' + filename[0].split('/')[2])
54 |     if idx % 20 == 0:
55 |         similar_dis_map_colorize = cv2.applyColorMap(np.uint8(255 * similar_distance_map_rz.data.cpu().numpy()[0][0]), cv2.COLORMAP_JET)
56 |         cv2.imwrite(save_weight_fig_dir, similar_dis_map_colorize)
57 |     # return the distance map
58 |     return similar_distance_map_rz.data.cpu().numpy()
59 | 
60 | def validate(net, val_dataloader, epoch, batch_idx, save_change_map_dir, save_pr_dir, best_metric, best_epoch, best_batch_idx):
61 |     net.eval()
62 |     with torch.no_grad():
63 |         # initialization
64 |         num = 0.0
65 |         # fewer thresholds make the sweep faster
66 |         num_thresh = 96
67 |         thresh = np.linspace(0.0, 2.2, num_thresh)
68 |         # metric_dict = {'total_fp': [0,...,0], 'total_fn': [0,...,0], 'total_posnum': 0, 'total_negnum': 0}
69 |         metric_dict = util.init_metric_dict(thresh=thresh)
70 |         for idx, batch in enumerate(val_dataloader):
71 |             input1, input2, targets, filename, height, width = batch
72 |             input1, input2, targets = input1.cuda(), input2.cuda(), targets.cuda()
73 |             out_conv5, out_fc, out_embedding = net(input1, input2)
74 |             out_embedding_t0, out_embedding_t1 = out_embedding  # already L2-normalized to unit vectors along the channel dim
75 |             embedding_distance_map = single_layer_similar_heatmap_visual(idx, out_embedding_t0, out_embedding_t1, save_change_map_dir, epoch, batch_idx, filename, 'embedding', 'l2')
76 |             num += 1
77 |             prob_change = embedding_distance_map[0][0]  # 256x256
78 |             gt = targets.data.cpu().numpy()
79 |             # per-batch FN, FP, etc.; FN here is computed at every threshold, posNum = TP + FN, negNum = TN + FP
80 |             FN, FP, posNum, negNum = mc.eval_image(gt[0], prob_change, cl_index=1, thresh=thresh)
81 |             # after the loop, metric_dict holds the sum of every counter over all batches
82 |             metric_dict['total_fp'] += FP
83 |             metric_dict['total_fn'] += FN
84 |             metric_dict['total_posnum'] += posNum
85 |             metric_dict['total_negnum'] += negNum
86 | 
87 |         # totals over the whole val set
88 |         total_fp = metric_dict['total_fp']
89 |         total_fn = metric_dict['total_fn']
90 |         total_posnum = metric_dict['total_posnum']
91 |         total_negnum = metric_dict['total_negnum']
92 |         # mc.pxEval_maximizeFMeasure computes the maximum F_score over the threshold sweep, on the accumulated val-set counts
93 |         res_dict = mc.pxEval_maximizeFMeasure(total_posnum, total_negnum, total_fn, total_fp, thresh=thresh)
94 |         metric_dict.setdefault('metric', res_dict)
95 |         # pull out precision, recall and the F-score
96 |         pr, recall, f_score = metric_dict['metric']['precision'], metric_dict['metric']['recall'], metric_dict['metric']['MaxF']
97 |         pr_save_epoch_dir = os.path.join(save_pr_dir)
98 |         check_dir(pr_save_epoch_dir)
99 |         pr_save_epoch_cat_dir = os.path.join(pr_save_epoch_dir)
100 |         check_dir(pr_save_epoch_cat_dir)
101 |         # save the metric log
102 |         mc.save_metric_json(metric_dict, pr_save_epoch_cat_dir, epoch, batch_idx)
103 |         pr_save_dir = os.path.join(pr_save_epoch_cat_dir, str(epoch) + '_' + str(batch_idx) + '_pr.png')
104 |         # save the P-R curve
105 |         mc.plotPrecisionRecall(pr, recall, pr_save_dir, benchmark_pr=None)
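        # f_score here is MaxF over the threshold sweep above; the training loop keeps
        # the checkpoint with the highest value, so this one number drives model selection.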
        print('f_max: ', f_score)
107 |         print('best_f_max: ', best_metric)
108 |         print('best_epoch: ', best_epoch)
109 |         print('best_batch_idx: ', best_batch_idx)
110 |     return f_score
111 | 
112 | 
113 | def config_parser():
114 |     parser = configargparse.ArgumentParser()
115 |     parser.add_argument("--resume", type=int, default=0, help='0: no resume, 100: load the best model, others: load trained model_i.pth after i_th epoch')
116 |     parser.add_argument("--start_epoch", type=int, default=0, help='from which epoch to continue training')
117 |     parser.add_argument("--datatime", type=str, default=None, help='used for resume, for example: 10.28_09')
118 |     return parser
119 | 
120 | def main():
121 |     # parse args
122 |     parser = config_parser()
123 |     args = parser.parse_args()
124 |     # logs
125 |     DateTime = datetime.now().strftime("%D-%H")
126 |     logname = DateTime.split('/')[0] + '.' + DateTime.split('/')[1] + '_' + DateTime.split('/')[2].split('-')[1]
127 |     logging.basicConfig(filename=logname + '.txt', level=logging.DEBUG)
128 |     # configs
129 |     best_metric = 0
130 |     best_epoch = 0
131 |     best_batch_idx = 0
132 |     # load datasets
133 |     train_transform_det = trans.Compose([trans.Scale(cfg.TRANSFROM_SCALES)])
134 |     val_transform_det = trans.Compose([trans.Scale(cfg.TRANSFROM_SCALES)])
135 |     train_data = dates.Dataset(cfg.TRAIN_DATA_PATH, cfg.TRAIN_LABEL_PATH, cfg.TRAIN_TXT_PATH, 'train', transform=True, transform_med=train_transform_det)
136 |     val_data = dates.Dataset(cfg.VAL_DATA_PATH, cfg.VAL_LABEL_PATH, cfg.VAL_TXT_PATH, 'val', transform=True, transform_med=val_transform_det)
137 |     train_loader = Data.DataLoader(train_data, batch_size=cfg.TRAIN_BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=False)
138 |     val_loader = Data.DataLoader(val_data, batch_size=cfg.VAL_BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)
139 | 
140 |     # build models
141 |     model = models.SiameseNet(norm_flag='l2')
142 |     # init training configs
143 |     model = model.cuda()
144 |     MaskLoss = ls.CLNew()
145 |     optimizer = torch.optim.Adam(params=model.parameters(), lr=cfg.INIT_LEARNING_RATE, weight_decay=cfg.DECAY)
146 | 
147 |     if args.resume == 100:
148 |         checkpoint = torch.load(cfg.TRAINED_BEST_PERFORMANCE_CKPT)
149 |         model.load_state_dict(checkpoint['net'])
150 |         optimizer.load_state_dict(checkpoint['optimizer'])
151 |         print('The best model has been loaded.')
152 | 
153 |     elif args.resume != 0:
154 |         checkpoint = torch.load(os.path.join(cfg.TRAINED_RESUME_PATH, args.datatime, 'model_'+str(args.resume)+'.pth'))
155 |         model.load_state_dict(checkpoint['net'])
156 |         optimizer.load_state_dict(checkpoint['optimizer'])
157 |         print(args.datatime+'/'+'model_'+str(args.resume)+' has been loaded.')
158 | 
159 |     else:
160 |         print('ResNet50 backbone has been loaded.')
161 | 
162 |     # check directories
163 |     ab_test_dir = os.path.join(cfg.SAVE_PRED_PATH, logname)
164 |     check_dir(ab_test_dir)
165 |     save_change_map_dir = os.path.join(ab_test_dir, 'distance_maps/')
166 |     save_pr_dir = os.path.join(ab_test_dir, 'pr_curves')
167 |     check_dir(save_change_map_dir)
168 |     check_dir(save_pr_dir)
169 | 
170 |     # train loop
171 |     time_start = time.time()
172 |     start = args.start_epoch
173 |     print('Start training from epoch {}.'.format(start))
174 |     for epoch in range(start, cfg.MAX_EPOCH):
175 |         for batch_idx, batch in enumerate(train_loader):
176 |             # -----lr needs to be adjusted for different batch_size-----
177 |             step = epoch * 10000 + batch_idx  # e.g. 30 * 10000 + 1000
178 |             util.adjust_learning_rate(cfg.INIT_LEARNING_RATE, optimizer, step)
179 |             # AMP training (left disabled below; GradScaler is imported above but unused)
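            # A minimal sketch of what the disabled AMP path would look like, assuming
            # the standard torch.cuda.amp recipe (scaler = GradScaler() created once
            # before the epoch loop -- none of this is in the original repo):
            #     with autocast():
            #         out_conv5, out_fc, out_embedding = model(img1, img2)
            #         loss = ...  # same three contrastive terms as below
            #     optimizer.zero_grad()
            #     scaler.scale(loss).backward()
            #     scaler.step(optimizer)
            #     scaler.update()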
            # with autocast():
181 |             model.train()
182 |             # img1, img2: (B, 3, 256, 256), label: (B, 256, 256)
183 |             img1, img2, label, filename, height, width = batch
184 |             img1, img2, label = img1.cuda(), img2.cuda(), label.cuda().float()
185 |             out_conv5, out_fc, out_embedding = model(img1, img2)  # forward
186 |             out_conv5_t0, out_conv5_t1 = out_conv5  # (B, 512, 32, 32) --> (B, 32, 256, 256)
187 |             out_fc_t0, out_fc_t1 = out_fc  # same shape
188 |             out_embedding_t0, out_embedding_t1 = out_embedding  # same shape
189 |             # the three rz_labels are identical: the label resized to each output's spatial size (256x256 here)
190 |             label_rz_conv5 = util.rz_label(label, size=out_conv5_t0.data.cpu().numpy().shape[2:]).cuda()
191 |             label_rz_fc = util.rz_label(label, size=out_fc_t0.data.cpu().numpy().shape[2:]).cuda()
192 |             label_rz_embedding = util.rz_label(label, size=out_embedding_t0.data.cpu().numpy().shape[2:]).cuda()
193 |             # the three contrastive losses
194 |             contrastive_loss_conv5 = MaskLoss(out_conv5_t0, out_conv5_t1, label_rz_conv5)  # a scalar
195 |             contrastive_loss_fc = MaskLoss(out_fc_t0, out_fc_t1, label_rz_fc)
196 |             contrastive_loss_embedding = MaskLoss(out_embedding_t0, out_embedding_t1, label_rz_embedding)
197 |             loss = contrastive_loss_conv5 + contrastive_loss_fc + contrastive_loss_embedding
198 | 
199 |             optimizer.zero_grad()
200 |             loss.backward()
201 |             optimizer.step()
202 | 
203 |             if batch_idx % 20 == 0:
204 |                 logging.info("epoch/batch_idx: [%d/%d] lr: %.6f best_f_max: %.4f best_epoch: %d best_batch_idx: %d loss: %.4f loss_conv5: %.4f loss_fc: %.4f "
205 |                              "loss_embedding: %.4f" % (epoch, batch_idx, optimizer.state_dict()['param_groups'][0]['lr'], best_metric, best_epoch, best_batch_idx, loss.item(), contrastive_loss_conv5.item(),
206 |                              contrastive_loss_fc.item(), contrastive_loss_embedding.item()))
207 |                 print("epoch/batch_idx: [%d/%d] lr: %.6f loss: %.4f loss_conv5: %.4f loss_fc: %.4f "
208 |                       "loss_embedding: %.4f" % (epoch, batch_idx, optimizer.state_dict()['param_groups'][0]['lr'], loss.item(), contrastive_loss_conv5.item(),
209 |                       contrastive_loss_fc.item(), contrastive_loss_embedding.item()))
210 |             # run an f_score validation every 1000 batches
211 |             # if (batch_idx) % 1000 == 0:
212 |             if (batch_idx != 0) and (batch_idx % 1000 == 0):
213 |                 model.eval()
214 |                 current_metric = validate(model, val_loader, epoch, batch_idx, save_change_map_dir, save_pr_dir, best_metric, best_epoch, best_batch_idx)
215 |                 if current_metric > best_metric:
216 |                     # update the best record first so the checkpoint name carries the new epoch/batch_idx
217 |                     best_metric = current_metric
218 |                     best_epoch = epoch
219 |                     best_batch_idx = batch_idx
220 |                     state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict()}
221 |                     torch.save(state, os.path.join(ab_test_dir, 'model_best_'+logname+'_'+str(best_epoch)+'_'+str(best_batch_idx)+'.pth'))
222 | 
223 |         # one full pass over the training set has finished
224 |         model.eval()
225 |         current_metric = validate(model, val_loader, epoch, batch_idx, save_change_map_dir, save_pr_dir, best_metric, best_epoch, best_batch_idx)
226 |         # model_i.pth is the model obtained after the i-th training epoch
227 |         state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict()}
228 |         torch.save(state, os.path.join(ab_test_dir, 'model_' + str(epoch) + '.pth'))
229 |         if current_metric > best_metric:
230 |             best_metric = current_metric
231 |             best_epoch = epoch
232 |             best_batch_idx = batch_idx
233 |             state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict()}
234 |             torch.save(state, os.path.join(ab_test_dir, 'model_best_'+logname+'_'+str(best_epoch)+'_'+str(best_batch_idx)+'.pth'))
235 | 
236 |     elapsed = round(time.time() - time_start)
237 |     print('Elapsed {}'.format(elapsed))
238 | 
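# For reference: ls.CLNew (layer/loss.py) is not included in this dump. A minimal
# sketch of the margin-based contrastive loss it presumably implements over per-pixel
# L2 distances; the margin value is an assumption, not taken from the repo, and the
# sketch reuses the torch / F imports at the top of this file:
def _contrastive_loss_sketch(feat_t0, feat_t1, label, margin=2.0):
    # feat_*: (B, C, H, W) L2-normalized embeddings; label: (B, H, W), 1 = changed
    dist = torch.norm(feat_t0 - feat_t1, p=2, dim=1)     # per-pixel distance, (B, H, W)
    loss_unchanged = (1 - label) * dist.pow(2)           # pull unchanged pixels together
    loss_changed = label * F.relu(margin - dist).pow(2)  # push changed pixels beyond the margin
    return (loss_unchanged + loss_changed).mean()
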
'__main__': 239 | main() 240 | 241 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/silence-tang/DASNet-V2/e56f6e095e1e232b85365ef500267665c4c47fb5/utils/__init__.py -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pylab 3 | 4 | # def evalExp(gtBin, cur_prob, thres): 5 | 6 | # # thresInf = np.concatenate(([-np.Inf], thres, [np.Inf])) 7 | 8 | # # 求FN 9 | # fnArray = cur_prob[(gtBin == True)] 10 | # FN = np.array([sum(fnArray <= t) for t in thres]) 11 | # # fnHist = np.histogram(fnArray, bins=thresInf)[0] 12 | # # fnCum = np.cumsum(fnHist) 13 | # # FN = fnCum[0: 0+len(thres)] 14 | 15 | # # 求FP 16 | # fpArray = cur_prob[(gtBin == False)] 17 | # FP = np.array([sum(fpArray > t) for t in thres]) 18 | # # fpHist = np.histogram(fpArray, bins=thresInf)[0] 19 | # # fpCum = np.flipud(np.cumsum(np.flipud(fpHist))) 20 | # # FP = fpCum[1: 1+len(thres)] 21 | 22 | # posNum = np.sum(gtBin == True) # 实际为changed的数量(TP+FN) 23 | # negNum = np.sum(gtBin == False) # 实际为unchanged的数量(TN+FP) 24 | 25 | # return FN, FP, posNum, negNum 26 | 27 | 28 | def evalExp(gtBin, cur_prob, thres): 29 | 30 | thresInf = np.concatenate(([-np.Inf], thres, [np.Inf])) # 便于后续进行直方图统计 31 | 32 | fnArray = cur_prob[(gtBin == True)] 33 | fnHist = np.histogram(fnArray, bins=thresInf)[0] 34 | fnCum = np.cumsum(fnHist) 35 | FN = fnCum[0: 0 + len(thres)] 36 | 37 | fpArray = cur_prob[(gtBin == False)] 38 | fpHist = np.histogram(fpArray, bins=thresInf)[0] # 左闭右开区间 39 | # 倒置求累计和再倒置回去 40 | fpCum = np.flipud(np.cumsum(np.flipud(fpHist))) 41 | FP = fpCum[1: 1 + len(thres)] # 从0或1开始取数取决于开闭区间 42 | 43 | posNum = np.sum(gtBin == True) 44 | negNum = np.sum(gtBin == False) 45 | 46 | return FN, FP, posNum, negNum 47 | 48 | 49 | def eval_image(gt_image, prob, cl_index, thresh): 50 | # 设定阈值 51 | thresh = thresh 52 | # true/false map of ground truth 53 | cl_gt = gt_image[:, :] == cl_index 54 | FN, FP, posNum, negNum = evalExp(cl_gt, prob, thresh) 55 | return FN, FP, posNum, negNum 56 | 57 | 58 | def pxEval_maximizeFMeasure(totalPosNum, totalNegNum, totalFN, totalFP, thresh=None): 59 | ''' 60 | @param totalPosNum: scalar 61 | @param totalNegNum: scalar 62 | @param totalFN: vector 63 | @param totalFP: vector 64 | @param thresh: vector 65 | ''' 66 | # TP TN 67 | totalTP = totalPosNum - totalFN 68 | totalTN = totalNegNum - totalFP 69 | 70 | valid = (totalTP >= 0) & (totalTN >= 0) # 检测有效值 71 | assert valid.all(), 'Detected invalid elements in eval' 72 | 73 | recall = totalTP / float(totalPosNum) # 分母一定非0 74 | precision = totalTP / (totalTP + totalFP + 1e-10) # 防止出现分母为0 75 | selector_invalid = (recall == 0) & (precision == 0) # 找到TP=0的情况 76 | recall = recall[~selector_invalid] # 将其排除 77 | precision = precision[~selector_invalid] 78 | 79 | # F-measure 80 | beta = 1.0 81 | betasq = beta ** 2 82 | F = (1 + betasq) * (precision * recall) / ((betasq * precision) + recall + 1e-10) # 防止出现分母为0 83 | index = F.argmax() 84 | MaxF = F[index] # 求maxf 85 | 86 | # recall_bst = recall[index] 87 | # precision_bst = precision[index] 88 | # TP = totalTP[index] 89 | # TN = totalTN[index] 90 | # FP = totalFP[index] 91 | # FN = totalFN[index] 92 | # valuesMaxF = np.zeros((1, 4), 'u4') 93 | # valuesMaxF[0, 0] = TP 94 | # valuesMaxF[0, 1] = TN 95 | # 
27 | 
28 | 
29 | def eval_image(gt_image, prob, cl_index, thresh):
30 |     # true/false map of the ground truth for class cl_index
31 |     cl_gt = gt_image[:, :] == cl_index
32 |     FN, FP, posNum, negNum = evalExp(cl_gt, prob, thresh)
33 |     return FN, FP, posNum, negNum
34 | 
35 | 
36 | def pxEval_maximizeFMeasure(totalPosNum, totalNegNum, totalFN, totalFP, thresh=None):
37 |     '''
38 |     @param totalPosNum: scalar
39 |     @param totalNegNum: scalar
40 |     @param totalFN: vector
41 |     @param totalFP: vector
42 |     @param thresh: vector
43 |     '''
44 |     # TP / TN at each threshold
45 |     totalTP = totalPosNum - totalFN
46 |     totalTN = totalNegNum - totalFP
47 | 
48 |     valid = (totalTP >= 0) & (totalTN >= 0)  # sanity-check the counts
49 |     assert valid.all(), 'Detected invalid elements in eval'
50 | 
51 |     recall = totalTP / float(totalPosNum)              # the denominator is guaranteed non-zero
52 |     precision = totalTP / (totalTP + totalFP + 1e-10)  # guard against a zero denominator
53 | 
54 |     selector_invalid = (recall == 0) & (precision == 0)  # thresholds where TP = 0
55 |     recall = recall[~selector_invalid]                   # drop them
56 |     precision = precision[~selector_invalid]
57 |     if thresh is not None:
58 |         thresh = np.asarray(thresh)[~selector_invalid]   # keep thresh aligned with the filtered P/R arrays
59 | 
60 |     # F-measure
61 |     beta = 1.0
62 |     betasq = beta ** 2
63 |     F = (1 + betasq) * (precision * recall) / ((betasq * precision) + recall + 1e-10)  # guard against a zero denominator
64 |     index = F.argmax()
65 |     MaxF = F[index]  # maximum F-score over all thresholds
66 | 
67 |     prob_eval_scores = {}
68 |     prob_eval_scores['MaxF'] = MaxF
69 |     prob_eval_scores['totalPosNum'] = totalPosNum
70 |     prob_eval_scores['totalNegNum'] = totalNegNum
71 |     prob_eval_scores['precision'] = precision
72 |     prob_eval_scores['recall'] = recall
73 |     prob_eval_scores['thresh'] = thresh
74 | 
75 |     if thresh is not None:
76 |         BestThresh = thresh[index]
77 |         prob_eval_scores['BestThresh'] = BestThresh
78 |         print('cur_best_thresh: ', BestThresh)
79 | 
80 |     # return a dict
81 |     return prob_eval_scores
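
> These helpers are meant to be composed: accumulate per-image FN/FP counts over the whole validation set, then maximize F over the thresholds once. A sketch of that pipeline, assuming a `val_pairs` iterable of (ground-truth mask, change-probability map) numpy arrays; the 256-point threshold grid is also an illustrative choice.

```python
import numpy as np
from utils.metric import eval_image, pxEval_maximizeFMeasure
from utils.utils import init_metric_dict

thresh = np.linspace(0.0, 1.0, 256)   # candidate decision thresholds (assumed grid)
metric = init_metric_dict(thresh)

for gt_image, prob in val_pairs:      # val_pairs is assumed: (HxW 0/1 label, HxW change probability)
    FN, FP, posNum, negNum = eval_image(gt_image, prob, cl_index=1, thresh=thresh)
    metric['total_fn'] += FN
    metric['total_fp'] += FP
    metric['total_posnum'] += posNum
    metric['total_negnum'] += negNum

scores = pxEval_maximizeFMeasure(metric['total_posnum'], metric['total_negnum'],
                                 metric['total_fn'], metric['total_fp'], thresh=thresh)
print(scores['MaxF'], scores['BestThresh'])  # best F1 over all thresholds, and where it occurs
```
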
82 | 
83 | 
84 | def setFigLinesBW(fig):
85 |     """
86 |     Take each axes in the figure, and for each line in the axes, make the
87 |     line viewable in black and white.
88 |     """
89 |     for ax in fig.get_axes():
90 |         setAxLinesBW(ax)
91 | 
92 | 
93 | def setAxLinesBW(ax):
94 |     """
95 |     Take each Line2D in the axes, ax, and convert the line style to be
96 |     suitable for black and white viewing.
97 |     """
98 |     MARKERSIZE = 3
99 | 
100 |     COLORMAP = {
101 |         'r': {'marker': "None", 'dash': (None, None)},
102 |         'g': {'marker': "None", 'dash': (None, None)},
103 |         'm': {'marker': "None", 'dash': (None, None)},
104 |         'b': {'marker': "None", 'dash': (None, None)},
105 |         'c': {'marker': "None", 'dash': (None, None)},
106 |         'y': {'marker': "None", 'dash': (None, None)},
107 |         'k': {'marker': 'o', 'dash': (None, None)}
108 |     }
109 | 
110 |     for line in ax.get_lines():
111 |         origColor = line.get_color()
112 |         line.set_dashes(COLORMAP[origColor]['dash'])
113 |         line.set_marker(COLORMAP[origColor]['marker'])
114 |         line.set_markersize(MARKERSIZE)
115 | 
116 | 
117 | def plotPrecisionRecall(precision, recall, outFileName, benchmark_pr=None, Fig=None, drawCol=0, textLabel=None,
118 |                         title=None, fontsize1=14, fontsize2=10, linewidth=3):
119 |     '''
120 |     :param precision:
121 |     :param recall:
122 |     :param outFileName: a single path or a list of paths to save the figure to
123 |     :param Fig:
124 |     :param drawCol:
125 |     :param textLabel:
126 |     :param fontsize1:
127 |     :param fontsize2:
128 |     :param linewidth:
129 |     '''
130 |     clearFig = False
131 | 
132 |     if Fig is None:
133 |         Fig = pylab.figure()
134 |         clearFig = True
135 | 
136 |     linecol = ['r', 'm', 'b', 'c']
137 | 
138 |     if benchmark_pr is not None:
139 |         benchmark_recall = np.array(benchmark_pr['recall'])
140 |         benchmark_precision = np.array(benchmark_pr['precision'])
141 |         pylab.plot(100 * benchmark_recall, 100 * benchmark_precision, linewidth=linewidth, color=linecol[drawCol],
142 |                    label=textLabel)
143 |     else:
144 |         pylab.plot(100 * recall, 100 * precision, linewidth=2, color=linecol[drawCol], label=textLabel)
145 | 
146 |     # write the precision-recall curve out as a graphic
147 |     setFigLinesBW(Fig)
148 |     if textLabel is not None:
149 |         pylab.legend(loc='lower left', prop={'size': fontsize2})
150 | 
151 |     if title is not None:
152 |         pylab.title(title, fontsize=fontsize1)
153 | 
154 |     pylab.ylabel('Precision [%]', fontsize=fontsize1)
155 |     pylab.xlabel('Recall [%]', fontsize=fontsize1)
156 | 
157 |     pylab.xlim(0, 100)
158 |     pylab.xticks([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
159 |                  ('0', '', '0.20', '', '0.40', '', '0.60', '', '0.80', '', '1.0'), fontsize=fontsize2)
160 |     pylab.ylim(0, 100)
161 |     pylab.yticks([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
162 |                  ('0', '', '0.20', '', '0.40', '', '0.60', '', '0.80', '', '1.0'), fontsize=fontsize2)
163 | 
164 |     if not isinstance(outFileName, list):
165 |         pylab.savefig(outFileName)
166 |     else:
167 |         for outFn in outFileName:
168 |             pylab.savefig(outFn)
169 |     if clearFig:
170 |         pylab.close()
171 |         Fig.clear()
172 | 
173 | 
174 | def save_metric_json(metrics, save_path, epoch, batch_idx):
175 |     recall_ = list(metrics['metric']['recall'])
176 |     precision_ = list(metrics['metric']['precision'])
177 |     f_score = metrics['metric']['MaxF']
178 |     metric_ = {'recall': recall_, 'precision': precision_, 'f-score': f_score}
179 |     with open(save_path + '/' + str(epoch) + '_' + str(batch_idx) + '_metric.json', 'w') as f:
180 |         f.write(json.dumps(metric_, ensure_ascii=False, indent=2))
181 | 
182 | 
183 | def RMS_Contrast(dist_map):
184 |     # RMS contrast of a distance map: standard deviation over mean
185 |     n, c, h, w = dist_map.shape
186 |     dist_map_l = np.resize(dist_map, (n * c * h * w))
187 |     mean = np.mean(dist_map_l, axis=0)
188 |     std = np.std(dist_map_l, axis=0, ddof=1)
189 |     contrast = std / mean
190 |     return contrast
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | 
3 | from PIL import Image
4 | import collections.abc
5 | 
6 | 
7 | class Scale(object):
8 |     """Rescale the input PIL.Image to the given size.
9 | 
10 |     Args:
11 |         size (sequence or int): Desired output size. If size is a sequence like
12 |             (w, h), output size will be matched to this. If size is an int,
13 |             the smaller edge of the image will be matched to this number,
14 |             i.e. if height > width, the image will be rescaled to
15 |             (size, size * height / width).
16 |         interpolation (int, optional): Desired interpolation. Default is
17 |             ``PIL.Image.BILINEAR``.
18 |     """
19 | 
20 |     def __init__(self, size, interpolation=Image.BILINEAR):
21 |         # collections.Iterable was removed in Python 3.10; use collections.abc.Iterable
22 |         assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
23 |         self.size = size
24 |         self.interpolation = interpolation
25 | 
26 |     def __call__(self, img):
27 |         """
28 |         Args:
29 |             img (PIL.Image): Image to be scaled.
30 | 
31 |         Returns:
32 |             PIL.Image: Rescaled image.
33 |         """
34 |         if isinstance(self.size, int):
35 |             w, h = img.size
36 |             # already at the target size on the smaller edge: return unchanged
37 |             if (w <= h and w == self.size) or (h <= w and h == self.size):
38 |                 return img
39 |             if w < h:
40 |                 ow = self.size
41 |                 oh = int(self.size * h / w)
42 |                 return img.resize((ow, oh), self.interpolation)
43 |             else:
44 |                 oh = self.size
45 |                 ow = int(self.size * w / h)
46 |                 return img.resize((ow, oh), self.interpolation)
47 |         else:
48 |             return img.resize(self.size, self.interpolation)
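
> A quick usage note on `Scale`: with an int it rescales the smaller edge and preserves the aspect ratio; with a (w, h) pair it resizes to exactly that size. A minimal example; the image file name is hypothetical (the example/CDD folders ship as placeholders).

```python
from PIL import Image
from utils.transforms import Scale

img = Image.open('example/CDD/train/A/0001.jpg')  # hypothetical sample path

small_edge = Scale(256)(img)        # int: smaller edge becomes 256, aspect ratio kept
fixed = Scale((256, 256))(img)      # (w, h) tuple: exact target size
print(small_edge.size, fixed.size)
```
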
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | 
5 | 
6 | def adjust_learning_rate(learning_rate, optimizer, step):
7 |     """Decay the initial learning rate in stages as the global step grows."""
8 |     if step <= 200000:
9 |         lr = learning_rate
10 |     elif 200000 < step <= 400000:
11 |         lr = 5e-5
12 |     elif 400000 < step <= 500000:
13 |         lr = 1e-5
14 |     elif 500000 < step <= 700000:  # made contiguous: the original branch started at 600000, leaving a gap that fell through to 1e-6
15 |         lr = 5e-6
16 |     else:
17 |         lr = 1e-6
18 | 
19 |     # optimizer.state_dict() returns a copy, so writing the lr into it has no effect;
20 |     # the rate has to be set on the param_groups directly
21 |     for param_group in optimizer.param_groups:
22 |         param_group['lr'] = lr
23 | 
24 | 
25 | def init_metric_dict(thresh):
26 |     # accumulators for the FP/FN counts at every threshold, plus total pos/neg pixel counts
27 |     metric_for_class = {}
28 |     total_fp = np.zeros(thresh.shape)
29 |     total_fn = np.zeros(thresh.shape)
30 |     metric_for_class.setdefault('total_fp', total_fp)
31 |     metric_for_class.setdefault('total_fn', total_fn)
32 |     metric_for_class.setdefault('total_posnum', 0)
33 |     metric_for_class.setdefault('total_negnum', 0)
34 |     return metric_for_class
35 | 
36 | 
37 | # The original label can be larger than the feature map, so interpolate it to build a resized label
38 | def rz_label(label, size):
39 |     gt_e = torch.unsqueeze(label, dim=1)
40 |     interp = nn.functional.interpolate(gt_e, (size[0], size[1]), mode='bilinear', align_corners=True)
41 |     gt_rz = torch.squeeze(interp, dim=1)
42 |     return gt_rz
--------------------------------------------------------------------------------
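
> Finally, a small check of how `rz_label` behaves when train.py shrinks ground truth to a feature map's size: a batch of 256x256 masks is bilinearly interpolated down to 32x32, so the resized values are soft numbers in [0, 1] rather than strictly binary, which is what the masked contrastive loss consumes. The random batch below is illustrative.

```python
import torch
from utils.utils import rz_label

label = (torch.rand(4, 256, 256) > 0.5).float()  # fake batch of binary masks
label_rz = rz_label(label, size=(32, 32))        # match a 32x32 feature map
print(label_rz.shape)                                # torch.Size([4, 32, 32])
print(label_rz.min().item(), label_rz.max().item())  # interpolated values lie in [0, 1]
```
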