├── README.md
├── dataset
│   ├── fewshot.py
│   └── transform.py
├── docs
│   ├── CrackNex.jpg
│   ├── vis.jpg
│   └── vis_wild.jpg
├── model
│   ├── CrackNex_matching.py
│   ├── DecompNet.py
│   ├── backbone
│   │   └── DecompNet.tar
│   └── resnet.py
├── test.py
├── test.sh
├── train.py
├── train.sh
└── util
├── __init__.py
├── loss.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # [ICRA2024] CrackNex: a Few-shot Low-light Crack Segmentation Model Based on Retinex Theory for UAV Inspections
2 | [Zhen Yao](https://scholar.google.com/citations?user=8-IhrB0AAAAJ&hl=en&oi=sra), [Jiawei Xu](https://scholar.google.com/citations?user=b3XkcPkAAAAJ&hl=en&oi=ao), Shuhang Hou, [Mooi Choo Chuah](https://scholar.google.com/citations?hl=en&user=SZBKvksAAAAJ).
3 |
4 |
5 |
6 |
7 | The codebase contains the official code of our paper [CrackNex: a Few-shot Low-light Crack Segmentation Model Based on Retinex Theory for UAV Inspections](https://arxiv.org/abs/2403.03063), ICRA 2024.
8 |
9 | ## Abstract
10 | Routine visual inspections of concrete structures are imperative for upholding the safety and integrity of critical infrastructure. Such visual inspections sometimes happen under low-light conditions, e.g., checking for bridge health. Crack segmentation under such conditions is challenging due to the poor contrast between cracks and their surroundings. However, most deep learning methods are designed for well-illuminated crack images and hence their performance drops dramatically in low-light scenes. In addition, conventional approaches require many annotated low-light crack images, which are time-consuming to collect. In this paper, we address these challenges by proposing CrackNex, a framework that utilizes reflectance information based on Retinex Theory to learn a unified illumination-invariant representation. Furthermore, we utilize few-shot segmentation to solve the limited training data problem. In CrackNex, both a support prototype and a reflectance prototype are extracted from the support set. Then, a prototype fusion module is designed to integrate the features from both prototypes. CrackNex outperforms the SOTA methods on multiple datasets. Additionally, we present the first benchmark dataset, LCSD, for low-light crack segmentation. LCSD consists of 102 well-illuminated crack images and 41 low-light crack images.
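The prototype extraction and fusion described above are implemented in `model/CrackNex_matching.py`. The following is a simplified sketch of that logic for orientation only (names and shapes follow the code, but it is not a drop-in module):

```python
import torch
import torch.nn.functional as F

def masked_average_pooling(feature, mask):
    # average a support feature map over its (resized) binary mask -> one prototype per image
    mask = F.interpolate(mask.unsqueeze(1), size=feature.shape[-2:],
                         mode='bilinear', align_corners=True)
    return torch.sum(feature * mask, dim=(2, 3)) / (mask.sum(dim=(2, 3)) + 1e-5)

def fuse_prototypes(FP, ref_FP, proj_weight_FP):
    # FP: foreground prototype from the RGB support features; ref_FP: prototype from the
    # reflectance features produced by DecompNet; both are (Bs, C, 1, 1) tensors.
    cate = torch.cat((FP, ref_FP), dim=1)                  # concatenate along channels
    normed = (cate - cate.mean(dim=1, keepdim=True)) / cate.std(dim=1, keepdim=True)
    weights = proj_weight_FP(normed)                       # small 1x1-conv head ending in a sigmoid
    return (1 + weights) * FP                              # re-weight the RGB prototype
```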
11 |
12 |
13 |
14 |
15 |
16 |
17 | ## LCSD Dataset
18 | We have established a crack segmentation dataset, LCSD. It comprises 102 well-illuminated crack images and 41 low-light crack images, all with fine-grained annotations.
19 |
20 | You can download the dataset from [this link](https://drive.google.com/drive/folders/1K81AjvIFje5BBxG1qxFmfqRhxf0PKNxe?usp=drive_link).
21 |
22 | ## Requirements
23 |
24 | - Python 3.7
25 | - PyTorch 1.9.1
26 | - CUDA 11.1
27 | - Pillow 8.4.0
28 | - OpenCV (opencv-python)
29 | - scikit-learn
30 |
31 | Conda environment settings:
32 | ```bash
33 | conda create -n cracknex python=3.7 -y
34 | conda activate cracknex
35 |
36 | pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
37 | conda install -c conda-forge cudatoolkit-dev==11.1.1
38 | pip install opencv-python==4.5.2.54
39 | pip install scikit-learn
40 | pip install tqdm
41 | pip install pillow==8.4.0
42 | ```
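Optionally, you can run a quick sanity check (not part of the original scripts) to confirm that the CUDA build of PyTorch is active before training:

```python
import torch

print(torch.__version__)          # expected: 1.9.1+cu111
print(torch.cuda.is_available())  # should be True on a machine with CUDA 11.1 set up
```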
43 |
44 | ## Data preparation
45 | **Pretrained model:** [ResNet-50](https://drive.google.com/file/d/1zphUj3ffl8J2HCRq5AjPsdOVojiGuQZB/view?usp=drive_link) | [ResNet-101](https://drive.google.com/file/d/1G6MVe_8ywk3NyHwpWWoUoZtvQ4pWJWj-/view?usp=drive_link)
46 |
47 |
48 | ### File Organization
49 |
50 | To prepare the LCSD dataset, download all images and labels from the link above. Put all images in the ``JPEGImages/`` subfolder, all labels in the ``SegmentationClassAug/`` subfolder, and the two split ``txt`` files in the ``ImageSets/`` subfolder. We also provide a zip file of our processed [LCSD](https://drive.google.com/file/d/1a8Q1ng38KfNad2emyWJE4mgGv-xs3Dg4/view?usp=drive_link) dataset. A minimal data-loading sketch follows the directory tree below.
51 | ```
52 | ../ # parent directory
53 | ├── ./CrackNex # current (project) directory
54 | │   ├── codes               # various codes
55 | │   └── ./pretrained        # pretrained model directory
56 | │       ├── resnet50.pth
57 | │       └── resnet101.pth
58 | └── Datasets_CrackNex/
59 | ├── LCSD/
60 | │ ├── ImageSets/
61 | │ ├── JPEGImages/
62 | │ └── SegmentationClassAug/
63 | ```
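The layout above is what the episodic dataset in `dataset/fewshot.py` expects. Below is a minimal loading sketch; the path and hyper-parameter values are illustrative and mirror the defaults used in `train.py`:

```python
from torch.utils.data import DataLoader
from dataset.fewshot import FewShot

# root points at the prepared LCSD folder; size/shot/episode mirror the train.py defaults
trainset = FewShot(root='../Datasets_CrackNex/LCSD', size=400, mode='train', shot=1, episode=200)
trainloader = DataLoader(trainset, batch_size=1, shuffle=True, num_workers=0, drop_last=True)

# each episode yields the support images (plus histogram-equalized copies and masks),
# the query image (with its equalized copy and mask), the sampled class id, and the image ids
img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q, cls, id_s, id_q = next(iter(trainloader))
```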
64 |
65 | ## Run the code
66 |
67 | You can adapt the `train.sh` and `test.sh` scripts to train and evaluate your models.
68 |
69 | You may change the ``backbone`` from ``resnet50`` to ``resnet101`` or change the ``shot`` from ``1`` to ``5`` for other settings.
70 |
71 | ```
72 | bash train.sh
73 | ```
74 |
75 | ```
76 | bash test.sh
77 | ```
78 |
79 | Remember to change the dataset path in the scripts (and, for testing, the checkpoint path).
80 |
81 | ## Evaluation and Trained Models
82 |
83 | ### LCSD
84 |
85 | | Method | Setting | Backbone | mIoU |
86 | | :-----: | :-----: | :---------: | :----: |
87 | | CrackNex (Ours) | 1-shot | ResNet-50 | [63.85](https://drive.google.com/file/d/1T9i1S_UlFOzjuSpn7Oeo9RX7BYgQ9jMw/view?usp=drive_link) |
88 | | CrackNex (Ours) | 1-shot | ResNet-101 | [66.10](https://drive.google.com/file/d/1vM3cRazeLmU2QnLIY4RjS7JTJBfHaCV3/view?usp=drive_link) |
89 | | CrackNex (Ours) | 5-shot | ResNet-50 | [65.17](https://drive.google.com/file/d/1uADCeaGZQNvr25dqVgXNJcjbZNAIrdI-/view?usp=drive_link) |
90 | | CrackNex (Ours) | 5-shot | ResNet-101 | [68.82](https://drive.google.com/file/d/1D3R9rCHrP58l48qSOhgtMzmDLvk8-8nf/view?usp=drive_link) |
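The released checkpoints are plain PyTorch ``state_dict``s saved by `train.py`. A minimal loading sketch based on `test.py` is shown below; the checkpoint file name is a placeholder, and the ResNet weights in `./pretrained/` and `./model/backbone/DecompNet.tar` must already be in place because the constructor loads them:

```python
import torch
from model.CrackNex_matching import CrackNex

model = CrackNex('resnet101')   # also loads ./pretrained/resnet101.pth and ./model/backbone/DecompNet.tar
state_dict = torch.load('path/to/downloaded_checkpoint.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```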
91 |
92 | ### Visualization in LCSD
93 |
94 |
95 | ![Qualitative results on LCSD](docs/vis.jpg)
96 |
97 |
98 | ### Visualization of videos in the wild
99 |
100 | ![Qualitative results on in-the-wild videos](docs/vis_wild.jpg)
101 |
102 |
103 | ## Acknowledgment
104 | This codebase is built on [SSP's baseline code](https://github.com/fanq15/SSP/). We thank the authors of SSP and other FSS works for their great contributions.
105 |
106 | ## Citation
107 | ```bibtex
108 | @inproceedings{yao2024cracknex,
109 | title={Cracknex: a few-shot low-light crack segmentation model based on retinex theory for uav inspections},
110 | author={Yao, Zhen and Xu, Jiawei and Hou, Shuhang and Chuah, Mooi Choo},
111 | booktitle={2024 IEEE International Conference on Robotics and Automation (ICRA)},
112 | pages={11155--11162},
113 | year={2024},
114 | organization={IEEE}
115 | }
116 | ```
117 |
--------------------------------------------------------------------------------
/dataset/fewshot.py:
--------------------------------------------------------------------------------
1 | from dataset.transform import crop, hflip, normalize
2 |
3 | from collections import defaultdict
4 | import numpy as np
5 | import os
6 | from PIL import Image
7 | import random
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms
11 |
12 |
13 | class FewShot(Dataset):
14 | """
15 | FewShot generates support-query pairs in an episodic manner,
16 | intended for the meta-training and meta-testing paradigm.
17 | """
18 |
19 | def __init__(self, root, size, mode, shot, episode):
20 | super(FewShot, self).__init__()
21 | self.size = size
22 | self.mode = mode
23 | self.fold = 0
24 | self.shot = shot
25 | self.episode = episode
26 |
27 | self.img_path = os.path.join(root, 'JPEGImages')
28 | self.mask_path = os.path.join(root, 'SegmentationClass')
29 | self.id_path = os.path.join(root, 'ImageSets')
30 |
31 | # class 1: normal-light crack; class 2: low-light crack
32 | n_class = 2
33 |
34 | interval = n_class // 2
35 | if self.mode == 'train':
36 | # base classes = all classes - novel classes
37 | self.classes = set(range(interval * self.fold + 1, interval * (self.fold + 1) + 1))
38 | else:
39 | # novel classes
40 | self.classes = set(range(1, n_class + 1)) - set(range(interval * self.fold + 1, interval * (self.fold + 1) + 1))
41 | # the image ids must be stored in 'train.txt' and 'val.txt'
42 | with open(os.path.join(self.id_path, '%s.txt' % mode), 'r') as f:
43 | self.ids = f.read().splitlines()
44 |
45 | self._filter_ids()
46 | self.cls_to_ids = self._map_cls_to_ids()
47 |
48 | def __getitem__(self, item):
49 | # the sampling strategy is based on the description in OSLSM paper
50 |
51 | # query id, image, mask
52 | if self.mode == 'train':
53 | id_q = random.choice(self.ids)
54 | else:
55 | id_q = self.ids[item]
56 | img_q = Image.open(os.path.join(self.img_path, id_q + ".jpg")).convert('RGB')
57 | mask_q = Image.fromarray(np.array(Image.open(os.path.join(self.mask_path, id_q + ".png"))))
58 | # target class
59 | cls = random.choice(sorted(set(np.unique(mask_q)) & self.classes))
60 |
61 | # support ids, images and masks
62 | id_s_list, img_s_list, hiseq_s_list, mask_s_list = [], [], [], []
63 | while True:
64 | id_s = random.choice(sorted(set(self.cls_to_ids[cls]) - {id_q} - set(id_s_list)))
65 | img_s = Image.open(os.path.join(self.img_path, id_s + ".jpg")).convert('RGB')
66 | mask_s = Image.fromarray(np.array(Image.open(os.path.join(self.mask_path, id_s + ".png"))))
67 |
68 | # small objects in support images are filtered following PFENet
69 | if np.sum(np.array(mask_s) == cls) < 2 * 32 * 32:
70 | continue
71 |
72 | id_s_list.append(id_s)
73 | img_s_list.append(img_s)
74 | mask_s_list.append(mask_s)
75 | if len(id_s_list) == self.shot:
76 | break
77 |
78 | if self.mode == 'train':
79 | img_q, mask_q = crop(img_q, mask_q, self.size)
80 | img_q, mask_q = hflip(img_q, mask_q)
81 | for k in range(self.shot):
82 | img_s_list[k], mask_s_list[k] = crop(img_s_list[k], mask_s_list[k], self.size)
83 | img_s_list[k], mask_s_list[k] = hflip(img_s_list[k], mask_s_list[k])
84 |
85 | img_q, hiseq_q, mask_q = normalize(img_q, mask_q)
86 |
87 | for k in range(self.shot):
88 | img_s_list[k], hiseq_s, mask_s_list[k] = normalize(img_s_list[k], mask_s_list[k])
89 | hiseq_s_list.append(hiseq_s)
90 |
91 | # filter out irrelevant classes by setting them as background
92 | mask_q[(mask_q != cls) & (mask_q != 255)] = 0
93 | mask_q[mask_q == cls] = 1
94 | for k in range(self.shot):
95 | mask_s_list[k][(mask_s_list[k] != cls) & (mask_s_list[k] != 255)] = 0
96 | mask_s_list[k][mask_s_list[k] == cls] = 1
97 |
98 | return img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q, cls, id_s_list, id_q
99 |
100 | def __len__(self):
101 | if self.mode == 'train':
102 | return self.episode
103 | else:
104 | return len(self.ids)
105 |
106 | # remove images that do not contain any valid classes
107 | # and remove images whose valid objects are all small (according to PFENet)
108 | def _filter_ids(self):
109 | for i in range(len(self.ids) - 1, -1, -1):
110 | mask = Image.fromarray(np.array(Image.open(os.path.join(self.mask_path, self.ids[i] + '.png'))))
111 | classes = set(np.unique(mask)) & self.classes
112 | if not classes:
113 | del self.ids[i]
114 | continue
115 |
116 | # remove images whose valid objects are all small (according to PFENet)
117 | exist_large_objects = False
118 | for cls in classes:
119 | if np.sum(np.array(mask) == cls) >= 2 * 32 * 32:
120 | exist_large_objects = True
121 | break
122 | if not exist_large_objects:
123 | del self.ids[i]
124 |
125 | # map each valid class to a list of image ids
126 | def _map_cls_to_ids(self):
127 | cls_to_ids = defaultdict(list)
128 | for id_ in self.ids:
129 | mask = np.array(Image.open(os.path.join(self.mask_path, id_ + ".png")))
130 | valid_classes = set(np.unique(mask)) & self.classes
131 | for cls in valid_classes:
132 | cls_to_ids[cls].append(id_)
133 | return cls_to_ids
134 |
--------------------------------------------------------------------------------
/dataset/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image, ImageOps
3 | import random
4 | import torch
5 | from torchvision import transforms
6 |
7 |
8 | def crop(img, mask, size):
9 | # padding height or width if smaller than cropping size
10 | w, h = img.size
11 | padw = size - w if w < size else 0
12 | padh = size - h if h < size else 0
13 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
14 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=255)
15 |
16 | # cropping
17 | w, h = img.size
18 | x = random.randint(0, w - size)
19 | y = random.randint(0, h - size)
20 | img = img.crop((x, y, x + size, y + size))
21 | mask = mask.crop((x, y, x + size, y + size))
22 |
23 | return img, mask
24 |
25 |
26 | def hflip(img, mask):
27 | if random.random() < 0.5:
28 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
29 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
30 | return img, mask
31 |
32 |
33 | def normalize(img, mask):
34 | """
35 | :param img: PIL image
36 | :param mask: PIL image, corresponding mask
37 | :return: normalized image tensor, its histogram-equalized counterpart (hiseq), and the mask tensor
38 | """
39 | hiseq = ImageOps.equalize(img, mask=None)
40 | hiseq = transforms.Compose([
41 | transforms.ToTensor(),
42 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
43 | ])(hiseq)
44 | img = transforms.Compose([
45 | transforms.ToTensor(),
46 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
47 | ])(img)
48 | mask = torch.from_numpy(np.array(mask)).long()
49 | return img, hiseq, mask
50 |
--------------------------------------------------------------------------------
/docs/CrackNex.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyaocoder/CrackNex/122725c5e43ed1a9967cad0f29b01627901f3569/docs/CrackNex.jpg
--------------------------------------------------------------------------------
/docs/vis.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyaocoder/CrackNex/122725c5e43ed1a9967cad0f29b01627901f3569/docs/vis.jpg
--------------------------------------------------------------------------------
/docs/vis_wild.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyaocoder/CrackNex/122725c5e43ed1a9967cad0f29b01627901f3569/docs/vis_wild.jpg
--------------------------------------------------------------------------------
/model/CrackNex_matching.py:
--------------------------------------------------------------------------------
1 | import model.resnet as resnet
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | import pdb
7 | from model.DecompNet import DecompNet, load_decomp_ckpt
8 |
9 | class ASPP_module(nn.Module):
10 | def __init__(self, inplanes, planes, rate):
11 | super(ASPP_module, self).__init__()
12 | if rate == 1:
13 | kernel_size = 1
14 | padding = 0
15 | else:
16 | kernel_size = 3
17 | padding = rate
18 | self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
19 | stride=1, padding=padding, dilation=rate, bias=False)
20 | self.bn = nn.BatchNorm2d(planes)
21 | self.relu = nn.ReLU()
22 |
23 | def forward(self, x):
24 | x = self.atrous_convolution(x)
25 | x = self.bn(x)
26 |
27 | return self.relu(x)
28 |
29 | class CrackNex(nn.Module):
30 | def __init__(self, backbone):
31 | super(CrackNex, self).__init__()
32 | decomp_path = './model/backbone/DecompNet.tar'
33 | rgb_backbone = resnet.__dict__[backbone](pretrained=True)
34 | ref_backbone = resnet.__dict__[backbone](pretrained=True)
35 | self.rgb_layer0 = nn.Sequential(rgb_backbone.conv1, rgb_backbone.bn1, rgb_backbone.relu, rgb_backbone.maxpool)
36 | self.rgb_layer1, self.rgb_layer2, self.rgb_layer3 = rgb_backbone.layer1, rgb_backbone.layer2, rgb_backbone.layer3
37 | self.ref_layer0 = nn.Sequential(ref_backbone.conv1, ref_backbone.bn1, ref_backbone.relu, ref_backbone.maxpool)
38 | self.ref_layer1, self.ref_layer2, self.ref_layer3 = ref_backbone.layer1, ref_backbone.layer2, ref_backbone.layer3
39 |
40 | self.proj_weight_FP = nn.Sequential(nn.Conv2d(2048, 1024, 1, stride=1, bias=False),
41 | nn.ReLU(),
42 | nn.Conv2d(1024, 1024, 1, stride=1, bias=False),
43 | nn.Sigmoid())
44 |
45 | self.proj_weight_BP = nn.Sequential(nn.Conv2d(2048, 1024, 1, stride=1, bias=False),
46 | nn.ReLU(),
47 | nn.Conv2d(1024, 1024, 1, stride=1, bias=False),
48 | nn.Sigmoid())
49 |
50 | rates = [1, 6, 12, 18]
51 | self.aspp1 = ASPP_module(1024, 512, rate=rates[0])
52 | self.aspp2 = ASPP_module(1024, 512, rate=rates[1])
53 | self.aspp3 = ASPP_module(1024, 512, rate=rates[2])
54 | self.aspp4 = ASPP_module(1024, 512, rate=rates[3])
55 | self.relu = nn.ReLU(inplace=True)
56 | self.dropout = nn.Dropout(0.1)
57 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
58 | nn.Conv2d(1024, 512, 1, stride=1, bias=False),
59 | nn.BatchNorm2d(512),
60 | nn.ReLU())
61 | self.conv1 = nn.Conv2d(2560, 1024, 1, bias=False)
62 | self.bn1 = nn.BatchNorm2d(1024)
63 |
64 | # 1x1 conv + BN to project the low-level reflectance features before concatenation in aspp_block
65 | self.conv2 = nn.Conv2d(256, 256, 1, bias=False)
66 | self.bn2 = nn.BatchNorm2d(256)
67 | self.last_conv = nn.Sequential(nn.Conv2d(1280, 1024, kernel_size=3, stride=1, padding=1, bias=False),
68 | nn.BatchNorm2d(1024),
69 | nn.ReLU(),
70 | nn.Dropout(0.1),
71 | nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, bias=False),
72 | nn.BatchNorm2d(1024),
73 | nn.ReLU(),
74 | nn.Dropout(0.1),
75 | nn.Conv2d(1024, 1024, kernel_size=1, stride=1))
76 |
77 | self.DecompNet = DecompNet()
78 | load_decomp_ckpt(self.DecompNet, decomp_path)
79 |
80 | weight = self.rgb_layer2._modules['0'].conv1.weight.clone()
81 | self.rgb_layer2._modules['0'].conv1 = nn.Conv2d(512, 128, kernel_size=(1,1), stride=(1,1), bias=False)
82 | self.rgb_layer2._modules['0'].conv1.weight.data[:, :256] = weight
83 | self.rgb_layer2._modules['0'].conv1.weight.data[:, 256:] = weight
84 | weight2 = self.rgb_layer2._modules['0'].downsample[0].weight.clone()
85 | self.rgb_layer2._modules['0'].downsample[0] = nn.Conv2d(512, 512, kernel_size=(1,1), stride=(2,2), bias=False)
86 | self.rgb_layer2._modules['0'].downsample[0].weight.data[:, :256] = weight2
87 | self.rgb_layer2._modules['0'].downsample[0].weight.data[:, 256:] = weight2
88 |
89 | def forward(self, img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q):
90 | h, w = img_q.shape[-2:]
91 | reflectance_q = self.DecompNet(img_q)
92 |
93 | # feature maps of support images
94 | feature_s_list = []
95 | feature_s_ref_list = []
96 | feature_s_ref_lowlevel_list = []
97 | for k in range(len(img_s_list)):
98 | with torch.no_grad():
99 | reflectance_s = self.DecompNet(img_s_list[k])
100 |
101 | s_0 = self.rgb_layer0(img_s_list[k]) # (100, 100)
102 | s_0_hiseq = self.rgb_layer0(hiseq_s_list[k])
103 |
104 | s_0 = self.rgb_layer1(s_0) # (100, 100)
105 | s_0_hiseq = self.rgb_layer1(s_0_hiseq)
106 |
107 | s_0_ref = self.ref_layer0(reflectance_s)
108 | s_0_ref = self.ref_layer1(s_0_ref)
109 | s_0_lowlevel = s_0_ref
110 |
111 | s_0 = self.rgb_layer2(torch.cat([s_0, s_0_hiseq], dim=1)) # (50, 50)
112 | s_0 = self.rgb_layer3(s_0) # (50, 50)
113 |
114 | s_0_ref = self.ref_layer2(s_0_ref) # (50, 50)
115 | s_0_ref = self.ref_layer3(s_0_ref) # (50, 50)
116 |
117 | s_0 = self.aspp_block(s_0, s_0_lowlevel)
118 |
119 | feature_s_list.append(s_0)
120 | feature_s_ref_list.append(s_0_ref)
121 | feature_s_ref_lowlevel_list.append(s_0_lowlevel)
122 | del s_0
123 | del s_0_hiseq
124 | del s_0_ref
125 | del s_0_lowlevel
126 | # feature map of query image
127 | with torch.no_grad():
128 | q_0 = self.rgb_layer0(img_q)
129 | q_0_hiseq = self.rgb_layer0(hiseq_q)
130 | q_0 = self.rgb_layer1(q_0)
131 | q_0_hiseq = self.rgb_layer1(q_0_hiseq)
132 |
133 | q_0_ref = self.ref_layer0(reflectance_q)
134 | q_0_ref = self.ref_layer1(q_0_ref)
135 | q_0_lowlevel = q_0_ref
136 |
137 | q_0 = self.rgb_layer2(torch.cat([q_0, q_0_hiseq], dim=1))
138 | feature_q = self.rgb_layer3(q_0) #(Bs, 1024, 50, 50)
139 |
140 | feature_q = self.aspp_block(feature_q, q_0_lowlevel)
141 |
142 | # foreground(target class) and background prototypes pooled from K support features
143 | feature_fg_list = []
144 | feature_bg_list = []
145 | feature_ref_fg_list = []
146 | feature_ref_bg_list = []
147 |
148 | supp_out_ls = []
149 | supp_out_ref_ls = []
150 | for k in range(len(img_s_list)):
151 | # Generate original prototype
152 | feature_fg = self.masked_average_pooling(feature_s_list[k],
153 | (mask_s_list[k] == 1).float())[None, :]
154 | feature_bg = self.masked_average_pooling(feature_s_list[k],
155 | (mask_s_list[k] == 0).float())[None, :]
156 | feature_fg_list.append(feature_fg) # (1, Bs, C)
157 | feature_bg_list.append(feature_bg) # (1, Bs, C)
158 |
159 | # Generate reflectance prototype
160 | feature_ref_fg = self.masked_average_pooling(feature_s_ref_list[k],
161 | (mask_s_list[k] == 1).float())[None, :]
162 | feature_ref_bg = self.masked_average_pooling(feature_s_ref_list[k],
163 | (mask_s_list[k] == 0).float())[None, :]
164 | feature_ref_fg_list.append(feature_ref_fg) # (1, Bs, C)
165 | feature_ref_bg_list.append(feature_ref_bg) # (1, Bs, C)
166 |
167 | if self.training:
168 | supp_similarity_fg = F.cosine_similarity(feature_s_list[k], feature_fg.squeeze(0)[..., None, None], dim=1)
169 | supp_similarity_bg = F.cosine_similarity(feature_s_list[k], feature_bg.squeeze(0)[..., None, None], dim=1)
170 | supp_out = torch.cat((supp_similarity_bg[:, None, ...], supp_similarity_fg[:, None, ...]), dim=1) * 10.0
171 |
172 | supp_out = F.interpolate(supp_out, size=(h, w), mode="bilinear", align_corners=True) # (Bs, 2, H, W)
173 | supp_out_ls.append(supp_out)
174 |
175 | # Reflectance support
176 | supp_similarity_ref_fg = F.cosine_similarity(feature_s_ref_list[k], feature_ref_fg.squeeze(0)[..., None, None], dim=1)
177 | supp_similarity_ref_bg = F.cosine_similarity(feature_s_ref_list[k], feature_ref_bg.squeeze(0)[..., None, None], dim=1)
178 | supp_out_ref = torch.cat((supp_similarity_ref_bg[:, None, ...], supp_similarity_ref_fg[:, None, ...]), dim=1) * 10.0
179 |
180 | supp_out_ref = F.interpolate(supp_out_ref, size=(h, w), mode="bilinear", align_corners=True) # (Bs, 2, H, W)
181 | supp_out_ref_ls.append(supp_out_ref)
182 |
183 | # average K foreground prototypes and K background prototypes (Bs, C, 1, 1)
184 | FP = torch.mean(torch.cat(feature_fg_list, dim=0), dim=0).unsqueeze(-1).unsqueeze(-1)
185 | BP = torch.mean(torch.cat(feature_bg_list, dim=0), dim=0).unsqueeze(-1).unsqueeze(-1)
186 |
187 | ref_FP = torch.mean(torch.cat(feature_ref_fg_list, dim=0), dim=0).unsqueeze(-1).unsqueeze(-1)
188 | ref_BP = torch.mean(torch.cat(feature_ref_bg_list, dim=0), dim=0).unsqueeze(-1).unsqueeze(-1)
189 |
190 | # Fuse two prototypes and generate updated features and prototypes
191 | cate_FP = torch.cat((FP, ref_FP), dim=1)
192 | cate_BP = torch.cat((BP, ref_BP), dim=1)
193 |
194 | normalized_FP = self.z_score_norm(cate_FP)
195 | normalized_BP = self.z_score_norm(cate_BP)
196 |
197 | weights_FP = self.proj_weight_FP(normalized_FP)
198 | weights_BP = self.proj_weight_BP(normalized_BP)
199 |
200 | FP = torch.mul((1 + weights_FP), FP)
201 | feature_q = torch.mul((1 + weights_FP), feature_q)
202 | BP = torch.mul((1 + weights_BP), BP)
203 |
204 | # measure the similarity of query features to fg/bg prototypes
205 | out_0 = self.similarity_func(feature_q, FP, BP) # (Bs, 2, H, W)
206 |
207 | out_1, out_2 = self.SSP_module(feature_q, out_0, FP, BP)
208 |
209 | out_0 = F.interpolate(out_0, size=(h, w), mode="bilinear", align_corners=True)
210 | out_1 = F.interpolate(out_1, size=(h, w), mode="bilinear", align_corners=True)
211 | out_2 = F.interpolate(out_2, size=(h, w), mode="bilinear", align_corners=True)
212 |
213 | out_ls = [out_2, out_1]
214 |
215 | if self.training:
216 | fg_q = self.masked_average_pooling(feature_q, (mask_q == 1).float())[None, :].squeeze(0)
217 | bg_q = self.masked_average_pooling(feature_q, (mask_q == 0).float())[None, :].squeeze(0)
218 |
219 | self_similarity_fg = F.cosine_similarity(feature_q, fg_q[..., None, None], dim=1)
220 | self_similarity_bg = F.cosine_similarity(feature_q, bg_q[..., None, None], dim=1)
221 | self_out = torch.cat((self_similarity_bg[:, None, ...], self_similarity_fg[:, None, ...]), dim=1) * 10.0
222 |
223 | self_out = F.interpolate(self_out, size=(h, w), mode="bilinear", align_corners=True)
224 | supp_out = torch.cat(supp_out_ls, 0)
225 | supp_out_ref = torch.cat(supp_out_ref_ls, 0)
226 |
227 | out_ls.append(self_out)
228 | out_ls.append(supp_out)
229 | out_ls.append(supp_out_ref)
230 |
231 | return out_ls
232 |
233 | def aspp_block(self, x, low_level_features):
234 | x1 = self.aspp1(x)
235 | x2 = self.aspp2(x)
236 | x3 = self.aspp3(x)
237 | x4 = self.aspp4(x)
238 | x5 = self.global_avg_pool(x)
239 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
240 |
241 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
242 |
243 | x = self.conv1(x)
244 | x = self.bn1(x)
245 | x = self.relu(x)
246 | x = self.dropout(x)
247 | x = F.interpolate(x, size=(low_level_features.size()[-2],
248 | low_level_features.size()[-1]), mode='bilinear', align_corners=True)
249 |
250 | low_level_features = self.conv2(low_level_features)
251 | low_level_features = self.bn2(low_level_features)
252 | low_level_features = self.relu(low_level_features)
253 |
254 | x = torch.cat((x, low_level_features), dim=1)
255 | x = self.last_conv(x)
256 |
257 | return x
258 |
259 | def z_score_norm(self, tensor):
260 | mu = torch.mean(tensor,dim=(1),keepdim=True)
261 | sd = torch.std(tensor,dim=(1),keepdim=True)
262 | normalized_tensor = (tensor - mu) / sd
263 |
264 | return normalized_tensor
265 |
266 | def SSP_module(self, feature_q, out_0, FP, BP):
267 | ##################### Self-Support Prototype (SSP) #####################
268 | SSFP_1, SSBP_1, ASFP_1, ASBP_1 = self.SSP_func(feature_q, out_0)
269 |
270 | FP_1 = FP * 0.5 + SSFP_1 * 0.5
271 | BP_1 = SSBP_1 * 0.3 + ASBP_1 * 0.7
272 |
273 | out_1 = self.similarity_func(feature_q, FP_1, BP_1)
274 |
275 | ##################### SSP Refinement #####################
276 | SSFP_2, SSBP_2, ASFP_2, ASBP_2 = self.SSP_func(feature_q, out_1)
277 |
278 | FP_2 = FP * 0.5 + SSFP_2 * 0.5
279 | BP_2 = SSBP_2 * 0.3 + ASBP_2 * 0.7
280 |
281 | FP_2 = FP * 0.5 + FP_1 * 0.2 + FP_2 * 0.3
282 | BP_2 = BP * 0.5 + BP_1 * 0.2 + BP_2 * 0.3
283 |
284 | out_2 = self.similarity_func(feature_q, FP_2, BP_2)
285 |
286 | out_2 = out_2 * 0.7 + out_1 * 0.3
287 |
288 | return out_1, out_2
289 |
290 | def SSP_func(self, feature_q, out):
291 | bs = feature_q.shape[0]
292 | pred_1 = out.softmax(1)
293 | pred_1 = pred_1.view(bs, 2, -1)
294 | pred_fg = pred_1[:, 1] # (Bs, H*W)
295 | pred_bg = pred_1[:, 0] # (Bs, H*W)
296 | fg_ls = []
297 | bg_ls = []
298 | fg_local_ls = []
299 | bg_local_ls = []
300 | for epi in range(bs):
301 | fg_thres = 0.7 #0.9 #0.6
302 | bg_thres = 0.6 #0.6
303 | cur_feat = feature_q[epi].view(1024, -1)
304 | f_h, f_w = feature_q[epi].shape[-2:]
305 | if (pred_fg[epi] > fg_thres).sum() > 0:
306 | fg_feat = cur_feat[:, (pred_fg[epi]>fg_thres)] #.mean(-1)
307 | else:
308 | fg_feat = cur_feat[:, torch.topk(pred_fg[epi], 12).indices] #.mean(-1)
309 | if (pred_bg[epi] > bg_thres).sum() > 0:
310 | bg_feat = cur_feat[:, (pred_bg[epi]>bg_thres)] #.mean(-1)
311 | else:
312 | bg_feat = cur_feat[:, torch.topk(pred_bg[epi], 12).indices] #.mean(-1)
313 | # global proto
314 | fg_proto = fg_feat.mean(-1)
315 | bg_proto = bg_feat.mean(-1)
316 | fg_ls.append(fg_proto.unsqueeze(0))
317 | bg_ls.append(bg_proto.unsqueeze(0))
318 |
319 | # local proto
320 | fg_feat_norm = fg_feat / torch.norm(fg_feat, 2, 0, True) # 1024, N1
321 | bg_feat_norm = bg_feat / torch.norm(bg_feat, 2, 0, True) # 1024, N2
322 | cur_feat_norm = cur_feat / torch.norm(cur_feat, 2, 0, True) # 1024, N3
323 |
324 | cur_feat_norm_t = cur_feat_norm.t() # N3, 1024
325 | fg_sim = torch.matmul(cur_feat_norm_t, fg_feat_norm) * 2.0 # N3, N1
326 | bg_sim = torch.matmul(cur_feat_norm_t, bg_feat_norm) * 2.0 # N3, N2
327 |
328 | fg_sim = fg_sim.softmax(-1)
329 | bg_sim = bg_sim.softmax(-1)
330 |
331 | fg_proto_local = torch.matmul(fg_sim, fg_feat.t()) # N3, 1024
332 | bg_proto_local = torch.matmul(bg_sim, bg_feat.t()) # N3, 1024
333 |
334 | fg_proto_local = fg_proto_local.t().view(1024, f_h, f_w).unsqueeze(0) # (1, 1024, f_h, f_w)
335 | bg_proto_local = bg_proto_local.t().view(1024, f_h, f_w).unsqueeze(0) # (1, 1024, f_h, f_w)
336 |
337 | fg_local_ls.append(fg_proto_local)
338 | bg_local_ls.append(bg_proto_local)
339 |
340 | # global proto
341 | new_fg = torch.cat(fg_ls, 0).unsqueeze(-1).unsqueeze(-1)
342 | new_bg = torch.cat(bg_ls, 0).unsqueeze(-1).unsqueeze(-1)
343 |
344 | # local proto
345 | new_fg_local = torch.cat(fg_local_ls, 0).unsqueeze(-1).unsqueeze(-1)
346 | new_bg_local = torch.cat(bg_local_ls, 0)
347 |
348 | return new_fg, new_bg, new_fg_local, new_bg_local
349 |
350 | def similarity_func(self, feature_q, fg_proto, bg_proto):
351 | similarity_fg = F.cosine_similarity(feature_q, fg_proto, dim=1)
352 | similarity_bg = F.cosine_similarity(feature_q, bg_proto, dim=1)
353 |
354 | out = torch.cat((similarity_bg[:, None, ...], similarity_fg[:, None, ...]), dim=1) * 10.0
355 | return out
356 |
357 | def masked_average_pooling(self, feature, mask):
358 | mask = F.interpolate(mask.unsqueeze(1), size=feature.shape[-2:], mode='bilinear', align_corners=True)
359 | masked_feature = torch.sum(feature * mask, dim=(2, 3)) \
360 | / (mask.sum(dim=(2, 3)) + 1e-5)
361 | return masked_feature
362 |
--------------------------------------------------------------------------------
/model/DecompNet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 |
5 | from PIL import Image
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 | import numpy as np
12 |
13 | def load_decomp_ckpt(model, ckpt_dir):
14 | assert os.path.exists(ckpt_dir), "Pretrained weights don't exist!"
15 |
16 | ckpt_dict = torch.load(ckpt_dir)
17 |
18 | model.load_state_dict(ckpt_dict)
19 |
20 | #freeze the DecompNet weights
21 | for _, param in model.named_parameters():
22 | param.requires_grad = False
23 |
24 | print('DecompNet successfully loaded from {}'.format(ckpt_dir))
25 |
26 | class DecompNet(nn.Module):
27 | def __init__(self, channel=64, kernel_size=3):
28 | super(DecompNet, self).__init__()
29 | # Shallow feature extraction
30 | self.net1_conv0 = nn.Conv2d(4, channel, kernel_size * 3,
31 | padding=4, padding_mode='replicate')
32 | # Activated layers!
33 | self.net1_convs = nn.Sequential(nn.Conv2d(channel, channel, kernel_size,
34 | padding=1, padding_mode='replicate'),
35 | nn.ReLU(),
36 | nn.Conv2d(channel, channel, kernel_size,
37 | padding=1, padding_mode='replicate'),
38 | nn.ReLU(),
39 | nn.Conv2d(channel, channel, kernel_size,
40 | padding=1, padding_mode='replicate'),
41 | nn.ReLU(),
42 | nn.Conv2d(channel, channel, kernel_size,
43 | padding=1, padding_mode='replicate'),
44 | nn.ReLU(),
45 | nn.Conv2d(channel, channel, kernel_size,
46 | padding=1, padding_mode='replicate'),
47 | nn.ReLU())
48 | # Final recon layer
49 | self.net1_recon = nn.Conv2d(channel, 4, kernel_size,
50 | padding=1, padding_mode='replicate')
51 |
52 | def forward(self, input_im):
53 | input_max= torch.max(input_im, dim=1, keepdim=True)[0]
54 | input_img= torch.cat((input_max, input_im), dim=1)
55 | feats0 = self.net1_conv0(input_img)
56 | featss = self.net1_convs(feats0)
57 | outs = self.net1_recon(featss)
58 | R = torch.sigmoid(outs[:, 0:3, :, :])  # reflectance map; the remaining output channel (illumination in RetinexNet) is unused here
59 | return R
--------------------------------------------------------------------------------
/model/backbone/DecompNet.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyaocoder/CrackNex/122725c5e43ed1a9967cad0f29b01627901f3569/model/backbone/DecompNet.tar
--------------------------------------------------------------------------------
/model/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
7 |
8 |
9 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 | padding=dilation, groups=groups, bias=False, dilation=dilation)
12 |
13 |
14 | def conv1x1(in_planes, out_planes, stride=1):
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
16 |
17 |
18 | class BasicBlock(nn.Module):
19 | expansion = 1
20 |
21 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
22 | base_width=64, dilation=1, norm_layer=None):
23 | super(BasicBlock, self).__init__()
24 | if norm_layer is None:
25 | norm_layer = nn.BatchNorm2d
26 | if groups != 1 or base_width != 64:
27 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
28 | if dilation > 1:
29 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
30 |
31 | self.conv1 = conv3x3(inplanes, planes, stride)
32 | self.bn1 = norm_layer(planes)
33 | self.relu = nn.ReLU(inplace=True)
34 | self.conv2 = conv3x3(planes, planes)
35 | self.bn2 = norm_layer(planes)
36 | self.downsample = downsample
37 | self.stride = stride
38 |
39 | def forward(self, x):
40 | identity = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 |
49 | if self.downsample is not None:
50 | identity = self.downsample(x)
51 |
52 | out += identity
53 | out = self.relu(out)
54 |
55 | return out
56 |
57 |
58 | class Bottleneck(nn.Module):
59 | expansion = 4
60 |
61 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
62 | base_width=64, dilation=1, norm_layer=None, last_relu=True):
63 | super(Bottleneck, self).__init__()
64 | if norm_layer is None:
65 | norm_layer = nn.BatchNorm2d
66 | width = int(planes * (base_width / 64.)) * groups
67 |
68 | self.conv1 = conv1x1(inplanes, width)
69 | self.bn1 = norm_layer(width)
70 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
71 | self.bn2 = norm_layer(width)
72 | self.conv3 = conv1x1(width, planes * self.expansion)
73 | self.bn3 = norm_layer(planes * self.expansion)
74 | self.relu = nn.ReLU(inplace=True)
75 | self.downsample = downsample
76 | self.stride = stride
77 | self.last_relu = last_relu
78 |
79 | def forward(self, x):
80 | identity = x
81 |
82 | out = self.conv1(x)
83 | out = self.bn1(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv2(out)
87 | out = self.bn2(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv3(out)
91 | out = self.bn3(out)
92 |
93 | if self.downsample is not None:
94 | identity = self.downsample(x)
95 |
96 | out += identity
97 | if self.last_relu:
98 | out = self.relu(out)
99 |
100 | return out
101 |
102 |
103 | class ResNet(nn.Module):
104 |
105 | def __init__(self, block, layers, zero_init_residual=False, groups=1,
106 | width_per_group=64, replace_stride_with_dilation=None, norm_layer=None):
107 | super(ResNet, self).__init__()
108 |
109 | self.out_channels = block.expansion * 256
110 |
111 | if norm_layer is None:
112 | norm_layer = nn.BatchNorm2d
113 | self._norm_layer = norm_layer
114 |
115 | self.inplanes = 128
116 | self.dilation = 1
117 | if replace_stride_with_dilation is None:
118 | replace_stride_with_dilation = [False, False, False]
119 | if len(replace_stride_with_dilation) != 3:
120 | raise ValueError("replace_stride_with_dilation should be None "
121 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
122 | self.groups = groups
123 | self.base_width = width_per_group
124 |
125 | self.conv1 = nn.Sequential(
126 | conv3x3(3, 64, stride=2),
127 | norm_layer(64),
128 | nn.ReLU(inplace=True),
129 | conv3x3(64, 64),
130 | norm_layer(64),
131 | nn.ReLU(inplace=True),
132 | conv3x3(64, 128)
133 | )
134 | self.bn1 = norm_layer(128)
135 | self.relu = nn.ReLU(inplace=True)
136 |
137 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
138 | self.layer1 = self._make_layer(block, 64, layers[0])
139 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
140 | dilate=replace_stride_with_dilation[0])
141 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
142 | dilate=replace_stride_with_dilation[1], last_relu=False)
143 |
144 | for m in self.modules():
145 | if isinstance(m, nn.Conv2d):
146 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
147 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
148 | nn.init.constant_(m.weight, 1)
149 | nn.init.constant_(m.bias, 0)
150 |
151 | if zero_init_residual:
152 | for m in self.modules():
153 | if isinstance(m, Bottleneck):
154 | nn.init.constant_(m.bn3.weight, 0)
155 | elif isinstance(m, BasicBlock):
156 | nn.init.constant_(m.bn2.weight, 0)
157 |
158 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, last_relu=True):
159 | """
160 | :param last_relu: in metric learning paradigm, the final relu is removed (last_relu = False)
161 | """
162 | norm_layer = self._norm_layer
163 | downsample = None
164 | previous_dilation = self.dilation
165 | if dilate:
166 | self.dilation *= stride
167 | stride = 1
168 | if stride != 1 or self.inplanes != planes * block.expansion:
169 | downsample = nn.Sequential(
170 | conv1x1(self.inplanes, planes * block.expansion, stride),
171 | norm_layer(planes * block.expansion),
172 | )
173 |
174 | layers = list()
175 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
176 | self.base_width, previous_dilation, norm_layer))
177 | self.inplanes = planes * block.expansion
178 | for i in range(1, blocks):
179 | use_relu = True if i != blocks - 1 else last_relu
180 | layers.append(block(self.inplanes, planes, groups=self.groups,
181 | base_width=self.base_width, dilation=self.dilation,
182 | norm_layer=norm_layer, last_relu=use_relu))
183 |
184 | return nn.Sequential(*layers)
185 |
186 | def base_forward(self, x):
187 | x = self.conv1(x)
188 | x = self.bn1(x)
189 | x = self.relu(x)
190 | x = self.maxpool(x)
191 |
192 | c1 = self.layer1(x)
193 | c2 = self.layer2(c1)
194 | c3 = self.layer3(c2)
195 |
196 | return c3
197 |
198 |
199 | def _resnet(arch, block, layers, pretrained, **kwargs):
200 | model = ResNet(block, layers, **kwargs)
201 | if pretrained:
202 | state_dict = torch.load("./pretrained/%s.pth" % arch)
203 | model.load_state_dict(state_dict, strict=False)
204 | return model
205 |
206 |
207 | def resnet18(pretrained=False):
208 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained)
209 |
210 |
211 | def resnet34(pretrained=False):
212 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained)
213 |
214 |
215 | def resnet50(pretrained=False):
216 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained,
217 | replace_stride_with_dilation=[False, True, True])
218 |
219 |
220 | def resnet101(pretrained=False):
221 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
222 | replace_stride_with_dilation=[False, True, True])
223 |
224 |
225 | def resnet152(pretrained=False):
226 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
227 | replace_stride_with_dilation=[False, True, True])
228 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from dataset.fewshot import FewShot
2 | from model.CrackNex_matching import CrackNex
3 | from util.utils import count_params, set_seed, mIOU
4 |
5 | import argparse
6 | import os
7 | import torch
8 | from torch.nn import DataParallel
9 | from torch.utils.data import DataLoader
10 | from tqdm import tqdm
11 | import glob
12 | import numpy as np
13 | from PIL import Image
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description='CrackNex: Few-shot Low-light Crack Segmentation')
17 | # basic arguments
18 | parser.add_argument('--data-root',
19 | type=str,
20 | required=True,
21 | help='root path of training dataset')
22 | parser.add_argument('--dataset',
23 | type=str,
24 | default='llCrackSeg9k',
25 | choices=['llCrackSeg9k', 'LCSD'],
26 | help='training dataset')
27 | parser.add_argument('--backbone',
28 | type=str,
29 | choices=['resnet50', 'resnet101'],
30 | default='resnet50',
31 | help='backbone of semantic segmentation model')
32 |
33 | # few-shot training arguments
34 | parser.add_argument('--shot',
35 | type=int,
36 | default=1,
37 | help='number of support pairs')
38 | parser.add_argument('--seed',
39 | type=int,
40 | default=0,
41 | help='random seed to generate testing samples')
42 | parser.add_argument('--path',
43 | type=str,
44 | help='checkpoint path')
45 | parser.add_argument('--savepath',
46 | type=str,
47 | default='./logs/',
48 | help='results saving path')
49 |
50 | args = parser.parse_args()
51 | return args
52 |
53 |
54 | def evaluate(model, dataloader, args):
55 | tbar = tqdm(dataloader)
56 |
57 | num_classes = 3
58 | metric = mIOU(num_classes)
59 |
60 | for i, (img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q, cls, _, id_q) in enumerate(tbar):
61 | img_q, hiseq_q, mask_q = img_q.cuda(), hiseq_q.cuda(), mask_q.cuda()
62 | for k in range(len(img_s_list)):
63 | img_s_list[k], hiseq_s_list[k], mask_s_list[k] = img_s_list[k].cuda(), hiseq_s_list[k].cuda(), mask_s_list[k].cuda()
64 | cls = cls[0].item()
65 |
66 | with torch.no_grad():
67 | out_ls = model(img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q)
68 | pred = torch.argmax(out_ls[0], dim=1)
69 |
70 | pred[pred == 1] = cls
71 | mask_q[mask_q == 1] = cls
72 |
73 | # if seed == 0:
74 | # result = pred.squeeze(0).cpu().numpy().copy()
75 | # result[result == 2] = 255
76 |
77 | # im = Image.fromarray(np.uint8(result))
78 | # im.save(args.savepath + id_q[0] + '.png')
79 |
80 | metric.add_batch(pred.cpu().numpy(), mask_q.cpu().numpy())
81 |
82 | tbar.set_description("Testing mIOU: %.2f" % (metric.evaluate() * 100.0))
83 |
84 | return metric.evaluate() * 100.0
85 |
86 | def main():
87 | args = parse_args()
88 | print('\n' + str(args))
89 |
90 | save_path = 'outdir/models/%s' % (args.dataset)
91 | os.makedirs(save_path, exist_ok=True)
92 |
93 | testset = FewShot(args.data_root.replace('train_coco', 'val_coco'), None, 'val',
94 | args.shot, 760 if args.dataset == 'LCSD' else 4000)
95 | testloader = DataLoader(testset, batch_size=1, shuffle=False,
96 | pin_memory=True, num_workers=4, drop_last=False)
97 |
98 | model = CrackNex(args.backbone)
99 | checkpoint_path = args.path
100 |
101 | print('Evaluating model:', checkpoint_path)
102 |
103 | checkpoint = torch.load(checkpoint_path)
104 | model.load_state_dict(checkpoint)
105 |
106 | #print(model)
107 | print('\nParams: %.1fM' % count_params(model))
108 |
109 | best_model = DataParallel(model).cuda()
110 |
111 | print('\nEvaluating on 5 seeds.....')
112 | total_miou = 0.0
113 | model.eval()
114 | for seed in range(5):
115 | print('\nRun %i:' % (seed + 1))
116 | set_seed(args.seed + seed)
117 |
118 | miou = evaluate(best_model, testloader, args)
119 | total_miou += miou
120 |
121 | print('\n' + '*' * 32)
122 | print('Averaged mIOU on 5 seeds: %.2f' % (total_miou / 5))
123 | print('*' * 32 + '\n')
124 |
125 |
126 | if __name__ == '__main__':
127 | main()
128 |
129 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | python -W ignore test.py \
2 | --dataset LCSD --data-root /Your/path/to/dataset \
3 | --backbone resnet101 --shot 5 --path /Your/path/to/checkpoint
4 |
5 | # python -W ignore test.py \
6 | # --dataset LCSD --data-root /Your/path/to/dataset \
7 | # --backbone resnet101 --shot 1 --path /Your/path/to/checkpoint
8 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from dataset.fewshot import FewShot
2 | from model.CrackNex_matching import CrackNex
3 | from util.utils import count_params, set_seed, calc_crack_pixel_weight, mIOU
4 |
5 | import argparse
6 | from copy import deepcopy
7 | import os
8 | import time
9 | import torch
10 | from torch.nn import CrossEntropyLoss, DataParallel
11 | from torch.optim import SGD
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description='CrackNex: Few-shot Low-light Crack Segmentation')
17 | # basic arguments
18 | parser.add_argument('--data-root',
19 | type=str,
20 | required=True,
21 | help='root path of training dataset')
22 | parser.add_argument('--dataset',
23 | type=str,
24 | default='LCSD',
25 | choices=['llCrackSeg9k', 'LCSD'],
26 | help='training dataset')
27 | parser.add_argument('--batch-size',
28 | type=int,
29 | default=1,
30 | help='batch size of training')
31 | parser.add_argument('--lr',
32 | type=float,
33 | default=0.001,
34 | help='learning rate')
35 | parser.add_argument('--loss',
36 | type=str,
37 | choices=['CE', 'weightedCE'],
38 | default='CE',
39 | help='loss function')
40 | parser.add_argument('--crop-size',
41 | type=int,
42 | default=400,
43 | help='cropping size of training samples')
44 | parser.add_argument('--backbone',
45 | type=str,
46 | choices=['resnet50', 'resnet101'],
47 | default='resnet50',
48 | help='backbone of semantic segmentation model')
49 |
50 | # few-shot training arguments
51 | parser.add_argument('--shot',
52 | type=int,
53 | default=1,
54 | help='number of support pairs')
55 | parser.add_argument('--episode',
56 | type=int,
57 | default=6000,
58 | choices=[6000, 18000, 24000, 36000],
59 | help='total episodes of training')
60 | parser.add_argument('--snapshot',
61 | type=int,
62 | default=200,
63 | choices=[200, 1200, 2000],
64 | help='save the model after each snapshot episodes')
65 | parser.add_argument('--seed',
66 | type=int,
67 | default=0,
67 | help='random seed to generate testing samples')
69 |
70 | args = parser.parse_args()
71 | return args
72 |
73 | def evaluate(model, dataloader, args):
74 | tbar = tqdm(dataloader)
75 |
76 | num_classes = 3
77 |
78 | metric = mIOU(num_classes)
79 | for i, (img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q, cls, _, id_q) in enumerate(tbar):
80 | img_q, hiseq_q, mask_q = img_q.cuda(), hiseq_q.cuda(), mask_q.cuda()
81 | for k in range(len(img_s_list)):
82 | img_s_list[k], hiseq_s_list[k], mask_s_list[k] = img_s_list[k].cuda(), hiseq_s_list[k].cuda(), mask_s_list[k].cuda()
83 | cls = cls[0].item()
84 |
85 | with torch.no_grad():
86 | out_ls = model(img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q)
87 | pred = torch.argmax(out_ls[0], dim=1)
88 |
89 | pred[pred == 1] = cls
90 | mask_q[mask_q == 1] = cls
91 |
92 | metric.add_batch(pred.cpu().numpy(), mask_q.cpu().numpy())
93 |
94 | tbar.set_description("Testing mIOU: %.2f" % (metric.evaluate() * 100.0))
95 |
96 | return metric.evaluate() * 100.0
97 |
98 | def main():
99 | args = parse_args()
100 | print('\n' + str(args))
101 |
102 | save_path = 'outdir/models/%s' % (args.dataset)
103 | os.makedirs(save_path, exist_ok=True)
104 |
105 | trainset = FewShot(args.data_root, args.crop_size,
106 | 'train', args.shot, args.snapshot)
107 | trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
108 | pin_memory=True, num_workers=0, drop_last=True)
109 | testset = FewShot(args.data_root, None, 'val',
110 | args.shot, 41 if args.dataset == 'LCSD' else 1486)
111 | testloader = DataLoader(testset, batch_size=1, shuffle=False,
112 | pin_memory=True, num_workers=0, drop_last=False)
113 |
114 | model = CrackNex(args.backbone)
115 | print('\nParams: %.1fM' % count_params(model))
116 |
117 | for param in model.rgb_layer0.parameters():
118 | param.requires_grad = False
119 | for param in model.rgb_layer1.parameters():
120 | param.requires_grad = False
121 | for param in model.ref_layer0.parameters():
122 | param.requires_grad = False
123 | for param in model.ref_layer1.parameters():
124 | param.requires_grad = False
125 |
126 | for module in model.modules():
127 | if isinstance(module, torch.nn.BatchNorm2d):
128 | for param in module.parameters():
129 | param.requires_grad = False
130 |
131 | if args.loss == 'CE':
132 | criterion = CrossEntropyLoss(ignore_index=255)
133 | elif args.loss == 'weightedCE':
134 | crack_weight = [1, 0.4] * calc_crack_pixel_weight(args.data_root)
135 | print(f'positive weight: {crack_weight}')
136 | criterion = CrossEntropyLoss(weight=torch.Tensor([crack_weight]).to('cuda').squeeze(0), ignore_index=255)
137 |
138 | optimizer = SGD([param for param in model.parameters() if param.requires_grad],
139 | lr=args.lr, momentum=0.9, weight_decay=5e-4)
140 |
141 | model = DataParallel(model).cuda()
142 | best_model = None
143 |
144 | iters = 0
145 | total_iters = args.episode // args.batch_size
146 | lr_decay_iters = [total_iters // 3, total_iters * 2 // 3]
147 | previous_best = 0
148 |
149 | # each snapshot is considered as an epoch
150 | for epoch in range(args.episode // args.snapshot):
151 | print("\n==> Epoch %i, learning rate = %.5f\t\t\t\t Previous best = %.2f"
152 | % (epoch, optimizer.param_groups[0]["lr"], previous_best))
153 |
154 | model.train()
155 |
156 | for module in model.modules():
157 | if isinstance(module, torch.nn.BatchNorm2d):
158 | module.eval()
159 |
160 | total_loss = 0.0
161 |
162 | tbar = tqdm(trainloader)
163 | set_seed(int(time.time()))
164 |
165 | for i, (img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q, _, _, _) in enumerate(tbar):
166 | img_q, hiseq_q, mask_q = img_q.cuda(), hiseq_q.cuda(), mask_q.cuda()
167 | for k in range(len(img_s_list)):
168 | img_s_list[k], hiseq_s_list[k], mask_s_list[k] = img_s_list[k].cuda(), hiseq_s_list[k].cuda(), mask_s_list[k].cuda()
169 |
170 | out_ls = model(img_s_list, hiseq_s_list, mask_s_list, img_q, hiseq_q, mask_q)
171 | mask_s = torch.cat(mask_s_list, dim=0)
172 |
173 | loss = criterion(out_ls[0], mask_q) + criterion(out_ls[1], mask_q) + criterion(out_ls[2], mask_q) + criterion(out_ls[3], mask_s) * 0.2
174 |
175 | optimizer.zero_grad()
176 | loss.backward()
177 | optimizer.step()
178 |
179 | total_loss += loss.item()
180 |
181 | iters += 1
182 | if iters in lr_decay_iters:
183 | optimizer.param_groups[0]['lr'] /= 10.0
184 |
185 | tbar.set_description('Loss: %.3f' % (total_loss / (i + 1)))
186 |
187 | model.eval()
188 | set_seed(args.seed)
189 | miou = evaluate(model, testloader, args)
190 |
191 | if miou >= previous_best:
192 | best_model = deepcopy(model)
193 | previous_best = miou
194 |
195 | print('\nEvaluating on 5 seeds.....')
196 | total_miou = 0.0
197 | for seed in range(5):
198 | print('\nRun %i:' % (seed + 1))
199 | set_seed(args.seed + seed)
200 |
201 | miou = evaluate(best_model, testloader, args)
202 | total_miou += miou
203 |
204 | print('\n' + '*' * 32)
205 | print('Averaged mIOU on 5 seeds: %.2f' % (total_miou / 5))
206 | print('*' * 32 + '\n')
207 |
208 | torch.save(best_model.module.state_dict(),
209 | os.path.join(save_path, '%s_%ishot_%.2f.pth' % (args.backbone, args.shot, total_miou / 5)))
210 |
211 |
212 | if __name__ == '__main__':
213 | main()
214 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python -W ignore train.py \
2 | --dataset LCSD --data-root /Your/path/to/dataset \
3 | --backbone resnet101 --shot 1 --episode 6000 --snapshot 200
4 |
5 | # python -W ignore train.py \
6 | # --dataset LCSD --data-root /Your/path/to/dataset \
7 | # --backbone resnet101 --shot 5 --episode 6000 --snapshot 200
8 |
9 | # python -W ignore train.py \
10 | # --dataset llCrackSeg9k --data-root /Your/path/to/dataset \
11 | # --backbone resnet101 --shot 1 --episode 18000 --snapshot 1200
12 |
13 | # python -W ignore train.py \
14 | # --dataset llCrackSeg9k --data-root /Your/path/to/dataset \
15 | # --backbone resnet101 --shot 5 --episode 18000 --snapshot 1200
16 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyaocoder/CrackNex/122725c5e43ed1a9967cad0f29b01627901f3569/util/__init__.py
--------------------------------------------------------------------------------
/util/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | def one_hot_it(label):
6 | """
7 | Convert a segmentation image label array to one-hot format
8 | by replacing each pixel value with a vector of length num_classes
9 |
10 | # Arguments
11 | label: The 2D array segmentation image label
12 | label_values: the per-class pixel values (hard-coded to [[0], [1]] below)
13 |
14 | # Returns
15 | A 2D array with the same width and height as the input, but
16 | with a depth size of num_classes
17 | """
18 | label_values = [[0], [1]]
19 | semantic_map = []
20 | for colour in label_values:
21 | equality = (label == colour[0]).unsqueeze(1).int()
22 | semantic_map.append(equality)
23 | semantic_map = torch.cat(semantic_map, dim=1)
24 |
25 | return semantic_map
26 |
27 | def identify_axis(shape):
28 | # Three dimensional
29 | if len(shape) == 5 : return [2,3,4]
30 | # Two dimensional
31 | elif len(shape) == 4 : return [2,3]
32 | # Exception - Unknown
33 | else : raise ValueError('Metric: Shape of tensor is neither 2D nor 3D.')
34 |
35 |
36 | class SymmetricFocalLoss(nn.Module):
37 | """
38 | Parameters
39 | ----------
40 | delta : float, optional
41 | controls weight given to false positive and false negatives, by default 0.7
42 | gamma : float, optional
43 | Focal Tversky loss' focal parameter controls degree of down-weighting of easy examples, by default 2.0
44 | epsilon : float, optional
45 | clip values to prevent division by zero error
46 | """
47 | def __init__(self, delta=0.7, gamma=2., epsilon=1e-07):
48 | super(SymmetricFocalLoss, self).__init__()
49 | self.delta = delta
50 | self.gamma = gamma
51 | self.epsilon = epsilon
52 |
53 | def forward(self, y_pred, y_true):
54 |
55 | y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
56 | cross_entropy = -y_true * torch.log(y_pred)
57 |
58 | # Calculate losses separately for each class
59 | back_ce = torch.pow(1 - y_pred[:,0,:,:], self.gamma) * cross_entropy[:,0,:,:]
60 | back_ce = (1 - self.delta) * back_ce
61 |
62 | fore_ce = torch.pow(1 - y_pred[:,1,:,:], self.gamma) * cross_entropy[:,1,:,:]
63 | fore_ce = self.delta * fore_ce
64 |
65 | loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], axis=-1), axis=-1))
66 |
67 | return loss
68 |
69 |
70 | class AsymmetricFocalLoss(nn.Module):
71 | """For Imbalanced datasets
72 | Parameters
73 | ----------
74 | delta : float, optional
75 | controls weight given to false positive and false negatives, by default 0.7
76 | gamma : float, optional
77 | Focal Tversky loss' focal parameter controls degree of down-weighting of easy examples, by default 2.0
78 | epsilon : float, optional
79 | clip values to prevent division by zero error
80 | """
81 | def __init__(self, delta=0.7, gamma=2., epsilon=1e-07):
82 | super(AsymmetricFocalLoss, self).__init__()
83 | self.delta = delta
84 | self.gamma = gamma
85 | self.epsilon = epsilon
86 |
87 | def forward(self, y_pred, y_true):
88 | y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
89 | cross_entropy = -y_true * torch.log(y_pred)
90 |
91 | # Calculate losses separately for each class, only suppressing background class
92 | back_ce = torch.pow(1 - y_pred[:,0,:,:], self.gamma) * cross_entropy[:,0,:,:]
93 | back_ce = (1 - self.delta) * back_ce
94 |
95 | fore_ce = cross_entropy[:,1,:,:]
96 | fore_ce = self.delta * fore_ce
97 |
98 | loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], axis=-1), axis=-1))
99 |
100 | return loss
101 |
102 |
103 | class SymmetricFocalTverskyLoss(nn.Module):
104 | """This is the implementation for binary segmentation.
105 | Parameters
106 | ----------
107 | delta : float, optional
108 | controls weight given to false positive and false negatives, by default 0.7
109 | gamma : float, optional
110 | focal parameter controls degree of down-weighting of easy examples, by default 0.75
111 | smooth : float, optional
112 | smoothing constant to prevent division by 0 errors, by default 0.000001
113 | epsilon : float, optional
114 | clip values to prevent division by zero error
115 | """
116 | def __init__(self, delta=0.7, gamma=0.75, epsilon=1e-07):
117 | super(SymmetricFocalTverskyLoss, self).__init__()
118 | self.delta = delta
119 | self.gamma = gamma
120 | self.epsilon = epsilon
121 |
122 | def forward(self, y_pred, y_true):
123 | y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
124 | axis = identify_axis(y_true.size())
125 |
126 | # Calculate true positives (tp), false negatives (fn) and false positives (fp)
127 | tp = torch.sum(y_true * y_pred, axis=axis)
128 | fn = torch.sum(y_true * (1-y_pred), axis=axis)
129 | fp = torch.sum((1-y_true) * y_pred, axis=axis)
130 | dice_class = (tp + self.epsilon)/(tp + self.delta*fn + (1-self.delta)*fp + self.epsilon)
131 |
132 | # Calculate losses separately for each class, enhancing both classes
133 | back_dice = (1-dice_class[:,0]) * torch.pow(1-dice_class[:,0], -self.gamma)
134 | fore_dice = (1-dice_class[:,1]) * torch.pow(1-dice_class[:,1], -self.gamma)
135 |
136 | # Average class scores
137 | loss = torch.mean(torch.stack([back_dice,fore_dice], axis=-1))
138 | return loss
139 |
140 |
141 | class AsymmetricFocalTverskyLoss(nn.Module):
142 | """This is the implementation for binary segmentation.
143 | Parameters
144 | ----------
145 | delta : float, optional
146 | controls weight given to false positives and false negatives, by default 0.7
147 | gamma : float, optional
148 | focal parameter controls degree of down-weighting of easy examples, by default 0.75
149 | smooth : float, optional
150 | smoothing constant to prevent division by zero errors, by default 0.000001
151 | epsilon : float, optional
152 | clip values to prevent division by zero error
153 | """
154 | def __init__(self, delta=0.7, gamma=0.75, epsilon=1e-07):
155 | super(AsymmetricFocalTverskyLoss, self).__init__()
156 | self.delta = delta
157 | self.gamma = gamma
158 | self.epsilon = epsilon
159 |
160 | def forward(self, y_pred, y_true):
161 | # Clip values to prevent division by zero error
162 | y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
163 | axis = identify_axis(y_true.size())
164 |
165 | # Calculate true positives (tp), false negatives (fn) and false positives (fp)
166 | tp = torch.sum(y_true * y_pred, axis=axis)
167 | fn = torch.sum(y_true * (1-y_pred), axis=axis)
168 | fp = torch.sum((1-y_true) * y_pred, axis=axis)
169 | dice_class = (tp + self.epsilon)/(tp + self.delta*fn + (1-self.delta)*fp + self.epsilon)
170 |
171 | # Calculate losses separately for each class, only enhancing foreground class
172 | back_dice = (1-dice_class[:,0])
173 | fore_dice = (1-dice_class[:,1]) * torch.pow(1-dice_class[:,1], -self.gamma)
174 |
175 | # Average class scores
176 | loss = torch.mean(torch.stack([back_dice,fore_dice], axis=-1))
177 | return loss
178 |
179 |
180 | class SymmetricUnifiedFocalLoss(nn.Module):
181 | """The Unified Focal loss is a new compound loss function that unifies Dice-based and cross entropy-based loss functions into a single framework.
182 | Parameters
183 | ----------
184 | weight : float, optional
185 | represents lambda parameter and controls weight given to symmetric Focal Tversky loss and symmetric Focal loss, by default 0.5
186 | delta : float, optional
187 | controls weight given to each class, by default 0.6
188 | gamma : float, optional
189 | focal parameter controls the degree of background suppression and foreground enhancement, by default 0.5
190 | epsilon : float, optional
191 | clip values to prevent division by zero error
192 | """
193 | def __init__(self, weight=0.5, delta=0.6, gamma=0.5):
194 | super(SymmetricUnifiedFocalLoss, self).__init__()
195 | self.weight = weight
196 | self.delta = delta
197 | self.gamma = gamma
198 | # self.sigmoid = torch.sigmoid()
199 |
200 | def forward(self, y_pred, y_true):
201 | y_true = one_hot_it(y_true)
202 | y_pred = torch.sigmoid(y_pred)
203 |
204 | symmetric_ftl = SymmetricFocalTverskyLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true)
205 | symmetric_fl = SymmetricFocalLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true)
206 | if self.weight is not None:
207 | return (self.weight * symmetric_ftl) + ((1-self.weight) * symmetric_fl)
208 | else:
209 | return symmetric_ftl + symmetric_fl
210 |
211 |
212 | class AsymmetricUnifiedFocalLoss(nn.Module):
213 | """The Unified Focal loss is a new compound loss function that unifies Dice-based and cross entropy-based loss functions into a single framework.
214 | Parameters
215 | ----------
216 | weight : float, optional
217 | represents lambda parameter and controls weight given to asymmetric Focal Tversky loss and asymmetric Focal loss, by default 0.5
218 | delta : float, optional
219 | controls weight given to each class, by default 0.6
220 | gamma : float, optional
221 | focal parameter controls the degree of background suppression and foreground enhancement, by default 0.2
222 | epsilon : float, optional
223 | clip values to prevent division by zero error
224 | """
225 | def __init__(self, weight=0.5, delta=0.6, gamma=0.2):
226 | super(AsymmetricUnifiedFocalLoss, self).__init__()
227 | self.weight = weight
228 | self.delta = delta
229 | self.gamma = gamma
230 | self.softmax = nn.Softmax(dim=1)  # class probabilities; the focal and Tversky terms below expect values in [0, 1], not log-probabilities
231 |
232 | def forward(self, y_pred, y_true):
233 | y_true = one_hot_it(y_true)
234 | y_pred = self.softmax(y_pred)
235 |
236 | # Obtain Asymmetric Focal Tversky loss
237 | asymmetric_ftl = AsymmetricFocalTverskyLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true)
238 |
239 | # Obtain Asymmetric Focal loss
240 | asymmetric_fl = AsymmetricFocalLoss(delta=self.delta, gamma=self.gamma)(y_pred, y_true)
241 |
242 | # Return weighted sum of Asymmetrical Focal loss and Asymmetric Focal Tversky loss
243 | if self.weight is not None:
244 | return (self.weight * asymmetric_ftl) + ((1-self.weight) * asymmetric_fl)
245 | else:
246 | return asymmetric_ftl + asymmetric_fl
--------------------------------------------------------------------------------
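For orientation, the compound losses above blend a region-based Focal Tversky term with a pixel-wise Focal cross-entropy term, weighted by `weight` (the lambda in the docstrings). The snippet below is a minimal usage sketch, assuming 2-class logits of shape (N, 2, H, W) and an integer mask of shape (N, H, W) that the `one_hot_it` helper defined earlier in `util/loss.py` expands to one-hot form; the dummy tensors and shapes are illustrative, not taken from `train.py`.

```python
import torch
from util.loss import AsymmetricUnifiedFocalLoss  # module path follows the repo layout

# Hypothetical shapes for illustration: 2-class logits and an integer mask
# (0 = background, 1 = crack). one_hot_it inside the loss is assumed to
# expand the mask to a (N, 2, H, W) one-hot tensor.
logits = torch.randn(4, 2, 64, 64, requires_grad=True)
mask = torch.randint(0, 2, (4, 64, 64))

criterion = AsymmetricUnifiedFocalLoss(weight=0.5, delta=0.6, gamma=0.2)
loss = criterion(logits, mask)   # scalar combining Focal Tversky and Focal CE terms
loss.backward()                  # gradients flow back to the logits
```

With `weight=0.5` the two terms contribute equally; values closer to 1 emphasise the region (Tversky) term, values closer to 0 the pixel-wise focal term.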
/util/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import torch
4 | import os
5 | import cv2
6 | from sklearn.utils.class_weight import compute_class_weight
7 | import pickle
8 | from pathlib import Path
9 |
10 | def count_params(model):
11 | param_num = sum(p.numel() for p in model.parameters())
12 | return param_num / 1e6
13 |
14 |
15 | def set_seed(seed):
16 | torch.manual_seed(seed)
17 | random.seed(seed)
18 | np.random.seed(seed)
19 |
20 | def calc_crack_pixel_weight(data_dir):
21 | print('Computing class weights...')
22 |
23 | cweight_path = data_dir + '/cweight.pkl'
24 |
25 | if os.path.exists(cweight_path):
26 | print('Loading saved class weights.')
27 | with open(cweight_path, 'rb') as f:
28 | weight = pickle.load(f)
29 | else:
30 | files = []
31 |
32 | for path in Path(data_dir + '/SegmentationClass').glob('*.*'):
33 | label = cv2.imread(str(path)).astype(np.uint8)
34 | if 2 not in np.unique(label):
35 | files.append(label)
36 |
37 | all_arr = np.stack(files, axis=0)[:,:,:,0]
38 |
39 | weight = compute_class_weight(class_weight='balanced', classes=np.unique(all_arr), y=all_arr.flatten())
40 |
41 | with open(cweight_path, 'wb') as f:
42 | pickle.dump(weight, f)
43 | print('Saved class weights under dataset path.')
44 |
45 | return weight
46 |
47 | class mIOU:
48 | def __init__(self, num_classes):
49 | self.num_classes = num_classes
50 | self.hist = np.zeros((num_classes, num_classes))
51 |
52 | def _fast_hist(self, label_pred, label_true):
53 | mask = (label_true >= 0) & (label_true < self.num_classes)
54 | hist = np.bincount(
55 | self.num_classes * label_true[mask].astype(int) +
56 | label_pred[mask], minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
57 | return hist
58 |
59 | def add_batch(self, predictions, gts):
60 | for lp, lt in zip(predictions, gts):
61 | self.hist += self._fast_hist(lp.flatten(), lt.flatten())
62 |
63 | def evaluate(self):
64 | iu = np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0) - np.diag(self.hist))
65 | return np.nanmean(iu)
66 |
--------------------------------------------------------------------------------
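As a quick illustration of the `mIOU` helper above, the sketch below accumulates the class confusion matrix over a batch of predicted and ground-truth masks and reports the mean IoU; the dummy arrays and shapes are assumptions for the example, not values taken from `test.py`.

```python
import numpy as np
from util.utils import mIOU, set_seed

set_seed(0)  # seed torch/random/numpy for a reproducible example

metric = mIOU(num_classes=2)

# Hypothetical masks of shape (N, H, W) with integer class ids
# (0 = background, 1 = crack).
preds = np.random.randint(0, 2, size=(4, 64, 64))
gts = np.random.randint(0, 2, size=(4, 64, 64))

metric.add_batch(preds, gts)       # accumulates the 2x2 confusion matrix
print('mIoU:', metric.evaluate())  # mean of the per-class IoU values
```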