├── README.md
├── dataset
│   ├── fewshot.py
│   └── transform.py
├── docs
│   ├── CrackNex.jpg
│   ├── vis.jpg
│   └── vis_wild.jpg
├── model
│   ├── CrackNex_matching.py
│   ├── DecompNet.py
│   ├── backbone
│   │   └── DecompNet.tar
│   └── resnet.py
├── test.py
├── test.sh
├── train.py
├── train.sh
└── util
    ├── __init__.py
    ├── loss.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
# [ICRA2024] CrackNex: a Few-shot Low-light Crack Segmentation Model Based on Retinex Theory for UAV Inspections
[Zhen Yao](https://scholar.google.com/citations?user=8-IhrB0AAAAJ&hl=en&oi=sra), [Jiawei Xu](https://scholar.google.com/citations?user=b3XkcPkAAAAJ&hl=en&oi=ao), Shuhang Hou, [Mooi Choo Chuah](https://scholar.google.com/citations?hl=en&user=SZBKvksAAAAJ).

This codebase contains the official code of our paper [CrackNex: a Few-shot Low-light Crack Segmentation Model Based on Retinex Theory for UAV Inspections](https://arxiv.org/abs/2403.03063), ICRA 2024.

## Abstract
Routine visual inspections of concrete structures are imperative for upholding the safety and integrity of critical infrastructure. Such inspections sometimes take place under low-light conditions, e.g., when checking bridge health. Crack segmentation under such conditions is challenging due to the poor contrast between cracks and their surroundings. However, most deep learning methods are designed for well-illuminated crack images, so their performance drops dramatically in low-light scenes. In addition, conventional approaches require many annotated low-light crack images, which are time-consuming to annotate. In this paper, we address these challenges by proposing CrackNex, a framework that utilizes reflectance information based on Retinex Theory to learn a unified illumination-invariant representation. Furthermore, we utilize few-shot segmentation to address the scarcity of training data.
In CrackNex, both a support prototype and a reflectance prototype are extracted from the support set. Then, a prototype fusion module integrates the features from both prototypes. CrackNex outperforms state-of-the-art methods on multiple datasets. Additionally, we present the first benchmark dataset for low-light crack segmentation, LCSD, which consists of 102 well-illuminated crack images and 41 low-light crack images.
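
The prototype matching described above can be illustrated with a small, dependency-free sketch. It is not the repository's implementation, which operates on PyTorch tensors in `model/CrackNex_matching.py`. The sketch assumes each feature map is given as one C-dimensional vector per pixel of a flattened H*W grid. It shows masked average pooling of support features into foreground and background prototypes, a fusion step that gates the support prototype with sigmoid weights computed from the z-scored concatenation of the support and reflectance prototypes, and cosine-similarity scoring scaled by 10. The `project` argument is a hypothetical stand-in for the learned 1x1-conv projection, and the reflectance prototype in the usage lines is made up.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def masked_average_pooling(features: Sequence[Vector], mask: Sequence[float],
                           eps: float = 1e-5) -> Vector:
    """Mask-weighted mean of per-pixel feature vectors (one C-dim vector per pixel)."""
    channels = len(features[0])
    total = [0.0] * channels
    for vec, weight in zip(features, mask):
        for c in range(channels):
            total[c] += vec[c] * weight
    denom = sum(mask) + eps  # eps guards against an empty mask
    return [t / denom for t in total]


def cosine_similarity(a: Sequence[float], b: Sequence[float], eps: float = 1e-8) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norm, eps)


def z_score(v: Sequence[float]) -> Vector:
    """Standardize with the sample (n - 1) standard deviation; assumes entries are not all equal."""
    mu = sum(v) / len(v)
    sd = math.sqrt(sum((x - mu) ** 2 for x in v) / (len(v) - 1))
    return [(x - mu) / sd for x in v]


def fuse_prototype(proto: Vector, ref_proto: Vector,
                   project: Callable[[Vector], Vector]) -> Vector:
    """Rescale `proto` by (1 + sigmoid gate), gates computed from z-scored [proto, ref_proto].

    `project` is a hypothetical stand-in for the learned 2C -> C projection.
    """
    logits = project(z_score(list(proto) + list(ref_proto)))
    gates = [1.0 / (1.0 + math.exp(-g)) for g in logits]
    return [(1.0 + g) * p for g, p in zip(gates, proto)]


def score_pixels(features: Sequence[Vector], fg: Vector, bg: Vector,
                 scale: float = 10.0) -> List[Tuple[float, float]]:
    """Per-pixel (background, foreground) logits: scaled cosine similarity to each prototype."""
    return [(scale * cosine_similarity(v, bg), scale * cosine_similarity(v, fg)) for v in features]


# Toy 2x2 support feature map (flattened) with 2 channels; pixels 0 and 3 are crack.
support = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
mask = [1.0, 0.0, 0.0, 1.0]
fg = masked_average_pooling(support, mask)
bg = masked_average_pooling(support, [1.0 - m for m in mask])
ref_fg = [0.5, 0.2]  # hypothetical reflectance prototype
fg = fuse_prototype(fg, ref_fg, project=lambda z: [0.0, 0.0])  # zero logits -> every gate is 0.5
query = [[0.9, 0.1], [0.1, 0.9]]
logits = score_pixels(query, fg, bg)
prediction = [int(f > b) for b, f in logits]
print(prediction)  # [1, 0]
```

With the all-zero projection every gate equals 0.5, which rescales the prototype uniformly and leaves cosine similarities unchanged. The gating therefore only changes predictions when the learned projection weights some channels more heavily than others.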

## LCSD Dataset
We have built a crack segmentation dataset, LCSD. It comprises 102 normal-light crack images and 41 low-light crack images with fine-grained annotations.

You can download the dataset from [this link](https://drive.google.com/drive/folders/1K81AjvIFje5BBxG1qxFmfqRhxf0PKNxe?usp=drive_link).

## Requirements

- Python 3.7
- PyTorch 1.9.1
- CUDA 11.1
- Pillow 8.4.0
- OpenCV
- scikit-learn

Conda environment settings:
```bash
conda create -n cracknex python=3.7 -y
conda activate cracknex

pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
conda install -c conda-forge cudatoolkit-dev==11.1.1
pip install opencv-python==4.5.2.54
pip install scikit-learn
pip install tqdm
pip install pillow==8.4.0
```

## Data preparation
**Pretrained models:** [ResNet-50](https://drive.google.com/file/d/1zphUj3ffl8J2HCRq5AjPsdOVojiGuQZB/view?usp=drive_link) | [ResNet-101](https://drive.google.com/file/d/1G6MVe_8ywk3NyHwpWWoUoZtvQ4pWJWj-/view?usp=drive_link)

### File Organization

To prepare the LCSD dataset, download all images and labels from the link above. Put all images in the ``JPEGImages/`` subfolder, all labels in the ``SegmentationClassAug/`` subfolder, and the two split files (``train.txt`` and ``val.txt``) in the ``ImageSets/`` subfolder. Note that `dataset/fewshot.py` reads labels from ``SegmentationClass/``, so make sure the label folder name matches the one the loader expects. We also provide a zip file of our processed [LCSD](https://drive.google.com/file/d/1a8Q1ng38KfNad2emyWJE4mgGv-xs3Dg4/view?usp=drive_link) dataset.
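
A quick way to confirm the folders just described are in place is to scan the dataset root before training. The helper below is a hypothetical sketch, not part of the repository. It assumes the split files are `ImageSets/train.txt` and `ImageSets/val.txt` (the names required by the comment in `dataset/fewshot.py`), that each listed id has a `.jpg` image and a `.png` label, and that `label_dir` defaults to `SegmentationClass`, the folder `dataset/fewshot.py` opens. The id `crack_001` in the usage lines is made up.

```python
import os
import tempfile
from typing import List, Sequence


def check_lcsd_layout(root: str, label_dir: str = "SegmentationClass",
                      splits: Sequence[str] = ("train", "val")) -> List[str]:
    """Return human-readable problems found under `root` (an empty list if none)."""
    problems: List[str] = []
    # the three sub-folders the loader needs (label folder name is configurable)
    for sub in ("JPEGImages", label_dir, "ImageSets"):
        if not os.path.isdir(os.path.join(root, sub)):
            problems.append(f"missing folder: {sub}/")
    for split in splits:
        split_file = os.path.join(root, "ImageSets", f"{split}.txt")
        if not os.path.isfile(split_file):
            problems.append(f"missing split file: ImageSets/{split}.txt")
            continue
        with open(split_file) as f:
            ids = [line.strip() for line in f if line.strip()]  # skip blank lines
        for id_ in ids:
            # every listed id needs an image and a label with the same base name
            if not os.path.isfile(os.path.join(root, "JPEGImages", id_ + ".jpg")):
                problems.append(f"{split}: missing image {id_}.jpg")
            if not os.path.isfile(os.path.join(root, label_dir, id_ + ".png")):
                problems.append(f"{split}: missing label {id_}.png")
    return problems


# Usage on a throwaway directory with one hypothetical id ("crack_001") and no val.txt.
with tempfile.TemporaryDirectory() as root:
    for sub in ("JPEGImages", "SegmentationClass", "ImageSets"):
        os.makedirs(os.path.join(root, sub))
    open(os.path.join(root, "JPEGImages", "crack_001.jpg"), "wb").close()
    open(os.path.join(root, "SegmentationClass", "crack_001.png"), "wb").close()
    with open(os.path.join(root, "ImageSets", "train.txt"), "w") as f:
        f.write("crack_001\n")
    report = check_lcsd_layout(root)

print(report)  # ['missing split file: ImageSets/val.txt']
```

The helper skips blank lines, but `FewShot` keeps every entry returned by `splitlines()`. An empty line inside a split file would therefore make the loader look for a label named `.png`, so such lines are worth removing.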
```
../                         # parent directory
├── ./CrackNex              # current (project) directory
│   ├── codes               # various codes
│   └── ./pretrained        # pretrained model directory
│       ├── resnet50.pth
│       └── resnet101.pth
└── Datasets_CrackNex/
    └── LCSD/
        ├── ImageSets/
        ├── JPEGImages/
        └── SegmentationClassAug/
```

## Run the code

You can adapt the `train.sh` and `test.sh` scripts to train and evaluate your models.

You may change the ``backbone`` from ``resnet50`` to ``resnet101``, or the ``shot`` from ``1`` to ``5``, for other settings.

```bash
bash train.sh
```

```bash
bash test.sh
```

Remember to update the dataset and checkpoint paths in the scripts.

## Evaluation and Trained Models

### LCSD

| Method | Setting | Backbone | mIoU (%) |
| :-----: | :-----: | :---------: | :----: |
| CrackNex (Ours) | 1-shot | ResNet-50 | [63.85](https://drive.google.com/file/d/1T9i1S_UlFOzjuSpn7Oeo9RX7BYgQ9jMw/view?usp=drive_link) |
| CrackNex (Ours) | 1-shot | ResNet-101 | [66.10](https://drive.google.com/file/d/1vM3cRazeLmU2QnLIY4RjS7JTJBfHaCV3/view?usp=drive_link) |
| CrackNex (Ours) | 5-shot | ResNet-50 | [65.17](https://drive.google.com/file/d/1uADCeaGZQNvr25dqVgXNJcjbZNAIrdI-/view?usp=drive_link) |
| CrackNex (Ours) | 5-shot | ResNet-101 | [68.82](https://drive.google.com/file/d/1D3R9rCHrP58l48qSOhgtMzmDLvk8-8nf/view?usp=drive_link) |

### Visualization in LCSD

### Visualization of videos in the wild

## Acknowledgment
This codebase is built on [SSP's baseline code](https://github.com/fanq15/SSP/). We thank SSP and other few-shot segmentation (FSS) works for their great contributions.

## Citation
```bibtex
@inproceedings{yao2024cracknex,
  title={{CrackNex}: a few-shot low-light crack segmentation model based on {Retinex} theory for {UAV} inspections},
  author={Yao, Zhen and Xu, Jiawei and Hou, Shuhang and Chuah, Mooi Choo},
  booktitle={2024 IEEE International Conference on Robotics and Automation (ICRA)},
  pages={11155--11162},
  year={2024},
  organization={IEEE}
}
```

--------------------------------------------------------------------------------
/dataset/fewshot.py:
--------------------------------------------------------------------------------
from dataset.transform import crop, hflip, normalize

from collections import defaultdict
import numpy as np
import os
from PIL import Image
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms


class FewShot(Dataset):
    """
    FewShot generates support-query pairs in an episodic manner,
    intended for the meta-training and meta-testing paradigm.
    """

    def __init__(self, root, size, mode, shot, episode):
        super(FewShot, self).__init__()
        self.size = size
        self.mode = mode
        self.fold = 0
        self.shot = shot
        self.episode = episode

        self.img_path = os.path.join(root, 'JPEGImages')
        self.mask_path = os.path.join(root, 'SegmentationClass')
        self.id_path = os.path.join(root, 'ImageSets')

        # class 1: normal-light crack; class 2: low-light crack
        # (with fold = 0, training samples class 1 and testing samples class 2)
        n_class = 2

        interval = n_class // 2
        if self.mode == 'train':
            # base classes = all classes - novel classes
            self.classes = set(range(interval * self.fold + 1, interval * (self.fold + 1) + 1))
        else:
            # novel classes
            self.classes = set(range(1, n_class + 1)) - set(range(interval * self.fold + 1, interval * (self.fold + 1) + 1))
        # the image ids must be stored in 'train.txt' and 'val.txt'
        with open(os.path.join(self.id_path, '%s.txt' % mode), 'r') as f:
            self.ids = f.read().splitlines()

        self._filter_ids()
        self.cls_to_ids = self._map_cls_to_cls()

    def __getitem__(self, item):
        # the sampling strategy is based on the description in the OSLSM paper

        # query id, image, mask
        if self.mode == 'train':
            id_q = random.choice(self.ids)
        else:
            id_q = self.ids[item]
        img_q = Image.open(os.path.join(self.img_path, id_q + ".jpg")).convert('RGB')
        mask_q = Image.fromarray(np.array(Image.open(os.path.join(self.mask_path, id_q + ".png"))))
        # target class
        cls = random.choice(sorted(set(np.unique(mask_q)) & self.classes))

        # support ids, images and masks
        id_s_list, img_s_list, hiseq_s_list, mask_s_list = [], [], [], []
        while True:
            id_s = random.choice(sorted(set(self.cls_to_ids[cls]) - {id_q} - set(id_s_list)))
            img_s = Image.open(os.path.join(self.img_path, id_s + ".jpg")).convert('RGB')
            mask_s = Image.fromarray(np.array(Image.open(os.path.join(self.mask_path, id_s + ".png"))))

            # small objects in support images are filtered following PFENet
            if np.sum(np.array(mask_s) == cls) < 2 * 32 * 32:
                continue

            id_s_list.append(id_s)
            img_s_list.append(img_s)
            mask_s_list.append(mask_s)
            if len(id_s_list) == self.shot:
                break

        if self.mode == 'train':
            img_q, mask_q = crop(img_q, mask_q, self.size)
            img_q, mask_q = hflip(img_q, mask_q)
            for k in range(self.shot):
                img_s_list[k], mask_s_list[k] = crop(img_s_list[k], mask_s_list[k], self.size)
                img_s_list[k], mask_s_list[k] = hflip(img_s_list[k], mask_s_list[k])

        img_q, hiseq_q, mask_q = normalize(img_q, mask_q)

        for k in range(self.shot):
            img_s_list[k], hiseq_s, mask_s_list[k] = normalize(img_s_list[k], mask_s_list[k])
            hiseq_s_list.append(hiseq_s)

        # filter out irrelevant classes by setting them as background
        mask_q[(mask_q != cls) & (mask_q != 255)] = 0
        mask_q[mask_q == cls] = 1
        for k in range(self.shot):
            mask_s_list[k][(mask_s_list[k] != cls) & (mask_s_list[k] != 255)] = 0
            mask_s_list[k][mask_s_list[k] == cls] = 1

        return img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q, cls, id_s_list, id_q

    def __len__(self):
        if self.mode == 'train':
            return self.episode
        else:
            return len(self.ids)

    # remove images that do not contain any valid classes
    # and remove images whose valid objects are all small (according to PFENet)
    def _filter_ids(self):
        for i in range(len(self.ids) - 1, -1, -1):
            mask = Image.fromarray(np.array(Image.open(os.path.join(self.mask_path, self.ids[i] + '.png'))))
            classes = set(np.unique(mask)) & self.classes
            if not classes:
                del self.ids[i]
                continue

            # remove images whose valid objects are all small (according to PFENet)
            exist_large_objects = False
            for cls in classes:
                if np.sum(np.array(mask) == cls) >= 2 * 32 * 32:
                    exist_large_objects = True
                    break
            if not exist_large_objects:
                del self.ids[i]

    # map each valid class to a list of image ids
    def _map_cls_to_cls(self):
        cls_to_ids = defaultdict(list)
        for id_ in self.ids:
            mask = np.array(Image.open(os.path.join(self.mask_path, id_ + ".png")))
            valid_classes = set(np.unique(mask)) & self.classes
            for cls in valid_classes:
                cls_to_ids[cls].append(id_)
        return cls_to_ids

--------------------------------------------------------------------------------
/dataset/transform.py:
--------------------------------------------------------------------------------
import numpy as np
from PIL import Image, ImageOps
import random
import torch
from torchvision import transforms


def crop(img, mask, size):
    # pad height or width if smaller than the cropping size
    w, h = img.size
    padw = size - w if w < size else 0
    padh = size - h if h < size else 0
    img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
    mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=255)

    # cropping
    w, h = img.size
    x = random.randint(0, w - size)
    y = random.randint(0, h - size)
    img = img.crop((x, y, x + size, y + size))
    mask = mask.crop((x, y, x + size, y + size))

    return img, mask


def hflip(img, mask):
    if random.random() < 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
    return img, mask


def normalize(img, mask):
    """
    :param img: PIL image
    :param mask: PIL image, corresponding mask
    :return: normalized image tensor, normalized histogram-equalized image tensor,
             and the mask as a LongTensor
    """
    hiseq = ImageOps.equalize(img, mask=None)
    hiseq = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])(hiseq)
    img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])(img)
    mask = torch.from_numpy(np.array(mask)).long()
    return img, hiseq, mask

--------------------------------------------------------------------------------
/docs/CrackNex.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyaocoder/CrackNex/122725c5e43ed1a9967cad0f29b01627901f3569/docs/CrackNex.jpg
--------------------------------------------------------------------------------
/docs/vis.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyaocoder/CrackNex/122725c5e43ed1a9967cad0f29b01627901f3569/docs/vis.jpg
--------------------------------------------------------------------------------
/docs/vis_wild.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyaocoder/CrackNex/122725c5e43ed1a9967cad0f29b01627901f3569/docs/vis_wild.jpg
--------------------------------------------------------------------------------
/model/CrackNex_matching.py:
--------------------------------------------------------------------------------
import model.resnet as resnet

import torch
from torch import nn
import torch.nn.functional as F
from model.DecompNet import DecompNet, load_decomp_ckpt

class ASPP_module(nn.Module):
    def __init__(self, inplanes, planes, rate):
        super(ASPP_module, self).__init__()
        if rate == 1:
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 3
            padding = rate
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=rate, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)

        return self.relu(x)

class CrackNex(nn.Module):
    def __init__(self, backbone):
        super(CrackNex, self).__init__()
        decomp_path = './model/backbone/DecompNet.tar'
        rgb_backbone = resnet.__dict__[backbone](pretrained=True)
        ref_backbone = resnet.__dict__[backbone](pretrained=True)
        self.rgb_layer0 = nn.Sequential(rgb_backbone.conv1, rgb_backbone.bn1, rgb_backbone.relu, rgb_backbone.maxpool)
        self.rgb_layer1, self.rgb_layer2, self.rgb_layer3 = rgb_backbone.layer1, rgb_backbone.layer2, rgb_backbone.layer3
        self.ref_layer0 = nn.Sequential(ref_backbone.conv1, ref_backbone.bn1, ref_backbone.relu, ref_backbone.maxpool)
        self.ref_layer1, self.ref_layer2, self.ref_layer3 = ref_backbone.layer1, ref_backbone.layer2, ref_backbone.layer3

        self.proj_weight_FP = nn.Sequential(nn.Conv2d(2048, 1024, 1, stride=1, bias=False),
                                            nn.ReLU(),
                                            nn.Conv2d(1024, 1024, 1, stride=1, bias=False),
                                            nn.Sigmoid())

        self.proj_weight_BP = nn.Sequential(nn.Conv2d(2048, 1024, 1, stride=1, bias=False),
                                            nn.ReLU(),
                                            nn.Conv2d(1024, 1024, 1, stride=1, bias=False),
                                            nn.Sigmoid())

        rates = [1, 6, 12, 18]
        self.aspp1 = ASPP_module(1024, 512, rate=rates[0])
        self.aspp2 = ASPP_module(1024, 512, rate=rates[1])
        self.aspp3 = ASPP_module(1024, 512, rate=rates[2])
        self.aspp4 = ASPP_module(1024, 512, rate=rates[3])
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.1)
        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(1024, 512, 1, stride=1, bias=False),
                                             nn.BatchNorm2d(512),
                                             nn.ReLU())
        self.conv1 = nn.Conv2d(2560, 1024, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(1024)

        # 1x1 conv on the 256-channel low-level reflectance features (channel count kept at 256)
        self.conv2 = nn.Conv2d(256, 256, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(256)
        self.last_conv = nn.Sequential(nn.Conv2d(1280, 1024, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(1024),
                                       nn.ReLU(),
                                       nn.Dropout(0.1),
                                       nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(1024),
                                       nn.ReLU(),
                                       nn.Dropout(0.1),
                                       nn.Conv2d(1024, 1024, kernel_size=1, stride=1))

        self.DecompNet = DecompNet()
        load_decomp_ckpt(self.DecompNet, decomp_path)

        # widen the first block of rgb_layer2 so it accepts the concatenated RGB and
        # histogram-equalized layer1 features (256 + 256 channels); both input halves
        # are initialized with the pretrained weights
        weight = self.rgb_layer2._modules['0'].conv1.weight.clone()
        self.rgb_layer2._modules['0'].conv1 = nn.Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        self.rgb_layer2._modules['0'].conv1.weight.data[:, :256] = weight
        self.rgb_layer2._modules['0'].conv1.weight.data[:, 256:] = weight
        weight2 = self.rgb_layer2._modules['0'].downsample[0].weight.clone()
        self.rgb_layer2._modules['0'].downsample[0] = nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        self.rgb_layer2._modules['0'].downsample[0].weight.data[:, :256] = weight2
        self.rgb_layer2._modules['0'].downsample[0].weight.data[:, 256:] = weight2

    def forward(self, img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q):
        h, w = img_q.shape[-2:]
        reflectance_q = self.DecompNet(img_q)

        # feature maps of support images
        feature_s_list = []
        feature_s_ref_list = []
        feature_s_ref_lowlevel_list = []
        for k in range(len(img_s_list)):
            with torch.no_grad():
                reflectance_s = self.DecompNet(img_s_list[k])

                s_0 = self.rgb_layer0(img_s_list[k])  # (100, 100)
                s_0_hiseq = self.rgb_layer0(hiseq_s_list[k])

                s_0 = self.rgb_layer1(s_0)  # (100, 100)
                s_0_hiseq = self.rgb_layer1(s_0_hiseq)

                s_0_ref = self.ref_layer0(reflectance_s)
                s_0_ref = self.ref_layer1(s_0_ref)
                s_0_lowlevel = s_0_ref

            s_0 = self.rgb_layer2(torch.cat([s_0, s_0_hiseq], dim=1))  # (50, 50)
            s_0 = self.rgb_layer3(s_0)  # (50, 50)

            s_0_ref = self.ref_layer2(s_0_ref)  # (50, 50)
            s_0_ref = self.ref_layer3(s_0_ref)  # (50, 50)

            s_0 = self.aspp_block(s_0, s_0_lowlevel)

            feature_s_list.append(s_0)
            feature_s_ref_list.append(s_0_ref)
            feature_s_ref_lowlevel_list.append(s_0_lowlevel)
            del s_0
            del s_0_hiseq
            del s_0_ref
            del s_0_lowlevel
        # feature map of query image
        with torch.no_grad():
            q_0 = self.rgb_layer0(img_q)
            q_0_hiseq = self.rgb_layer0(hiseq_q)
            q_0 = self.rgb_layer1(q_0)
            q_0_hiseq = self.rgb_layer1(q_0_hiseq)

            q_0_ref = self.ref_layer0(reflectance_q)
            q_0_ref = self.ref_layer1(q_0_ref)
            q_0_lowlevel = q_0_ref

        q_0 = self.rgb_layer2(torch.cat([q_0, q_0_hiseq], dim=1))
        feature_q = self.rgb_layer3(q_0)  # (Bs, 1024, 50, 50)

        feature_q = self.aspp_block(feature_q, q_0_lowlevel)

        # foreground (target class) and background prototypes pooled from K support features
        feature_fg_list = []
        feature_bg_list = []
        feature_ref_fg_list = []
        feature_ref_bg_list = []

        supp_out_ls = []
        supp_out_ref_ls = []
        for k in range(len(img_s_list)):
            # Generate original prototype
            feature_fg = self.masked_average_pooling(feature_s_list[k],
                                                     (mask_s_list[k] == 1).float())[None, :]
            feature_bg = self.masked_average_pooling(feature_s_list[k],
                                                     (mask_s_list[k] == 0).float())[None, :]
            feature_fg_list.append(feature_fg)  # (1, Bs, C)
            feature_bg_list.append(feature_bg)  # (1, Bs, C)

            # Generate reflectance prototype
            feature_ref_fg = self.masked_average_pooling(feature_s_ref_list[k],
                                                         (mask_s_list[k] == 1).float())[None, :]
            feature_ref_bg = self.masked_average_pooling(feature_s_ref_list[k],
                                                         (mask_s_list[k] == 0).float())[None, :]
            feature_ref_fg_list.append(feature_ref_fg)  # (1, Bs, C)
            feature_ref_bg_list.append(feature_ref_bg)  # (1, Bs, C)

            if self.training:
                supp_similarity_fg = F.cosine_similarity(feature_s_list[k], feature_fg.squeeze(0)[..., None, None], dim=1)
                supp_similarity_bg = F.cosine_similarity(feature_s_list[k], feature_bg.squeeze(0)[..., None, None], dim=1)
                supp_out = torch.cat((supp_similarity_bg[:, None, ...], supp_similarity_fg[:, None, ...]), dim=1) * 10.0

                supp_out = F.interpolate(supp_out, size=(h, w), mode="bilinear", align_corners=True)  # (Bs, 2, H, W)
                supp_out_ls.append(supp_out)

                # Reflectance support
                supp_similarity_ref_fg = F.cosine_similarity(feature_s_ref_list[k], feature_ref_fg.squeeze(0)[..., None, None], dim=1)
                supp_similarity_ref_bg = F.cosine_similarity(feature_s_ref_list[k], feature_ref_bg.squeeze(0)[..., None, None], dim=1)
                supp_out_ref = torch.cat((supp_similarity_ref_bg[:, None, ...], supp_similarity_ref_fg[:, None, ...]), dim=1) * 10.0

                supp_out_ref = F.interpolate(supp_out_ref, size=(h, w), mode="bilinear", align_corners=True)  # (Bs, 2, H, W)
                supp_out_ref_ls.append(supp_out_ref)

        # average K foreground prototypes and K background prototypes (Bs, C, 1, 1)
        FP = torch.mean(torch.cat(feature_fg_list, dim=0), dim=0).unsqueeze(-1).unsqueeze(-1)
        BP = torch.mean(torch.cat(feature_bg_list, dim=0), dim=0).unsqueeze(-1).unsqueeze(-1)

        ref_FP = torch.mean(torch.cat(feature_ref_fg_list, dim=0), dim=0).unsqueeze(-1).unsqueeze(-1)
        ref_BP = torch.mean(torch.cat(feature_ref_bg_list, dim=0), dim=0).unsqueeze(-1).unsqueeze(-1)

        # Fuse the two prototypes: channel weights predicted from the z-scored
        # concatenation rescale the prototypes and the query features
        cate_FP = torch.cat((FP, ref_FP), dim=1)
        cate_BP = torch.cat((BP, ref_BP), dim=1)

        normalized_FP = self.z_score_norm(cate_FP)
        normalized_BP = self.z_score_norm(cate_BP)

        weights_FP = self.proj_weight_FP(normalized_FP)
        weights_BP = self.proj_weight_BP(normalized_BP)

        FP = torch.mul((1 + weights_FP), FP)
        feature_q = torch.mul((1 + weights_FP), feature_q)
        BP = torch.mul((1 + weights_BP), BP)

        # measure the similarity of query features to fg/bg prototypes
        out_0 = self.similarity_func(feature_q, FP, BP)  # (Bs, 2, H, W)

        out_1, out_2 = self.SSP_module(feature_q, out_0, FP, BP)

        out_0 = F.interpolate(out_0, size=(h, w), mode="bilinear", align_corners=True)
        out_1 = F.interpolate(out_1, size=(h, w), mode="bilinear", align_corners=True)
        out_2 = F.interpolate(out_2, size=(h, w), mode="bilinear", align_corners=True)

        out_ls = [out_2, out_1]

        if self.training:
            fg_q = self.masked_average_pooling(feature_q, (mask_q == 1).float())[None, :].squeeze(0)
            bg_q = self.masked_average_pooling(feature_q, (mask_q == 0).float())[None, :].squeeze(0)

            self_similarity_fg = F.cosine_similarity(feature_q, fg_q[..., None, None], dim=1)
            self_similarity_bg = F.cosine_similarity(feature_q, bg_q[..., None, None], dim=1)
            self_out = torch.cat((self_similarity_bg[:, None, ...], self_similarity_fg[:, None, ...]), dim=1) * 10.0

            self_out = F.interpolate(self_out, size=(h, w), mode="bilinear", align_corners=True)
            supp_out = torch.cat(supp_out_ls, 0)
            supp_out_ref = torch.cat(supp_out_ref_ls, 0)

            out_ls.append(self_out)
            out_ls.append(supp_out)
            out_ls.append(supp_out_ref)

        return out_ls

    def aspp_block(self, x, low_level_features):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = F.interpolate(x, size=(low_level_features.size()[-2],
                                   low_level_features.size()[-1]), mode='bilinear', align_corners=True)

        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)

        x = torch.cat((x, low_level_features), dim=1)
        x = self.last_conv(x)

        return x

    def z_score_norm(self, tensor):
        mu = torch.mean(tensor, dim=1, keepdim=True)
        sd = torch.std(tensor, dim=1, keepdim=True)
        normalized_tensor = (tensor - mu) / sd

        return normalized_tensor

    def SSP_module(self, feature_q, out_0, FP, BP):
        ##################### Self-Support Prototype (SSP) #####################
        SSFP_1, SSBP_1, ASFP_1, ASBP_1 = self.SSP_func(feature_q, out_0)

        FP_1 = FP * 0.5 + SSFP_1 * 0.5
        BP_1 = SSBP_1 * 0.3 + ASBP_1 * 0.7

        out_1 = self.similarity_func(feature_q, FP_1, BP_1)

        ##################### SSP Refinement #####################
        SSFP_2, SSBP_2, ASFP_2, ASBP_2 = self.SSP_func(feature_q, out_1)

        FP_2 = FP * 0.5 + SSFP_2 * 0.5
        BP_2 = SSBP_2 * 0.3 + ASBP_2 * 0.7

        FP_2 = FP * 0.5 + FP_1 * 0.2 + FP_2 * 0.3
        BP_2 = BP * 0.5 + BP_1 * 0.2 + BP_2 * 0.3

        out_2 = self.similarity_func(feature_q, FP_2, BP_2)

        out_2 = out_2 * 0.7 + out_1 * 0.3

        return out_1, out_2

    def SSP_func(self, feature_q, out):
        bs = feature_q.shape[0]
        pred_1 = out.softmax(1)
        pred_1 = pred_1.view(bs, 2, -1)
        pred_fg = pred_1[:, 1]  # (Bs, H*W)
        pred_bg = pred_1[:, 0]  # (Bs, H*W)
        fg_ls = []
        bg_ls = []
        fg_local_ls = []
        bg_local_ls = []
        for epi in range(bs):
            fg_thres = 0.7  # 0.9 #0.6
            bg_thres = 0.6  # 0.6
            cur_feat = feature_q[epi].view(1024, -1)
            f_h, f_w = feature_q[epi].shape[-2:]
            if (pred_fg[epi] > fg_thres).sum() > 0:
                fg_feat = cur_feat[:, (pred_fg[epi] > fg_thres)]  # .mean(-1)
            else:
                fg_feat = cur_feat[:, torch.topk(pred_fg[epi], 12).indices]  # .mean(-1)
            if (pred_bg[epi] > bg_thres).sum() > 0:
                bg_feat = cur_feat[:, (pred_bg[epi] > bg_thres)]  # .mean(-1)
            else:
                bg_feat = cur_feat[:, torch.topk(pred_bg[epi], 12).indices]  # .mean(-1)
            # global proto
            fg_proto = fg_feat.mean(-1)
            bg_proto = bg_feat.mean(-1)
            fg_ls.append(fg_proto.unsqueeze(0))
            bg_ls.append(bg_proto.unsqueeze(0))

            # local proto
            fg_feat_norm = fg_feat / torch.norm(fg_feat, 2, 0, True)  # 1024, N1
            bg_feat_norm = bg_feat / torch.norm(bg_feat, 2, 0, True)  # 1024, N2
            cur_feat_norm = cur_feat / torch.norm(cur_feat, 2, 0, True)  # 1024, N3

            cur_feat_norm_t = cur_feat_norm.t()  # N3, 1024
            fg_sim = torch.matmul(cur_feat_norm_t, fg_feat_norm) * 2.0  # N3, N1
            bg_sim = torch.matmul(cur_feat_norm_t, bg_feat_norm) * 2.0  # N3, N2

            fg_sim = fg_sim.softmax(-1)
            bg_sim = bg_sim.softmax(-1)

            fg_proto_local = torch.matmul(fg_sim, fg_feat.t())  # N3, 1024
            bg_proto_local = torch.matmul(bg_sim, bg_feat.t())  # N3, 1024

            fg_proto_local = fg_proto_local.t().view(1024, f_h, f_w).unsqueeze(0)  # 1, 1024, f_h, f_w
            bg_proto_local = bg_proto_local.t().view(1024, f_h, f_w).unsqueeze(0)  # 1, 1024, f_h, f_w

            fg_local_ls.append(fg_proto_local)
            bg_local_ls.append(bg_proto_local)

        # global proto
        new_fg = torch.cat(fg_ls, 0).unsqueeze(-1).unsqueeze(-1)
        new_bg = torch.cat(bg_ls, 0).unsqueeze(-1).unsqueeze(-1)

        # local proto (SSP_module only uses the local background prototype)
        new_fg_local = torch.cat(fg_local_ls, 0).unsqueeze(-1).unsqueeze(-1)
        new_bg_local = torch.cat(bg_local_ls, 0)

        return new_fg, new_bg, new_fg_local, new_bg_local

    def similarity_func(self, feature_q, fg_proto, bg_proto):
        similarity_fg = F.cosine_similarity(feature_q, fg_proto, dim=1)
        similarity_bg = F.cosine_similarity(feature_q, bg_proto, dim=1)

        out = torch.cat((similarity_bg[:, None, ...], similarity_fg[:, None, ...]), dim=1) * 10.0
        return out

    def masked_average_pooling(self, feature, mask):
        mask = F.interpolate(mask.unsqueeze(1), size=feature.shape[-2:], mode='bilinear', align_corners=True)
        masked_feature = torch.sum(feature * mask, dim=(2, 3)) \
                         / (mask.sum(dim=(2, 3)) + 1e-5)
        return masked_feature

--------------------------------------------------------------------------------
/model/DecompNet.py:
--------------------------------------------------------------------------------
import os
import time
import random

from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

def load_decomp_ckpt(model, ckpt_dir):
    assert os.path.exists(ckpt_dir), "Pretrained weights don't exist!"

    # map to CPU so the checkpoint loads regardless of the device it was saved on
    ckpt_dict = torch.load(ckpt_dir, map_location='cpu')

    model.load_state_dict(ckpt_dict)

    # freeze the DecompNet weights
    for _, param in model.named_parameters():
        param.requires_grad = False

    print('DecompNet successfully loaded from {}'.format(ckpt_dir))

class DecompNet(nn.Module):
    def __init__(self, channel=64, kernel_size=3):
        super(DecompNet, self).__init__()
        # Shallow feature extraction
        self.net1_conv0 = nn.Conv2d(4, channel, kernel_size * 3,
                                    padding=4, padding_mode='replicate')
        # Activated layers (conv + ReLU)
33 | self.net1_convs = nn.Sequential(nn.Conv2d(channel, channel, kernel_size, 34 | padding=1, padding_mode='replicate'), 35 | nn.ReLU(), 36 | nn.Conv2d(channel, channel, kernel_size, 37 | padding=1, padding_mode='replicate'), 38 | nn.ReLU(), 39 | nn.Conv2d(channel, channel, kernel_size, 40 | padding=1, padding_mode='replicate'), 41 | nn.ReLU(), 42 | nn.Conv2d(channel, channel, kernel_size, 43 | padding=1, padding_mode='replicate'), 44 | nn.ReLU(), 45 | nn.Conv2d(channel, channel, kernel_size, 46 | padding=1, padding_mode='replicate'), 47 | nn.ReLU()) 48 | # Final recon layer 49 | self.net1_recon = nn.Conv2d(channel, 4, kernel_size, 50 | padding=1, padding_mode='replicate') 51 | 52 | def forward(self, input_im): 53 | input_max= torch.max(input_im, dim=1, keepdim=True)[0] 54 | input_img= torch.cat((input_max, input_im), dim=1) 55 | feats0 = self.net1_conv0(input_img) 56 | featss = self.net1_convs(feats0) 57 | outs = self.net1_recon(featss) 58 | R = torch.sigmoid(outs[:, 0:3, :, :]) 59 | return R -------------------------------------------------------------------------------- /model/backbone/DecompNet.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyaocoder/CrackNex/122725c5e43ed1a9967cad0f29b01627901f3569/model/backbone/DecompNet.tar -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=dilation, groups=groups, bias=False, dilation=dilation) 12 | 13 | 14 | def conv1x1(in_planes, out_planes, stride=1): 15 | return nn.Conv2d(in_planes, 
out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 22 | base_width=64, dilation=1, norm_layer=None, last_relu=True): 23 | super(BasicBlock, self).__init__() 24 | if norm_layer is None: 25 | norm_layer = nn.BatchNorm2d 26 | if groups != 1 or base_width != 64: 27 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 28 | if dilation > 1: 29 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 30 | 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = norm_layer(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = norm_layer(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | self.last_relu = last_relu 39 | def forward(self, x): 40 | identity = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | identity = self.downsample(x) 51 | 52 | out += identity 53 | if self.last_relu: 54 | out = self.relu(out) 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 62 | base_width=64, dilation=1, norm_layer=None, last_relu=True): 63 | super(Bottleneck, self).__init__() 64 | if norm_layer is None: 65 | norm_layer = nn.BatchNorm2d 66 | width = int(planes * (base_width / 64.)) * groups 67 | 68 | self.conv1 = conv1x1(inplanes, width) 69 | self.bn1 = norm_layer(width) 70 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 71 | self.bn2 = norm_layer(width) 72 | self.conv3 = conv1x1(width, planes * self.expansion) 73 | self.bn3 = norm_layer(planes * self.expansion) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | self.last_relu = last_relu 78 | 79 | def forward(self, x): 80
| identity = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | identity = self.downsample(x) 95 | 96 | out += identity 97 | if self.last_relu: 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | class ResNet(nn.Module): 104 | 105 | def __init__(self, block, layers, zero_init_residual=False, groups=1, 106 | width_per_group=64, replace_stride_with_dilation=None, norm_layer=None): 107 | super(ResNet, self).__init__() 108 | 109 | self.out_channels = block.expansion * 256 110 | 111 | if norm_layer is None: 112 | norm_layer = nn.BatchNorm2d 113 | self._norm_layer = norm_layer 114 | 115 | self.inplanes = 128 116 | self.dilation = 1 117 | if replace_stride_with_dilation is None: 118 | replace_stride_with_dilation = [False, False, False] 119 | if len(replace_stride_with_dilation) != 3: 120 | raise ValueError("replace_stride_with_dilation should be None " 121 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 122 | self.groups = groups 123 | self.base_width = width_per_group 124 | 125 | self.conv1 = nn.Sequential( 126 | conv3x3(3, 64, stride=2), 127 | norm_layer(64), 128 | nn.ReLU(inplace=True), 129 | conv3x3(64, 64), 130 | norm_layer(64), 131 | nn.ReLU(inplace=True), 132 | conv3x3(64, 128) 133 | ) 134 | self.bn1 = norm_layer(128) 135 | self.relu = nn.ReLU(inplace=True) 136 | 137 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 138 | self.layer1 = self._make_layer(block, 64, layers[0]) 139 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 140 | dilate=replace_stride_with_dilation[0]) 141 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 142 | dilate=replace_stride_with_dilation[1], last_relu=False) 143 | 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | 
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 147 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 148 | nn.init.constant_(m.weight, 1) 149 | nn.init.constant_(m.bias, 0) 150 | 151 | if zero_init_residual: 152 | for m in self.modules(): 153 | if isinstance(m, Bottleneck): 154 | nn.init.constant_(m.bn3.weight, 0) 155 | elif isinstance(m, BasicBlock): 156 | nn.init.constant_(m.bn2.weight, 0) 157 | 158 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, last_relu=True): 159 | """ 160 | :param last_relu: in metric learning paradigm, the final relu is removed (last_relu = False) 161 | """ 162 | norm_layer = self._norm_layer 163 | downsample = None 164 | previous_dilation = self.dilation 165 | if dilate: 166 | self.dilation *= stride 167 | stride = 1 168 | if stride != 1 or self.inplanes != planes * block.expansion: 169 | downsample = nn.Sequential( 170 | conv1x1(self.inplanes, planes * block.expansion, stride), 171 | norm_layer(planes * block.expansion), 172 | ) 173 | 174 | layers = list() 175 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 176 | self.base_width, previous_dilation, norm_layer)) 177 | self.inplanes = planes * block.expansion 178 | for i in range(1, blocks): 179 | use_relu = True if i != blocks - 1 else last_relu 180 | layers.append(block(self.inplanes, planes, groups=self.groups, 181 | base_width=self.base_width, dilation=self.dilation, 182 | norm_layer=norm_layer, last_relu=use_relu)) 183 | 184 | return nn.Sequential(*layers) 185 | 186 | def base_forward(self, x): 187 | x = self.conv1(x) 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | x = self.maxpool(x) 191 | 192 | c1 = self.layer1(x) 193 | c2 = self.layer2(c1) 194 | c3 = self.layer3(c2) 195 | 196 | return c3 197 | 198 | 199 | def _resnet(arch, block, layers, pretrained, **kwargs): 200 | model = ResNet(block, layers, **kwargs) 201 | if pretrained: 202 | state_dict = torch.load("./pretrained/%s.pth" % arch) 203 | 
model.load_state_dict(state_dict, strict=False) 204 | return model 205 | 206 | 207 | def resnet18(pretrained=False): 208 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained) 209 | 210 | 211 | def resnet34(pretrained=False): 212 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained) 213 | 214 | 215 | def resnet50(pretrained=False): 216 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, 217 | replace_stride_with_dilation=[False, True, True]) 218 | 219 | 220 | def resnet101(pretrained=False): 221 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, 222 | replace_stride_with_dilation=[False, True, True]) 223 | 224 | 225 | def resnet152(pretrained=False): 226 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, 227 | replace_stride_with_dilation=[False, True, True]) 228 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from dataset.fewshot import FewShot 2 | from model.CrackNex_matching import CrackNex 3 | from util.utils import count_params, set_seed, mIOU 4 | 5 | import argparse 6 | import os 7 | import torch 8 | from torch.nn import DataParallel 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | import glob 12 | import numpy as np 13 | from PIL import Image 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='Mining Latent Classes for Few-shot Segmentation') 17 | # basic arguments 18 | parser.add_argument('--data-root', 19 | type=str, 20 | required=True, 21 | help='root path of training dataset') 22 | parser.add_argument('--dataset', 23 | type=str, 24 | default='llCrackSeg9k', 25 | choices=['llCrackSeg9k', 'LCSD'], 26 | help='training dataset') 27 | parser.add_argument('--backbone', 28 | type=str, 29 | choices=['resnet50', 'resnet101'], 30 | default='resnet50', 31 | help='backbone of semantic segmentation model') 32 | 
33 | # few-shot training arguments 34 | parser.add_argument('--shot', 35 | type=int, 36 | default=1, 37 | help='number of support pairs') 38 | parser.add_argument('--seed', 39 | type=int, 40 | default=0, 41 | help='random seed to generate testing samples') 42 | parser.add_argument('--path', 43 | type=str, 44 | help='checkpoint path') 45 | parser.add_argument('--savepath', 46 | type=str, 47 | default='./logs/', 48 | help='results saving path') 49 | 50 | args = parser.parse_args() 51 | return args 52 | 53 | 54 | def evaluate(model, dataloader, args): 55 | tbar = tqdm(dataloader) 56 | 57 | num_classes = 3 58 | metric = mIOU(num_classes) 59 | 60 | for i, (img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q, cls, _, id_q) in enumerate(tbar): 61 | img_q, hiseq_q, mask_q = img_q.cuda(), hiseq_q.cuda(), mask_q.cuda() 62 | for k in range(len(img_s_list)): 63 | img_s_list[k], hiseq_s_list[k], mask_s_list[k] = img_s_list[k].cuda(), hiseq_s_list[k].cuda(), mask_s_list[k].cuda() 64 | cls = cls[0].item() 65 | 66 | with torch.no_grad(): 67 | out_ls = model(img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q) 68 | pred = torch.argmax(out_ls[0], dim=1) 69 | 70 | pred[pred == 1] = cls 71 | mask_q[mask_q == 1] = cls 72 | 73 | # if seed == 0: 74 | # result = pred.squeeze(0).cpu().numpy().copy() 75 | # result[result == 2] = 255 76 | 77 | # im = Image.fromarray(np.uint8(result)) 78 | # im.save(args.savepath + id_q[0] + '.png') 79 | 80 | metric.add_batch(pred.cpu().numpy(), mask_q.cpu().numpy()) 81 | 82 | tbar.set_description("Testing mIOU: %.2f" % (metric.evaluate() * 100.0)) 83 | 84 | return metric.evaluate() * 100.0 85 | 86 | def main(): 87 | args = parse_args() 88 | print('\n' + str(args)) 89 | 90 | save_path = 'outdir/models/%s' % (args.dataset) 91 | os.makedirs(save_path, exist_ok=True) 92 | 93 | testset = FewShot(args.data_root.replace('train_coco', 'val_coco'), None, 'val', 94 | args.shot, 760 if args.dataset == 'LCSD' else 4000) 95 | testloader =
DataLoader(testset, batch_size=1, shuffle=False, 96 | pin_memory=True, num_workers=4, drop_last=False) 97 | 98 | model = CrackNex(args.backbone) 99 | checkpoint_path = args.path 100 | 101 | print('Evaluating model:', checkpoint_path) 102 | 103 | checkpoint = torch.load(checkpoint_path) 104 | model.load_state_dict(checkpoint) 105 | 106 | #print(model) 107 | print('\nParams: %.1fM' % count_params(model)) 108 | 109 | best_model = DataParallel(model).cuda() 110 | 111 | print('\nEvaluating on 5 seeds.....') 112 | total_miou = 0.0 113 | model.eval() 114 | for seed in range(5): 115 | print('\nRun %i:' % (seed + 1)) 116 | set_seed(args.seed + seed) 117 | 118 | miou = evaluate(best_model, testloader, args) 119 | total_miou += miou 120 | 121 | print('\n' + '*' * 32) 122 | print('Averaged mIOU on 5 seeds: %.2f' % (total_miou / 5)) 123 | print('*' * 32 + '\n') 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | 129 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python -W ignore test.py \ 2 | --dataset LCSD --data-root /Your/path/to/dataset \ 3 | --backbone resnet101 --shot 5 --path /Your/path/to/checkpoint 4 | 5 | # python -W ignore test.py \ 6 | # --dataset LCSD --data-root /Your/path/to/dataset \ 7 | # --backbone resnet101 --shot 1 --path /Your/path/to/checkpoint 8 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from dataset.fewshot import FewShot 2 | from model.CrackNex_matching import CrackNex 3 | from util.utils import count_params, set_seed, calc_crack_pixel_weight, mIOU 4 | 5 | import argparse 6 | from copy import deepcopy 7 | import os 8 | import time 9 | import torch 10 | from torch.nn import CrossEntropyLoss, DataParallel 11 | from torch.optim import SGD 12 | from torch.utils.data import DataLoader 
13 | from tqdm import tqdm 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='Mining Latent Classes for Few-shot Segmentation') 17 | # basic arguments 18 | parser.add_argument('--data-root', 19 | type=str, 20 | required=True, 21 | help='root path of training dataset') 22 | parser.add_argument('--dataset', 23 | type=str, 24 | default='LCSD', 25 | choices=['llCrackSeg9k', 'LCSD'], 26 | help='training dataset') 27 | parser.add_argument('--batch-size', 28 | type=int, 29 | default=1, 30 | help='batch size of training') 31 | parser.add_argument('--lr', 32 | type=float, 33 | default=0.001, 34 | help='learning rate') 35 | parser.add_argument('--loss', 36 | type=str, 37 | choices=['CE', 'weightedCE'], 38 | default='CE', 39 | help='loss function') 40 | parser.add_argument('--crop-size', 41 | type=int, 42 | default=400, 43 | help='cropping size of training samples') 44 | parser.add_argument('--backbone', 45 | type=str, 46 | choices=['resnet50', 'resnet101'], 47 | default='resnet50', 48 | help='backbone of semantic segmentation model') 49 | 50 | # few-shot training arguments 51 | parser.add_argument('--shot', 52 | type=int, 53 | default=1, 54 | help='number of support pairs') 55 | parser.add_argument('--episode', 56 | type=int, 57 | default=6000, 58 | choices=[6000, 18000, 24000, 36000], 59 | help='total episodes of training') 60 | parser.add_argument('--snapshot', 61 | type=int, 62 | default=200, 63 | choices=[200, 1200, 2000], 64 | help='episodes per epoch; the model is evaluated (and the best copy kept) after each epoch') 65 | parser.add_argument('--seed', 66 | type=int, 67 | default=0, 68 | help='random seed to generate testing samples') 69 | 70 | args = parser.parse_args() 71 | return args 72 | 73 | def evaluate(model, dataloader, args): 74 | tbar = tqdm(dataloader) 75 | 76 | num_classes = 3 77 | 78 | metric = mIOU(num_classes) 79 | for i, (img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q, cls, _, id_q) in enumerate(tbar): 80 | img_q, hiseq_q, mask_q = img_q.cuda(),
hiseq_q.cuda(), mask_q.cuda() 81 | for k in range(len(img_s_list)): 82 | img_s_list[k], hiseq_s_list[k], mask_s_list[k] = img_s_list[k].cuda(), hiseq_s_list[k].cuda(), mask_s_list[k].cuda() 83 | cls = cls[0].item() 84 | 85 | with torch.no_grad(): 86 | out_ls = model(img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q) 87 | pred = torch.argmax(out_ls[0], dim=1) 88 | 89 | pred[pred == 1] = cls 90 | mask_q[mask_q == 1] = cls 91 | 92 | metric.add_batch(pred.cpu().numpy(), mask_q.cpu().numpy()) 93 | 94 | tbar.set_description("Testing mIOU: %.2f" % (metric.evaluate() * 100.0)) 95 | 96 | return metric.evaluate() * 100.0 97 | 98 | def main(): 99 | args = parse_args() 100 | print('\n' + str(args)) 101 | 102 | save_path = 'outdir/models/%s' % (args.dataset) 103 | os.makedirs(save_path, exist_ok=True) 104 | 105 | trainset = FewShot(args.data_root, args.crop_size, 106 | 'train', args.shot, args.snapshot) 107 | trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, 108 | pin_memory=True, num_workers=0, drop_last=True) 109 | testset = FewShot(args.data_root, None, 'val', 110 | args.shot, 41 if args.dataset == 'LCSD' else 1486) 111 | testloader = DataLoader(testset, batch_size=1, shuffle=False, 112 | pin_memory=True, num_workers=0, drop_last=False) 113 | 114 | model = CrackNex(args.backbone) 115 | print('\nParams: %.1fM' % count_params(model)) 116 | 117 | for param in model.rgb_layer0.parameters(): 118 | param.requires_grad = False 119 | for param in model.rgb_layer1.parameters(): 120 | param.requires_grad = False 121 | for param in model.ref_layer0.parameters(): 122 | param.requires_grad = False 123 | for param in model.ref_layer1.parameters(): 124 | param.requires_grad = False 125 | 126 | for module in model.modules(): 127 | if isinstance(module, torch.nn.BatchNorm2d): 128 | for param in module.parameters(): 129 | param.requires_grad = False 130 | 131 | if args.loss == 'CE': 132 | criterion = CrossEntropyLoss(ignore_index=255) 133 | elif 
args.loss == 'weightedCE': 134 | crack_weight = [1, 0.4] * calc_crack_pixel_weight(args.data_root) 135 | print(f'positive weight: {crack_weight}') 136 | criterion = CrossEntropyLoss(weight=torch.Tensor([crack_weight]).to('cuda').squeeze(0), ignore_index=255) 137 | 138 | optimizer = SGD([param for param in model.parameters() if param.requires_grad], 139 | lr=args.lr, momentum=0.9, weight_decay=5e-4) 140 | 141 | model = DataParallel(model).cuda() 142 | best_model = None 143 | 144 | iters = 0 145 | total_iters = args.episode // args.batch_size 146 | lr_decay_iters = [total_iters // 3, total_iters * 2 // 3] 147 | previous_best = 0 148 | 149 | # each snapshot is considered as an epoch 150 | for epoch in range(args.episode // args.snapshot): 151 | print("\n==> Epoch %i, learning rate = %.5f\t\t\t\t Previous best = %.2f" 152 | % (epoch, optimizer.param_groups[0]["lr"], previous_best)) 153 | 154 | model.train() 155 | 156 | for module in model.modules(): 157 | if isinstance(module, torch.nn.BatchNorm2d): 158 | module.eval() 159 | 160 | total_loss = 0.0 161 | 162 | tbar = tqdm(trainloader) 163 | set_seed(int(time.time())) 164 | 165 | for i, (img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q, _, _, _) in enumerate(tbar): 166 | img_q, hiseq_q, mask_q = img_q.cuda(), hiseq_q.cuda(), mask_q.cuda() 167 | for k in range(len(img_s_list)): 168 | img_s_list[k], hiseq_s_list[k], mask_s_list[k] = img_s_list[k].cuda(), hiseq_s_list[k].cuda(), mask_s_list[k].cuda() 169 | 170 | out_ls = model(img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q) 171 | mask_s = torch.cat(mask_s_list, dim=0) 172 | 173 | loss = criterion(out_ls[0], mask_q) + criterion(out_ls[1], mask_q) + criterion(out_ls[2], mask_q) + criterion(out_ls[3], mask_s) * 0.2 174 | 175 | optimizer.zero_grad() 176 | loss.backward() 177 | optimizer.step() 178 | 179 | total_loss += loss.item() 180 | 181 | iters += 1 182 | if iters in lr_decay_iters: 183 | optimizer.param_groups[0]['lr'] /= 10.0 184 | 185 | 
tbar.set_description('Loss: %.3f' % (total_loss / (i + 1))) 186 | 187 | model.eval() 188 | set_seed(args.seed) 189 | miou = evaluate(model, testloader, args) 190 | 191 | if miou >= previous_best: 192 | best_model = deepcopy(model) 193 | previous_best = miou 194 | 195 | print('\nEvaluating on 5 seeds.....') 196 | total_miou = 0.0 197 | for seed in range(5): 198 | print('\nRun %i:' % (seed + 1)) 199 | set_seed(args.seed + seed) 200 | 201 | miou = evaluate(best_model, testloader, args) 202 | total_miou += miou 203 | 204 | print('\n' + '*' * 32) 205 | print('Averaged mIOU on 5 seeds: %.2f' % (total_miou / 5)) 206 | print('*' * 32 + '\n') 207 | 208 | torch.save(best_model.module.state_dict(), 209 | os.path.join(save_path, '%s_%ishot_%.2f.pth' % (args.backbone, args.shot, total_miou / 5))) 210 | 211 | 212 | if __name__ == '__main__': 213 | main() 214 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python -W ignore train.py \ 2 | --dataset LCSD --data-root /Your/path/to/dataset \ 3 | --backbone resnet101 --shot 1 --episode 6000 --snapshot 200 4 | 5 | # python -W ignore train.py \ 6 | # --dataset LCSD --data-root /Your/path/to/dataset \ 7 | # --backbone resnet101 --shot 5 --episode 6000 --snapshot 200 8 | 9 | # python -W ignore train.py \ 10 | # --dataset llCrackSeg9k --data-root /Your/path/to/dataset \ 11 | # --backbone resnet101 --shot 1 --episode 18000 --snapshot 1200 12 | 13 | # python -W ignore train.py \ 14 | # --dataset llCrackSeg9k --data-root /Your/path/to/dataset \ 15 | # --backbone resnet101 --shot 5 --episode 18000 --snapshot 1200 16 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyaocoder/CrackNex/122725c5e43ed1a9967cad0f29b01627901f3569/util/__init__.py 
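The step schedule in `train.py` divides the learning rate by 10 when the iteration counter reaches `total_iters // 3` and again at `total_iters * 2 // 3`; with the default `--batch-size 1`, `total_iters` equals `--episode`. A minimal pure-Python sketch of that rule, assuming a single parameter group as in `train.py` (the helper name `step_lr` is hypothetical and not part of the repository):

```python
def step_lr(base_lr, completed_iters, total_iters):
    """Learning rate after `completed_iters` optimizer steps (hypothetical helper).

    Mirrors train.py: the counter is incremented after each step, and the
    learning rate is divided by 10 when it reaches total_iters // 3 and
    again when it reaches total_iters * 2 // 3.
    """
    lr = base_lr
    # A set, because train.py checks `iters in lr_decay_iters`, so a repeated
    # milestone only triggers once; a milestone of 0 is never reached because
    # the counter is already 1 when it is first checked.
    for milestone in sorted({total_iters // 3, total_iters * 2 // 3}):
        if 1 <= milestone <= completed_iters:
            lr /= 10.0
    return lr


if __name__ == "__main__":
    # LCSD settings from train.sh: --episode 6000, default --batch-size 1
    for it in (0, 1999, 2000, 4000, 5999):
        print(it, step_lr(0.001, it, 6000))
```

As a design note, `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_decay_iters, gamma=0.1)`, stepped once after every `optimizer.step()`, should give the same decay points for distinct milestones (up to floating-point rounding), but it rescales every parameter group rather than only `param_groups[0]`.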
-------------------------------------------------------------------------------- /util/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | def one_hot_it(label): 6 | """ 7 | Convert a segmentation image label array to one-hot format 8 | by replacing each pixel value with a vector of length num_classes 9 | 10 | # Arguments 11 | label: the segmentation label tensor of shape (N, H, W) 12 | (the label values are fixed to [[0], [1]] inside the function) 13 | 14 | # Returns 15 | An int tensor of shape (N, num_classes, H, W): the same height and width as the input, 16 | with one channel per class 17 | """ 18 | label_values = [[0], [1]] 19 | semantic_map = [] 20 | for colour in label_values: 21 | equality = (label == colour[0]).unsqueeze(1).int() 22 | semantic_map.append(equality) 23 | semantic_map = torch.cat(semantic_map, dim=1) 24 | 25 | return semantic_map 26 | 27 | def identify_axis(shape): 28 | # Three dimensional 29 | if len(shape) == 5 : return [2,3,4] 30 | # Two dimensional 31 | elif len(shape) == 4 : return [2,3] 32 | # Exception - Unknown 33 | else : raise ValueError('Metric: Shape of tensor is neither 2D nor 3D.') 34 | 35 | 36 | class SymmetricFocalLoss(nn.Module): 37 | """ 38 | Parameters 39 | ---------- 40 | delta : float, optional 41 | controls weight given to false positives and false negatives, by default 0.7 42 | gamma : float, optional 43 | focal parameter controls degree of down-weighting of easy examples, by default 2.0 44 | epsilon : float, optional 45 | clip values to prevent division by zero error 46 | """ 47 | def __init__(self, delta=0.7, gamma=2., epsilon=1e-07): 48 | super(SymmetricFocalLoss, self).__init__() 49 | self.delta = delta 50 | self.gamma = gamma 51 | self.epsilon = epsilon 52 | 53 | def forward(self, y_pred, y_true): 54 | 55 | y_pred = torch.clamp(y_pred, self.epsilon, 1.
- self.epsilon) 56 | cross_entropy = -y_true * torch.log(y_pred) 57 | 58 | # Calculate losses separately for each class 59 | back_ce = torch.pow(1 - y_pred[:,0,:,:], self.gamma) * cross_entropy[:,0,:,:] 60 | back_ce = (1 - self.delta) * back_ce 61 | 62 | fore_ce = torch.pow(1 - y_pred[:,1,:,:], self.gamma) * cross_entropy[:,1,:,:] 63 | fore_ce = self.delta * fore_ce 64 | 65 | loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], axis=-1), axis=-1)) 66 | 67 | return loss 68 | 69 | 70 | class AsymmetricFocalLoss(nn.Module): 71 | """For Imbalanced datasets 72 | Parameters 73 | ---------- 74 | delta : float, optional 75 | controls weight given to false positives and false negatives, by default 0.7 76 | gamma : float, optional 77 | focal parameter controls degree of down-weighting of easy examples, by default 2.0 78 | epsilon : float, optional 79 | clip values to prevent division by zero error 80 | """ 81 | def __init__(self, delta=0.7, gamma=2., epsilon=1e-07): 82 | super(AsymmetricFocalLoss, self).__init__() 83 | self.delta = delta 84 | self.gamma = gamma 85 | self.epsilon = epsilon 86 | 87 | def forward(self, y_pred, y_true): 88 | y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon) 89 | cross_entropy = -y_true * torch.log(y_pred) 90 | 91 | # Calculate losses separately for each class, only suppressing background class 92 | back_ce = torch.pow(1 - y_pred[:,0,:,:], self.gamma) * cross_entropy[:,0,:,:] 93 | back_ce = (1 - self.delta) * back_ce 94 | 95 | fore_ce = cross_entropy[:,1,:,:] 96 | fore_ce = self.delta * fore_ce 97 | 98 | loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], axis=-1), axis=-1)) 99 | 100 | return loss 101 | 102 | 103 | class SymmetricFocalTverskyLoss(nn.Module): 104 | """This is the implementation for binary segmentation.
105 | Parameters 106 | ---------- 107 | delta : float, optional 108 | controls weight given to false positives and false negatives, by default 0.7 109 | gamma : float, optional 110 | focal parameter controls degree of down-weighting of easy examples, by default 0.75 111 | smooth : float, optional 112 | smoothing constant to prevent division by 0 errors, by default 0.000001 113 | epsilon : float, optional 114 | clip values to prevent division by zero error 115 | """ 116 | def __init__(self, delta=0.7, gamma=0.75, epsilon=1e-07): 117 | super(SymmetricFocalTverskyLoss, self).__init__() 118 | self.delta = delta 119 | self.gamma = gamma 120 | self.epsilon = epsilon 121 | 122 | def forward(self, y_pred, y_true): 123 | y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon) 124 | axis = identify_axis(y_true.size()) 125 | 126 | # Calculate true positives (tp), false negatives (fn) and false positives (fp) 127 | tp = torch.sum(y_true * y_pred, axis=axis) 128 | fn = torch.sum(y_true * (1-y_pred), axis=axis) 129 | fp = torch.sum((1-y_true) * y_pred, axis=axis) 130 | dice_class = (tp + self.epsilon)/(tp + self.delta*fn + (1-self.delta)*fp + self.epsilon) 131 | 132 | # Calculate losses separately for each class, enhancing both classes 133 | back_dice = (1-dice_class[:,0]) * torch.pow(1-dice_class[:,0], -self.gamma) 134 | fore_dice = (1-dice_class[:,1]) * torch.pow(1-dice_class[:,1], -self.gamma) 135 | 136 | # Average class scores 137 | loss = torch.mean(torch.stack([back_dice,fore_dice], axis=-1)) 138 | return loss 139 | 140 | 141 | class AsymmetricFocalTverskyLoss(nn.Module): 142 | """This is the implementation for binary segmentation. 143 |
143 | Parameters 144 | ---------- 145 | delta : float, optional 146 | controls weight given to false positives and false negatives, by default 0.7 147 | gamma : float, optional 148 | focal parameter controls degree of down-weighting of easy examples, by default 0.75 149 | smooth : float, optional 150 | smoothing constant to prevent division by 0 errors, by default 0.000001 151 | epsilon : float, optional 152 | clip values to prevent division by zero error 153 | """ 154 | def __init__(self, delta=0.7, gamma=0.75, epsilon=1e-07): 155 | super(AsymmetricFocalTverskyLoss, self).__init__() 156 | self.delta = delta 157 | self.gamma = gamma 158 | self.epsilon = epsilon 159 | 160 | def forward(self, y_pred, y_true): 161 | # Clip values to prevent division by zero error 162 | y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon) 163 | axis = identify_axis(y_true.size()) 164 | 165 | # Calculate true positives (tp), false negatives (fn) and false positives (fp) 166 | tp = torch.sum(y_true * y_pred, axis=axis) 167 | fn = torch.sum(y_true * (1-y_pred), axis=axis) 168 | fp = torch.sum((1-y_true) * y_pred, axis=axis) 169 | dice_class = (tp + self.epsilon)/(tp + self.delta*fn + (1-self.delta)*fp + self.epsilon) 170 | 171 | # Calculate losses separately for each class, only enhancing foreground class 172 | back_dice = (1-dice_class[:,0]) 173 | fore_dice = (1-dice_class[:,1]) * torch.pow(1-dice_class[:,1], -self.gamma) 174 | 175 | # Average class scores 176 | loss = torch.mean(torch.stack([back_dice,fore_dice], axis=-1)) 177 | return loss 178 | 179 | 180 | class SymmetricUnifiedFocalLoss(nn.Module): 181 | """The Unified Focal loss is a new compound loss function that unifies Dice-based and cross entropy-based loss functions into a single framework. 182 |
182 | Parameters 183 | ---------- 184 | weight : float, optional 185 | represents lambda parameter and controls weight given to symmetric Focal Tversky loss and symmetric Focal loss, by default 0.5 186 | delta : float, optional 187 | controls weight given to each class, by default 0.6 188 | gamma : float, optional 189 | focal parameter controls the degree of background suppression and foreground enhancement, by default 0.5 190 | epsilon : float, optional 191 | clip values to prevent division by zero error 192 | """ 193 | def __init__(self, weight=0.5, delta=0.6, gamma=0.5): 194 | super(SymmetricUnifiedFocalLoss, self).__init__() 195 | self.weight = weight 196 | self.delta = delta 197 | self.gamma = gamma 198 | # self.sigmoid = torch.sigmoid() 199 | 200 | def forward(self, y_pred, y_true): 201 | y_true = one_hot_it(y_true) 202 | y_pred = torch.sigmoid(y_pred) 203 | 204 | symmetric_ftl = SymmetricFocalTverskyLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true) 205 | symmetric_fl = SymmetricFocalLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true) 206 | if self.weight is not None: 207 | return (self.weight * symmetric_ftl) + ((1-self.weight) * symmetric_fl) 208 | else: 209 | return symmetric_ftl + symmetric_fl 210 | 211 | 212 | class AsymmetricUnifiedFocalLoss(nn.Module): 213 | """The Unified Focal loss is a new compound loss function that unifies Dice-based and cross entropy-based loss functions into a single framework. 
214 | Parameters 215 | ---------- 216 | weight : float, optional 217 | represents lambda parameter and controls weight given to asymmetric Focal Tversky loss and asymmetric Focal loss, by default 0.5 218 | delta : float, optional 219 | controls weight given to each class, by default 0.6 220 | gamma : float, optional 221 | focal parameter controls the degree of background suppression and foreground enhancement, by default 0.2 222 | epsilon : float, optional 223 | clip values to prevent division by zero error 224 | """ 225 | def __init__(self, weight=0.5, delta=0.6, gamma=0.2): 226 | super(AsymmetricUnifiedFocalLoss, self).__init__() 227 | self.weight = weight 228 | self.delta = delta 229 | self.gamma = gamma 230 | self.softmax = nn.Softmax(dim=1) 231 | 232 | def forward(self, y_pred, y_true): 233 | y_true = one_hot_it(y_true) 234 | y_pred = self.softmax(y_pred) 235 | 236 | # Obtain Asymmetric Focal Tversky loss 237 | asymmetric_ftl = AsymmetricFocalTverskyLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true) 238 | 239 | # Obtain Asymmetric Focal loss 240 | asymmetric_fl = AsymmetricFocalLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true) 241 | 242 | # Return weighted sum of Asymmetric Focal loss and Asymmetric Focal Tversky loss 243 | if self.weight is not None: 244 | return (self.weight * asymmetric_ftl) + ((1-self.weight) * asymmetric_fl) 245 | else: 246 | return asymmetric_ftl + asymmetric_fl -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import os 5 | import cv2 6 | from sklearn.utils.class_weight import compute_class_weight 7 | import pickle 8 | from pathlib import Path 9 | 10 | def count_params(model): 11 | param_num = sum(p.numel() for p in model.parameters()) 12 | return param_num / 1e6 13 | 14 | 15 | def set_seed(seed): 16 | torch.manual_seed(seed) 17 |
random.seed(seed) 18 | np.random.seed(seed) 19 | 20 | def calc_crack_pixel_weight(data_dir): 21 | print('Computing class weights...') 22 | 23 | cweight_path = data_dir + '/cweight.pkl' 24 | 25 | if os.path.exists(cweight_path): 26 | print('Loading saved class weights.') 27 | with open(cweight_path, 'rb') as f: 28 | weight = pickle.load(f) 29 | else: 30 | files = [] 31 | 32 | for path in Path(data_dir + '/SegmentationClass').glob('*.*'): 33 | label = cv2.imread(str(path)).astype(np.uint8) 34 | if 2 not in np.unique(label): 35 | files.append(label) 36 | 37 | all_arr = np.stack(files, axis=0)[:,:,:,0] 38 | 39 | weight = compute_class_weight(class_weight='balanced', classes=np.unique(all_arr), y=all_arr.flatten()) 40 | 41 | with open(cweight_path, 'wb') as f: 42 | pickle.dump(weight, f) 43 | print('Saved class weights under dataset path.') 44 | 45 | return weight 46 | 47 | class mIOU: 48 | def __init__(self, num_classes): 49 | self.num_classes = num_classes 50 | self.hist = np.zeros((num_classes, num_classes)) 51 | 52 | def _fast_hist(self, label_pred, label_true): 53 | mask = (label_true >= 0) & (label_true < self.num_classes) 54 | hist = np.bincount( 55 | self.num_classes * label_true[mask].astype(int) + 56 | label_pred[mask], minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes) 57 | return hist 58 | 59 | def add_batch(self, predictions, gts): 60 | for lp, lt in zip(predictions, gts): 61 | self.hist += self._fast_hist(lp.flatten(), lt.flatten()) 62 | 63 | def evaluate(self): 64 | iu = np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0) - np.diag(self.hist)) 65 | return np.nanmean(iu) 66 | --------------------------------------------------------------------------------
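The `mIOU` class in `util/utils.py` accumulates a confusion matrix whose rows are ground-truth classes and whose columns are predictions, then averages the per-class IoU, TP / (row sum + column sum - TP), with `np.nanmean`. The following is a minimal pure-Python sketch of the same computation for small inputs; the helper names `confusion_matrix` and `mean_iou` are hypothetical and not part of the repository, and predictions are assumed to already lie in `[0, num_classes)`.

```python
def confusion_matrix(preds, gts, num_classes):
    """Count pixels per (ground-truth, prediction) pair (hypothetical helper).

    Rows are ground-truth classes and columns are predicted classes, matching
    the `num_classes * label_true + label_pred` indexing in mIOU._fast_hist.
    Ground-truth labels outside [0, num_classes), such as 255, are skipped.
    """
    hist = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, gts):
        if 0 <= t < num_classes:
            hist[t][p] += 1
    return hist


def mean_iou(hist):
    """Mean of per-class IoU = TP / (row sum + column sum - TP)."""
    ious = []
    for c in range(len(hist)):
        tp = hist[c][c]
        gt_total = sum(hist[c])                   # pixels labelled c
        pred_total = sum(row[c] for row in hist)  # pixels predicted as c
        union = gt_total + pred_total - tp
        # A class absent from both labels and predictions gives 0/0; the NumPy
        # version turns that into NaN, which np.nanmean then ignores.
        if union > 0:
            ious.append(tp / union)
    return sum(ious) / len(ious) if ious else float("nan")


if __name__ == "__main__":
    preds = [0, 1, 1, 0]
    gts = [0, 1, 0, 255]  # the last pixel carries an ignore label
    hist = confusion_matrix(preds, gts, num_classes=3)
    print(hist)            # [[1, 1, 0], [0, 1, 0], [0, 0, 0]]
    print(mean_iou(hist))  # 0.5
```

Because both `test.py` and `train.py` evaluate with `num_classes = 3`, a class that never appears in either the labels or the predictions simply drops out of the average instead of counting as zero IoU.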