├── assets
│   └── figure_02.jpg
├── requirements.txt
├── configs
│   └── BlackSoil.yaml
├── test.py
├── util
│   ├── dist_helper.py
│   ├── evaluate.py
│   ├── utils.py
│   └── saliency_metric.py
├── result
│   └── README.md
├── testdata
│   └── README.md
├── prediction
│   ├── mIOU.py
│   └── collage.py
├── README.md
├── dataset
│   ├── data.py
│   └── transform.py
├── DiceLoss.py
├── baseline
│   ├── mamba_unet.py
│   ├── unet.py
│   ├── resnet.py
│   ├── BS_Mamba.py
│   ├── UltraLighet_VM_Unet.py
│   ├── local_scan.py
│   ├── multi_mamba.py
│   ├── hrnet.py
│   ├── local_vmamba.py
│   └── vmamba.py
├── result.py
├── fortest
│   └── predict_multi.py
└── train.py

/assets/figure_02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/BS-Mamba/HEAD/assets/figure_02.jpg
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
causal-conv1d==1.2.0.post2
cuda-nvcc==11.8.89
cudatoolkit==11.8.0
einops==0.7.0
fvcore==0.1.5.post20221221
mamba-ssm==1.2.0.post1
mpmath==1.3.0
ninja==1.11.1.1
numpy==1.26.4
numpy-base==1.26.4
opencv-python==4.9.0.80
pillow==10.
--------------------------------------------------------------------------------

/configs/BlackSoil.yaml:
--------------------------------------------------------------------------------
# arguments for dataset
dataset: grassset
nclass: 2
crop_size: 384
data_root: grassset2

# arguments for training
epochs: 60
batch_size: 8
lr: 0.0002 # 4 GPUs
lr_multi: 10.0
criterion:
  name: CELoss
  kwargs:
    ignore_index: -100
thresh_init: 0.85

# arguments for model
model: BS_Mamba
backbone: BS_Mamba
pretrain: True
multi_grid: False
replace_stride_with_dilation: [False, True, True]
# dilations: [6, 12, 18]
dilations: [12, 24, 36]
--------------------------------------------------------------------------------

/test.py:
--------------------------------------------------------------------------------
import argparse
import logging
import os
import pprint
from tqdm import tqdm
import numpy as np
import torch
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
from prediction.collage import CollageGenerator
from prediction.mIOU import MIouCalculator
from fortest.predict_multi import ImagePredictor
# from fortest.predict_multi import ckpt_path


def main():
    previous_best = 0.770  # the mIoU of the checkpoint to evaluate
    predictor = ImagePredictor(previous_best)
    image_predictor = predictor.main(previous_best)
    collage = CollageGenerator(image_folder=image_predictor,
                               image_list_file="...",  # the path to the test-set list
                               output_folder="...",    # the path to the output folder
                               group_size=150,
                               rows=10,
                               cols=15)
    pre_collage = collage.create_and_save_collages()
    MIou = MIouCalculator(mask_dir="...",  # the test-set labels
                          pred_dir=pre_collage)

    miou = MIou.compute_miou()
    print("mIOU:", miou)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
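
Before running `test.py`, the three `"..."` placeholders above have to be filled in. A minimal sketch of the wiring, where every path is a hypothetical stand-in for wherever the QTP-BS data actually lives:

```python
# Hypothetical paths throughout; only the class names come from this repo.
from prediction.collage import CollageGenerator
from prediction.mIOU import MIouCalculator

collage = CollageGenerator(image_folder="./testdata/test2_0.770",       # per-tile predictions
                           image_list_file="/data/grassset2/test.txt",  # test-set tile list
                           output_folder="./prediction/collages",
                           group_size=150, rows=10, cols=15)
pred_dir = collage.create_and_save_collages()
miou = MIouCalculator(mask_dir="/data/grassset2/test_masks",
                      pred_dir=pred_dir).compute_miou()
print("mIOU:", miou)
```
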
/util/dist_helper.py:
--------------------------------------------------------------------------------
import os
import subprocess

import torch
import torch.distributed as dist


def setup_distributed(backend="nccl", port=None):
    """Set up torch.distributed from SLURM (or torchrun-style) environment variables.
    Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
    Originally licensed MIT, Copyright (c) 2020 Wei Li
    """
    num_gpus = torch.cuda.device_count()

    if "SLURM_JOB_ID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]
        addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
        # specify master port
        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = "25063"
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(rank % num_gpus)
        os.environ["RANK"] = str(rank)
    else:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    torch.cuda.set_device(rank % num_gpus)

    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size
--------------------------------------------------------------------------------
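
`setup_distributed` expects either SLURM variables or the `RANK`/`WORLD_SIZE`/`MASTER_*` variables that `torchrun` exports. A single-process sanity check (illustrative only, not part of the repo; it assumes at least one visible GPU, since the function pins a CUDA device):

```python
# Single-process sanity check for setup_distributed (illustrative only).
# torchrun normally sets these variables; here we fake them for one process.
import os
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

from util.dist_helper import setup_distributed

rank, world_size = setup_distributed(backend="nccl")  # needs a visible GPU
print(f"initialized rank {rank} of {world_size}")
```
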
/result/README.md:
--------------------------------------------------------------------------------
# BS-Mamba for Black-Soil Area Detection on the Qinghai-Tibetan Plateau
Extremely degraded grassland on the Qinghai-Tibetan Plateau (QTP) presents a significant environmental challenge due to overgrazing, climate change, and rodent activity, which degrade vegetation cover and soil quality. These extremely degraded grasslands on the QTP, commonly referred to as black-soil areas, require accurate assessment to guide effective restoration efforts. In this paper, we present a newly created QTP black-soil dataset, annotated under expert guidance. We introduce a novel neural network model, BS-Mamba, specifically designed for black-soil area detection using UAV remote sensing imagery. The BS-Mamba model demonstrates higher accuracy than state-of-the-art models in identifying black-soil areas across two independent test datasets. This research contributes to grassland restoration by providing an efficient method for assessing the extent of black-soil areas on the QTP.


# Experiment
## Train QTP-BS dataset
```sh
python train.py
```

## Two test sets
```sh
python test.py
```

# Citation
We appreciate it if you cite the following paper:
```
@Article{maxjars2025,
  author  = {Xuan Ma and Zewen Lv and Chengcai Ma and Tao Zhang and Yuelan Xin and Kun Zhan},
  journal = {Journal of Applied Remote Sensing},
  title   = {BS-Mamba for Black-Soil Area Detection on the Qinghai-Tibetan Plateau},
  year    = {2025},
}
```

# Contact
https://kunzhan.github.io/

If you have any questions, feel free to contact me. (Email: `ice.echo#gmail.com`)
--------------------------------------------------------------------------------
/testdata/README.md:
--------------------------------------------------------------------------------
# BS-Mamba for Black-Soil Area Detection on the Qinghai-Tibetan Plateau
Extremely degraded grassland on the Qinghai-Tibetan Plateau (QTP) presents a significant environmental challenge due to overgrazing, climate change, and rodent activity, which degrade vegetation cover and soil quality. These extremely degraded grasslands on the QTP, commonly referred to as black-soil areas, require accurate assessment to guide effective restoration efforts. In this paper, we present a newly created QTP black-soil dataset, annotated under expert guidance. We introduce a novel neural network model, BS-Mamba, specifically designed for black-soil area detection using UAV remote sensing imagery. The BS-Mamba model demonstrates higher accuracy than state-of-the-art models in identifying black-soil areas across two independent test datasets. This research contributes to grassland restoration by providing an efficient method for assessing the extent of black-soil areas on the QTP.


# Experiment
## Train QTP-BS dataset
```sh
python train.py
```

## Two test sets
```sh
python test.py
```

# Citation
We appreciate it if you cite the following paper:
```
@Article{maxjars2025,
  author  = {Xuan Ma and Zewen Lv and Chengcai Ma and Tao Zhang and Yuelan Xin and Kun Zhan},
  journal = {Journal of Applied Remote Sensing},
  title   = {BS-Mamba for Black-Soil Area Detection on the Qinghai-Tibetan Plateau},
  year    = {2025},
}
```

# Contact
https://kunzhan.github.io/

If you have any questions, feel free to contact me. (Email: `ice.echo#gmail.com`)
--------------------------------------------------------------------------------

/prediction/mIOU.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import os
from PIL import Image

class MIouCalculator:
    def __init__(self, mask_dir, pred_dir):
        self.mask_dir = mask_dir
        self.pred_dir = pred_dir

    def calculate_miou(self, pred_path, mask_path):
        # read the prediction and label images and convert them to NumPy arrays
        pred_image = np.array(Image.open(pred_path))
        mask_image = np.array(Image.open(mask_path))

        # both images are black-and-white, so binarize them to 0/1
        pred_image = (pred_image > 0).astype(np.uint8)
        mask_image = (mask_image > 0).astype(np.uint8)

        # compute intersection and union
        intersection = np.logical_and(pred_image, mask_image).sum()
        union = np.logical_or(pred_image, mask_image).sum()

        # IoU of the foreground class
        iou = intersection / (union + 1e-10)

        # express as a percentage
        mIOU = iou * 100.0

        return mIOU

    def compute_miou(self):
        all_miou = []
        for filename in os.listdir(self.pred_dir):
            if filename.endswith(".png"):
                pred_path = os.path.join(self.pred_dir, filename)
                mask_path = os.path.join(self.mask_dir, filename)
                # IoU of a single image (two classes: black and white)
                mIOU = self.calculate_miou(pred_path, mask_path)
                all_miou.append(mIOU)

                print(f"{filename}: mIOU = {mIOU:.2f}")

        average_miou = np.mean(all_miou)
        return average_miou
--------------------------------------------------------------------------------
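
`MIouCalculator` can also be run on its own; the only requirement is that the two folders contain identically named `.png` files. The paths below are hypothetical:

```python
# Hypothetical directories holding matching .png masks and predictions.
from prediction.mIOU import MIouCalculator

calc = MIouCalculator(mask_dir="/data/grassset2/test_masks",
                      pred_dir="./prediction/collages")
print(f"average mIOU = {calc.compute_miou():.2f}")
```
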
/README.md:
--------------------------------------------------------------------------------
# BS-Mamba for Black-Soil Area Detection on the Qinghai-Tibetan Plateau
- [arXiv](https://arxiv.org/abs/2503.12495)
- QTP-BS dataset
  - [Download link](https://drive.google.com/file/d/1x91CinTrJd08omRcuY4ZMm7XPn1yWFzZ/view?usp=sharing)

Extremely degraded grassland on the Qinghai-Tibetan Plateau (QTP) presents a significant environmental challenge due to overgrazing, climate change, and rodent activity, which degrade vegetation cover and soil quality. These extremely degraded grasslands on the QTP, commonly referred to as black-soil areas, require accurate assessment to guide effective restoration efforts. In this paper, we present a newly created QTP black-soil dataset, annotated under expert guidance. We introduce a novel neural network model, BS-Mamba, specifically designed for black-soil area detection using UAV remote sensing imagery. The BS-Mamba model demonstrates higher accuracy than state-of-the-art models in identifying black-soil areas across two independent test datasets. This research contributes to grassland restoration by providing an efficient method for assessing the extent of black-soil areas on the QTP.

![framework](assets/figure_02.jpg)

# Experiment
## Train QTP-BS dataset
```sh
python train.py
```

## Two test sets
```sh
python test.py
```

# Citation
We appreciate it if you cite the following paper:
```
@Article{maxjars2025,
  author  = {Xuan Ma and Zewen Lv and Chengcai Ma and Tao Zhang and Yuelan Xin and Kun Zhan},
  journal = {Journal of Applied Remote Sensing},
  title   = {BS-Mamba for Black-Soil Area Detection on the Qinghai-Tibetan Plateau},
  year    = {2025},
}
```

# Contact
https://kunzhan.github.io/

If you have any questions, feel free to contact me. (Email: `ice.echo#gmail.com`)
--------------------------------------------------------------------------------

/prediction/collage.py:
--------------------------------------------------------------------------------
import os
import cv2
import numpy as np

def load_images(image_folder, image_list):
    images = []
    for image_name in image_list:
        image_path = os.path.join(image_folder, f"{image_name}.png")
        image = cv2.imread(image_path)
        images.append(image)
    return images

def create_collage(images, rows, cols):
    row_images = []
    for i in range(0, len(images), cols):
        row = np.hstack(images[i:i+cols])
        row_images.append(row)
    collage = np.vstack(row_images)
    return collage

class CollageGenerator:
    def __init__(self, image_folder, image_list_file, output_folder, group_size=150, rows=10, cols=15):
        self.image_folder = image_folder
        self.image_list_file = image_list_file
        self.output_folder = output_folder
        self.group_size = group_size
        self.rows = rows
        self.cols = cols

    def create_and_save_collages(self):
        # read the image-list file
        with open(self.image_list_file, 'r') as f:
            image_list = [line.strip() for line in f.readlines()]

        # load the images
        images = load_images(self.image_folder, image_list)

        # split the tiles into groups of group_size images each
        grouped_images = [images[i:i+self.group_size] for i in range(0, len(images), self.group_size)]

        os.makedirs(self.output_folder, exist_ok=True)
        for idx, group in enumerate(grouped_images):
            first_image_name = image_list[idx * self.group_size]
            name_parts = first_image_name.split("_")[:2]
            collage_name = "_".join(name_parts)

            collage = create_collage(group, self.rows, self.cols)
            cv2.imwrite(os.path.join(self.output_folder, f"{collage_name}.png"), collage)
        return self.output_folder


if __name__ == "__main__":
    # CollageGenerator requires image_folder, image_list_file, and output_folder
    CollageGenerator(image_folder="...",
                     image_list_file="...",
                     output_folder="...").create_and_save_collages()
--------------------------------------------------------------------------------
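
Each full-size UAV image is evaluated as `rows x cols = 10 x 15 = 150` tiles, so `group_size` must equal `rows * cols` for `create_collage` to reassemble complete images. A defensive check along these lines (`check_grid` is a hypothetical helper added here for illustration, not in the repo):

```python
# Illustrative sanity check: the grouping must match the collage grid.
def check_grid(group_size: int, rows: int, cols: int) -> None:
    assert group_size == rows * cols, (
        f"group_size={group_size} cannot be tiled as {rows}x{cols}")

check_grid(150, 10, 15)  # holds for the default configuration
```
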
"val.txt" 26 | if mode == 'train_l' or mode == 'train_u': 27 | with open(self.id_path, 'r') as f: 28 | # with open(mini_path, 'r') as f: 29 | self.ids = f.read().splitlines() 30 | random.shuffle(self.ids) 31 | if mode == 'train_l' and nsample is not None: 32 | self.ids *= math.ceil(nsample / len(self.ids)) 33 | random.shuffle(self.ids) 34 | self.ids = self.ids[:nsample] 35 | else: 36 | with open(val_path, 'r') as f: 37 | self.ids = f.read().splitlines() 38 | 39 | 40 | def __getitem__(self, item): 41 | id = self.ids[item] 42 | if self.mode == 'train_l' or self.mode == 'train_u': 43 | img = Image.open(self.root + "train/image/" + id+'.png').convert('RGB') 44 | mask = Image.fromarray(np.array(Image.open(self.root + "train/mask/" + id+'.png' ))/255) 45 | # img, mask = hflip(img, mask, p=0.5) # 随机水平翻转 46 | # img = self.strong_aug(img) # 随机应用一些图像增强 47 | mask = torch.from_numpy(np.array(mask)).long() 48 | cutmix_box = obtain_cutmix_box(img.size[0], p=0.5) 49 | img= normalize(img) 50 | return img, mask, cutmix_box 51 | 52 | else: 53 | img = Image.open( self.root + "val/image/" + id+'.png').convert('RGB') 54 | mask = Image.fromarray(np.array(Image.open(self.root + "val/mask/" + id+'.png' ))/255) 55 | img, mask = normalize(img, mask) 56 | return img, mask, id 57 | 58 | 59 | def __len__(self): 60 | return len(self.ids) -------------------------------------------------------------------------------- /DiceLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | class DiceLoss(nn.Module): 4 | def __init__(self, n_classes): 5 | super(DiceLoss, self).__init__() 6 | self.n_classes = n_classes 7 | 8 | def _one_hot_encoder(self, input_tensor): 9 | tensor_list = [] 10 | for i in range(self.n_classes): 11 | temp_prob = input_tensor == i * torch.ones_like(input_tensor) 12 | tensor_list.append(temp_prob) 13 | output_tensor = torch.cat(tensor_list, dim=1) 14 | return output_tensor.float() 15 | 16 | def _dice_loss(self, score, target, ignore): 17 | target = target.float() 18 | smooth = 1e-5 19 | intersect = torch.sum(score[ignore != 1] * target[ignore != 1]) 20 | y_sum = torch.sum(target[ignore != 1] * target[ignore != 1]) 21 | z_sum = torch.sum(score[ignore != 1] * score[ignore != 1]) 22 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 23 | loss = 1 - loss 24 | return loss 25 | 26 | def forward(self, inputs, target, weight=None, softmax=True, ignore=None): 27 | if softmax: 28 | inputs = torch.softmax(inputs, dim=1) 29 | target = self._one_hot_encoder(target) 30 | if weight is None: 31 | weight = [1] * self.n_classes 32 | assert inputs.size() == target.size(), 'predict & target shape do not match' 33 | class_wise_dice = [] 34 | loss = 0.0 35 | for i in range(0, self.n_classes): 36 | dice = self._dice_loss(inputs[:, i], target[:, i], ignore) 37 | class_wise_dice.append(1.0 - dice.item()) 38 | loss += dice * weight[i] 39 | return loss / self.n_classes 40 | 41 | 42 | class BCELoss(nn.Module): 43 | def __init__(self): 44 | super(BCELoss, self).__init__() 45 | self.bceloss = nn.BCELoss() 46 | 47 | def forward(self, pred, target): 48 | size = pred.size(0) 49 | pred_ = pred.view(size, -1) 50 | target_ = target.view(size, -1) 51 | 52 | return self.bceloss(pred_, target_) 53 | 54 | class IouLoss(nn.Module): 55 | def __init__(self, reduction='mean'): 56 | super(IouLoss, self).__init__() 57 | self.reduction = reduction 58 | 59 | def forward(self, inputs, targets, smooth=1): 60 | # 该代码是二分类代码 61 | """ 62 | output : NxCxHxW Variable 63 | 
/DiceLoss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target, ignore):
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score[ignore != 1] * target[ignore != 1])
        y_sum = torch.sum(target[ignore != 1] * target[ignore != 1])
        z_sum = torch.sum(score[ignore != 1] * score[ignore != 1])
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=True, ignore=None):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict & target shape do not match'
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i], ignore)
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes


class BCELoss(nn.Module):
    def __init__(self):
        super(BCELoss, self).__init__()
        self.bceloss = nn.BCELoss()

    def forward(self, pred, target):
        size = pred.size(0)
        pred_ = pred.view(size, -1)
        target_ = target.view(size, -1)

        return self.bceloss(pred_, target_)

class IouLoss(nn.Module):
    def __init__(self, reduction='mean'):
        super(IouLoss, self).__init__()
        self.reduction = reduction

    def forward(self, inputs, targets, smooth=1):
        # this is binary-segmentation code
        """
        output : NxCxHxW Variable
        target : NxHxW LongTensor
        """
        # normalize the inputs first if they are not already probabilities
        inputs = torch.softmax(inputs, dim=1)
        # the loss compares the positive/negative overlap of pred and targets,
        # so the soft predictions (0~1) have to be turned into hard 0/1 labels:
        # inputs goes from NxCxHxW to NxHxW and holds 0/1 labels instead of scores
        # (note: argmax is not differentiable, so no gradient flows through this term)
        inputs = torch.argmax(inputs, 1).squeeze(0)
        # IoU formula
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection

        Iou_loss = 1 - (intersection + smooth) / (union + smooth)  # smooth avoids division by zero

        if self.reduction == 'mean':
            return Iou_loss.mean()
        elif self.reduction == 'sum':
            return Iou_loss.sum()
        else:
            return Iou_loss
--------------------------------------------------------------------------------
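
A toy check of the two losses above, with shapes matching the docstring (`NxCxHxW` logits against `NxHxW` integer labels); note `DiceLoss`'s one-hot encoder expects an extra channel dimension on the target. Not taken from the repo's training code:

```python
# Toy shapes only; illustrative, not from the repo.
import torch
from DiceLoss import DiceLoss, IouLoss

logits = torch.randn(2, 2, 64, 64)         # N x C x H x W
target = torch.randint(0, 2, (2, 64, 64))  # N x H x W

print(IouLoss(reduction='mean')(logits, target).item())
print(DiceLoss(n_classes=2)(logits, target.unsqueeze(1)).item())
```
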
/util/evaluate.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import numpy as np
from sklearn.metrics import confusion_matrix
from DiceLoss import DiceLoss, IouLoss

def evaluate_add(model, loader):
    return_dict = {}  # dictionary collecting the evaluation results
    model.eval()  # switch the model to evaluation mode

    # loss functions (cross-entropy and IoU)
    criterion_ce = nn.CrossEntropyLoss()
    criterion_iou = IouLoss('mean')
    total_loss = 0  # accumulator for the total loss
    mIOU_list = []  # mean IoU of each group of 150 tiles
    f1_list = []  # mean F1 of each group of 150 tiles
    accuracy_list = []  # mean accuracy of each group of 150 tiles
    sensitivity_list = []  # mean sensitivity of each group of 150 tiles
    specificity_list = []  # mean specificity of each group of 150 tiles

    epoch_tn, epoch_fp, epoch_fn, epoch_tp = 0, 0, 0, 0  # global accumulators
    num_large_images = 0  # number of full-size images processed

    with torch.no_grad():  # disable gradient computation during evaluation
        # iterate over the batches produced by the data loader
        for i, (img, mask, ids) in enumerate(loader):
            img = img.cuda()  # move the input image to the GPU
            res = model(img)  # forward pass to obtain the prediction
            pred = res.argmax(dim=1)  # predicted class labels (argmax over the class dimension)
            pred_np = np.array(pred.cpu()).squeeze(axis=0).reshape(-1)  # prediction as a flat numpy array
            gt_np = np.array(mask).squeeze(axis=0).reshape(-1)  # ground truth as a flat numpy array
            # accumulate the total loss
            loss_ce = criterion_ce(res.cpu(), mask)
            loss_iou = criterion_iou(res.cpu(), mask)
            loss = 0.5 * loss_iou + 0.5 * loss_ce
            total_loss += loss.item()

            confusion = confusion_matrix(gt_np, pred_np, labels=[0, 1])
            TN, FP, FN, TP = confusion[0, 0], confusion[0, 1], confusion[1, 0], confusion[1, 1]

            epoch_tn += TN
            epoch_fp += FP
            epoch_fn += FN
            epoch_tp += TP

            # every 150 tiles (one full-size image), compute the aggregated metrics once
            if (i + 1) % 150 == 0:
                accuracy = float(epoch_tn + epoch_tp) / float(epoch_tn + epoch_fp + epoch_fn + epoch_tp) if float(epoch_tn + epoch_fp + epoch_fn + epoch_tp) != 0 else 0
                sensitivity = float(epoch_tp) / float(epoch_tp + epoch_fn) if float(epoch_tp + epoch_fn) != 0 else 0
                specificity = float(epoch_tn) / float(epoch_tn + epoch_fp) if float(epoch_tn + epoch_fp) != 0 else 0
                f1_or_dsc = float(2 * epoch_tp) / float(2 * epoch_tp + epoch_fp + epoch_fn) if float(2 * epoch_tp + epoch_fp + epoch_fn) != 0 else 0
                miou = float(epoch_tp) / float(epoch_tp + epoch_fp + epoch_fn) if float(epoch_tp + epoch_fp + epoch_fn) != 0 else 0

                mIOU_list.append(miou)  # store the mIOU
                f1_list.append(f1_or_dsc)  # store the F1 score
                accuracy_list.append(accuracy)  # store the accuracy
                sensitivity_list.append(sensitivity)  # store the sensitivity
                specificity_list.append(specificity)  # store the specificity

                # reset the accumulators
                epoch_tn, epoch_fp, epoch_fn, epoch_tp = 0, 0, 0, 0

    # compute the average metrics
    mean_mIOU = np.mean(mIOU_list)
    mean_f1 = np.mean(f1_list)
    mean_accuracy = np.mean(accuracy_list)
    mean_sensitivity = np.mean(sensitivity_list)
    mean_specificity = np.mean(specificity_list)
    loss = total_loss / len(loader)

    return_dict['iou_class'] = mIOU_list
    return_dict['mean_mIOU'] = mean_mIOU  # average mIOU over all groups
    return_dict['f1_or_dsc'] = mean_f1
    return_dict['accuracy'] = mean_accuracy
    return_dict['sensitivity'] = mean_sensitivity
    return_dict['specificity'] = mean_specificity
    return_dict['Loss_val'] = loss  # overall loss value

    return return_dict
--------------------------------------------------------------------------------
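
For reference, all five group-level metrics in `evaluate_add` are plain functions of the accumulated confusion-matrix counts; a standalone restatement (`metrics_from_counts` is a hypothetical helper, not part of the repo):

```python
# Metrics from confusion-matrix counts, as accumulated in evaluate_add.
def metrics_from_counts(tp: int, fp: int, tn: int, fn: int) -> dict:
    eps = 1e-10  # guards against empty denominators
    return {
        "accuracy":    (tp + tn) / (tp + tn + fp + fn + eps),
        "sensitivity": tp / (tp + fn + eps),
        "specificity": tn / (tn + fp + eps),
        "f1_or_dsc":   2 * tp / (2 * tp + fp + fn + eps),
        "iou":         tp / (tp + fp + fn + eps),
    }

print(metrics_from_counts(tp=80, fp=10, tn=100, fn=10))
```
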
/baseline/mamba_unet.py:
--------------------------------------------------------------------------------
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
from mamba_sys import VSSM
# from vmamba import VSSM

logger = logging.getLogger(__name__)

class MambaUnet(nn.Module):
    def __init__(self, img_size=384, num_classes=21843, zero_head=False, vis=False):
        super(MambaUnet, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head

        self.mamba_unet = VSSM(
            img_size=384,
            in_chans=3,
            num_classes=self.num_classes,
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=11,
            mlp_ratio=4.,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.0,
            drop_path_rate=0.1,
            ape=False,
            patch_norm=True,
            use_checkpoint=False)

    def forward(self, x):
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        logits = self.mamba_unet(x)
        return logits

    def load_from(self, config):
        pretrained_path = config.MODEL.PRETRAIN_CKPT
        if pretrained_path is not None:
            print("pretrained_path:{}".format(pretrained_path))
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            pretrained_dict = torch.load(pretrained_path, map_location=device)
            if "model" not in pretrained_dict:
                print("---start loading pretrained model by splitting---")
                pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
                for k in list(pretrained_dict.keys()):
                    if "output" in k:
                        print("delete key:{}".format(k))
                        del pretrained_dict[k]
                msg = self.mamba_unet.load_state_dict(pretrained_dict, strict=False)
                # print(msg)
                return
            pretrained_dict = pretrained_dict['model']
            print("---start loading pretrained model of swin encoder---")

            model_dict = self.mamba_unet.state_dict()
            full_dict = copy.deepcopy(pretrained_dict)
            for k, v in pretrained_dict.items():
                if "layers." in k:
                    current_layer_num = 3 - int(k[7:8])
                    current_k = "layers_up." + str(current_layer_num) + k[8:]
                    full_dict.update({current_k: v})
            for k in list(full_dict.keys()):
                if k in model_dict:
                    if full_dict[k].shape != model_dict[k].shape:
                        print("delete:{};shape pretrain:{};shape model:{}".format(k, v.shape, model_dict[k].shape))
                        del full_dict[k]

            msg = self.mamba_unet.load_state_dict(full_dict, strict=False)
            # print(msg)
        else:
            print("no pretrained weights")
--------------------------------------------------------------------------------
/baseline/unet.py:
--------------------------------------------------------------------------------
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Sequential):  # two consecutive conv layers
    def __init__(self, in_channels, out_channels, mid_channels=None):
        if mid_channels is None:
            mid_channels = out_channels
        super(DoubleConv, self).__init__(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )


class Down(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__(
            nn.MaxPool2d(2, stride=2),
            DoubleConv(in_channels, out_channels)
        )

# upsample the input feature map, then extract features with a conv block
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    # forward pass
    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # [N, C, H, W]
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]

        # padding_left, padding_right, padding_top, padding_bottom
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class OutConv(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1)
        )


class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 3,
                 num_classes: int = 2,
                 bilinear: bool = True,
                 base_c: int = 64):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down(base_c * 8, base_c * 16 // factor)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x)
        # from ipdb import set_trace
        # set_trace()
        return logits
--------------------------------------------------------------------------------
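
A quick shape check for the U-Net baseline above (illustrative, not in the repo): it maps `(N, 3, H, W)` to `(N, num_classes, H, W)` at full resolution:

```python
# Smoke test: UNet preserves spatial resolution in its logits.
import torch
from baseline.unet import UNet

model = UNet(in_channels=3, num_classes=2, base_c=64)
x = torch.randn(1, 3, 384, 384)
print(model(x).shape)  # torch.Size([1, 2, 384, 384])
```
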
/util/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import logging
import os


def count_params(model):
    param_num = sum(p.numel() for p in model.parameters())
    return param_num / 1e6


def color_map(dataset='pascal'):
    cmap = np.zeros((256, 3), dtype='uint8')

    if dataset == 'pascal' or dataset == 'coco':
        def bitget(byteval, idx):
            return (byteval & (1 << idx)) != 0

        for i in range(256):
            r = g = b = 0
            c = i
            for j in range(8):
                r = r | (bitget(c, 0) << 7-j)
                g = g | (bitget(c, 1) << 7-j)
                b = b | (bitget(c, 2) << 7-j)
                c = c >> 3

            cmap[i] = np.array([r, g, b])

    elif dataset == 'cityscapes':
        cmap[0] = np.array([128, 64, 128])
        cmap[1] = np.array([244, 35, 232])
        cmap[2] = np.array([70, 70, 70])
        cmap[3] = np.array([102, 102, 156])
        cmap[4] = np.array([190, 153, 153])
        cmap[5] = np.array([153, 153, 153])
        cmap[6] = np.array([250, 170, 30])
        cmap[7] = np.array([220, 220, 0])
        cmap[8] = np.array([107, 142, 35])
        cmap[9] = np.array([152, 251, 152])
        cmap[10] = np.array([70, 130, 180])
        cmap[11] = np.array([220, 20, 60])
        cmap[12] = np.array([255, 0, 0])
        cmap[13] = np.array([0, 0, 142])
        cmap[14] = np.array([0, 0, 70])
        cmap[15] = np.array([0, 60, 100])
        cmap[16] = np.array([0, 80, 100])
        cmap[17] = np.array([0, 0, 230])
        cmap[18] = np.array([119, 11, 32])

    return cmap


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        self.val = 0.0
        self.avg = 0.0

    def update(self, val, num=1):
        if self.length > 0:
            # currently assert num==1 to avoid bad usage, refine when there are some explict requirements
            assert num == 1
            self.history.append(val)
            if len(self.history) > self.length:
                del self.history[0]

            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val * num
            self.count += num
            self.avg = self.sum / self.count


def intersectionAndUnion(output, target, K, ignore_index=255):
    # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
    assert output.ndim in [1, 2, 3]
    assert output.shape == target.shape
    output = output.reshape(output.size).copy()
    target = target.reshape(target.size)
    output[np.where(target == ignore_index)[0]] = ignore_index
    intersection = output[np.where(output == target)[0]]
    area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
    area_output, _ = np.histogram(output, bins=np.arange(K + 1))
    area_target, _ = np.histogram(target, bins=np.arange(K + 1))
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target


logs = set()


def init_log(name, level=logging.INFO):
    if (name, level) in logs:
        return
    logs.add((name, level))
    logger = logging.getLogger(name)
    logger.setLevel(level)
    ch = logging.StreamHandler()
    ch.setLevel(level)
    if "SLURM_PROCID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        logger.addFilter(lambda record: rank == 0)
    else:
        rank = 0
    format_str = "[%(asctime)s][%(levelname)8s] %(message)s"
    formatter = logging.Formatter(format_str)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger
--------------------------------------------------------------------------------
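
`intersectionAndUnion` returns per-class histograms, so per-class IoU follows by element-wise division; a toy example (illustrative, not part of the repo):

```python
# Toy per-class IoU from the histograms returned by intersectionAndUnion.
import numpy as np
from util.utils import intersectionAndUnion

pred = np.array([0, 1, 1, 0])
gt = np.array([0, 1, 0, 0])
inter, union, _ = intersectionAndUnion(pred, gt, K=2)
print(inter / (union + 1e-10))  # IoU of class 0 and class 1
```
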
/result.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from PIL import Image

class MIouCalculator:
    def __init__(self, mask_dir, pred_dir):
        self.mask_dir = mask_dir
        self.pred_dir = pred_dir

    def calculate_metrics(self, pred_image, mask_image):
        # compute TP, FP, TN, FN
        tp = np.sum((pred_image == 1) & (mask_image == 1))  # True Positive
        tn = np.sum((pred_image == 0) & (mask_image == 0))  # True Negative
        fp = np.sum((pred_image == 1) & (mask_image == 0))  # False Positive
        fn = np.sum((pred_image == 0) & (mask_image == 1))  # False Negative

        # compute accuracy and F1 score
        accuracy = float(tp + tn) / float(tp + tn + fp + fn) if (tp + tn + fp + fn) != 0 else 0
        f1_score = float(2 * tp) / float(2 * tp + fp + fn) if (2 * tp + fp + fn) != 0 else 0

        return accuracy, f1_score

    def calculate_miou(self, pred_image, mask_image):
        # compute the IoU of foreground and background
        ious = {}
        metrics = {}

        for category, (pred_binary, mask_binary) in {"foreground": (pred_image > 0, mask_image > 0),
                                                     "background": (pred_image == 0, mask_image == 0)}.items():
            pred_binary = pred_binary.astype(np.uint8)
            mask_binary = mask_binary.astype(np.uint8)

            intersection = np.logical_and(pred_binary, mask_binary).sum()
            union = np.logical_or(pred_binary, mask_binary).sum()

            # compute the IoU
            iou = intersection / (union + 1e-10)
            ious[category] = iou

            # compute accuracy and F1 score
            accuracy, f1_score = self.calculate_metrics(pred_binary, mask_binary)
            metrics[category] = {"accuracy": accuracy * 100.0, "f1_score": f1_score * 100.0}

        return ious, metrics

    def compute_miou(self):
        all_ious = {"foreground": [], "background": []}
        all_metrics = {"foreground": {"accuracy": [], "f1_score": []}, "background": {"accuracy": [], "f1_score": []}}

        for filename in os.listdir(self.pred_dir):
            if filename.endswith(".png"):
                pred_path = os.path.join(self.pred_dir, filename)
                mask_path = os.path.join(self.mask_dir, filename)

                # read the images
                pred_image = np.array(Image.open(pred_path).convert("L"))
                mask_image = np.array(Image.open(mask_path).convert("L"))

                # compute foreground/background IoU, accuracy, and F1 score
                ious, metrics = self.calculate_miou(pred_image, mask_image)

                for category in ["foreground", "background"]:
                    all_ious[category].append(ious[category] * 100.0)
                    all_metrics[category]["accuracy"].append(metrics[category]["accuracy"])
                    all_metrics[category]["f1_score"].append(metrics[category]["f1_score"])

        # average over foreground and background separately
        average_ious = {category: np.mean(all_ious[category]) for category in all_ious}
        average_metrics = {category: {metric: np.mean(all_metrics[category][metric])
                                      for metric in all_metrics[category]}
                           for category in all_metrics}

        # overall averages (foreground and background combined)
        overall_average_miou = np.mean([average_ious["foreground"], average_ious["background"]])
        overall_average_accuracy = np.mean([average_metrics["foreground"]["accuracy"], average_metrics["background"]["accuracy"]])
        overall_average_f1_score = np.mean([average_metrics["foreground"]["f1_score"], average_metrics["background"]["f1_score"]])

        # print the final results
        print("Average Metrics:")
        print(f"Foreground: mIoU = {average_ious['foreground']:.2f}, "
              f"Accuracy = {average_metrics['foreground']['accuracy']:.2f}, "
              f"F1 Score = {average_metrics['foreground']['f1_score']:.2f}")
        print(f"Background: mIoU = {average_ious['background']:.2f}, "
              f"Accuracy = {average_metrics['background']['accuracy']:.2f}, "
              f"F1 Score = {average_metrics['background']['f1_score']:.2f}")
        print(f"Overall: mIoU = {overall_average_miou:.2f}, "
              f"Accuracy = {overall_average_accuracy:.2f}, "
              f"F1 Score = {overall_average_f1_score:.2f}")

        print(f"Overall Average mIoU: {overall_average_miou:.2f}, "
              f"Accuracy: {overall_average_accuracy:.2f}, "
              f"F1 Score: {overall_average_f1_score:.2f}")

        return overall_average_miou, overall_average_accuracy, overall_average_f1_score

# usage example
mask_directory = "..."  # replace with the real mask folder path
pred_directory = "/home/mac/gdnet_12/testdata/gdnet_1_0.779/"  # replace with the real prediction folder path

miou_calculator = MIouCalculator(mask_directory, pred_directory)
average_miou, average_accuracy, average_f1_score = miou_calculator.compute_miou()
--------------------------------------------------------------------------------
/fortest/predict_multi.py:
--------------------------------------------------------------------------------
import os
import time
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
import yaml
import argparse
from baseline.BS_Mamba import BS_Mamba
# from baseline.gdnet import gdnet


class ImagePredictor:
    def __init__(self, previous_best):
        self.parser = argparse.ArgumentParser(description='Black_soil_detection_net')
        self.parser.add_argument('--model', default="415", type=str)
        self.parser.add_argument('--config', default="./configs/BlackSoil.yaml", type=str)
        self.previous_best = previous_best

    def main(self, previous_best):
        args = self.parser.parse_args()
        cfg = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
        mIOU = previous_best

        mdl = '%s_%.3f' % ('gdnet', mIOU)  # e.g. mdl = 'UNet_24'

        dataset = 'test2_0.770'

        test_data = '/data/grassset2_8/sample2/image2_crop/'
        to_test = {'test': test_data}
        weights_path = "/home/mac/gdnet_1/result/" + mdl + '.pth'

        ckpt_path = './testdata/' + dataset + mdl

        if not os.path.exists(ckpt_path):
            os.mkdir(ckpt_path)

        # roi_mask_path = "./DRIVE/test/mask/01_test_mask.gif"
        assert os.path.exists(weights_path), f"weights {weights_path} not found."
        # assert os.path.exists(img_path), f"image {img_path} not found."
        # assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."
        # [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        # mean = (0.709, 0.381, 0.224) # Ori
        # std = (0.127, 0.079, 0.043)
        mean = (0.485, 0.456, 0.406)  # Cor
        std = (0.229, 0.224, 0.225)
        # mean = (0.342, 0.413, 0.359) # Train
        # std = (0.085, 0.094, 0.091)

        # get device
        device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")
        print("using {} device.".format(device))

        # create model
        model = BS_Mamba()

        # load weights
        # model.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])
        model.load_state_dict(torch.load(weights_path, map_location='cpu'))
        '''
        # load the weight parameters
        state_dict = torch.load(weights_path, map_location='cpu')

        # inspect the shape of a weight tensor in the loaded state dict
        weight_shape = state_dict['in_conv.0.weight'].shape
        print("Weight shape:", weight_shape)

        # load the model weights from the loaded state dict
        model.load_state_dict(state_dict)
        '''
        model.to(device)
        model.eval()

        # load roi mask
        # roi_img = Image.open(roi_mask_path).convert('L')
        # roi_img = np.array(roi_img)
        '''
        original_img = Image.open(img_path).convert('RGB')

        # from pil image to tensor and normalize
        data_transform = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize(mean=mean, std=std)])
        img = data_transform(original_img)
        # expand batch dimension
        img = torch.unsqueeze(img, dim=0)
        '''

        # load images
        with torch.no_grad():
            for name, root in to_test.items():
                '''
                name = test
                root = /home/ljs/code/cloud/AIR_CD/Test_Images_png
                '''
                # get the list of image names
                img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.png')]
                # start timing
                # t_start = self.time_synchronized()
                # image preprocessing
                data_transform = transforms.Compose([transforms.ToTensor(),
                                                     transforms.Normalize(mean=mean, std=std)])
                for idx, img_name in enumerate(img_list):
                    print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                    original_img = Image.open(os.path.join(root, img_name + '.png')).convert('RGB')
                    img = data_transform(original_img)

                    img = torch.unsqueeze(img, dim=0)
                    # warm-up forward pass to initialize the model
                    img_height, img_width = img.shape[-2:]
                    init_img = torch.zeros((1, 3, img_height, img_width), device=device)
                    model(init_img)

                    output = model(img.to(device))

                    prediction = output.argmax(1).squeeze(0)

                    prediction = prediction.to("cpu").numpy().astype(np.uint8)
                    # set foreground pixels to 255 (white)
                    prediction[prediction == 1] = 255
                    mask = Image.fromarray(prediction)
                    mask.save(os.path.join(ckpt_path, img_name + '.png'))

        output_value = ckpt_path
        return output_value
--------------------------------------------------------------------------------
/baseline/resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


__all__ = ['ResNet', 'resnet50', 'resnet101']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups

        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, zero_init_residual=False, groups=1,
                 width_per_group=64, replace_stride_with_dilation=None, norm_layer=None):
        super(ResNet, self).__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 128
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
            norm_layer(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            norm_layer(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = list()
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def base_forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        c1 = self.layer1(x)
        c2 = self.layer2(c1)
        c3 = self.layer3(c2)
        c4 = self.layer4(c3)

        return c1, c2, c3, c4


def _resnet(arch, block, layers, pretrained, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        # build the checkpoint filename from the architecture name
        pretrained_path = "/home/mac/selfmatch/pretrained/%s.pth" % arch
        state_dict = torch.load(pretrained_path)
        model.load_state_dict(state_dict, strict=False)
    return model


def resnet50(pretrained=False, **kwargs):
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, **kwargs)


def resnet101(pretrained=False, **kwargs):
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, **kwargs)
--------------------------------------------------------------------------------
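
A smoke test for the dilated ResNet backbone (illustrative, not in the repo): with `replace_stride_with_dilation=[False, True, True]`, as in `configs/BlackSoil.yaml`, the last two stages keep an overall stride of 8:

```python
# Smoke test; pretrained=False avoids the hard-coded checkpoint path.
import torch
from baseline.resnet import resnet101

net = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
c1, c2, c3, c4 = net.base_forward(torch.randn(1, 3, 384, 384))
print(c4.shape)  # torch.Size([1, 2048, 48, 48]) -- overall stride 8
```
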
/dataset/transform.py:
--------------------------------------------------------------------------------
import random
import math

import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance
import torch
from torchvision import transforms


def crop(img, mask, size, ignore_value=255):
    w, h = img.size
    padw = size - w if w < size else 0
    padh = size - h if h < size else 0
    img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
    mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=ignore_value)

    w, h = img.size
    x = random.randint(0, w - size)
    y = random.randint(0, h - size)
    img = img.crop((x, y, x + size, y + size))
    mask = mask.crop((x, y, x + size, y + size))

    return img, mask


def hflip(img, mask, p=0.5):
    if random.random() < p:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
    return img, mask


def normalize(img, mask=None):
    img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        # transforms.Normalize([0.342, 0.413, 0.359], [0.085, 0.094, 0.091]),
    ])(img)
    if mask is not None:
        mask = torch.from_numpy(np.array(mask)).long()
        return img, mask
    return img


def resize(img, mask, ratio_range):
    w, h = img.size
    long_side = random.randint(int(max(h, w) * ratio_range[0]), int(max(h, w) * ratio_range[1]))

    if h > w:
        oh = long_side
        ow = int(1.0 * w * long_side / h + 0.5)
    else:
        ow = long_side
        oh = int(1.0 * h * long_side / w + 0.5)

    img = img.resize((ow, oh), Image.BILINEAR)
    mask = mask.resize((ow, oh), Image.NEAREST)
    return img, mask


def blur(img, p=0.5):
    if random.random() < p:
        sigma = np.random.uniform(0.1, 2.0)
        img = img.filter(ImageFilter.GaussianBlur(radius=sigma))
    return img

# with probability p (default 0.5), build a binary mask of the CutMix region; the cut region is 1
def obtain_cutmix_box(img_size, p=0.5, size_min=0.02, size_max=0.4, ratio_1=0.3, ratio_2=1/0.3):
    mask = torch.zeros(img_size, img_size)
    if random.random() > p:
        return mask

    size = np.random.uniform(size_min, size_max) * img_size * img_size
    while True:
        ratio = np.random.uniform(ratio_1, ratio_2)
        cutmix_w = int(np.sqrt(size / ratio))
        cutmix_h = int(np.sqrt(size * ratio))
        x = np.random.randint(0, img_size)
        y = np.random.randint(0, img_size)

        if x + cutmix_w <= img_size and y + cutmix_h <= img_size:
            break

    mask[y:y + cutmix_h, x:x + cutmix_w] = 1

    return mask

def img_aug_autocontrast(img, scale=None):
    return ImageOps.autocontrast(img)


def img_aug_equalize(img, scale=None):
    return ImageOps.equalize(img)


def img_aug_invert(img, scale=None):
    return ImageOps.invert(img)


def img_aug_identity(img, scale=None):
    return img


def img_aug_blur(img, scale=[0.1, 2.0]):
    assert scale[0] < scale[1]
    sigma = np.random.uniform(scale[0], scale[1])
    return img.filter(ImageFilter.GaussianBlur(radius=sigma))


def img_aug_contrast(img, scale=[0.05, 0.95], p=0.2):
    if random.random() < p:
        min_v, max_v = min(scale), max(scale)
        v = float(max_v - min_v) * random.random()
        v = max_v - v
        return ImageEnhance.Contrast(img).enhance(v)
    else:
        return img


def img_aug_brightness(img, scale=[0.05, 0.95]):
    min_v, max_v = min(scale), max(scale)
    v = float(max_v - min_v) * random.random()
    v = max_v - v
    # print(f"final:{v}")
    return ImageEnhance.Brightness(img).enhance(v)


def img_aug_color(img, scale=[0.05, 0.95]):
    min_v, max_v = min(scale), max(scale)
    v = float(max_v - min_v) * random.random()
    v = max_v - v
    # print(f"final:{v}")
    return ImageEnhance.Color(img).enhance(v)


def img_aug_sharpness(img, scale=[0.05, 0.95]):
    min_v, max_v = min(scale), max(scale)
    v = float(max_v - min_v) * random.random()
    v = max_v - v
    # print(f"final:{v}")
    return ImageEnhance.Sharpness(img).enhance(v)


def img_aug_hue(img, scale=[0, 0.5]):
    min_v, max_v = min(scale), max(scale)
    v = float(max_v - min_v) * random.random()
    v += min_v
    if np.random.random() < 0.5:
        hue_factor = -v
    else:
        hue_factor = v
    # print(f"Final-V:{hue_factor}")
    input_mode = img.mode
    if input_mode in {"L", "1", "I", "F"}:
        return img
    h, s, v = img.convert("HSV").split()
    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition take cares of rotation across boundaries
    with np.errstate(over="ignore"):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, "L")
    img = Image.merge("HSV", (h, s, v)).convert(input_mode)
    return img


def img_aug_posterize(img, scale=[4, 8]):
    min_v, max_v = min(scale), max(scale)
    v = float(max_v - min_v) * random.random()
    # print(min_v, max_v, v)
    v = int(np.ceil(v))
    v = max(1, v)
    v = max_v - v
    # print(f"final:{v}")
    return ImageOps.posterize(img, v)


def img_aug_solarize(img, scale=[1, 256]):
    min_v, max_v = min(scale), max(scale)
    v = float(max_v - min_v) * random.random()
    # print(min_v, max_v, v)
    v = int(np.ceil(v))
    v = max(1, v)
    v = max_v - v
    # print(f"final:{v}")
    return ImageOps.solarize(img, v)


def get_augment_list():
    l = [
        (img_aug_identity, None),
        (img_aug_autocontrast, None),
        (img_aug_equalize, None),
        (img_aug_blur, [0.1, 2.0]),
        (img_aug_contrast, [0.05, 0.95]),
        (img_aug_brightness, [0.05, 0.95]),
        (img_aug_color, [0.05, 0.95]),
        (img_aug_sharpness, [0.05, 0.95]),
        (img_aug_posterize, [4, 8]),
        (img_aug_solarize, [1, 256]),
        (img_aug_hue, [0, 0.5])
    ]
    return l


class strong_img_aug:
    def __init__(self, num_augs=4, flag_using_random_num=True):
        self.n = num_augs
        self.augment_list = get_augment_list()
        self.flag_using_random_num = flag_using_random_num

    def __call__(self, img):
        if self.flag_using_random_num:
            max_num = np.random.randint(1, high=self.n + 1)
        else:
            max_num = self.n
        ops = random.choices(self.augment_list, k=max_num)
        for op, scales in ops:
            img = op(img, scales)
        return img
--------------------------------------------------------------------------------

/train.py:
--------------------------------------------------------------------------------
import argparse
import logging
import os
import pprint
from tqdm import tqdm
import numpy as np
import torch
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.optim import SGD
from torch.utils.data import DataLoader

import yaml
# import sys
import datetime
from tensorboardX import SummaryWriter

from dataset.data import BSDataset
from baseline.BS_Mamba import BS_Mamba
# from baseline.mamba_unet import MambaUnet
# from baseline.unet import UNet
# from baseline.local_vmamba import UPerNet
from DiceLoss import DiceLoss, IouLoss
from util.evaluate import evaluate_add
from util.utils import count_params, init_log
import random


parser = argparse.ArgumentParser(description='Black_soil_detection_net')
parser.add_argument('--gpu', default='0', type=int, help='id(s) for CUDA_VISIBLE_DEVICES')
parser.add_argument('--config', default="./configs/BlackSoil.yaml", type=str)
parser.add_argument('--save-path', default="./result/", type=str)
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--port', default=None, type=int)

def init_seeds(seed=0, cuda_deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.enabled = True
    if cuda_deterministic:
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:
        cudnn.deterministic = False
        cudnn.benchmark = True


def main():
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    cfg = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
    model_name = 'BS_Mamba'  # UNet\MambaUnet\...

    results_file = args.save_path + "results_{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

    logger = init_log('global', logging.INFO)
    logger.propagate = 0

    rank = 0

    if rank == 0:
        logger.info('{}\n'.format(pprint.pformat(cfg)))

    if rank == 0:
        os.makedirs(args.save_path, exist_ok=True)
    init_seeds(0, False)

    model = BS_Mamba()
    # model = UNet()
    # model = MambaUnet()
    params_to_optimize = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(
        params_to_optimize,
        lr=cfg['lr'], betas=(0.9, 0.999), weight_decay=0.0005
    )
    # optimizer = torch.optim.SGD(param_list, lr=cfg['lr'],
    #                             momentum=args.momentum, weight_decay=args.weight_decay)
    model.cuda()
    criterion_ce = nn.CrossEntropyLoss().cuda()
    criterion_iou = IouLoss(reduction='mean')
    Trainset1 = BSDataset(cfg['dataset'], cfg['data_root'], 'train_u', cfg['crop_size'])
    Trainset2 = BSDataset(cfg['dataset'], cfg['data_root'], 'train_u', cfg['crop_size'])
    Valset = BSDataset(cfg['dataset'], cfg['data_root'], 'val')

    Trainloader1 = DataLoader(Trainset1, batch_size=cfg['batch_size'],
                              pin_memory=False, num_workers=0, drop_last=True, sampler=None)
    Trainloader2 = DataLoader(Trainset2, batch_size=cfg['batch_size'],
                              pin_memory=False, num_workers=0, drop_last=True, sampler=None)
    Valloader = DataLoader(Valset, batch_size=1, pin_memory=True, num_workers=4,
                           drop_last=True, sampler=None)

    total_iters = len(Trainloader1) * cfg['epochs']
    previous_best = 0.0
    writer = {'loss_tra': SummaryWriter('./result/loss_tra'), 'loss_val': SummaryWriter('./result/loss_val')}
    writer_iou = {'val_iou': SummaryWriter('./result/val_iou')}

    for epoch in range(cfg['epochs']):
        if rank == 0:
            logger.info('===========> Epoch: {:}, LR: {:.6f}, Previous best: {:.6f}'.format(
                epoch, optimizer.param_groups[0]['lr'], previous_best))

        total_loss = 0.0

        if rank == 0:
            tbar = tqdm(total=len(Trainloader1), desc=f'Epoch {epoch}')
        loader = zip(Trainloader1, Trainloader2)
        for i, ((img, mask, cutmix_box), (img_mix, mask_mix, _)) in enumerate(loader):
            img, mask = img.cuda(), mask.cuda()
            img_mix, mask_mix = img_mix.cuda(), mask_mix.cuda()
            cutmix_box = cutmix_box.cuda()
            img[cutmix_box.unsqueeze(1).expand(img.shape) == 1] = \
                img_mix[cutmix_box.unsqueeze(1).expand(img.shape) == 1]
            mask[cutmix_box == 1] = mask_mix[cutmix_box == 1]

model.train()
124 | pre = model(img)
125 |
126 | loss_ce = criterion_ce(pre, mask)
127 |
128 | loss_iou = criterion_iou(pre, mask)
129 |
130 | loss = 0.5*loss_ce + 0.25*loss_iou # weighted sum: lambda_1 = 0.5 (CE), lambda_2 = 0.25 (IoU)
131 |
132 |
133 | optimizer.zero_grad()
134 | loss.backward()
135 | optimizer.step()
136 |
137 | total_loss += loss.item()
138 | writer['loss_tra'].add_scalar('Loss/Total', total_loss / (i + 1), epoch)
139 | iters = epoch * len(Trainloader1) + i
140 | lr = cfg['lr'] * (1 - iters / total_iters) ** 0.9 # polynomial decay, power 0.9
141 | optimizer.param_groups[0]["lr"] = lr
142 |
143 | if rank == 0:
144 | tbar.update(1)
145 | tbar.set_description(' Loss: {:.3f} '
146 | .format(
147 | total_loss / (i + 1)
148 | ))
149 |
150 |
151 | if rank == 0:
152 | tbar.close()
153 |
154 | res_val = evaluate_add(model, Valloader)
155 |
156 | class_IOU = res_val['iou_class']
157 | mIOU = res_val["mean_mIOU"]
158 | f1_or_dsc = res_val['f1_or_dsc']
159 | accuracy = res_val['accuracy']
160 | sensitivity = res_val['sensitivity']
161 | specificity = res_val['specificity']
162 | loss_val = res_val['Loss_val']
163 |
164 | writer['loss_val'].add_scalar('Loss/Total', loss_val, epoch)
165 | writer_iou['val_iou'].add_scalar('val_iou', mIOU, epoch)
166 |
167 | if rank == 0:
168 | logger.info('***** Evaluation ***** >>>> mIOU: {:.6f}\n'.format(mIOU))
169 |
170 | with open(results_file, "a") as f:
171 | train_info = f"[epoch: {epoch}]\n" \
172 | f"train_loss: {total_loss / (i + 1):.4f}\n" \
173 | f"lr: {lr:.6f}\n" \
174 | f"val_mIOU: {mIOU}\n" \
175 | f"val_class_IOU: {class_IOU}\n" \
176 | f"f1_or_dsc: {f1_or_dsc:.6f}\n" \
177 | f"accuracy: {accuracy:.6f}\n" \
178 | f"sensitivity: {sensitivity:.6f}\n" \
179 | f"specificity: {specificity:.6f}\n" \
180 | f"Loss_val: {loss_val:.4f}\n"
181 |
182 | f.write(train_info + "\n\n")
183 |
184 |
185 | if mIOU > previous_best and rank == 0:
186 | if previous_best != 0:
187 | os.remove(os.path.join(args.save_path, '%s_%.3f.pth' % (model_name, previous_best)))
188 | previous_best = mIOU
189 | torch.save(model.state_dict(), os.path.join(args.save_path, '%s_%.3f.pth' % (model_name, mIOU))) # use model_name so the filename always matches the os.remove pattern above
190 |
191 |
192 |
193 | if __name__ == '__main__':
194 | main()
195 |
--------------------------------------------------------------------------------
/baseline/BS_Mamba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .vmamba import VSSM
5 | from typing import List
6 |
7 | class LocalFeatureEnhancement(nn.Module):
8 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
9 | super(LocalFeatureEnhancement, self).__init__()
10 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
11 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding)
12 | self.relu = nn.ReLU(inplace=True)
13 | self.batch_norm1 = nn.BatchNorm2d(out_channels)
14 | self.batch_norm2 = nn.BatchNorm2d(out_channels)
15 |
16 | def forward(self, x):
17 | x = self.conv1(x)
18 | x = self.batch_norm1(x)
19 | x = self.relu(x)
20 | x = self.conv2(x)
21 | x = self.batch_norm2(x)
22 | x = self.relu(x)
23 | return x
24 |
25 | class GlobalFeatureEnhancement(nn.Module):
26 | def __init__(self, in_channels, out_channels):
27 | super(GlobalFeatureEnhancement, self).__init__()
28 | self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
29 | self.fc1 = nn.Linear(in_channels, out_channels)
30 | self.fc2 = nn.Linear(out_channels, in_channels)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.sigmoid = nn.Sigmoid()
33 |
34 | def forward(self, x):
35 | b, c, h, w = x.size()
36 | # global average pooling
37 | x_global = self.global_avg_pool(x)
38 | x_global = x_global.view(b, c)
39 | # fully connected layers (squeeze-and-excitation style bottleneck)
40 | x_global = self.fc1(x_global)
41 | x_global = self.relu(x_global)
42 | x_global = self.fc2(x_global)
43 | x_global = self.sigmoid(x_global)
44 | # reweight the original features channel-wise
45 | x_global = x_global.view(b, c, 1, 1)
46 | x = x * x_global
47 | return x
48 |
49 |
50 | class DoubleConv(nn.Sequential): # two stacked 3x3 conv-BN-ReLU layers
51 | def __init__(self, in_channels, out_channels, mid_channels=None):
52 | if mid_channels is None:
53 | mid_channels = out_channels
54 | super(DoubleConv, self).__init__(
55 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
56 | nn.BatchNorm2d(mid_channels),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
59 | nn.BatchNorm2d(out_channels),
60 | nn.ReLU(inplace=True)
61 | )
62 |
63 |
64 | class Down(nn.Sequential):
65 | def __init__(self, in_channels, out_channels):
66 | super(Down, self).__init__(
67 | nn.AvgPool2d(2, stride=2),
68 | DoubleConv(in_channels, out_channels)
69 | )
70 |
71 | class Up(nn.Module):
72 | def __init__(self, in_channels, out_channels, bilinear=True):
73 | super(Up, self).__init__()
74 | if bilinear:
75 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
76 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
77 | else:
78 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
79 | self.conv = DoubleConv(in_channels, out_channels)
80 |
81 | # forward pass: upsample x1, pad it to match x2, concatenate, convolve
82 | def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
83 | x1 = self.up(x1)
84 | # [N, C, H, W]
85 | diff_y = x2.size()[2] - x1.size()[2]
86 | diff_x = x2.size()[3] - x1.size()[3]
87 |
88 | # padding_left, padding_right, padding_top, padding_bottom
89 | x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
90 | diff_y // 2, diff_y - diff_y // 2])
91 |
92 | x = torch.cat([x2, x1], dim=1)
93 | x = self.conv(x)
94 | return x
95 |
96 | class OutConv(nn.Sequential):
97 | def __init__(self, in_channels, num_classes):
98 | super(OutConv, self).__init__(
99 | nn.Conv2d(in_channels, num_classes, kernel_size=1)
100 | )
101 |
102 | class FeatureExtractor2(nn.Module):
103 | def __init__(self,
104 | input_channels=3,
105 | num_classes=2,
106 | mid_channel=48,
107 | depths=[2,2,2,2],
108 | depths_decoder=[2, 2, 9, 2],
109 | drop_path_rate=0.2,
110 | load_ckpt_path=None,
111 | deep_supervision=True
112 | ):
113 | super().__init__()
114 | self.num_classes = num_classes
115 | self.vmunet = VSSM(in_chans=input_channels,
116 | num_classes=num_classes,
117 | depths=depths,
118 | depths_decoder=depths_decoder,
119 | drop_path_rate=drop_path_rate,
120 | )
121 |
122 |
123 | def forward(self, x):
124 | f1, f2, f3, f4 = self.vmunet(x) # [b c h w]
125 | return [f1, f2, f3, f4]
126 |
127 | class FeatureExtractor1(nn.Module):
128 | def __init__(self,
129 | in_channels: int = 3,
130 | num_classes: int = 2,
131 | bilinear: bool = True,
132 | base_c: int = 32):
133 | super(FeatureExtractor1, self).__init__()
134 | self.in_channels = in_channels
135 | self.num_classes = num_classes
136 | self.bilinear = bilinear
137 |
138 | self.in_conv = DoubleConv(in_channels, base_c)
139 | self.down1 = Down(base_c, base_c * 2)
140 | self.down2 = Down(base_c * 2, base_c * 4)
141 | self.down3 = Down(base_c * 4, base_c * 8)
142 | factor = 2 if bilinear else 1
143 | self.down4 = Down(base_c * 8, base_c * 16 // factor)
144 |
145 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
146 |
147 | x1 = self.in_conv(x)
148 | x2 = self.down1(x1)
149 | x3 = self.down2(x2)
150 | x4 = self.down3(x3)
151 | return [x1, x2, x3, x4]
152 |
153 | class AttentionFusion(nn.Module):
154 | def __init__(self, in_channels):
155 | super(AttentionFusion, self).__init__()
156 | def forward(self, x1, x2):
157 | fused = torch.add(x1, x2) # simple element-wise sum fusion; in_channels is kept for interface symmetry
158 | return fused
159 |
160 |
161 | class BS_Mamba(nn.Module):
162 | def __init__(self, out_channels: int = 3,
163 | in_channels: int = 1,
164 | num_classes: int = 2,
165 | bilinear: bool = True,
166 | base_c: int = 32):
167 | super(BS_Mamba, self).__init__()
168 | self.feature_extractor1 = FeatureExtractor1()
169 | self.feature_extractor2 = FeatureExtractor2()
170 | self.attention_fusion = nn.ModuleList([
171 | AttentionFusion(in_channels) for in_channels in [64, 128, 256, 512]
172 | ]) # assumed channel count at each level
173 | self.up1 = Up(base_c * 12, base_c * 4, bilinear)
174 | self.up2 = Up(base_c * 6, base_c * 2, bilinear)
175 | self.up3 = Up(base_c * 3, base_c, bilinear)
176 | self.out_conv = OutConv(base_c, num_classes)
177 | self.local_feature_enhancement1 = LocalFeatureEnhancement(in_channels=32, out_channels=32)
178 | self.local_feature_enhancement2 = LocalFeatureEnhancement(in_channels=64, out_channels=64)
179 | self.local_feature_enhancement3 = LocalFeatureEnhancement(in_channels=128, out_channels=128)
180 | self.local_feature_enhancement4 = LocalFeatureEnhancement(in_channels=256, out_channels=256)
181 | self.global_feature_enhancement1 = GlobalFeatureEnhancement(in_channels=32, out_channels=32)
182 | self.global_feature_enhancement2 = GlobalFeatureEnhancement(in_channels=64, out_channels=64)
183 | self.global_feature_enhancement3 = GlobalFeatureEnhancement(in_channels=128, out_channels=128)
184 | self.global_feature_enhancement4 = GlobalFeatureEnhancement(in_channels=256, out_channels=256)
185 | def forward(self, x):
186 | features1 = self.feature_extractor1(x)
187 | features2 = self.feature_extractor2(x)
188 | features1[0] = self.local_feature_enhancement1(features1[0])
189 | features1[1] = self.local_feature_enhancement2(features1[1])
190 | features1[2] = self.local_feature_enhancement3(features1[2])
191 | features1[3] = self.local_feature_enhancement4(features1[3])
192 | features2[0] = self.global_feature_enhancement1(features2[0])
193 | features2[1] = self.global_feature_enhancement2(features2[1])
194 | features2[2] = self.global_feature_enhancement3(features2[2])
195 | features2[3] = self.global_feature_enhancement4(features2[3])
196 | fused_features = [self.attention_fusion[i](f1, f2) for i, (f1, f2) in enumerate(zip(features1[:4], features2))]
197 | # fused_features = [torch.cat((f1, f2), dim=1) for f1, f2 in zip(features1, features2)]
198 | x = self.up1(fused_features[3], fused_features[2])
199 | x = self.up2(x, fused_features[1])
200 | x = self.up3(x, fused_features[0])
201 | x = self.out_conv(x)
202 | # output = self.decoder(fused_features)
203 | return x
204 |
205 | # quick sanity check
206 | # x = torch.randn(1, 3, 256, 256) # example input
207 | # model = BS_Mamba()
208 | # output = model(x)
209 | # print(output.shape)
210 |
--------------------------------------------------------------------------------
/baseline/UltraLighet_VM_Unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from timm.models.layers import trunc_normal_
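# Orientation note (descriptive, based on the code below): PVMLayer splits its
# C input channels into four C/4 chunks and runs all four through a single
# shared Mamba block (d_model = input_dim // 4), adds a learnable skip_scale
# residual per chunk, then re-concatenates, LayerNorms, and projects to
# output_dim. For example, with input_dim=64 the shared Mamba sees four
# sequences of width 16 instead of one of width 64, which is where the
# "UltraLight" parameter saving comes from.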
6 | import math 7 | from mamba_ssm import Mamba 8 | 9 | 10 | class PVMLayer(nn.Module): 11 | def __init__(self, input_dim, output_dim, d_state = 16, d_conv = 4, expand = 2): 12 | super().__init__() 13 | self.input_dim = input_dim 14 | self.output_dim = output_dim 15 | self.norm = nn.LayerNorm(input_dim) 16 | self.mamba = Mamba( 17 | d_model=input_dim//4, # Model dimension d_model 18 | d_state=d_state, # SSM state expansion factor 19 | d_conv=d_conv, # Local convolution width 20 | expand=expand, # Block expansion factor 21 | ) 22 | self.proj = nn.Linear(input_dim, output_dim) 23 | self.skip_scale= nn.Parameter(torch.ones(1)) 24 | 25 | def forward(self, x): 26 | if x.dtype == torch.float16: 27 | x = x.type(torch.float32) 28 | B, C = x.shape[:2] 29 | assert C == self.input_dim 30 | n_tokens = x.shape[2:].numel() 31 | img_dims = x.shape[2:] 32 | x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) 33 | x_norm = self.norm(x_flat) 34 | 35 | x1, x2, x3, x4 = torch.chunk(x_norm, 4, dim=2) 36 | x_mamba1 = self.mamba(x1) + self.skip_scale * x1 37 | x_mamba2 = self.mamba(x2) + self.skip_scale * x2 38 | x_mamba3 = self.mamba(x3) + self.skip_scale * x3 39 | x_mamba4 = self.mamba(x4) + self.skip_scale * x4 40 | x_mamba = torch.cat([x_mamba1, x_mamba2,x_mamba3,x_mamba4], dim=2) 41 | 42 | x_mamba = self.norm(x_mamba) 43 | x_mamba = self.proj(x_mamba) 44 | out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims) 45 | return out 46 | 47 | 48 | class Channel_Att_Bridge(nn.Module): 49 | def __init__(self, c_list, split_att='fc'): 50 | super().__init__() 51 | c_list_sum = sum(c_list) - c_list[-1] 52 | self.split_att = split_att 53 | self.avgpool = nn.AdaptiveAvgPool2d(1) 54 | self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False) 55 | self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1) 56 | self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1) 57 | self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1) 58 | self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1) 59 | self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1) 60 | self.sigmoid = nn.Sigmoid() 61 | 62 | def forward(self, t1, t2, t3, t4, t5): 63 | att = torch.cat((self.avgpool(t1), 64 | self.avgpool(t2), 65 | self.avgpool(t3), 66 | self.avgpool(t4), 67 | self.avgpool(t5)), dim=1) 68 | att = self.get_all_att(att.squeeze(-1).transpose(-1, -2)) 69 | if self.split_att != 'fc': 70 | att = att.transpose(-1, -2) 71 | att1 = self.sigmoid(self.att1(att)) 72 | att2 = self.sigmoid(self.att2(att)) 73 | att3 = self.sigmoid(self.att3(att)) 74 | att4 = self.sigmoid(self.att4(att)) 75 | att5 = self.sigmoid(self.att5(att)) 76 | if self.split_att == 'fc': 77 | att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1) 78 | att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2) 79 | att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3) 80 | att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4) 81 | att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5) 82 | else: 83 | att1 = att1.unsqueeze(-1).expand_as(t1) 84 | att2 = att2.unsqueeze(-1).expand_as(t2) 85 | att3 = att3.unsqueeze(-1).expand_as(t3) 86 | att4 = att4.unsqueeze(-1).expand_as(t4) 87 | att5 = att5.unsqueeze(-1).expand_as(t5) 88 | 89 | return att1, att2, att3, att4, att5 90 | 91 | 92 | class 
Spatial_Att_Bridge(nn.Module): 93 | def __init__(self): 94 | super().__init__() 95 | self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3), 96 | nn.Sigmoid()) 97 | 98 | def forward(self, t1, t2, t3, t4, t5): 99 | t_list = [t1, t2, t3, t4, t5] 100 | att_list = [] 101 | for t in t_list: 102 | avg_out = torch.mean(t, dim=1, keepdim=True) 103 | max_out, _ = torch.max(t, dim=1, keepdim=True) 104 | att = torch.cat([avg_out, max_out], dim=1) 105 | att = self.shared_conv2d(att) 106 | att_list.append(att) 107 | return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4] 108 | 109 | 110 | class SC_Att_Bridge(nn.Module): 111 | def __init__(self, c_list, split_att='fc'): 112 | super().__init__() 113 | 114 | self.catt = Channel_Att_Bridge(c_list, split_att=split_att) 115 | self.satt = Spatial_Att_Bridge() 116 | 117 | def forward(self, t1, t2, t3, t4, t5): 118 | r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5 119 | 120 | satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5) 121 | t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5 122 | 123 | r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5 124 | t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5 125 | 126 | catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5) 127 | t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5 128 | 129 | return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_ 130 | 131 | 132 | class UltraLight_VM_UNet(nn.Module): 133 | 134 | def __init__(self, num_classes=1, input_channels=3, c_list=[8,16,24,32,48,64], 135 | split_att='fc', bridge=True): 136 | super().__init__() 137 | 138 | self.bridge = bridge 139 | 140 | self.encoder1 = nn.Sequential( 141 | nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1), 142 | ) 143 | self.encoder2 =nn.Sequential( 144 | nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1), 145 | ) 146 | self.encoder3 = nn.Sequential( 147 | nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1), 148 | ) 149 | self.encoder4 = nn.Sequential( 150 | PVMLayer(input_dim=c_list[2], output_dim=c_list[3]) 151 | ) 152 | self.encoder5 = nn.Sequential( 153 | PVMLayer(input_dim=c_list[3], output_dim=c_list[4]) 154 | ) 155 | self.encoder6 = nn.Sequential( 156 | PVMLayer(input_dim=c_list[4], output_dim=c_list[5]) 157 | ) 158 | 159 | if bridge: 160 | self.scab = SC_Att_Bridge(c_list, split_att) 161 | print('SC_Att_Bridge was used') 162 | 163 | self.decoder1 = nn.Sequential( 164 | PVMLayer(input_dim=c_list[5], output_dim=c_list[4]) 165 | ) 166 | self.decoder2 = nn.Sequential( 167 | PVMLayer(input_dim=c_list[4], output_dim=c_list[3]) 168 | ) 169 | self.decoder3 = nn.Sequential( 170 | PVMLayer(input_dim=c_list[3], output_dim=c_list[2]) 171 | ) 172 | self.decoder4 = nn.Sequential( 173 | nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1), 174 | ) 175 | self.decoder5 = nn.Sequential( 176 | nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1), 177 | ) 178 | self.ebn1 = nn.GroupNorm(4, c_list[0]) 179 | self.ebn2 = nn.GroupNorm(4, c_list[1]) 180 | self.ebn3 = nn.GroupNorm(4, c_list[2]) 181 | self.ebn4 = nn.GroupNorm(4, c_list[3]) 182 | self.ebn5 = nn.GroupNorm(4, c_list[4]) 183 | self.dbn1 = nn.GroupNorm(4, c_list[4]) 184 | self.dbn2 = nn.GroupNorm(4, c_list[3]) 185 | self.dbn3 = nn.GroupNorm(4, c_list[2]) 186 | self.dbn4 = nn.GroupNorm(4, c_list[1]) 187 | self.dbn5 = nn.GroupNorm(4, c_list[0]) 188 | 189 | self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1) 190 | 191 | 
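# Weight initialization (see _init_weights below): Linear layers get
# truncated-normal init with std 0.02; Conv1d/Conv2d layers get a zero-mean
# normal init scaled by fan-out (Kaiming-style), with biases zeroed. Note the
# encoder/decoder use GroupNorm(4, C), which, unlike BatchNorm, does not
# depend on the batch size.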
self.apply(self._init_weights) 192 | 193 | def _init_weights(self, m): 194 | if isinstance(m, nn.Linear): 195 | trunc_normal_(m.weight, std=.02) 196 | if isinstance(m, nn.Linear) and m.bias is not None: 197 | nn.init.constant_(m.bias, 0) 198 | elif isinstance(m, nn.Conv1d): 199 | n = m.kernel_size[0] * m.out_channels 200 | m.weight.data.normal_(0, math.sqrt(2. / n)) 201 | elif isinstance(m, nn.Conv2d): 202 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 203 | fan_out //= m.groups 204 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 205 | if m.bias is not None: 206 | m.bias.data.zero_() 207 | 208 | def forward(self, x): 209 | 210 | out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2)) 211 | t1 = out # b, c0, H/2, W/2 212 | 213 | out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2)) 214 | t2 = out # b, c1, H/4, W/4 215 | 216 | out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2)) 217 | t3 = out # b, c2, H/8, W/8 218 | 219 | out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)),2,2)) 220 | t4 = out # b, c3, H/16, W/16 221 | 222 | out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)),2,2)) 223 | t5 = out # b, c4, H/32, W/32 224 | 225 | if self.bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5) 226 | 227 | out = F.gelu(self.encoder6(out)) # b, c5, H/32, W/32 228 | 229 | out5 = F.gelu(self.dbn1(self.decoder1(out))) # b, c4, H/32, W/32 230 | out5 = torch.add(out5, t5) # b, c4, H/32, W/32 231 | 232 | out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c3, H/16, W/16 233 | out4 = torch.add(out4, t4) # b, c3, H/16, W/16 234 | 235 | out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c2, H/8, W/8 236 | out3 = torch.add(out3, t3) # b, c2, H/8, W/8 237 | 238 | out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c1, H/4, W/4 239 | out2 = torch.add(out2, t2) # b, c1, H/4, W/4 240 | 241 | out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c0, H/2, W/2 242 | out1 = torch.add(out1, t1) # b, c0, H/2, W/2 243 | 244 | out0 = F.interpolate(self.final(out1),scale_factor=(2,2),mode ='bilinear',align_corners=True) # b, num_class, H, W 245 | 246 | return torch.sigmoid(out0) 247 | 248 | -------------------------------------------------------------------------------- /baseline/local_scan.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | import triton 6 | import triton.language as tl 7 | 8 | 9 | @triton.jit 10 | def triton_local_scan( 11 | x, # x point (B, C, H, W) or (B, C, L) 12 | y, # y point (B, C, H, W) or (B, C, L) 13 | K: tl.constexpr, # window size 14 | flip: tl.constexpr, # whether to flip the tokens 15 | BC: tl.constexpr, # number of channels in each program 16 | BH: tl.constexpr, # number of heights in each program 17 | BW: tl.constexpr, # number of width in each program 18 | DC: tl.constexpr, # original channels 19 | DH: tl.constexpr, # original height 20 | DW: tl.constexpr, # original width 21 | NH: tl.constexpr, # number of programs on height 22 | NW: tl.constexpr, # number of programs on width 23 | ): 24 | i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) # program id of hw axis, c axis, batch axis 25 | i_h, i_w = (i_hw // NW), (i_hw % NW) # 
program idx of h and w 26 | _mask_h = (i_h * BH + tl.arange(0, BH)) < DH 27 | _mask_w = (i_w * BW + tl.arange(0, BW)) < DW 28 | _mask_hw = _mask_h[:, None] & _mask_w[None, :] # [BH, BW] 29 | _for_C = min(DC - i_c * BC, BC) # valid number of c in the program 30 | 31 | _tmp0 = i_c * BC * DH * DW # start offset of this program 32 | _tmp1 = DC * DH * DW # n_elements in one batch 33 | _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] # offsets of elements in this program 34 | 35 | p_x = x + i_b * _tmp1 + _tmp2 36 | 37 | _i = (tl.arange(0, BH) + BH * i_h)[:, None] 38 | _j = (tl.arange(0, BW) + BW * i_w)[None, :] 39 | _c_offset = ((DW // K) * (_i // K) + (_j // K)) * K * K + (_i % K) * K + _j % K 40 | if flip: 41 | _c_offset = DH * DW - _c_offset - 1 42 | 43 | p_y = y + i_b * _tmp1 + _tmp0 + _c_offset 44 | for idxc in range(_for_C): 45 | _idx = idxc * DH * DW 46 | _x = tl.load(p_x + _idx, mask=_mask_hw) 47 | tl.store(p_y + _idx, _x, mask=_mask_hw) 48 | tl.debug_barrier() 49 | 50 | 51 | @triton.jit 52 | def triton_local_reverse( 53 | x, # x point (B, C, H, W) or (B, C, L) 54 | y, # y point (B, C, H, W) or (B, C, L) 55 | K: tl.constexpr, # window size 56 | flip: tl.constexpr, # whether to flip the tokens 57 | BC: tl.constexpr, # number of channels in each program 58 | BH: tl.constexpr, # number of heights in each program 59 | BW: tl.constexpr, # number of width in each program 60 | DC: tl.constexpr, # original channels 61 | DH: tl.constexpr, # original height 62 | DW: tl.constexpr, # original width 63 | NH: tl.constexpr, # number of programs on height 64 | NW: tl.constexpr, # number of programs on width 65 | ): 66 | i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) # program id of hw axis, c axis, batch axis 67 | i_h, i_w = (i_hw // NW), (i_hw % NW) # program idx of h and w 68 | _mask_h = (i_h * BH + tl.arange(0, BH)) < DH 69 | _mask_w = (i_w * BW + tl.arange(0, BW)) < DW 70 | _mask_hw = _mask_h[:, None] & _mask_w[None, :] # [BH, BW] 71 | _for_C = min(DC - i_c * BC, BC) # valid number of c in the program 72 | 73 | _tmp0 = i_c * BC * DH * DW # start offset of this program 74 | _tmp1 = DC * DH * DW # n_elements in one batch 75 | _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] # offsets of elements in this program 76 | 77 | p_x = x + i_b * _tmp1 + _tmp2 78 | 79 | _i = (tl.arange(0, BH) + BH * i_h)[:, None] 80 | _j = (tl.arange(0, BW) + BW * i_w)[None, :] 81 | _o = _i * DW + _j 82 | 83 | _i = _o // (K * K) // (DW // K) * K + _o % (K * K) // K 84 | _j = _o // (K * K) % (DW // K) * K + _o % (K * K) % K 85 | _c_offset = _i * DW + _j 86 | if flip: 87 | _c_offset = DH * DW - _c_offset - 1 88 | 89 | p_y = y + i_b * _tmp1 + _tmp0 + _c_offset 90 | for idxc in range(_for_C): 91 | _idx = idxc * DH * DW 92 | _x = tl.load(p_x + _idx, mask=_mask_hw) 93 | tl.store(p_y + _idx, _x, mask=_mask_hw) 94 | tl.debug_barrier() 95 | 96 | 97 | class LocalScanTriton(torch.autograd.Function): 98 | @staticmethod 99 | def forward(ctx, x: torch.Tensor, K: int, flip: bool, H: int = None, W: int = None): 100 | ori_x = x 101 | B, C = x.shape[:2] 102 | if H is None or W is None: 103 | if len(x.shape) == 4: 104 | H, W = x.shape[-2:] 105 | elif len(x.shape) == 3: 106 | raise RuntimeError("x must be BCHW format to infer the H W") 107 | B, C, H, W = int(B), int(C), int(H), int(W) 108 | 109 | ctx.ori_shape = (B, C, H, W) 110 | # pad tensor to make it evenly divisble by window size 111 | x, (H, W) = pad_tensor(x, K, 
H, W) 112 | ctx.shape = (B, C, H, W) 113 | 114 | BC, BH, BW = min(triton.next_power_of_2(C), 1), min(triton.next_power_of_2(H), 64), min(triton.next_power_of_2(W), 64) 115 | NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) 116 | ctx.triton_shape = (BC, BH, BW, NC, NH, NW) 117 | ctx.K = K 118 | ctx.flip = flip 119 | 120 | if x.stride(-1) != 1: 121 | x = x.contiguous() 122 | 123 | if len(ori_x.shape) == 4: 124 | y = x.new_empty((B, C, H, W)) 125 | elif len(ori_x.shape) == 3: 126 | y = x.new_empty((B, C, H * W)) 127 | 128 | triton_local_scan[(NH * NW, NC, B)](x, y, K, flip, BC, BH, BW, C, H, W, NH, NW) 129 | return y 130 | 131 | @staticmethod 132 | def backward(ctx, y: torch.Tensor): 133 | # out: (b, k, d, l) 134 | B, C, H, W = ctx.shape 135 | BC, BH, BW, NC, NH, NW = ctx.triton_shape 136 | 137 | if y.stride(-1) != 1: 138 | y = y.contiguous() 139 | if len(y.shape) == 4 or ctx.shape != ctx.ori_shape: 140 | x = y.new_empty((B, C, H, W)) 141 | else: 142 | x = y.new_empty((B, C, H * W)) 143 | 144 | triton_local_reverse[(NH * NW, NC, B)](y, x, ctx.K, ctx.flip, BC, BH, BW, C, H, W, NH, NW) 145 | 146 | if ctx.shape != ctx.ori_shape: 147 | _, _, ori_H, ori_W = ctx.ori_shape 148 | x = x[:, :, :ori_H, :ori_W] 149 | if len(y.shape) == 3: 150 | x = x.flatten(2) 151 | 152 | return x, None, None, None, None 153 | 154 | 155 | class LocalReverseTriton(torch.autograd.Function): 156 | @staticmethod 157 | def forward(ctx, x: torch.Tensor, K: int, flip: bool, H: int = None, W: int = None): 158 | B, C = x.shape[:2] 159 | if H is None or W is None: 160 | if len(x.shape) == 4: 161 | H, W = x.shape[-2:] 162 | elif len(x.shape) == 3: 163 | raise RuntimeError("x must be BCHW format to infer the H W") 164 | B, C, H, W = int(B), int(C), int(H), int(W) 165 | 166 | ctx.ori_shape = (B, C, H, W) 167 | # x may have been padded 168 | Hg, Wg = math.ceil(H / K), math.ceil(W / K) 169 | H, W = Hg * K, Wg * K 170 | ctx.shape = (B, C, H, W) 171 | 172 | BC, BH, BW = min(triton.next_power_of_2(C), 1), min(triton.next_power_of_2(H), 64), min(triton.next_power_of_2(W), 64) 173 | NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) 174 | ctx.triton_shape = (BC, BH, BW, NC, NH, NW) 175 | ctx.K = K 176 | ctx.flip = flip 177 | 178 | if x.stride(-1) != 1: 179 | x = x.contiguous() 180 | 181 | if len(x.shape) == 4 or ctx.ori_shape != ctx.shape: 182 | y = x.new_empty((B, C, H, W)) 183 | else: 184 | y = x.new_empty((B, C, H * W)) 185 | 186 | triton_local_reverse[(NH * NW, NC, B)](x, y, K, flip, BC, BH, BW, C, H, W, NH, NW) 187 | 188 | if ctx.ori_shape != ctx.shape: 189 | ori_H, ori_W = ctx.ori_shape[-2:] 190 | y = y[:, :, :ori_H, :ori_W] 191 | if len(x.shape) == 3: 192 | y = y.flatten(2) 193 | 194 | return y 195 | 196 | @staticmethod 197 | def backward(ctx, y: torch.Tensor): 198 | # out: (b, k, d, l) 199 | B, C, H, W = ctx.ori_shape 200 | BC, BH, BW, NC, NH, NW = ctx.triton_shape 201 | 202 | _is_y_BCHW = len(y.shape) == 4 203 | 204 | y, (H, W) = pad_tensor(y, ctx.K, H, W) 205 | 206 | if y.stride(-1) != 1: 207 | y = y.contiguous() 208 | 209 | if _is_y_BCHW: 210 | x = y.new_empty((B, C, H, W)) 211 | else: 212 | x = y.new_empty((B, C, H * W)) 213 | 214 | triton_local_scan[(NH * NW, NC, B)](y, x, ctx.K, ctx.flip, BC, BH, BW, C, H, W, NH, NW) 215 | 216 | return x, None, None, None, None 217 | 218 | 219 | 220 | def pad_tensor(x, w, H, W): 221 | if H % w == 0 and W % w == 0: 222 | return x, (H, W) 223 | B, C = x.shape[:2] 224 | if len(x.shape) == 3: 225 | x = x.view(B, C, H, W) 226 | 227 | Hg, Wg = math.ceil(H 
/ w), math.ceil(W / w) 228 | newH, newW = Hg * w, Wg * w 229 | x = F.pad(x, (0, newW - W, 0, newH - H)) 230 | 231 | # We can skip flattening x back to BCL as the next operation 232 | # is triton_local_reverse / triton_local_scan, which supports 233 | # both BCHW and BCL inputs 234 | # if len(ori_x.shape) == 3: 235 | # x = x.flatten(2) 236 | 237 | return x, (newH, newW) 238 | 239 | 240 | """PyTorch code for local scan and local reverse""" 241 | 242 | 243 | def local_scan(x, w=7, H=14, W=14, flip=False, column_first=False): 244 | """Local windowed scan in LocalMamba 245 | Input: 246 | x: [B, L, C] 247 | H, W: original width and height before padding 248 | column_first: column-wise scan first (the additional direction in VMamba) 249 | Return: [B, C, L] 250 | """ 251 | B, L, C = x.shape 252 | x = x.view(B, H, W, C) 253 | Hg, Wg = math.ceil(H / w), math.ceil(W / w) 254 | if H % w != 0 or W % w != 0: 255 | newH, newW = Hg * w, Wg * w 256 | x = F.pad(x, (0, 0, 0, newW - W, 0, newH - H)) 257 | if column_first: 258 | x = x.view(B, Hg, w, Wg, w, C).permute(0, 5, 3, 1, 4, 2).reshape(B, C, -1) 259 | else: 260 | x = x.view(B, Hg, w, Wg, w, C).permute(0, 5, 1, 3, 2, 4).reshape(B, C, -1) 261 | if flip: 262 | x = x.flip([-1]) 263 | return x 264 | 265 | 266 | def local_scan_bchw(x, w=7, H=14, W=14, flip=False, column_first=False): 267 | """Local windowed scan in LocalMamba 268 | Input: 269 | x: [B, C, H, W] 270 | H, W: original width and height before padding 271 | column_first: column-wise scan first (the additional direction in VMamba) 272 | Return: [B, C, L] 273 | """ 274 | B, C, _, _ = x.shape 275 | x = x.view(B, C, H, W) 276 | Hg, Wg = math.ceil(H / w), math.ceil(W / w) 277 | if H % w != 0 or W % w != 0: 278 | newH, newW = Hg * w, Wg * w 279 | x = F.pad(x, (0, newW - W, 0, newH - H)) 280 | if column_first: 281 | x = x.view(B, C, Hg, w, Wg, w).permute(0, 1, 4, 2, 5, 3).reshape(B, C, -1) 282 | else: 283 | x = x.view(B, C, Hg, w, Wg, w).permute(0, 1, 2, 4, 3, 5).reshape(B, C, -1) 284 | if flip: 285 | x = x.flip([-1]) 286 | return x 287 | 288 | 289 | def local_reverse(x, w=7, H=14, W=14, flip=False, column_first=False): 290 | """Local windowed scan in LocalMamba 291 | Input: 292 | x: [B, C, L] 293 | H, W: original width and height before padding 294 | column_first: column-wise scan first (the additional direction in VMamba) 295 | Return: [B, C, L] 296 | """ 297 | B, C, L = x.shape 298 | Hg, Wg = math.ceil(H / w), math.ceil(W / w) 299 | if flip: 300 | x = x.flip([-1]) 301 | if H % w != 0 or W % w != 0: 302 | if column_first: 303 | x = x.view(B, C, Wg, Hg, w, w).permute(0, 1, 3, 5, 2, 4).reshape(B, C, Hg * w, Wg * w) 304 | else: 305 | x = x.view(B, C, Hg, Wg, w, w).permute(0, 1, 2, 4, 3, 5).reshape(B, C, Hg * w, Wg * w) 306 | x = x[:, :, :H, :W].reshape(B, C, -1) 307 | else: 308 | if column_first: 309 | x = x.view(B, C, Wg, Hg, w, w).permute(0, 1, 3, 5, 2, 4).reshape(B, C, L) 310 | else: 311 | x = x.view(B, C, Hg, Wg, w, w).permute(0, 1, 2, 4, 3, 5).reshape(B, C, L) 312 | return x 313 | -------------------------------------------------------------------------------- /util/saliency_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | from scipy.ndimage import convolve, distance_transform_edt as bwdist 4 | # import math 5 | import torch 6 | import numpy 7 | class cal_fm_(object): 8 | def __init__(self, num, thds=255): 9 | # TruePositive + TrueNegative, for accuracy 10 | self.tp_fp = 0 11 | # TruePositive 12 | 
self.tp = 0
13 | # Number of '1' predictions, for precision
14 | self.pred_true = 0
15 | # Number of '1's in gt mask, for recall
16 | self.gt_true = 0
17 | # List to save mean absolute error of each image
18 | self.mae_list = []
19 | self.img_size = 384
20 | self.fscore, self.cnt, self.number = 0, 0, 256
21 | self.mean_pr, self.mean_re, self.threshod = 0, 0, np.linspace(0, 1, self.number, endpoint=False)
22 | self.num = num
23 | self.thds = thds
24 | self.precision = np.zeros((self.num, self.thds))
25 | self.recall = np.zeros((self.num, self.thds))
26 | self.meanF = np.zeros((self.num,1))
27 | self.idx = 0
28 | # self.num_black = 0
29 | self.max_F = 0
30 | def update(self, pred, gt):
31 |
32 | # prediction, recall, Fmeasure_temp = self.cal(pred, gt)
33 | # self.precision[self.idx, :] = prediction
34 | # self.recall[self.idx, :] = recall
35 | # self.meanF[self.idx, :] = Fmeasure_temp
36 |
37 | self.idx += 1
38 | def cal(self, pred, gt):
39 |
40 | # self.tp += numpy.dot(pred, gt).sum()
41 | self.tp += (pred*gt).sum()
42 | self.pred_true += pred.sum()
43 | self.gt_true += gt.sum()
44 |
45 | # ae = torch.mean(torch.abs(res - gt), dim=(0, 1)).cpu().numpy()
46 | # mae_list.extend(ae)
47 | # mae_list.append(ae.item())
48 | self.cnt += 1
49 |
50 | def show(self):
51 | precision = self.tp / self.pred_true
52 | recall = self.tp / self.gt_true
53 | avgf = (precision*recall*(1.3))/(0.3*precision+recall+1e-12)
54 | return avgf.max(),avgf,precision,recall
55 |
56 | class cal_fm(object):
57 | # Fmeasure(maxFm,meanFm)---Frequency-tuned salient region detection(CVPR 2009)
58 | def __init__(self, num, thds=255):
59 | self.num = num
60 | self.thds = thds
61 | self.precision = np.zeros((self.num, self.thds))
62 | self.recall = np.zeros((self.num, self.thds))
63 | self.meanF = np.zeros((self.num,1))
64 | self.idx = 0
65 | self.num_black = 0
66 | self.max_F = 0
67 |
68 | # def update(self, pred, gt):
69 | # if gt.max() != 0:
70 | # prediction, recall, Fmeasure_temp = self.cal(pred, gt)
71 | # self.precision[self.idx, :] = prediction
72 | # self.recall[self.idx, :] = recall
73 | # self.meanF[self.idx, :] = Fmeasure_temp
74 | # else:
75 | # self.meanF[self.idx, :] = 1 # set F to 1 for all-black images
76 | # self.idx += 1
77 | def update(self, pred, gt):
78 | if gt.max() != 0:
79 | self.num_black += 1 # counts masks that contain foreground (name kept from source)
80 | prediction, recall, Fmeasure_temp = self.cal(pred, gt)
81 | self.precision[self.idx, :] = prediction
82 | self.recall[self.idx, :] = recall
83 | self.meanF[self.idx, :] = Fmeasure_temp
84 |
85 | self.idx += 1
86 |
87 | def cal(self, pred, gt):
88 | ######################## meanF ##############################
89 | th = 2 * pred.mean()
90 | if th > 1:
91 | th = 1
92 | # binarize the prediction: 1 for foreground, 0 for background
93 | binary = np.zeros_like(pred)
94 | binary[pred >= th] = 1
95 | hard_gt = np.zeros_like(gt)
96 | hard_gt[gt > 0.5] = 1
97 | # tp: number of correctly predicted foreground pixels
98 | tp = (binary * hard_gt).sum()
99 | if tp == 0:
100 | if hard_gt.all() == 0:
101 | meanF = 1
102 | else:
103 | meanF = 0
104 | else:
105 | # compute precision and recall
106 | pre = tp / binary.sum()
107 | rec = tp / hard_gt.sum()
108 | # beta^2 is set directly to 0.3
109 | meanF = (1.3 * pre * rec) / (0.3 * pre + rec + 1e-8)
110 | if meanF > self.max_F and meanF != 1:
111 | self.max_F = meanF
112 | ######################## maxF ##############################
113 | # rescale from [0, 1] to [0, 255]
114 | pred = np.uint8(pred * 255)
115 | # split the predictions by ground-truth class
116 | target = pred[gt > 0.5]
117 | nontarget = pred[gt <= 0.5]
118 | # These two lines compute histograms of the predicted values inside the
119 | # target and non-target regions: np.histogram bins the predictions into
120 | # 256 intervals and counts how many values fall in each bin.
121 | targetHist, _ = np.histogram(target, bins=range(256))
122 | nontargetHist, _ = np.histogram(nontarget, bins=range(256))
123 | # These two lines take cumulative sums over the two histograms: np.flip
124 | # first reverses the bin order (highest threshold first), then np.cumsum
125 | # accumulates, giving TP and FP counts at every possible threshold.
126 | targetHist = np.cumsum(np.flip(targetHist), axis=0)
127 | nontargetHist = np.cumsum(np.flip(nontargetHist), axis=0)
128 |
129 | precision = (targetHist) / (targetHist + nontargetHist + 1e-8)
130 | recall = (targetHist) / (np.sum(gt)+1e-8)
131 | # F = (1.3 * precision * recall) / (0.3 * precision + recall + 1e-10)
132 | # if F.max() > self.max_F:
133 | # self.max_F = F
134 | return precision, recall, meanF
135 |
136 | def show(self):
137 | assert self.num == self.idx
138 | precision = self.precision.mean(axis=0)
139 | recall = self.recall.mean(axis=0)
140 | # if precision == 0 and recall == 0:
141 | # fmeasure = 1
142 | fmeasure = (1.3 * precision * recall) / (0.3 * precision + recall)
143 | fmeasure_avg = self.meanF.mean(axis=0)
144 | # return fmeasure.max(),fmeasure_avg[0],precision,recall
145 | return fmeasure.max(),fmeasure_avg[0],precision,recall
146 | # def show(self):
147 | # assert self.num == self.idx
148 | # # precision = -np.partition(-self.precision, 1)[1]
149 | # # recall = -np.partition(-self.recall, 1)[1]
150 | # precision = self.precision.mean(axis=0)
151 | # recall = self.recall.mean(axis=0)
152 | # # precision = self.precision
153 | # # recall = self.recall
154 | # # fmeasure = (1.3 * precision * recall + 1e-10) / (0.3 * precision + recall + 1e-10)
155 | # # fmeasure_max = -np.partition(-fmeasure, 1)[1]
156 | # fmeasure_avg = self.meanF.mean(axis=0)
157 | # # return fmeasure.max(),fmeasure_avg[0],precision,recall
158 | # return self.max_F,fmeasure_avg[0],precision,recall
159 |
160 |
161 | class cal_mae(object):
162 | # mean absolute error
163 | def __init__(self):
164 | self.prediction = []
165 |
166 | def update(self, pred, gt):
167 | score = self.cal(pred, gt)
168 | self.prediction.append(score)
169 |
170 | def cal(self, pred, gt):
171 | return np.mean(np.abs(pred - gt))
172 |
173 | def show(self):
174 | return np.mean(self.prediction)
175 |
176 |
177 | class cal_sm(object):
178 | # Structure-measure: A new way to evaluate foreground maps (ICCV 2017)
179 | def __init__(self, alpha=0.5):
180 | self.prediction = []
181 | self.alpha = alpha
182 |
183 | def update(self, pred, gt):
184 | gt = gt > 0.5
185 | score = self.cal(pred, gt)
186 | self.prediction.append(score)
187 |
188 | def show(self):
189 | return np.mean(self.prediction)
190 |
191 | def cal(self, pred, gt):
192 | y = np.mean(gt)
193 | if y == 0:
194 | score = 1 - np.mean(pred)
195 | elif y == 1:
196 | score = np.mean(pred)
197 | else:
198 | # guard against the region term returning NaN
199 | region_value = self.region(pred, gt)
200 | if np.isnan(region_value):
201 | region_value = 0
202 | score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * region_value
203 | return score
204 |
205 | def object(self, pred, gt):
206 | fg = pred * gt
207 | bg = (1 - pred) * (1 - gt)
208 |
209 | u = np.mean(gt)
210 | return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt))
211 |
212 | def s_object(self, in1, in2):
213 | x = np.mean(in1[in2])
214 | sigma_x = np.std(in1[in2])
215 | return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8)
216 |
217 | def region(self, pred, gt):
218 | [y, x] =
ndimage.center_of_mass(gt) 219 | y = int(round(y)) + 1 220 | x = int(round(x)) + 1 221 | [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y) 222 | pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y) 223 | 224 | score1 = self.ssim(pred1, gt1) 225 | score2 = self.ssim(pred2, gt2) 226 | score3 = self.ssim(pred3, gt3) 227 | score4 = self.ssim(pred4, gt4) 228 | 229 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 230 | 231 | def divideGT(self, gt, x, y): 232 | h, w = gt.shape 233 | area = h * w 234 | LT = gt[0:y, 0:x] 235 | RT = gt[0:y, x:w] 236 | LB = gt[y:h, 0:x] 237 | RB = gt[y:h, x:w] 238 | 239 | w1 = x * y / area 240 | w2 = y * (w - x) / area 241 | w3 = (h - y) * x / area 242 | w4 = (h - y) * (w - x) / area 243 | 244 | return LT, RT, LB, RB, w1, w2, w3, w4 245 | 246 | def dividePred(self, pred, x, y): 247 | h, w = pred.shape 248 | LT = pred[0:y, 0:x] 249 | RT = pred[0:y, x:w] 250 | LB = pred[y:h, 0:x] 251 | RB = pred[y:h, x:w] 252 | 253 | return LT, RT, LB, RB 254 | 255 | def ssim(self, in1, in2): 256 | in2 = np.float32(in2) 257 | h, w = in1.shape 258 | N = h * w 259 | 260 | x = np.mean(in1) 261 | y = np.mean(in2) 262 | sigma_x = np.var(in1) 263 | sigma_y = np.var(in2) 264 | sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1) 265 | 266 | alpha = 4 * x * y * sigma_xy 267 | beta = (x * x + y * y) * (sigma_x + sigma_y) 268 | 269 | if alpha != 0: 270 | score = alpha / (beta + 1e-8) 271 | elif alpha == 0 and beta == 0: 272 | score = 1 273 | else: 274 | score = 0 275 | 276 | return score 277 | 278 | class cal_em(object): 279 | #Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018) 280 | def __init__(self): 281 | self.prediction = [] 282 | 283 | def update(self, pred, gt): 284 | score = self.cal(pred, gt) 285 | self.prediction.append(score) 286 | 287 | def cal(self, pred, gt): 288 | th = 2 * pred.mean() 289 | if th > 1: 290 | th = 1 291 | FM = np.zeros(gt.shape) 292 | FM[pred >= th] = 1 293 | FM = np.array(FM,dtype=bool) 294 | GT = np.array(gt,dtype=bool) 295 | dFM = np.double(FM) 296 | if (sum(sum(np.double(GT)))==0): 297 | enhanced_matrix = 1.0-dFM 298 | elif (sum(sum(np.double(~GT)))==0): 299 | enhanced_matrix = dFM 300 | else: 301 | dGT = np.double(GT) 302 | align_matrix = self.AlignmentTerm(dFM, dGT) 303 | enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix) 304 | [w, h] = np.shape(GT) 305 | score = sum(sum(enhanced_matrix))/ (w * h - 1 + 1e-8) 306 | return score 307 | def AlignmentTerm(self,dFM,dGT): 308 | mu_FM = np.mean(dFM) 309 | mu_GT = np.mean(dGT) 310 | align_FM = dFM - mu_FM 311 | align_GT = dGT - mu_GT 312 | align_Matrix = 2. * (align_GT * align_FM)/ (align_GT* align_GT + align_FM* align_FM + 1e-8) 313 | return align_Matrix 314 | def EnhancedAlignmentTerm(self,align_Matrix): 315 | enhanced = np.power(align_Matrix + 1,2) / 4 316 | return enhanced 317 | def show(self): 318 | return np.mean(self.prediction) 319 | 320 | 321 | 322 | 323 | class cal_wfm(object): 324 | # How to Evaluate Foreground Maps? 
W_Fm (weighted F-measure, CVPR 2014)
325 | def __init__(self, beta=1):
326 | self.beta = beta
327 | self.eps = 1e-6
328 | self.scores_list = []
329 |
330 | def update(self, pred, gt):
331 | assert pred.ndim == gt.ndim and pred.shape == gt.shape
332 | assert pred.max() <= 1 and pred.min() >= 0
333 | assert gt.max() <= 1 and gt.min() >= 0
334 |
335 | gt = gt > 0.5
336 |
337 | if gt.max() == 0:
338 | score = 1
339 | else:
340 | score = self.cal(pred, gt)
341 | self.scores_list.append(score)
342 |
343 | def matlab_style_gauss2D(self, shape=(7, 7), sigma=5):
344 | """
345 | 2D gaussian mask - should give the same result as MATLAB's
346 | fspecial('gaussian',[shape],[sigma])
347 | """
348 | m, n = [(ss - 1.) / 2. for ss in shape]
349 | y, x = np.ogrid[-m:m + 1, -n:n + 1]
350 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
351 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
352 | sumh = h.sum()
353 | if sumh != 0:
354 | h /= sumh
355 | return h
356 |
357 | def cal(self, pred, gt):
358 | # [Dst,IDXT] = bwdist(dGT);
359 | '''
360 | bwdist (scipy's distance_transform_edt) assigns every nonzero input
361 | pixel its Euclidean distance to the nearest zero input pixel.
362 | Here the input is the boolean mask (gt == 0): True on the background,
363 | False on the foreground. Every background pixel therefore receives
364 | its distance to the nearest foreground (ground-truth) pixel, while
365 | foreground pixels receive distance 0.
366 |
367 | With return_indices=True the call also returns, for every pixel, the
368 | coordinates of that nearest foreground pixel.
369 |
370 | Interpreting the results:
371 |
372 | Dst has the same shape as gt; each entry is the distance described above.
373 | Idxt stacks the row indices and the column indices of the nearest
374 | foreground pixel, so it can be used to look that pixel up. [coordinates]
375 | '''
376 | Dst, Idxt = bwdist(gt == 0, return_indices=True)
377 | '''
378 | Idxt.shape = (2,352,352)
379 | '''
380 | # %Pixel dependency
381 | # E = abs(FG-dGT);
382 | E = np.abs(pred - gt)
383 | # Et = E;
384 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region
385 | Et = np.copy(E)
386 | '''
387 | This is the key adjustment step: it uses the distance map to fix up
388 | the error assigned to background pixels.
389 |
390 | gt == 0 gives the mask of background pixels.
391 |
392 | Idxt[0][gt == 0], Idxt[1][gt == 0] pick, for every background pixel,
393 | the index of its nearest foreground pixel (Idxt[0] = rows, Idxt[1] = columns).
394 |
395 | Et[gt == 0] = ... then replaces each background pixel's error with the
396 | error at its nearest foreground pixel. This injects the spatial
397 | dependency between pixels, so a background pixel's error is influenced
398 | by the nearby foreground, handling the edges of the foreground correctly.
399 | '''
400 |
401 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]]
402 |
403 | # K = fspecial('gaussian',7,5);
404 | # EA = imfilter(Et,K);
405 | # MIN_E_EA(GT & EA<E) = EA(GT & EA<E);
406 | K = self.matlab_style_gauss2D((7, 7), sigma=5)
407 | EA = convolve(Et, weights=K, mode='constant', cval=0)
408 | MIN_E_EA = np.where(gt & (EA < E), EA, E)
409 | '''
410 | The adjusted errors are smoothed with a 7x7 Gaussian (sigma = 5);
411 | inside the ground-truth foreground, the smaller of the original and
412 | the smoothed error is kept (the MATLAB line above).
413 |
414 | np.where(condition, x, y) takes x where condition is True and y
415 | elsewhere, e.g.
416 | np.where(np.arange(10) > 5, 1, -1)
417 | # array([-1, -1, -1, -1, -1, -1, 1, 1, 1, 1])
418 | '''
419 | B = np.where(gt == 0, 2 - np.exp(np.log(0.5) / 5 * Dst), np.ones_like(gt))
420 | Ew = MIN_E_EA * B
421 |
422 | # TPw = sum(dGT(:)) - sum(sum(Ew(GT)));
423 | # FPw = sum(sum(Ew(~GT)));
424 | TPw = np.sum(gt) - np.sum(Ew[gt == 1])
425 | FPw = np.sum(Ew[gt == 0])
426 |
427 | # R = 1- mean2(Ew(GT)); %Weighed Recall
428 | # P = TPw./(eps+TPw+FPw); %Weighted Precision
429 | R = 1 - np.mean(Ew[gt])
430 | P = TPw / (self.eps + TPw + FPw)
431 |
432 | # % Q = (1+Beta^2)*(R*P)./(eps+R+(Beta.*P));
433 | Q = ((1 + self.beta) * R * P + 1e-8) / (self.eps + R + self.beta * P + 1e-8)
434 |
435 | return Q
436 |
437 | def show(self):
438 | return np.mean(self.scores_list)
--------------------------------------------------------------------------------
/baseline/multi_mamba.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao, Albert Gu.
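# Orientation note (descriptive, based on the code below): MultiScan flattens
# the 2-D token grid along several orders -- 'h' (row-major), 'v'
# (column-major), their flipped variants, and windowed orders 'w2'/'w7'
# (window sizes 2 and 7, from LocalMamba) -- runs one selective-scan branch
# per order, reverses each result back to a common layout, and sums the
# branches (optionally with learned softmax weights when searching).
# Worked example for a 2x4 grid (H=2, W=4), tokens numbered row-major 0..7:
#   'h'      -> 0 1 2 3 4 5 6 7
#   'h_flip' -> 7 6 5 4 3 2 1 0
#   'v'      -> 0 4 1 5 2 6 3 7
#   'w2'     -> 0 1 4 5 2 3 6 7   (window by window, row-major inside each)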
2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | from einops import rearrange, repeat 10 | import logging 11 | 12 | 13 | try: 14 | from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn_no_out_proj 15 | except ImportError: 16 | mamba_inner_fn_no_out_proj = None 17 | 18 | from .local_scan import LocalScanTriton, LocalReverseTriton, local_scan, local_scan_bchw, local_reverse 19 | 20 | 21 | class MultiScan(nn.Module): 22 | 23 | ALL_CHOICES = ('h', 'h_flip', 'v', 'v_flip', 'w2', 'w2_flip', 'w7', 'w7_flip') 24 | 25 | def __init__(self, dim, choices=None, token_size=(14, 14)): 26 | super().__init__() 27 | self.token_size = token_size 28 | if choices is None: 29 | self.choices = MultiScan.ALL_CHOICES 30 | self.norms = nn.ModuleList([nn.LayerNorm(dim, elementwise_affine=False) for _ in self.choices]) 31 | self.weights = nn.Parameter(1e-3 * torch.randn(len(self.choices), 1, 1, 1)) 32 | self._iter = 0 33 | self.logger = logging.getLogger() 34 | self.search = True 35 | else: 36 | self.choices = choices 37 | self.search = False 38 | 39 | def forward(self, xs): 40 | """ 41 | Input @xs: [[B, L, D], ...] 42 | """ 43 | if self.search: 44 | weights = self.weights.softmax(0) 45 | xs = [norm(x) for norm, x in zip(self.norms, xs)] 46 | xs = torch.stack(xs) * weights 47 | x = xs.sum(0) 48 | if self._iter % 200 == 0: 49 | if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: 50 | self.logger.info(str(weights.detach().view(-1).tolist())) 51 | self._iter += 1 52 | else: 53 | x = torch.stack(xs).sum(0) 54 | return x 55 | 56 | def multi_scan(self, x): 57 | """ 58 | Input @x: shape [B, L, D] 59 | """ 60 | xs = [] 61 | for direction in self.choices: 62 | xs.append(self.scan(x, direction)) 63 | return xs 64 | 65 | def multi_reverse(self, xs): 66 | new_xs = [] 67 | for x, direction in zip(xs, self.choices): 68 | new_xs.append(self.reverse(x, direction)) 69 | return new_xs 70 | 71 | def scan(self, x, direction='h'): 72 | """ 73 | Input @x: shape [B, L, D] or [B, C, H, W] 74 | Return torch.Tensor: shape [B, D, L] 75 | """ 76 | H, W = self.token_size 77 | if len(x.shape) == 3: 78 | if direction == 'h': 79 | return x.transpose(-2, -1) 80 | elif direction == 'h_flip': 81 | return x.transpose(-2, -1).flip([-1]) 82 | elif direction == 'v': 83 | return rearrange(x, 'b (h w) d -> b d (w h)', h=H, w=W) 84 | elif direction == 'v_flip': 85 | return rearrange(x, 'b (h w) d -> b d (w h)', h=H, w=W).flip([-1]) 86 | elif direction.startswith('w'): 87 | K = int(direction[1:].split('_')[0]) 88 | flip = direction.endswith('flip') 89 | return local_scan(x, K, H, W, flip=flip) 90 | # return LocalScanTriton.apply(x.transpose(-2, -1), K, flip, H, W) 91 | else: 92 | raise RuntimeError(f'Direction {direction} not found.') 93 | elif len(x.shape) == 4: 94 | if direction == 'h': 95 | return x.flatten(2) 96 | elif direction == 'h_flip': 97 | return x.flatten(2).flip([-1]) 98 | elif direction == 'v': 99 | return rearrange(x, 'b d h w -> b d (w h)', h=H, w=W) 100 | elif direction == 'v_flip': 101 | return rearrange(x, 'b d h w -> b d (w h)', h=H, w=W).flip([-1]) 102 | elif direction.startswith('w'): 103 | K = int(direction[1:].split('_')[0]) 104 | flip = direction.endswith('flip') 105 | return local_scan_bchw(x, K, H, W, flip=flip) 106 | # return LocalScanTriton.apply(x, K, flip, H, W).flatten(2) 107 | else: 108 | raise RuntimeError(f'Direction {direction} not found.') 109 | 110 | def reverse(self, x, direction='h'): 111 | """ 112 
| Input @x: shape [B, D, L] 113 | Return torch.Tensor: shape [B, D, L] 114 | """ 115 | H, W = self.token_size 116 | if direction == 'h': 117 | return x 118 | elif direction == 'h_flip': 119 | return x.flip([-1]) 120 | elif direction == 'v': 121 | return rearrange(x, 'b d (h w) -> b d (w h)', h=H, w=W) 122 | elif direction == 'v_flip': 123 | return rearrange(x.flip([-1]), 'b d (h w) -> b d (w h)', h=H, w=W) 124 | elif direction.startswith('w'): 125 | K = int(direction[1:].split('_')[0]) 126 | flip = direction.endswith('flip') 127 | return local_reverse(x, K, H, W, flip=flip) 128 | # return LocalReverseTriton.apply(x, K, flip, H, W) 129 | else: 130 | raise RuntimeError(f'Direction {direction} not found.') 131 | 132 | def __repr__(self): 133 | scans = ', '.join(self.choices) 134 | return super().__repr__().replace(self.__class__.__name__, f'{self.__class__.__name__}[{scans}]') 135 | 136 | 137 | class BiAttn(nn.Module): 138 | def __init__(self, in_channels, act_ratio=0.125, act_fn=nn.GELU, gate_fn=nn.Sigmoid): 139 | super().__init__() 140 | reduce_channels = int(in_channels * act_ratio) 141 | self.norm = nn.LayerNorm(in_channels) 142 | self.global_reduce = nn.Linear(in_channels, reduce_channels) 143 | # self.local_reduce = nn.Linear(in_channels, reduce_channels) 144 | self.act_fn = act_fn() 145 | self.channel_select = nn.Linear(reduce_channels, in_channels) 146 | # self.spatial_select = nn.Linear(reduce_channels * 2, 1) 147 | self.gate_fn = gate_fn() 148 | 149 | def forward(self, x): 150 | ori_x = x 151 | x = self.norm(x) 152 | x_global = x.mean(1, keepdim=True) 153 | x_global = self.act_fn(self.global_reduce(x_global)) 154 | # x_local = self.act_fn(self.local_reduce(x)) 155 | 156 | c_attn = self.channel_select(x_global) 157 | c_attn = self.gate_fn(c_attn) # [B, 1, C] 158 | # s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) 159 | # s_attn = self.gate_fn(s_attn) # [B, N, 1] 160 | 161 | attn = c_attn #* s_attn # [B, N, C] 162 | return ori_x * attn 163 | 164 | 165 | class MultiMamba(nn.Module): 166 | def __init__( 167 | self, 168 | d_model, 169 | d_state=16, 170 | d_conv=4, 171 | expand=2, 172 | dt_rank="auto", 173 | dt_min=0.001, 174 | dt_max=0.1, 175 | dt_init="random", 176 | dt_scale=1.0, 177 | dt_init_floor=1e-4, 178 | conv_bias=True, 179 | bias=False, 180 | use_fast_path=True, # Fused kernel options 181 | layer_idx=None, 182 | device=None, 183 | dtype=None, 184 | bimamba_type="none", 185 | directions=None, 186 | token_size=(14, 14), 187 | use_middle_cls_token=False, 188 | ): 189 | factory_kwargs = {"device": device, "dtype": dtype} 190 | super().__init__() 191 | self.d_model = d_model 192 | self.d_state = d_state 193 | self.d_conv = d_conv 194 | self.expand = expand 195 | self.d_inner = int(self.expand * self.d_model) 196 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 197 | self.use_fast_path = use_fast_path 198 | self.layer_idx = layer_idx 199 | self.bimamba_type = bimamba_type 200 | self.token_size = token_size 201 | self.use_middle_cls_token = use_middle_cls_token 202 | 203 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 204 | 205 | self.activation = "silu" 206 | self.act = nn.SiLU() 207 | 208 | 209 | self.multi_scan = MultiScan(self.d_inner, choices=directions, token_size=token_size) 210 | '''new for search''' 211 | A = repeat( 212 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 213 | "n -> d n", 214 | d=self.d_inner, 215 | ).contiguous() 216 | 
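# S4D-real initialization: A[d, n] = n for n = 1..d_state, identical across
# channels d. It is stored as A_log = log(A) and materialized in forward()
# as A = -exp(A_log), so every state-transition eigenvalue stays strictly
# negative (a stable, decaying SSM) for any learned value of A_log.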
A_log = torch.log(A) # Keep A_log in fp32 217 | for i in range(len(self.multi_scan.choices)): 218 | setattr(self, f'A_log_{i}', nn.Parameter(A_log)) 219 | getattr(self, f'A_log_{i}')._no_weight_decay = True 220 | 221 | conv1d = nn.Conv1d( 222 | in_channels=self.d_inner, 223 | out_channels=self.d_inner, 224 | bias=conv_bias, 225 | kernel_size=d_conv, 226 | groups=self.d_inner, 227 | padding=d_conv - 1, 228 | **factory_kwargs, 229 | ) 230 | setattr(self, f'conv1d_{i}', conv1d) 231 | 232 | x_proj = nn.Linear( 233 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 234 | ) 235 | setattr(self, f'x_proj_{i}', x_proj) 236 | 237 | dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 238 | 239 | # Initialize special dt projection to preserve variance at initialization 240 | dt_init_std = self.dt_rank**-0.5 * dt_scale 241 | if dt_init == "constant": 242 | nn.init.constant_(dt_proj.weight, dt_init_std) 243 | elif dt_init == "random": 244 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 245 | else: 246 | raise NotImplementedError 247 | 248 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 249 | dt = torch.exp( 250 | torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 251 | + math.log(dt_min) 252 | ).clamp(min=dt_init_floor) 253 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 254 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 255 | with torch.no_grad(): 256 | dt_proj.bias.copy_(inv_dt) 257 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 258 | dt_proj.bias._no_reinit = True 259 | 260 | setattr(self, f'dt_proj_{i}', dt_proj) 261 | 262 | D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 263 | D._no_weight_decay = True 264 | setattr(self, f'D_{i}', D) 265 | 266 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 267 | 268 | self.attn = BiAttn(self.d_inner) 269 | 270 | def forward(self, hidden_states, inference_params=None): 271 | """ 272 | hidden_states: (B, L, D) 273 | Returns: same shape as hidden_states 274 | """ 275 | xz = self.in_proj(hidden_states) 276 | 277 | if self.use_middle_cls_token: 278 | """ 279 | Steps to use middle cls token 280 | # 1. split cls token out 281 | # 2. do 2d scan 282 | # 3. append cls token to the middle 283 | # 4. ssm 284 | # 5. split cls token out 285 | # 6. reverse tokens 286 | # 7. append cls token to the middle 287 | """ 288 | cls_position = (xz.shape[1] - 1) // 2 289 | cls_token = xz[:, cls_position:cls_position+1] 290 | xz = torch.cat([xz[:, :cls_position], xz[:, cls_position+1:]], dim=1) 291 | 292 | xs = self.multi_scan.multi_scan(xz) # [[BDL], [BDL], ...] 
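# At this point xs holds one sequence per scan direction, each of shape
# [B, 2*d_inner, L]; x and z are still packed together because in_proj was
# applied before scanning. mamba_inner_fn_no_out_proj below splits them,
# applies the depthwise conv and the selective scan, and gates with z.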
293 | if self.use_middle_cls_token: 294 | # step 3 295 | xs = [torch.cat([x[:, :, :cls_position], cls_token.transpose(-2, -1), x[:, :, cls_position:]], dim=2) for x in xs] 296 | 297 | outs = [] 298 | for i, xz in enumerate(xs): 299 | # xz = rearrange(xz, "b l d -> b d l") 300 | A = -torch.exp(getattr(self, f'A_log_{i}').float()) 301 | conv1d = getattr(self, f'conv1d_{i}') 302 | x_proj = getattr(self, f'x_proj_{i}') 303 | dt_proj = getattr(self, f'dt_proj_{i}') 304 | D = getattr(self, f'D_{i}') 305 | 306 | out = mamba_inner_fn_no_out_proj( 307 | xz, 308 | conv1d.weight, 309 | conv1d.bias, 310 | x_proj.weight, 311 | dt_proj.weight, 312 | A, 313 | None, # input-dependent B 314 | None, # input-dependent C 315 | D, 316 | delta_bias=dt_proj.bias.float(), 317 | delta_softplus=True, 318 | ) 319 | outs.append(out) 320 | 321 | if self.use_middle_cls_token: 322 | # step 5 323 | new_outs = [] 324 | cls_tokens = [] 325 | for out in outs: 326 | cls_tokens.append(out[:, :, cls_position:cls_position+1]) 327 | new_outs.append(torch.cat([out[:, :, :cls_position], out[:, :, cls_position+1:]], dim=2)) 328 | outs = new_outs 329 | 330 | outs = self.multi_scan.multi_reverse(outs) 331 | 332 | if self.use_middle_cls_token: 333 | # step 7 334 | new_outs = [] 335 | for out, cls_token in zip(outs, cls_tokens): 336 | new_outs.append(torch.cat([out[:, :, :cls_position], cls_token, out[:, :, cls_position:]], dim=2)) 337 | outs = new_outs 338 | 339 | outs = [self.attn(rearrange(out, 'b d l -> b l d')) for out in outs] 340 | out = self.multi_scan(outs) 341 | out = F.linear(out, self.out_proj.weight, self.out_proj.bias) 342 | 343 | return out 344 | 345 | 346 | try: 347 | import selective_scan_cuda_oflex 348 | except: 349 | selective_scan_cuda_oflex = None 350 | 351 | class SelectiveScanOflex(torch.autograd.Function): 352 | @staticmethod 353 | @torch.cuda.amp.custom_fwd 354 | def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True): 355 | ctx.delta_softplus = delta_softplus 356 | out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex) 357 | ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 358 | return out 359 | 360 | @staticmethod 361 | @torch.cuda.amp.custom_bwd 362 | def backward(ctx, dout, *args): 363 | u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors 364 | if dout.stride(-1) != 1: 365 | dout = dout.contiguous() 366 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd( 367 | u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 368 | ) 369 | return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None 370 | 371 | 372 | class MultiVMamba(nn.Module): 373 | def __init__( 374 | self, 375 | d_model, 376 | d_state=16, 377 | d_conv=4, 378 | expand=2, 379 | dt_rank="auto", 380 | dt_min=0.001, 381 | dt_max=0.1, 382 | dt_init="random", 383 | dt_scale=1.0, 384 | dt_init_floor=1e-4, 385 | conv_bias=True, 386 | bias=False, 387 | use_fast_path=True, # Fused kernel options 388 | layer_idx=None, 389 | device=None, 390 | dtype=None, 391 | bimamba_type="none", 392 | directions=None, 393 | token_size=(14, 14), 394 | ): 395 | factory_kwargs = {"device": device, "dtype": dtype} 396 | super().__init__() 397 | self.d_model = d_model 398 | self.d_state = d_state 399 | self.d_conv = d_conv 400 | self.expand = expand 401 | self.d_inner = int(self.expand * self.d_model) 402 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 403 | self.use_fast_path = use_fast_path 404 | 
self.layer_idx = layer_idx 405 | self.bimamba_type = bimamba_type 406 | self.token_size = token_size 407 | 408 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 409 | 410 | self.activation = "silu" 411 | self.act = nn.SiLU() 412 | 413 | 414 | self.multi_scan = MultiScan(self.d_inner, choices=directions, token_size=token_size) 415 | '''new for search''' 416 | A = repeat( 417 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 418 | "n -> d n", 419 | d=self.d_inner, 420 | ).contiguous() 421 | A_log = torch.log(A) # Keep A_log in fp32 422 | for i in range(len(self.multi_scan.choices)): 423 | setattr(self, f'A_log_{i}', nn.Parameter(A_log)) 424 | getattr(self, f'A_log_{i}')._no_weight_decay = True 425 | 426 | x_proj = nn.Linear( 427 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 428 | ) 429 | setattr(self, f'x_proj_{i}', x_proj) 430 | 431 | conv1d = nn.Conv1d( 432 | in_channels=self.d_inner, 433 | out_channels=self.d_inner, 434 | bias=conv_bias, 435 | kernel_size=d_conv, 436 | groups=self.d_inner, 437 | padding=d_conv - 1, 438 | **factory_kwargs, 439 | ) 440 | setattr(self, f'conv1d_{i}', conv1d) 441 | 442 | dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 443 | 444 | # Initialize special dt projection to preserve variance at initialization 445 | dt_init_std = self.dt_rank**-0.5 * dt_scale 446 | if dt_init == "constant": 447 | nn.init.constant_(dt_proj.weight, dt_init_std) 448 | elif dt_init == "random": 449 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 450 | else: 451 | raise NotImplementedError 452 | 453 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 454 | dt = torch.exp( 455 | torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 456 | + math.log(dt_min) 457 | ).clamp(min=dt_init_floor) 458 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 459 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 460 | with torch.no_grad(): 461 | dt_proj.bias.copy_(inv_dt) 462 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 463 | dt_proj.bias._no_reinit = True 464 | 465 | setattr(self, f'dt_proj_{i}', dt_proj) 466 | 467 | D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 468 | D._no_weight_decay = True 469 | setattr(self, f'D_{i}', D) 470 | 471 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 472 | 473 | self.attn = BiAttn(self.d_inner) 474 | 475 | def forward(self, hidden_states, inference_params=None): 476 | """ 477 | hidden_states: (B, L, D) 478 | Returns: same shape as hidden_states 479 | """ 480 | batch_size, seq_len, dim = hidden_states.shape 481 | xz = self.in_proj(hidden_states) 482 | x, z = xz.chunk(2, dim=2) 483 | z = self.act(z) 484 | 485 | xs = self.multi_scan.multi_scan(x) 486 | outs = [] 487 | for i, xz in enumerate(xs): 488 | xz = rearrange(xz, "b l d -> b d l") 489 | A = -torch.exp(getattr(self, f'A_log_{i}').float()) 490 | x_proj = getattr(self, f'x_proj_{i}') 491 | conv1d = getattr(self, f'conv1d_{i}') 492 | dt_proj = getattr(self, f'dt_proj_{i}') 493 | D = getattr(self, f'D_{i}') 494 | 495 | xz = conv1d(xz)[:, :, :seq_len] 496 | xz = self.act(xz) 497 | 498 | N = A.shape[-1] 499 | R = dt_proj.weight.shape[-1] 500 | 501 | x_dbl = F.linear(rearrange(xz, 'b d l -> b l d'), x_proj.weight) 502 | dts, B, C = torch.split(x_dbl, [R, N, N], dim=2) 503 | dts = F.linear(dts, dt_proj.weight) 
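# Shape bookkeeping for the projection above (a sketch; B_, L_ denote batch
# and sequence length, R = dt_rank, N = d_state): x_dbl is [B_, L_, R + 2N]
# and splits into dts [B_, L_, R], B [B_, L_, N], C [B_, L_, N]; dt_proj then
# lifts dts to [B_, L_, d_inner] before everything is rearranged below into
# the [b, d, l] / [b, 1, d, l] layouts that SelectiveScanOflex expects.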
504 | 505 | dts = rearrange(dts, 'b l d -> b d l') 506 | B = rearrange(B, 'b l d -> b 1 d l') 507 | C = rearrange(C, 'b l d -> b 1 d l') 508 | D = D.float() 509 | delta_bias = dt_proj.bias.float() 510 | 511 | out = SelectiveScanOflex.apply(xz.contiguous(), dts.contiguous(), A.contiguous(), B.contiguous(), C.contiguous(), D.contiguous(), delta_bias, True, True) 512 | 513 | outs.append(rearrange(out, "b d l -> b l d")) 514 | 515 | outs = self.multi_scan.multi_reverse(outs) 516 | outs = [self.attn(out) for out in outs] 517 | out = self.multi_scan(outs) 518 | out = out * z 519 | out = self.out_proj(out) 520 | 521 | return out 522 | 523 | -------------------------------------------------------------------------------- /baseline/hrnet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import logging 13 | import functools 14 | 15 | import numpy as np 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch._utils 20 | import torch.nn.functional as F 21 | 22 | # from .bn_helper import BatchNorm2d, BatchNorm2d_class, relu_inplace 23 | 24 | BatchNorm2d = nn.BatchNorm2d 25 | # BN_MOMENTUM = 0.01 26 | relu_inplace = True 27 | BN_MOMENTUM = 0.1 28 | ALIGN_CORNERS = True 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | def conv3x3(in_planes, out_planes, stride=1): 33 | """3x3 convolution with padding""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 35 | padding=1, bias=False) 36 | 37 | from yacs.config import CfgNode as CN 38 | # configs for HRNet48 39 | HRNET_48 = CN() 40 | HRNET_48.FINAL_CONV_KERNEL = 1 41 | 42 | HRNET_48.STAGE1 = CN() 43 | HRNET_48.STAGE1.NUM_MODULES = 1 44 | HRNET_48.STAGE1.NUM_BRANCHES = 1 45 | HRNET_48.STAGE1.NUM_BLOCKS = [4] 46 | HRNET_48.STAGE1.NUM_CHANNELS = [64] 47 | HRNET_48.STAGE1.BLOCK = 'BOTTLENECK' 48 | HRNET_48.STAGE1.FUSE_METHOD = 'SUM' 49 | 50 | HRNET_48.STAGE2 = CN() 51 | HRNET_48.STAGE2.NUM_MODULES = 1 52 | HRNET_48.STAGE2.NUM_BRANCHES = 2 53 | HRNET_48.STAGE2.NUM_BLOCKS = [4, 4] 54 | HRNET_48.STAGE2.NUM_CHANNELS = [48, 96] 55 | HRNET_48.STAGE2.BLOCK = 'BASIC' 56 | HRNET_48.STAGE2.FUSE_METHOD = 'SUM' 57 | 58 | HRNET_48.STAGE3 = CN() 59 | HRNET_48.STAGE3.NUM_MODULES = 4 60 | HRNET_48.STAGE3.NUM_BRANCHES = 3 61 | HRNET_48.STAGE3.NUM_BLOCKS = [4, 4, 4] 62 | HRNET_48.STAGE3.NUM_CHANNELS = [48, 96, 192] 63 | HRNET_48.STAGE3.BLOCK = 'BASIC' 64 | HRNET_48.STAGE3.FUSE_METHOD = 'SUM' 65 | 66 | HRNET_48.STAGE4 = CN() 67 | HRNET_48.STAGE4.NUM_MODULES = 3 68 | HRNET_48.STAGE4.NUM_BRANCHES = 4 69 | HRNET_48.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 70 | HRNET_48.STAGE4.NUM_CHANNELS = [48, 96, 192, 384] 71 | HRNET_48.STAGE4.BLOCK = 'BASIC' 72 | HRNET_48.STAGE4.FUSE_METHOD = 'SUM' 73 | 74 | HRNET_32 = CN() 75 | HRNET_32.FINAL_CONV_KERNEL = 1 76 | 77 | HRNET_32.STAGE1 = CN() 78 | HRNET_32.STAGE1.NUM_MODULES = 1 79 | HRNET_32.STAGE1.NUM_BRANCHES = 1 80 | HRNET_32.STAGE1.NUM_BLOCKS = [4] 81 | HRNET_32.STAGE1.NUM_CHANNELS = [64] 82 | HRNET_32.STAGE1.BLOCK = 'BOTTLENECK' 83 | HRNET_32.STAGE1.FUSE_METHOD = 'SUM' 84 | 85 | HRNET_32.STAGE2 = CN() 86 | HRNET_32.STAGE2.NUM_MODULES = 1 87 | 
HRNET_32.STAGE2.NUM_BRANCHES = 2 88 | HRNET_32.STAGE2.NUM_BLOCKS = [4, 4] 89 | HRNET_32.STAGE2.NUM_CHANNELS = [32, 64] 90 | HRNET_32.STAGE2.BLOCK = 'BASIC' 91 | HRNET_32.STAGE2.FUSE_METHOD = 'SUM' 92 | 93 | HRNET_32.STAGE3 = CN() 94 | HRNET_32.STAGE3.NUM_MODULES = 4 95 | HRNET_32.STAGE3.NUM_BRANCHES = 3 96 | HRNET_32.STAGE3.NUM_BLOCKS = [4, 4, 4] 97 | HRNET_32.STAGE3.NUM_CHANNELS = [32, 64, 128] 98 | HRNET_32.STAGE3.BLOCK = 'BASIC' 99 | HRNET_32.STAGE3.FUSE_METHOD = 'SUM' 100 | 101 | HRNET_32.STAGE4 = CN() 102 | HRNET_32.STAGE4.NUM_MODULES = 3 103 | HRNET_32.STAGE4.NUM_BRANCHES = 4 104 | HRNET_32.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 105 | HRNET_32.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] 106 | HRNET_32.STAGE4.BLOCK = 'BASIC' 107 | HRNET_32.STAGE4.FUSE_METHOD = 'SUM' 108 | 109 | 110 | HRNET_18 = CN() 111 | HRNET_18.FINAL_CONV_KERNEL = 1 112 | 113 | HRNET_18.STAGE1 = CN() 114 | HRNET_18.STAGE1.NUM_MODULES = 1 115 | HRNET_18.STAGE1.NUM_BRANCHES = 1 116 | HRNET_18.STAGE1.NUM_BLOCKS = [4] 117 | HRNET_18.STAGE1.NUM_CHANNELS = [64] 118 | HRNET_18.STAGE1.BLOCK = 'BOTTLENECK' 119 | HRNET_18.STAGE1.FUSE_METHOD = 'SUM' 120 | 121 | HRNET_18.STAGE2 = CN() 122 | HRNET_18.STAGE2.NUM_MODULES = 1 123 | HRNET_18.STAGE2.NUM_BRANCHES = 2 124 | HRNET_18.STAGE2.NUM_BLOCKS = [4, 4] 125 | HRNET_18.STAGE2.NUM_CHANNELS = [18, 36] 126 | HRNET_18.STAGE2.BLOCK = 'BASIC' 127 | HRNET_18.STAGE2.FUSE_METHOD = 'SUM' 128 | 129 | HRNET_18.STAGE3 = CN() 130 | HRNET_18.STAGE3.NUM_MODULES = 4 131 | HRNET_18.STAGE3.NUM_BRANCHES = 3 132 | HRNET_18.STAGE3.NUM_BLOCKS = [4, 4, 4] 133 | HRNET_18.STAGE3.NUM_CHANNELS = [18, 36, 72] 134 | HRNET_18.STAGE3.BLOCK = 'BASIC' 135 | HRNET_18.STAGE3.FUSE_METHOD = 'SUM' 136 | 137 | HRNET_18.STAGE4 = CN() 138 | HRNET_18.STAGE4.NUM_MODULES = 3 139 | HRNET_18.STAGE4.NUM_BRANCHES = 4 140 | HRNET_18.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 141 | HRNET_18.STAGE4.NUM_CHANNELS = [18, 36, 72, 144] 142 | HRNET_18.STAGE4.BLOCK = 'BASIC' 143 | HRNET_18.STAGE4.FUSE_METHOD = 'SUM' 144 | 145 | class BasicBlock(nn.Module): 146 | expansion = 1 147 | 148 | def __init__(self, inplanes, planes, stride=1, downsample=None): 149 | super(BasicBlock, self).__init__() 150 | self.conv1 = conv3x3(inplanes, planes, stride) 151 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 152 | self.relu = nn.ReLU(inplace=relu_inplace) 153 | self.conv2 = conv3x3(planes, planes) 154 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 155 | self.downsample = downsample 156 | self.stride = stride 157 | 158 | def forward(self, x): 159 | residual = x 160 | 161 | out = self.conv1(x) 162 | out = self.bn1(out) 163 | out = self.relu(out) 164 | 165 | out = self.conv2(out) 166 | out = self.bn2(out) 167 | 168 | if self.downsample is not None: 169 | residual = self.downsample(x) 170 | 171 | out = out + residual 172 | out = self.relu(out) 173 | 174 | return out 175 | 176 | 177 | class Bottleneck(nn.Module): 178 | expansion = 4 179 | 180 | def __init__(self, inplanes, planes, stride=1, downsample=None): 181 | super(Bottleneck, self).__init__() 182 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 183 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 184 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 185 | padding=1, bias=False) 186 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 187 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 188 | bias=False) 189 | self.bn3 = BatchNorm2d(planes * self.expansion, 190 | momentum=BN_MOMENTUM) 191 | self.relu = nn.ReLU(inplace=relu_inplace) 
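# Channel flow of this bottleneck (expansion = 4):
#   inplanes --1x1--> planes --3x3(stride)--> planes --1x1--> planes * 4
# with `downsample` projecting the residual whenever its shape differs
# (e.g. a hypothetical inplanes=64, planes=64 gives 64 -> 64 -> 64 -> 256).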
192 | self.downsample = downsample 193 | self.stride = stride 194 | 195 | def forward(self, x): 196 | residual = x 197 | 198 | out = self.conv1(x) 199 | out = self.bn1(out) 200 | out = self.relu(out) 201 | 202 | out = self.conv2(out) 203 | out = self.bn2(out) 204 | out = self.relu(out) 205 | 206 | out = self.conv3(out) 207 | out = self.bn3(out) 208 | 209 | if self.downsample is not None: 210 | residual = self.downsample(x) 211 | 212 | out = out + residual 213 | out = self.relu(out) 214 | 215 | return out 216 | 217 | 218 | class HighResolutionModule(nn.Module): 219 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels, 220 | num_channels, fuse_method, multi_scale_output=True): 221 | super(HighResolutionModule, self).__init__() 222 | self._check_branches( 223 | num_branches, blocks, num_blocks, num_inchannels, num_channels) 224 | 225 | self.num_inchannels = num_inchannels 226 | self.fuse_method = fuse_method 227 | self.num_branches = num_branches 228 | 229 | self.multi_scale_output = multi_scale_output 230 | 231 | self.branches = self._make_branches( 232 | num_branches, blocks, num_blocks, num_channels) 233 | self.fuse_layers = self._make_fuse_layers() 234 | self.relu = nn.ReLU(inplace=relu_inplace) 235 | 236 | def _check_branches(self, num_branches, blocks, num_blocks, 237 | num_inchannels, num_channels): 238 | if num_branches != len(num_blocks): 239 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( 240 | num_branches, len(num_blocks)) 241 | logger.error(error_msg) 242 | raise ValueError(error_msg) 243 | 244 | if num_branches != len(num_channels): 245 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( 246 | num_branches, len(num_channels)) 247 | logger.error(error_msg) 248 | raise ValueError(error_msg) 249 | 250 | if num_branches != len(num_inchannels): 251 | error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( 252 | num_branches, len(num_inchannels)) 253 | logger.error(error_msg) 254 | raise ValueError(error_msg) 255 | 256 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels, 257 | stride=1): 258 | downsample = None 259 | if stride != 1 or \ 260 | self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: 261 | downsample = nn.Sequential( 262 | nn.Conv2d(self.num_inchannels[branch_index], 263 | num_channels[branch_index] * block.expansion, 264 | kernel_size=1, stride=stride, bias=False), 265 | BatchNorm2d(num_channels[branch_index] * block.expansion, 266 | momentum=BN_MOMENTUM), 267 | ) 268 | 269 | layers = [] 270 | layers.append(block(self.num_inchannels[branch_index], 271 | num_channels[branch_index], stride, downsample)) 272 | self.num_inchannels[branch_index] = \ 273 | num_channels[branch_index] * block.expansion 274 | for i in range(1, num_blocks[branch_index]): 275 | layers.append(block(self.num_inchannels[branch_index], 276 | num_channels[branch_index])) 277 | 278 | return nn.Sequential(*layers) 279 | 280 | def _make_branches(self, num_branches, block, num_blocks, num_channels): 281 | branches = [] 282 | 283 | for i in range(num_branches): 284 | branches.append( 285 | self._make_one_branch(i, block, num_blocks, num_channels)) 286 | 287 | return nn.ModuleList(branches) 288 | 289 | def _make_fuse_layers(self): 290 | if self.num_branches == 1: 291 | return None 292 | 293 | num_branches = self.num_branches 294 | num_inchannels = self.num_inchannels 295 | fuse_layers = [] 296 | for i in range(num_branches if self.multi_scale_output else 1): 297 | fuse_layer = [] 298 | for j in range(num_branches): 
299 | if j > i: 300 | fuse_layer.append(nn.Sequential( 301 | nn.Conv2d(num_inchannels[j], 302 | num_inchannels[i], 303 | 1, 304 | 1, 305 | 0, 306 | bias=False), 307 | BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) 308 | elif j == i: 309 | fuse_layer.append(None) 310 | else: 311 | conv3x3s = [] 312 | for k in range(i-j): 313 | if k == i - j - 1: 314 | num_outchannels_conv3x3 = num_inchannels[i] 315 | conv3x3s.append(nn.Sequential( 316 | nn.Conv2d(num_inchannels[j], 317 | num_outchannels_conv3x3, 318 | 3, 2, 1, bias=False), 319 | BatchNorm2d(num_outchannels_conv3x3, 320 | momentum=BN_MOMENTUM))) 321 | else: 322 | num_outchannels_conv3x3 = num_inchannels[j] 323 | conv3x3s.append(nn.Sequential( 324 | nn.Conv2d(num_inchannels[j], 325 | num_outchannels_conv3x3, 326 | 3, 2, 1, bias=False), 327 | BatchNorm2d(num_outchannels_conv3x3, 328 | momentum=BN_MOMENTUM), 329 | nn.ReLU(inplace=relu_inplace))) 330 | fuse_layer.append(nn.Sequential(*conv3x3s)) 331 | fuse_layers.append(nn.ModuleList(fuse_layer)) 332 | 333 | return nn.ModuleList(fuse_layers) 334 | 335 | def get_num_inchannels(self): 336 | return self.num_inchannels 337 | 338 | def forward(self, x): 339 | if self.num_branches == 1: 340 | return [self.branches[0](x[0])] 341 | 342 | for i in range(self.num_branches): 343 | x[i] = self.branches[i](x[i]) 344 | 345 | x_fuse = [] 346 | for i in range(len(self.fuse_layers)): 347 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) 348 | for j in range(1, self.num_branches): 349 | if i == j: 350 | y = y + x[j] 351 | elif j > i: 352 | width_output = x[i].shape[-1] 353 | height_output = x[i].shape[-2] 354 | y = y + F.interpolate( 355 | self.fuse_layers[i][j](x[j]), 356 | size=[height_output, width_output], 357 | mode='bilinear', align_corners=ALIGN_CORNERS) 358 | else: 359 | y = y + self.fuse_layers[i][j](x[j]) 360 | x_fuse.append(self.relu(y)) 361 | 362 | return x_fuse 363 | 364 | blocks_dict = { 365 | 'BASIC': BasicBlock, 366 | 'BOTTLENECK': Bottleneck 367 | } 368 | 369 | class HRNet(nn.Module): 370 | 371 | def __init__(self, num_classes, **kwargs): 372 | global ALIGN_CORNERS 373 | extra = HRNET_48 374 | super(HRNet, self).__init__() 375 | ALIGN_CORNERS = True 376 | # ALIGN_CORNERS = config.MODEL.ALIGN_CORNERS 377 | self.num_classes = num_classes 378 | # stem net 379 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, 380 | bias=False) 381 | self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM) 382 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, 383 | bias=False) 384 | self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM) 385 | self.relu = nn.ReLU(inplace=relu_inplace) 386 | 387 | self.stage1_cfg = extra['STAGE1'] 388 | num_channels = self.stage1_cfg['NUM_CHANNELS'][0] 389 | block = blocks_dict[self.stage1_cfg['BLOCK']] 390 | num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] 391 | self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) 392 | stage1_out_channel = block.expansion*num_channels 393 | 394 | self.stage2_cfg = extra['STAGE2'] 395 | num_channels = self.stage2_cfg['NUM_CHANNELS'] 396 | block = blocks_dict[self.stage2_cfg['BLOCK']] 397 | num_channels = [ 398 | num_channels[i] * block.expansion for i in range(len(num_channels))] 399 | self.transition1 = self._make_transition_layer( 400 | [stage1_out_channel], num_channels) 401 | self.stage2, pre_stage_channels = self._make_stage( 402 | self.stage2_cfg, num_channels) 403 | 404 | self.stage3_cfg = extra['STAGE3'] 405 | num_channels = self.stage3_cfg['NUM_CHANNELS'] 406 | block = 
blocks_dict[self.stage3_cfg['BLOCK']] 407 | num_channels = [ 408 | num_channels[i] * block.expansion for i in range(len(num_channels))] 409 | self.transition2 = self._make_transition_layer( 410 | pre_stage_channels, num_channels) 411 | self.stage3, pre_stage_channels = self._make_stage( 412 | self.stage3_cfg, num_channels) 413 | 414 | self.stage4_cfg = extra['STAGE4'] 415 | num_channels = self.stage4_cfg['NUM_CHANNELS'] 416 | block = blocks_dict[self.stage4_cfg['BLOCK']] 417 | num_channels = [ 418 | num_channels[i] * block.expansion for i in range(len(num_channels))] 419 | self.transition3 = self._make_transition_layer( 420 | pre_stage_channels, num_channels) 421 | self.stage4, pre_stage_channels = self._make_stage( 422 | self.stage4_cfg, num_channels, multi_scale_output=True) 423 | 424 | last_inp_channels = int(np.sum(pre_stage_channels)) 425 | 426 | self.last_layer = nn.Sequential( 427 | nn.Conv2d( 428 | in_channels=last_inp_channels, 429 | out_channels=last_inp_channels, 430 | kernel_size=1, 431 | stride=1, 432 | padding=0), 433 | BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM), 434 | nn.ReLU(inplace=relu_inplace), 435 | nn.Conv2d( 436 | in_channels=last_inp_channels, 437 | out_channels=self.num_classes, 438 | kernel_size=extra.FINAL_CONV_KERNEL, 439 | stride=1, 440 | padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0) 441 | ) 442 | self.init_weights() 443 | 444 | def _make_transition_layer( 445 | self, num_channels_pre_layer, num_channels_cur_layer): 446 | num_branches_cur = len(num_channels_cur_layer) 447 | num_branches_pre = len(num_channels_pre_layer) 448 | 449 | transition_layers = [] 450 | for i in range(num_branches_cur): 451 | if i < num_branches_pre: 452 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]: 453 | transition_layers.append(nn.Sequential( 454 | nn.Conv2d(num_channels_pre_layer[i], 455 | num_channels_cur_layer[i], 456 | 3, 457 | 1, 458 | 1, 459 | bias=False), 460 | BatchNorm2d( 461 | num_channels_cur_layer[i], momentum=BN_MOMENTUM), 462 | nn.ReLU(inplace=relu_inplace))) 463 | else: 464 | transition_layers.append(None) 465 | else: 466 | conv3x3s = [] 467 | for j in range(i+1-num_branches_pre): 468 | inchannels = num_channels_pre_layer[-1] 469 | outchannels = num_channels_cur_layer[i] \ 470 | if j == i-num_branches_pre else inchannels 471 | conv3x3s.append(nn.Sequential( 472 | nn.Conv2d( 473 | inchannels, outchannels, 3, 2, 1, bias=False), 474 | BatchNorm2d(outchannels, momentum=BN_MOMENTUM), 475 | nn.ReLU(inplace=relu_inplace))) 476 | transition_layers.append(nn.Sequential(*conv3x3s)) 477 | 478 | return nn.ModuleList(transition_layers) 479 | 480 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 481 | downsample = None 482 | if stride != 1 or inplanes != planes * block.expansion: 483 | downsample = nn.Sequential( 484 | nn.Conv2d(inplanes, planes * block.expansion, 485 | kernel_size=1, stride=stride, bias=False), 486 | BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 487 | ) 488 | 489 | layers = [] 490 | layers.append(block(inplanes, planes, stride, downsample)) 491 | inplanes = planes * block.expansion 492 | for i in range(1, blocks): 493 | layers.append(block(inplanes, planes)) 494 | 495 | return nn.Sequential(*layers) 496 | 497 | def _make_stage(self, layer_config, num_inchannels, 498 | multi_scale_output=True): 499 | num_modules = layer_config['NUM_MODULES'] 500 | num_branches = layer_config['NUM_BRANCHES'] 501 | num_blocks = layer_config['NUM_BLOCKS'] 502 | num_channels = layer_config['NUM_CHANNELS'] 503 | block = 
blocks_dict[layer_config['BLOCK']] 504 | fuse_method = layer_config['FUSE_METHOD'] 505 | 506 | modules = [] 507 | for i in range(num_modules): 508 | # multi_scale_output is only used in the last module 509 | if not multi_scale_output and i == num_modules - 1: 510 | reset_multi_scale_output = False 511 | else: 512 | reset_multi_scale_output = True 513 | modules.append( 514 | HighResolutionModule(num_branches, 515 | block, 516 | num_blocks, 517 | num_inchannels, 518 | num_channels, 519 | fuse_method, 520 | reset_multi_scale_output) 521 | ) 522 | num_inchannels = modules[-1].get_num_inchannels() 523 | 524 | return nn.Sequential(*modules), num_inchannels 525 | 526 | def forward(self, input): 527 | x = self.conv1(input) 528 | x = self.bn1(x) 529 | x = self.relu(x) 530 | x = self.conv2(x) 531 | x = self.bn2(x) 532 | x = self.relu(x) 533 | x = self.layer1(x) 534 | 535 | x_list = [] 536 | for i in range(self.stage2_cfg['NUM_BRANCHES']): 537 | if self.transition1[i] is not None: 538 | x_list.append(self.transition1[i](x)) 539 | else: 540 | x_list.append(x) 541 | y_list = self.stage2(x_list) 542 | 543 | x_list = [] 544 | for i in range(self.stage3_cfg['NUM_BRANCHES']): 545 | if self.transition2[i] is not None: 546 | if i < self.stage2_cfg['NUM_BRANCHES']: 547 | x_list.append(self.transition2[i](y_list[i])) 548 | else: 549 | x_list.append(self.transition2[i](y_list[-1])) 550 | else: 551 | x_list.append(y_list[i]) 552 | y_list = self.stage3(x_list) 553 | 554 | x_list = [] 555 | for i in range(self.stage4_cfg['NUM_BRANCHES']): 556 | if self.transition3[i] is not None: 557 | if i < self.stage3_cfg['NUM_BRANCHES']: 558 | x_list.append(self.transition3[i](y_list[i])) 559 | else: 560 | x_list.append(self.transition3[i](y_list[-1])) 561 | else: 562 | x_list.append(y_list[i]) 563 | x = self.stage4(x_list) 564 | 565 | # Upsampling 566 | x0_h, x0_w = x[0].size(2), x[0].size(3) 567 | x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS) 568 | x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS) 569 | x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS) 570 | 571 | x = torch.cat([x[0], x1, x2, x3], 1) 572 | 573 | x = self.last_layer(x) 574 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 575 | logits = x 576 | 577 | return {"out": logits} 578 | # return x 579 | 580 | def init_weights(self, pretrained=''): 581 | logger.info('=> init weights from normal distribution') 582 | for m in self.modules(): 583 | if isinstance(m, nn.Conv2d): 584 | nn.init.normal_(m.weight, std=0.001) 585 | elif isinstance(m, nn.BatchNorm2d): 586 | nn.init.constant_(m.weight, 1) 587 | nn.init.constant_(m.bias, 0) 588 | if os.path.isfile(pretrained): 589 | pretrained_dict = torch.load(pretrained) 590 | logger.info('=> loading pretrained model {}'.format(pretrained)) 591 | model_dict = self.state_dict() 592 | pretrained_dict = {k: v for k, v in pretrained_dict.items() 593 | if k in model_dict.keys()} 594 | for k, _ in pretrained_dict.items(): 595 | logger.info( 596 | '=> loading {} from pretrained model {}'.format(k, pretrained)) 597 | model_dict.update(pretrained_dict) 598 | self.load_state_dict(model_dict) 599 | 600 | # def get_seg_model(cfg, **kwargs): 601 | # model = HighResolutionNet(cfg, **kwargs) 602 | # model.init_weights(cfg.MODEL.PRETRAINED) 603 | 604 | # return model 605 | -------------------------------------------------------------------------------- /baseline/local_vmamba.py:
-------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | import warnings 4 | from functools import partial 5 | from typing import Optional, Callable, Any 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | from einops import rearrange, repeat 13 | from timm.models.layers import DropPath, trunc_normal_ 14 | from timm.models.registry import register_model 15 | from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count 16 | from baseline.multi_mamba import MultiScan 17 | 18 | 19 | DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" 20 | 21 | 22 | try: 23 | "sscore acts the same as mamba_ssm" 24 | SSMODE = "sscore" 25 | import selective_scan_cuda_core 26 | print("Using \"selective_scan_cuda_core\"") 27 | except Exception as e: 28 | warnings.warn(f"{e}\n\"selective_scan_cuda_core\" not found, use default \"selective_scan_cuda\" instead.") 29 | # print(e, flush=True) 30 | SSMODE = "mamba_ssm" 31 | import selective_scan_cuda 32 | 33 | 34 | # fvcore flops ======================================= 35 | 36 | def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): 37 | """ 38 | u: r(B D L) 39 | delta: r(B D L) 40 | A: r(D N) 41 | B: r(B N L) 42 | C: r(B N L) 43 | D: r(D) 44 | z: r(B D L) 45 | delta_bias: r(D), fp32 46 | 47 | ignores: 48 | [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] 49 | """ 50 | assert not with_complex 51 | # https://github.com/state-spaces/mamba/issues/110 52 | flops = 9 * B * L * D * N 53 | if with_D: 54 | flops += B * D * L 55 | if with_Z: 56 | flops += B * D * L 57 | return flops 58 | 59 | def selective_scan_flop_jit(inputs, outputs): 60 | B, D, L = inputs[0].type().sizes() 61 | N = inputs[2].type().sizes()[1] 62 | flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False, with_Group=True) 63 | return flops 64 | 65 | 66 | class SelectiveScan(torch.autograd.Function): 67 | 68 | @staticmethod 69 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 70 | def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1): 71 | assert nrows in [1, 2, 3, 4], f"{nrows}" # 8+ is too slow to compile 72 | assert u.shape[1] % (B.shape[1] * nrows) == 0, f"{nrows}, {u.shape}, {B.shape}" 73 | ctx.delta_softplus = delta_softplus 74 | ctx.nrows = nrows 75 | # all in float 76 | if u.stride(-1) != 1: 77 | u = u.contiguous() 78 | if delta.stride(-1) != 1: 79 | delta = delta.contiguous() 80 | if D is not None: 81 | D = D.contiguous() 82 | if B.stride(-1) != 1: 83 | B = B.contiguous() 84 | if C.stride(-1) != 1: 85 | C = C.contiguous() 86 | if B.dim() == 3: 87 | B = B.unsqueeze(dim=1) 88 | ctx.squeeze_B = True 89 | if C.dim() == 3: 90 | C = C.unsqueeze(dim=1) 91 | ctx.squeeze_C = True 92 | 93 | if SSMODE == "mamba_ssm": 94 | out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) 95 | else: 96 | out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) 97 | 98 | ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 99 | return out 100 | 101 | @staticmethod 102 | @torch.cuda.amp.custom_bwd 103 | def backward(ctx, dout, *args): 104 | u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors 105 | if dout.stride(-1) != 1: 106 | dout = dout.contiguous() 107 | 
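# The two backends expose slightly different bwd signatures, mirroring fwd
# above: the mamba_ssm kernel also takes z/out-related arguments (passed as
# None here, since no gating tensor was handed to fwd), while
# selective_scan_cuda_core takes an nrows chunking argument instead.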
108 | if SSMODE == "mamba_ssm": 109 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( 110 | u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, 111 | False # option to recompute out_z, not used here 112 | ) 113 | else: 114 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd( 115 | u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 116 | # u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.nrows, 117 | ) 118 | 119 | dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB 120 | dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC 121 | return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None) 122 | 123 | 124 | """ 125 | Local Mamba 126 | """ 127 | class MultiScanVSSM(MultiScan): 128 | 129 | ALL_CHOICES = MultiScan.ALL_CHOICES 130 | 131 | def __init__(self, dim, choices=None): 132 | super().__init__(dim, choices=choices, token_size=None) 133 | self.attn = BiAttn(dim) 134 | 135 | def merge(self, xs): 136 | # xs: [B, K, D, L] 137 | # return: [B, D, L] 138 | 139 | # remove the padded tokens 140 | xs = [xs[:, i, :, :l] for i, l in enumerate(self.scan_lengths)] 141 | xs = super().multi_reverse(xs) 142 | xs = [self.attn(x.transpose(-2, -1)) for x in xs] 143 | x = super().forward(xs) 144 | return x 145 | 146 | 147 | def multi_scan(self, x): 148 | # x: [B, C, H, W] 149 | # return: [B, K, C, H * W] 150 | B, C, H, W = x.shape 151 | self.token_size = (H, W) 152 | 153 | xs = super().multi_scan(x) # [[B, C, H, W], ...] 154 | 155 | self.scan_lengths = [x.shape[2] for x in xs] 156 | max_length = max(self.scan_lengths) 157 | 158 | # pad the tokens into the same length as VMamba compute all directions together 159 | new_xs = [] 160 | for x in xs: 161 | if x.shape[2] < max_length: 162 | x = F.pad(x, (0, max_length - x.shape[2])) 163 | new_xs.append(x) 164 | return torch.stack(new_xs, 1) 165 | 166 | def __repr__(self): 167 | scans = ', '.join(self.choices) 168 | return super().__repr__().replace('MultiScanVSSM', f'MultiScanVSSM[{scans}]') 169 | 170 | 171 | class BiAttn(nn.Module): 172 | def __init__(self, in_channels, act_ratio=0.125, act_fn=nn.GELU, gate_fn=nn.Sigmoid): 173 | super().__init__() 174 | reduce_channels = int(in_channels * act_ratio) 175 | self.norm = nn.LayerNorm(in_channels) 176 | self.global_reduce = nn.Linear(in_channels, reduce_channels) 177 | # self.local_reduce = nn.Linear(in_channels, reduce_channels) 178 | self.act_fn = act_fn() 179 | self.channel_select = nn.Linear(reduce_channels, in_channels) 180 | # self.spatial_select = nn.Linear(reduce_channels * 2, 1) 181 | self.gate_fn = gate_fn() 182 | 183 | def forward(self, x): 184 | ori_x = x 185 | x = self.norm(x) 186 | x_global = x.mean(1, keepdim=True) 187 | x_global = self.act_fn(self.global_reduce(x_global)) 188 | # x_local = self.act_fn(self.local_reduce(x)) 189 | 190 | c_attn = self.channel_select(x_global) 191 | c_attn = self.gate_fn(c_attn) # [B, 1, C] 192 | # s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) 193 | # s_attn = self.gate_fn(s_attn) # [B, N, 1] 194 | 195 | attn = c_attn #* s_attn # [B, N, C] 196 | out = ori_x * attn 197 | return out 198 | 199 | 200 | def multi_selective_scan( 201 | x: torch.Tensor=None, 202 | x_proj_weight: torch.Tensor=None, 203 | x_proj_bias: torch.Tensor=None, 204 | dt_projs_weight: torch.Tensor=None, 205 | dt_projs_bias: torch.Tensor=None, 206 | A_logs: torch.Tensor=None, 207 | Ds: torch.Tensor=None, 208 | out_norm: 
torch.nn.Module=None, 209 | nrows = -1, 210 | delta_softplus = True, 211 | to_dtype=True, 212 | multi_scan=None, 213 | ): 214 | B, D, H, W = x.shape 215 | D, N = A_logs.shape 216 | K, D, R = dt_projs_weight.shape 217 | L = H * W 218 | 219 | if nrows < 1: 220 | if D % 4 == 0: 221 | nrows = 4 222 | elif D % 3 == 0: 223 | nrows = 3 224 | elif D % 2 == 0: 225 | nrows = 2 226 | else: 227 | nrows = 1 228 | 229 | xs = multi_scan.multi_scan(x) 230 | 231 | L = xs.shape[-1] 232 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight) # l fixed 233 | 234 | if x_proj_bias is not None: 235 | x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1) 236 | dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) 237 | dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight) 238 | 239 | xs = xs.view(B, -1, L).to(torch.float) 240 | dts = dts.contiguous().view(B, -1, L).to(torch.float) 241 | As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state) 242 | Bs = Bs.contiguous().to(torch.float) 243 | Cs = Cs.contiguous().to(torch.float) 244 | Ds = Ds.to(torch.float) # (K * c) 245 | delta_bias = dt_projs_bias.view(-1).to(torch.float) 246 | 247 | def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): 248 | return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) 249 | 250 | ys: torch.Tensor = selective_scan( 251 | xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus, nrows, 252 | ).view(B, K, -1, L) 253 | 254 | y = multi_scan.merge(ys) 255 | 256 | y = out_norm(y).view(B, H, W, -1) 257 | 258 | return (y.to(x.dtype) if to_dtype else y) 259 | 260 | 261 | class PatchMerging2D(nn.Module): 262 | def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm): 263 | super().__init__() 264 | self.dim = dim 265 | self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False) 266 | self.norm = norm_layer(4 * dim) 267 | 268 | @staticmethod 269 | def _patch_merging_pad(x: torch.Tensor): 270 | H, W, _ = x.shape[-3:] 271 | if (W % 2 != 0) or (H % 2 != 0): 272 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 273 | x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C 274 | x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C 275 | x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C 276 | x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C 277 | x = torch.cat([x0, x1, x2, x3], -1) # ... 
H/2 W/2 4*C 278 | return x 279 | 280 | def forward(self, x): 281 | x = self._patch_merging_pad(x) 282 | x = self.norm(x) 283 | x = self.reduction(x) 284 | 285 | return x 286 | 287 | 288 | class SS2D(nn.Module): 289 | def __init__( 290 | self, 291 | # basic dims =========== 292 | d_model=96, 293 | d_state=16, 294 | ssm_ratio=2.0, 295 | dt_rank="auto", 296 | act_layer=nn.SiLU, 297 | # dwconv =============== 298 | d_conv=3, # < 2 means no conv 299 | conv_bias=True, 300 | # ====================== 301 | dropout=0.0, 302 | bias=False, 303 | # dt init ============== 304 | dt_min=0.001, 305 | dt_max=0.1, 306 | dt_init="random", 307 | dt_scale=1.0, 308 | dt_init_floor=1e-4, 309 | simple_init=False, 310 | directions=None, 311 | **kwargs, 312 | ): 313 | factory_kwargs = {"device": None, "dtype": None} 314 | super().__init__() 315 | d_expand = int(ssm_ratio * d_model) 316 | d_inner = d_expand 317 | self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank 318 | self.d_state = math.ceil(d_model / 6) if d_state == "auto" else d_state # 20240109 319 | self.d_conv = d_conv 320 | 321 | self.out_norm = nn.LayerNorm(d_inner) 322 | 323 | self.K = len(MultiScanVSSM.ALL_CHOICES) if directions is None else len(directions) 324 | self.K2 = self.K 325 | 326 | # in proj ======================================= 327 | self.in_proj = nn.Linear(d_model, d_expand * 2, bias=bias, **factory_kwargs) 328 | self.act: nn.Module = act_layer() 329 | 330 | # conv ======================================= 331 | if self.d_conv > 1: 332 | self.conv2d = nn.Conv2d( 333 | in_channels=d_expand, 334 | out_channels=d_expand, 335 | groups=d_expand, 336 | bias=conv_bias, 337 | kernel_size=d_conv, 338 | padding=(d_conv - 1) // 2, 339 | **factory_kwargs, 340 | ) 341 | 342 | # rank ratio ===================================== 343 | self.ssm_low_rank = False 344 | if d_inner < d_expand: 345 | self.ssm_low_rank = True 346 | self.in_rank = nn.Conv2d(d_expand, d_inner, kernel_size=1, bias=False, **factory_kwargs) 347 | self.out_rank = nn.Linear(d_inner, d_expand, bias=False, **factory_kwargs) 348 | 349 | # x proj ============================ 350 | self.x_proj = [ 351 | nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs) 352 | for _ in range(self.K) 353 | ] 354 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) 355 | del self.x_proj 356 | 357 | # dt proj ============================ 358 | self.dt_projs = [ 359 | self.dt_init(self.dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) 360 | for _ in range(self.K) 361 | ] 362 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) 363 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner) 364 | del self.dt_projs 365 | 366 | # A, D ======================================= 367 | self.A_logs = self.A_log_init(self.d_state, d_inner, copies=self.K2, merge=True) # (K * D, N) 368 | self.Ds = self.D_init(d_inner, copies=self.K2, merge=True) # (K * D) 369 | 370 | # out proj ======================================= 371 | self.out_proj = nn.Linear(d_expand, d_model, bias=bias, **factory_kwargs) 372 | self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() 373 | 374 | # Local Mamba 375 | self.multi_scan = MultiScanVSSM(d_expand, choices=directions) 376 | 377 | if simple_init: 378 | # simple init dt_projs, A_logs, Ds 379 | self.Ds = nn.Parameter(torch.ones((self.K2 * d_inner))) 380 | self.A_logs = nn.Parameter(torch.randn((self.K2 * d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 381 | self.dt_projs_weight = nn.Parameter(torch.randn((self.K, d_inner, self.dt_rank))) 382 | self.dt_projs_bias = nn.Parameter(torch.randn((self.K, d_inner))) 383 | 384 | @staticmethod 385 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): 386 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 387 | 388 | # Initialize special dt projection to preserve variance at initialization 389 | dt_init_std = dt_rank**-0.5 * dt_scale 390 | if dt_init == "constant": 391 | nn.init.constant_(dt_proj.weight, dt_init_std) 392 | elif dt_init == "random": 393 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 394 | else: 395 | raise NotImplementedError 396 | 397 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 398 | dt = torch.exp( 399 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 400 | + math.log(dt_min) 401 | ).clamp(min=dt_init_floor) 402 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 403 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 404 | with torch.no_grad(): 405 | dt_proj.bias.copy_(inv_dt) 406 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 407 | # dt_proj.bias._no_reinit = True 408 | 409 | return dt_proj 410 | 411 | @staticmethod 412 | def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): 413 | # S4D real initialization 414 | A = repeat( 415 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 416 | "n -> d n", 417 | d=d_inner, 418 | ).contiguous() 419 | A_log = torch.log(A) # Keep A_log in fp32 420 | if copies > 0: 421 | A_log = repeat(A_log, "d n -> r d n", r=copies) 422 | if merge: 423 | A_log = A_log.flatten(0, 1) 424 | A_log = nn.Parameter(A_log) 425 | A_log._no_weight_decay = True 426 | return A_log 427 | 428 | @staticmethod 429 | def D_init(d_inner, copies=-1, device=None, merge=True): 430 | # D "skip" parameter 431 | D = torch.ones(d_inner, device=device) 432 | if copies > 0: 433 | D = repeat(D, "n1 -> r n1", r=copies) 434 | if merge: 435 | D = D.flatten(0, 1) 436 | D = nn.Parameter(D) # Keep in fp32 437 | D._no_weight_decay = True 438 | return D 439 | 440 | def forward_core(self, x: torch.Tensor, nrows=-1, channel_first=False): 441 | nrows = 1 442 | if not channel_first: 443 | x = x.permute(0, 3, 1, 2).contiguous() 444 | if self.ssm_low_rank: 445 | x = self.in_rank(x) 446 | x = multi_selective_scan( 447 | x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias, 448 | self.A_logs, self.Ds, self.out_norm, 449 | nrows=nrows, delta_softplus=True, multi_scan=self.multi_scan, 450 | ) 451 | if self.ssm_low_rank: 452 | x = self.out_rank(x) 453 | return x 454 | 455 | def forward(self, x: torch.Tensor): 456 | xz = self.in_proj(x) 457 | if self.d_conv > 1: 458 | x, z = xz.chunk(2, dim=-1) # (b, h, w, d) 459 | z = self.act(z) 460 | x = x.permute(0, 3, 1, 2).contiguous() 461 | x = self.act(self.conv2d(x)) # (b, d, h, w) 462 | else: 463 | xz = self.act(xz) 464 | x, z = xz.chunk(2, dim=-1) # (b, h, w, d) 465 | y = self.forward_core(x, 
channel_first=(self.d_conv > 1)) 466 | y = y * z 467 | out = self.dropout(self.out_proj(y)) 468 | return out 469 | 470 | 471 | class Permute(nn.Module): 472 | def __init__(self, *args): 473 | super().__init__() 474 | self.args = args 475 | 476 | def forward(self, x: torch.Tensor): 477 | return x.permute(*self.args) 478 | 479 | 480 | class Mlp(nn.Module): 481 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False): 482 | super().__init__() 483 | out_features = out_features or in_features 484 | hidden_features = hidden_features or in_features 485 | 486 | Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear 487 | self.fc1 = Linear(in_features, hidden_features) 488 | self.act = act_layer() 489 | self.fc2 = Linear(hidden_features, out_features) 490 | self.drop = nn.Dropout(drop) 491 | 492 | def forward(self, x): 493 | x = self.fc1(x) 494 | x = self.act(x) 495 | x = self.drop(x) 496 | x = self.fc2(x) 497 | x = self.drop(x) 498 | return x 499 | 500 | 501 | class VSSBlock(nn.Module): 502 | def __init__( 503 | self, 504 | hidden_dim: int = 0, 505 | drop_path: float = 0, 506 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 507 | # ============================= 508 | ssm_d_state: int = 16, 509 | ssm_ratio=2.0, 510 | ssm_dt_rank: Any = "auto", 511 | ssm_act_layer=nn.SiLU, 512 | ssm_conv: int = 3, 513 | ssm_conv_bias=True, 514 | ssm_drop_rate: float = 0, 515 | ssm_simple_init=False, 516 | # ============================= 517 | use_checkpoint: bool = False, 518 | directions=None, 519 | **kwargs, 520 | ): 521 | super().__init__() 522 | self.use_checkpoint = use_checkpoint 523 | self.norm = norm_layer(hidden_dim) 524 | self.op = SS2D( 525 | d_model=hidden_dim, 526 | d_state=ssm_d_state, 527 | ssm_ratio=ssm_ratio, 528 | dt_rank=ssm_dt_rank, 529 | act_layer=ssm_act_layer, 530 | # ========================== 531 | d_conv=ssm_conv, 532 | conv_bias=ssm_conv_bias, 533 | # ========================== 534 | dropout=ssm_drop_rate, 535 | # bias=False, 536 | # ========================== 537 | # dt_min=0.001, 538 | # dt_max=0.1, 539 | # dt_init="random", 540 | # dt_scale="random", 541 | # dt_init_floor=1e-4, 542 | simple_init=ssm_simple_init, 543 | # ========================== 544 | directions=directions 545 | ) 546 | self.drop_path = DropPath(drop_path) 547 | 548 | def _forward(self, input: torch.Tensor): 549 | x = input + self.drop_path(self.op(self.norm(input))) 550 | return x 551 | 552 | def forward(self, input: torch.Tensor): 553 | if self.use_checkpoint: 554 | return checkpoint.checkpoint(self._forward, input) 555 | else: 556 | return self._forward(input) 557 | 558 | 559 | class VSSM(nn.Module): 560 | def __init__( 561 | self, 562 | patch_size=4, 563 | in_chans=3, 564 | num_classes=2, 565 | depths=[2, 2, 9, 2], 566 | dims=[32, 64, 128, 256], 567 | # ========================= 568 | ssm_d_state=16, 569 | ssm_ratio=2.0, 570 | ssm_dt_rank="auto", 571 | ssm_act_layer="silu", 572 | ssm_conv=3, 573 | ssm_conv_bias=True, 574 | ssm_drop_rate=0.0, 575 | ssm_simple_init=False, 576 | # ========================= 577 | drop_path_rate=0.1, 578 | patch_norm=True, 579 | norm_layer="LN", 580 | use_checkpoint=False, 581 | directions=None, 582 | **kwargs, 583 | ): 584 | super().__init__() 585 | self.num_classes = num_classes 586 | self.num_layers = len(depths) 587 | if isinstance(dims, int): 588 | dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] 589 | self.num_features = dims[-1] 
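# Example of the dims expansion above (hypothetical value): passing the
# integer dims=32 with depths=[2, 2, 9, 2] expands to dims=[32, 64, 128, 256]
# -- channels double at each stage -- so num_features = dims[-1] = 256.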
590 | self.dims = dims 591 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 592 | 593 | _NORMLAYERS = dict( 594 | ln=nn.LayerNorm, 595 | bn=nn.BatchNorm2d, 596 | ) 597 | 598 | _ACTLAYERS = dict( 599 | silu=nn.SiLU, 600 | gelu=nn.GELU, 601 | relu=nn.ReLU, 602 | sigmoid=nn.Sigmoid, 603 | ) 604 | 605 | if norm_layer.lower() in ["ln"]: 606 | norm_layer: nn.Module = _NORMLAYERS[norm_layer.lower()] 607 | 608 | if ssm_act_layer.lower() in ["silu", "gelu", "relu"]: 609 | ssm_act_layer: nn.Module = _ACTLAYERS[ssm_act_layer.lower()] 610 | 611 | self.patch_embed = nn.Sequential( 612 | nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=True), 613 | Permute(0, 2, 3, 1), 614 | (norm_layer(dims[0]) if patch_norm else nn.Identity()), 615 | ) 616 | 617 | self.layers = nn.ModuleList() 618 | for i_layer in range(self.num_layers): 619 | downsample = PatchMerging2D( 620 | self.dims[i_layer], 621 | self.dims[i_layer + 1], 622 | norm_layer=norm_layer, 623 | ) if (i_layer < self.num_layers - 1) else nn.Identity() 624 | 625 | self.layers.append(self._make_layer( 626 | dim = self.dims[i_layer], 627 | drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 628 | use_checkpoint=use_checkpoint, 629 | norm_layer=norm_layer, 630 | downsample=downsample, 631 | # ================= 632 | ssm_d_state=ssm_d_state, 633 | ssm_ratio=ssm_ratio, 634 | ssm_dt_rank=ssm_dt_rank, 635 | ssm_act_layer=ssm_act_layer, 636 | ssm_conv=ssm_conv, 637 | ssm_conv_bias=ssm_conv_bias, 638 | ssm_drop_rate=ssm_drop_rate, 639 | ssm_simple_init=ssm_simple_init, 640 | # ================= 641 | directions=None if directions is None else directions[sum(depths[:i_layer]):sum(depths[:i_layer + 1])] 642 | )) 643 | 644 | self.classifier = nn.Sequential(OrderedDict( 645 | norm=norm_layer(self.num_features), # B,H,W,C 646 | permute=Permute(0, 3, 1, 2), 647 | avgpool=nn.AdaptiveAvgPool2d(1), 648 | flatten=nn.Flatten(1), 649 | head=nn.Linear(self.num_features, num_classes), 650 | )) 651 | 652 | self.apply(self._init_weights) 653 | 654 | def _init_weights(self, m: nn.Module): 655 | if isinstance(m, nn.Linear): 656 | trunc_normal_(m.weight, std=.02) 657 | if isinstance(m, nn.Linear) and m.bias is not None: 658 | nn.init.constant_(m.bias, 0) 659 | elif isinstance(m, nn.LayerNorm): 660 | if m.bias is not None: 661 | nn.init.constant_(m.bias, 0) 662 | if m.weight is not None: 663 | nn.init.constant_(m.weight, 1.0) 664 | 665 | # used in building optimizer 666 | @torch.jit.ignore 667 | def no_weight_decay(self): 668 | return {} 669 | 670 | @staticmethod 671 | def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm): 672 | return nn.Sequential( 673 | Permute(0, 3, 1, 2), 674 | nn.Conv2d(dim, out_dim, kernel_size=2, stride=2), 675 | Permute(0, 2, 3, 1), 676 | norm_layer(out_dim), 677 | ) 678 | 679 | @staticmethod 680 | def _make_layer( 681 | dim=96, 682 | drop_path=[0.1, 0.1], 683 | use_checkpoint=False, 684 | norm_layer=nn.LayerNorm, 685 | downsample=nn.Identity(), 686 | # =========================== 687 | ssm_d_state=16, 688 | ssm_ratio=2.0, 689 | ssm_dt_rank="auto", 690 | ssm_act_layer=nn.SiLU, 691 | ssm_conv=3, 692 | ssm_conv_bias=True, 693 | ssm_drop_rate=0.0, 694 | ssm_simple_init=False, 695 | # =========================== 696 | directions=None, 697 | **kwargs, 698 | ): 699 | depth = len(drop_path) 700 | blocks = [] 701 | for d in range(depth): 702 | blocks.append(VSSBlock( 703 | hidden_dim=dim, 704 | drop_path=drop_path[d], 705 | norm_layer=norm_layer, 706 | 
ssm_d_state=ssm_d_state, 707 | ssm_ratio=ssm_ratio, 708 | ssm_dt_rank=ssm_dt_rank, 709 | ssm_act_layer=ssm_act_layer, 710 | ssm_conv=ssm_conv, 711 | ssm_conv_bias=ssm_conv_bias, 712 | ssm_drop_rate=ssm_drop_rate, 713 | ssm_simple_init=ssm_simple_init, 714 | use_checkpoint=use_checkpoint, 715 | directions=directions[d] if directions is not None else None 716 | )) 717 | 718 | return nn.Sequential(OrderedDict( 719 | blocks=nn.Sequential(*blocks,), 720 | downsample=downsample, 721 | )) 722 | 723 | def forward_features(self, x): # x [1, 3, 256, 256] 724 | skip_list = [] 725 | x = self.patch_embed(x) # x [1, 64, 64, 96] , dims=[96, 192, 384, 768] 726 | for layer in self.layers: 727 | skip_list.append(x) # x [1, 96, 64, 64] 728 | x = layer(x) # x: final output of the encoder 729 | return x, skip_list # x [1, 8, 8, 768] , len(skip_list) = 4 , skip_list[0] [1, 64, 64, 96] 730 | 731 | def forward_features_up(self, x, skip_list): 732 | for inx, layer_up in enumerate(self.layers_up): # x [1, 8, 8, 768] 733 | if inx == 0: 734 | x = layer_up(x) 735 | else: 736 | x = layer_up(x+skip_list[-inx]) 737 | 738 | return x # [1, 64, 64, 96] 739 | 740 | def forward_final(self, x): # x [1, 64, 64, 96] 741 | x = self.final_up(x) # x [1, 256, 256, 24] 742 | x = x.permute(0,3,1,2) # x [1, 24, 256, 256] 743 | x = self.final_conv(x) # x [1, 1, 256, 256] 744 | return x 745 | 746 | def forward_backbone(self, x): 747 | x = self.patch_embed(x) 748 | if self.ape: 749 | x = x + self.absolute_pos_embed 750 | x = self.pos_drop(x) 751 | 752 | for layer in self.layers: 753 | x = layer(x) 754 | return x 755 | 756 | def forward_bak(self, x): 757 | x, skip_list = self.forward_features(x) # skip_list[0] [2, 64, 64, 96] [3] [2, 8, 8, 768] 758 | x = self.forward_features_up(x, skip_list) 759 | x = self.forward_final(x) 760 | 761 | return x 762 | # modified from SIM to the SDI module 763 | def forward(self, x): 764 | x, skip_list = self.forward_features(x) 765 | skip_list[0] = skip_list[0].permute(0, 3, 1, 2) 766 | skip_list[1] = skip_list[1].permute(0, 3, 1, 2) 767 | skip_list[2] = skip_list[2].permute(0, 3, 1, 2) 768 | skip_list[3] = skip_list[3].permute(0, 3, 1, 2) 769 | for i, o in enumerate(skip_list): # 4x upsampling 770 | skip_list[i] = F.interpolate(o, scale_factor=4, mode='bilinear') # skip_list[0] [2, 64, 64, 96] [3] [2, 8, 8, 768] 771 | return skip_list[0], skip_list[1], skip_list[2], skip_list[3] 772 | 773 | 774 | 775 | 776 | -------------------------------------------------------------------------------- /baseline/vmamba.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | import warnings 4 | from functools import partial 5 | from typing import Optional, Callable, Any 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | from einops import rearrange, repeat 13 | from timm.models.layers import DropPath, trunc_normal_ 14 | from timm.models.registry import register_model 15 | from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count 16 | from baseline.multi_mamba import MultiScan 17 | 18 | 19 | DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" 20 | 21 | 22 | try: 23 | "sscore acts the same as mamba_ssm" 24 | SSMODE = "sscore" 25 | import selective_scan_cuda_core 26 | print("Using \"selective_scan_cuda_core\"") 27 | except Exception as e: 28 | warnings.warn(f"{e}\n\"selective_scan_cuda_core\" not found, use default 
\"selective_scan_cuda\" instead.") 29 | # print(e, flush=True) 30 | SSMODE = "mamba_ssm" 31 | import selective_scan_cuda 32 | 33 | 34 | # fvcore flops ======================================= 35 | 36 | def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): 37 | """ 38 | u: r(B D L) 39 | delta: r(B D L) 40 | A: r(D N) 41 | B: r(B N L) 42 | C: r(B N L) 43 | D: r(D) 44 | z: r(B D L) 45 | delta_bias: r(D), fp32 46 | 47 | ignores: 48 | [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] 49 | """ 50 | assert not with_complex 51 | # https://github.com/state-spaces/mamba/issues/110 52 | flops = 9 * B * L * D * N 53 | if with_D: 54 | flops += B * D * L 55 | if with_Z: 56 | flops += B * D * L 57 | return flops 58 | 59 | def selective_scan_flop_jit(inputs, outputs): 60 | B, D, L = inputs[0].type().sizes() 61 | N = inputs[2].type().sizes()[1] 62 | flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False, with_Group=True) 63 | return flops 64 | 65 | # def q_shift(input, shift_pixel=1, gamma=1/4, patch_resolution=None): 66 | # assert gamma <= 1/4 67 | # input = input.permute(0, 3, 1, 2) 68 | # B, C, H, W = input.shape 69 | # output = torch.zeros_like(input) 70 | # output[:, 0:int(C*gamma), :, shift_pixel:W] = input[:, 0:int(C*gamma), :, 0:W-shift_pixel] 71 | # output[:, int(C*gamma):int(C*gamma*2), :, 0:W-shift_pixel] = input[:, int(C*gamma):int(C*gamma*2), :, shift_pixel:W] 72 | # output[:, int(C*gamma*2):int(C*gamma*3), shift_pixel:H, :] = input[:, int(C*gamma*2):int(C*gamma*3), 0:H-shift_pixel, :] 73 | # output[:, int(C*gamma*3):int(C*gamma*4), 0:H-shift_pixel, :] = input[:, int(C*gamma*3):int(C*gamma*4), shift_pixel:H, :] 74 | # output[:, int(C*gamma*4):, ...] = input[:, int(C*gamma*4):, ...] 
75 | # return output.permute(0, 2, 3, 1) 76 | 77 | 78 | class SelectiveScan(torch.autograd.Function): 79 | 80 | @staticmethod 81 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 82 | def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1): 83 | assert nrows in [1, 2, 3, 4], f"{nrows}" # 8+ is too slow to compile 84 | assert u.shape[1] % (B.shape[1] * nrows) == 0, f"{nrows}, {u.shape}, {B.shape}" 85 | ctx.delta_softplus = delta_softplus 86 | ctx.nrows = nrows 87 | # all in float 88 | if u.stride(-1) != 1: 89 | u = u.contiguous() 90 | if delta.stride(-1) != 1: 91 | delta = delta.contiguous() 92 | if D is not None: 93 | D = D.contiguous() 94 | if B.stride(-1) != 1: 95 | B = B.contiguous() 96 | if C.stride(-1) != 1: 97 | C = C.contiguous() 98 | if B.dim() == 3: 99 | B = B.unsqueeze(dim=1) 100 | ctx.squeeze_B = True 101 | if C.dim() == 3: 102 | C = C.unsqueeze(dim=1) 103 | ctx.squeeze_C = True 104 | 105 | if SSMODE == "mamba_ssm": 106 | out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) 107 | else: 108 | out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) 109 | 110 | ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 111 | return out 112 | 113 | @staticmethod 114 | @torch.cuda.amp.custom_bwd 115 | def backward(ctx, dout, *args): 116 | u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors 117 | if dout.stride(-1) != 1: 118 | dout = dout.contiguous() 119 | 120 | if SSMODE == "mamba_ssm": 121 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( 122 | u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, 123 | False # option to recompute out_z, not used here 124 | ) 125 | else: 126 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd( 127 | u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 128 | # u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.nrows, 129 | ) 130 | 131 | dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB 132 | dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC 133 | return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None) 134 | 135 | 136 | """ 137 | Local Mamba 138 | """ 139 | class MultiScanVSSM(MultiScan): 140 | 141 | ALL_CHOICES = MultiScan.ALL_CHOICES 142 | 143 | def __init__(self, dim, choices=None): 144 | super().__init__(dim, choices=choices, token_size=None) 145 | self.attn = BiAttn(dim) 146 | 147 | def merge(self, xs): 148 | # xs: [B, K, D, L] 149 | # return: [B, D, L] 150 | 151 | # remove the padded tokens 152 | xs = [xs[:, i, :, :l] for i, l in enumerate(self.scan_lengths)] 153 | xs = super().multi_reverse(xs) 154 | xs = [self.attn(x.transpose(-2, -1)) for x in xs] 155 | x = super().forward(xs) 156 | return x 157 | 158 | 159 | def multi_scan(self, x): 160 | # x: [B, C, H, W] 161 | # return: [B, K, C, H * W] 162 | B, C, H, W = x.shape 163 | self.token_size = (H, W) 164 | 165 | xs = super().multi_scan(x) # [[B, C, H, W], ...] 
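# Sketch of the padding below (hypothetical lengths): if the per-direction
# scans return sequences of lengths [64, 64, 60, 60], each tensor is
# right-padded with zeros to max_length = 64 so that all K directions can be
# stacked and run through one batched scan; merge() later slices each result
# back to its recorded scan_length before reversing the scan order.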
166 | 167 | self.scan_lengths = [x.shape[2] for x in xs] 168 | max_length = max(self.scan_lengths) 169 | 170 | # pad the tokens into the same length as VMamba compute all directions together 171 | new_xs = [] 172 | for x in xs: 173 | if x.shape[2] < max_length: 174 | x = F.pad(x, (0, max_length - x.shape[2])) 175 | new_xs.append(x) 176 | return torch.stack(new_xs, 1) 177 | 178 | def __repr__(self): 179 | scans = ', '.join(self.choices) 180 | return super().__repr__().replace('MultiScanVSSM', f'MultiScanVSSM[{scans}]') 181 | 182 | 183 | class BiAttn(nn.Module): 184 | def __init__(self, in_channels, act_ratio=0.125, act_fn=nn.GELU, gate_fn=nn.Sigmoid): 185 | super().__init__() 186 | reduce_channels = int(in_channels * act_ratio) 187 | self.norm = nn.LayerNorm(in_channels) 188 | self.global_reduce = nn.Linear(in_channels, reduce_channels) 189 | # self.local_reduce = nn.Linear(in_channels, reduce_channels) 190 | self.act_fn = act_fn() 191 | self.channel_select = nn.Linear(reduce_channels, in_channels) 192 | # self.spatial_select = nn.Linear(reduce_channels * 2, 1) 193 | self.gate_fn = gate_fn() 194 | 195 | def forward(self, x): 196 | ori_x = x 197 | x = self.norm(x) 198 | x_global = x.mean(1, keepdim=True) 199 | x_global = self.act_fn(self.global_reduce(x_global)) 200 | # x_local = self.act_fn(self.local_reduce(x)) 201 | 202 | c_attn = self.channel_select(x_global) 203 | c_attn = self.gate_fn(c_attn) # [B, 1, C] 204 | # s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) 205 | # s_attn = self.gate_fn(s_attn) # [B, N, 1] 206 | 207 | attn = c_attn #* s_attn # [B, N, C] 208 | out = ori_x * attn 209 | return out 210 | 211 | 212 | def multi_selective_scan( 213 | x: torch.Tensor=None, 214 | x_proj_weight: torch.Tensor=None, 215 | x_proj_bias: torch.Tensor=None, 216 | dt_projs_weight: torch.Tensor=None, 217 | dt_projs_bias: torch.Tensor=None, 218 | A_logs: torch.Tensor=None, 219 | Ds: torch.Tensor=None, 220 | out_norm: torch.nn.Module=None, 221 | nrows = -1, 222 | delta_softplus = True, 223 | to_dtype=True, 224 | multi_scan=None, 225 | ): 226 | B, D, H, W = x.shape 227 | D, N = A_logs.shape 228 | K, D, R = dt_projs_weight.shape 229 | L = H * W 230 | 231 | if nrows < 1: 232 | if D % 4 == 0: 233 | nrows = 4 234 | elif D % 3 == 0: 235 | nrows = 3 236 | elif D % 2 == 0: 237 | nrows = 2 238 | else: 239 | nrows = 1 240 | 241 | xs = multi_scan.multi_scan(x) 242 | 243 | L = xs.shape[-1] 244 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight) # l fixed 245 | 246 | if x_proj_bias is not None: 247 | x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1) 248 | dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) 249 | dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight) 250 | 251 | xs = xs.view(B, -1, L).to(torch.float) 252 | dts = dts.contiguous().view(B, -1, L).to(torch.float) 253 | As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state) 254 | Bs = Bs.contiguous().to(torch.float) 255 | Cs = Cs.contiguous().to(torch.float) 256 | Ds = Ds.to(torch.float) # (K * c) 257 | delta_bias = dt_projs_bias.view(-1).to(torch.float) 258 | 259 | def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): 260 | return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) 261 | 262 | ys: torch.Tensor = selective_scan( 263 | xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus, nrows, 264 | ).view(B, K, -1, L) 265 | 266 | y = multi_scan.merge(ys) 267 | 268 | y = out_norm(y).view(B, H, W, -1) 269 | 270 | 
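# Shape flow, inferred from the calls above: selective_scan yields ys as
# [B, K, D, L]; merge() strips the per-direction padding, undoes each scan
# order, applies BiAttn and combines the K directions into [B, L, D];
# out_norm + view then restore the spatial [B, H, W, D] layout returned below.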
271 | 
272 | 
273 | class PatchMerging2D(nn.Module):
274 |     def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm):
275 |         super().__init__()
276 |         self.dim = dim
277 |         self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False)
278 |         self.norm = norm_layer(4 * dim)
279 | 
280 |     @staticmethod
281 |     def _patch_merging_pad(x: torch.Tensor):
282 |         H, W, _ = x.shape[-3:]
283 |         if (W % 2 != 0) or (H % 2 != 0):
284 |             x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
285 |         x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
286 |         x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
287 |         x2 = x[..., 0::2, 1::2, :]  # ... H/2 W/2 C
288 |         x3 = x[..., 1::2, 1::2, :]  # ... H/2 W/2 C
289 |         x = torch.cat([x0, x1, x2, x3], -1)  # ... H/2 W/2 4*C
290 |         return x
291 | 
292 |     def forward(self, x):
293 |         x = self._patch_merging_pad(x)
294 |         x = self.norm(x)
295 |         x = self.reduction(x)
296 | 
297 |         return x
298 | 
299 | 
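`PatchMerging2D` halves the spatial resolution by folding each 2x2 neighborhood into the channel axis (C -> 4C), normalizing, then linearly reducing 4C -> 2C (or `out_dim`). A quick CPU-only shape check with illustrative sizes:

```python
import torch

merge = PatchMerging2D(dim=96)        # the class defined above
x = torch.randn(2, 64, 64, 96)        # (B, H, W, C), channels-last as used throughout this file
y = merge(x)
assert y.shape == (2, 32, 32, 192)    # half the resolution, twice the channels
```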
300 | class SS2D(nn.Module):
301 |     def __init__(
302 |         self,
303 |         # basic dims ===========
304 |         d_model=96,
305 |         d_state=16,
306 |         ssm_ratio=2.0,
307 |         dt_rank="auto",
308 |         act_layer=nn.SiLU,
309 |         # dwconv ===============
310 |         d_conv=3,  # < 2 means no conv
311 |         conv_bias=True,
312 |         # ======================
313 |         dropout=0.0,
314 |         bias=False,
315 |         # dt init ==============
316 |         dt_min=0.001,
317 |         dt_max=0.1,
318 |         dt_init="random",
319 |         dt_scale=1.0,
320 |         dt_init_floor=1e-4,
321 |         simple_init=False,
322 |         directions=None,
323 |         **kwargs,
324 |     ):
325 |         factory_kwargs = {"device": None, "dtype": None}
326 |         super().__init__()
327 |         d_expand = int(ssm_ratio * d_model)
328 |         d_inner = d_expand
329 |         self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
330 |         self.d_state = math.ceil(d_model / 6) if d_state == "auto" else d_state  # 20240109
331 |         self.d_conv = d_conv
332 | 
333 |         self.out_norm = nn.LayerNorm(d_inner)
334 | 
335 |         self.K = len(MultiScanVSSM.ALL_CHOICES) if directions is None else len(directions)
336 |         self.K2 = self.K
337 | 
338 |         # in proj =======================================
339 |         self.in_proj = nn.Linear(d_model, d_expand * 2, bias=bias, **factory_kwargs)
340 |         self.act: nn.Module = act_layer()
341 | 
342 |         # conv =======================================
343 |         if self.d_conv > 1:
344 |             self.conv2d = nn.Conv2d(
345 |                 in_channels=d_expand,
346 |                 out_channels=d_expand,
347 |                 groups=d_expand,
348 |                 bias=conv_bias,
349 |                 kernel_size=d_conv,
350 |                 padding=(d_conv - 1) // 2,
351 |                 **factory_kwargs,
352 |             )
353 | 
354 |         # rank ratio =====================================
355 |         self.ssm_low_rank = False
356 |         if d_inner < d_expand:
357 |             self.ssm_low_rank = True
358 |             self.in_rank = nn.Conv2d(d_expand, d_inner, kernel_size=1, bias=False, **factory_kwargs)
359 |             self.out_rank = nn.Linear(d_inner, d_expand, bias=False, **factory_kwargs)
360 | 
361 |         # x proj ============================
362 |         self.x_proj = [
363 |             nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs)
364 |             for _ in range(self.K)
365 |         ]
366 |         self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K, N, inner)
367 |         del self.x_proj
368 | 
369 |         # dt proj ============================
370 |         self.dt_projs = [
371 |             self.dt_init(self.dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs)
372 |             for _ in range(self.K)
373 |         ]
374 |         self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K, inner, rank)
375 |         self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K, inner)
376 |         del self.dt_projs
377 | 
378 |         # A, D =======================================
379 |         self.A_logs = self.A_log_init(self.d_state, d_inner, copies=self.K2, merge=True)  # (K * D, N)
380 |         self.Ds = self.D_init(d_inner, copies=self.K2, merge=True)  # (K * D)
381 | 
382 |         # out proj =======================================
383 |         self.out_proj = nn.Linear(d_expand, d_model, bias=bias, **factory_kwargs)
384 |         self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
385 | 
386 |         # Local Mamba
387 |         self.multi_scan = MultiScanVSSM(d_expand, choices=directions)
388 | 
389 |         if simple_init:
390 |             # simple init for dt_projs, A_logs, Ds
391 |             self.Ds = nn.Parameter(torch.ones((self.K2 * d_inner)))
392 |             self.A_logs = nn.Parameter(torch.randn((self.K2 * d_inner, self.d_state)))  # A == -A_logs.exp() < 0; 0 < exp(A * dt) < 1
393 |             self.dt_projs_weight = nn.Parameter(torch.randn((self.K, d_inner, self.dt_rank)))
394 |             self.dt_projs_bias = nn.Parameter(torch.randn((self.K, d_inner)))
395 | 
396 |     @staticmethod
397 |     def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
398 |         dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
399 | 
400 |         # Initialize special dt projection to preserve variance at initialization
401 |         dt_init_std = dt_rank**-0.5 * dt_scale
402 |         if dt_init == "constant":
403 |             nn.init.constant_(dt_proj.weight, dt_init_std)
404 |         elif dt_init == "random":
405 |             nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
406 |         else:
407 |             raise NotImplementedError
408 | 
409 |         # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
410 |         dt = torch.exp(
411 |             torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
412 |             + math.log(dt_min)
413 |         ).clamp(min=dt_init_floor)
414 |         # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
415 |         inv_dt = dt + torch.log(-torch.expm1(-dt))
416 |         with torch.no_grad():
417 |             dt_proj.bias.copy_(inv_dt)
418 |         # Our initialization would set all Linear.bias to zero; this one would need to be marked with _no_reinit
419 |         # dt_proj.bias._no_reinit = True
420 | 
421 |         return dt_proj
422 | 
423 |     @staticmethod
424 |     def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
425 |         # S4D real initialization
426 |         A = repeat(
427 |             torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
428 |             "n -> d n",
429 |             d=d_inner,
430 |         ).contiguous()
431 |         A_log = torch.log(A)  # Keep A_log in fp32
432 |         if copies > 0:
433 |             A_log = repeat(A_log, "d n -> r d n", r=copies)
434 |             if merge:
435 |                 A_log = A_log.flatten(0, 1)
436 |         A_log = nn.Parameter(A_log)
437 |         A_log._no_weight_decay = True
438 |         return A_log
439 | 
440 |     @staticmethod
441 |     def D_init(d_inner, copies=-1, device=None, merge=True):
442 |         # D "skip" parameter
443 |         D = torch.ones(d_inner, device=device)
444 |         if copies > 0:
445 |             D = repeat(D, "n1 -> r n1", r=copies)
446 |             if merge:
447 |                 D = D.flatten(0, 1)
448 |         D = nn.Parameter(D)  # Keep in fp32
449 |         D._no_weight_decay = True
450 |         return D
451 | 
452 |     def forward_core(self, x: torch.Tensor, nrows=-1, channel_first=False):
453 |         nrows = 1
454 |         if not channel_first:
455 |             x = x.permute(0, 3, 1, 2).contiguous()
456 |         if self.ssm_low_rank:
457 |             x = self.in_rank(x)
458 |         x = multi_selective_scan(
459 |             x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,
460 |             self.A_logs, self.Ds, self.out_norm,
461 |             nrows=nrows, delta_softplus=True, multi_scan=self.multi_scan,
462 |         )
463 |         if self.ssm_low_rank:
464 |             x = self.out_rank(x)
465 |         return x
466 | 
467 |     def forward(self, x: torch.Tensor):
468 |         xz = self.in_proj(x)
469 |         if self.d_conv > 1:
470 |             x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
471 |             z = self.act(z)
472 |             x = x.permute(0, 3, 1, 2).contiguous()
473 |             x = self.act(self.conv2d(x))  # (b, d, h, w)
474 |         else:
475 |             xz = self.act(xz)
476 |             x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
477 |         y = self.forward_core(x, channel_first=(self.d_conv > 1))
478 |         y = y * z
479 |         out = self.dropout(self.out_proj(y))
480 |         return out
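`dt_init` draws timesteps log-uniformly from [dt_min, dt_max] and stores them through the inverse of softplus, so that `F.softplus(bias)` recovers the intended range at runtime. A CPU-only sketch of that property; the rank and width are arbitrary illustration values:

```python
import torch.nn.functional as F

proj = SS2D.dt_init(dt_rank=6, d_inner=192, dt_min=0.001, dt_max=0.1)
dt = F.softplus(proj.bias)                       # undoes the inverse-softplus storage
assert 0.001 - 1e-5 <= dt.min() and dt.max() <= 0.1 + 1e-5
```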
481 | 
482 | 
483 | class Permute(nn.Module):
484 |     def __init__(self, *args):
485 |         super().__init__()
486 |         self.args = args
487 | 
488 |     def forward(self, x: torch.Tensor):
489 |         return x.permute(*self.args)
490 | 
491 | 
492 | class Mlp(nn.Module):
493 |     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., channels_first=False):
494 |         super().__init__()
495 |         out_features = out_features or in_features
496 |         hidden_features = hidden_features or in_features
497 | 
498 |         Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear
499 |         self.fc1 = Linear(in_features, hidden_features)
500 |         self.act = act_layer()
501 |         self.fc2 = Linear(hidden_features, out_features)
502 |         self.drop = nn.Dropout(drop)
503 | 
504 |     def forward(self, x):
505 |         x = self.fc1(x)
506 |         x = self.act(x)
507 |         x = self.drop(x)
508 |         x = self.fc2(x)
509 |         x = self.drop(x)
510 |         return x
511 | 
512 | 
513 | class VSSBlock(nn.Module):
514 |     def __init__(
515 |         self,
516 |         hidden_dim: int = 0,
517 |         drop_path: float = 0,
518 |         norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
519 |         # =============================
520 |         ssm_d_state: int = 16,
521 |         ssm_ratio=2.0,
522 |         ssm_dt_rank: Any = "auto",
523 |         ssm_act_layer=nn.SiLU,
524 |         ssm_conv: int = 3,
525 |         ssm_conv_bias=True,
526 |         ssm_drop_rate: float = 0,
527 |         ssm_simple_init=False,
528 |         # =============================
529 |         use_checkpoint: bool = False,
530 |         directions=None,
531 |         **kwargs,
532 |     ):
533 |         super().__init__()
534 |         self.use_checkpoint = use_checkpoint
535 |         self.norm = norm_layer(hidden_dim)
536 |         self.op = SS2D(
537 |             d_model=hidden_dim,
538 |             d_state=ssm_d_state,
539 |             ssm_ratio=ssm_ratio,
540 |             dt_rank=ssm_dt_rank,
541 |             act_layer=ssm_act_layer,
542 |             # ==========================
543 |             d_conv=ssm_conv,
544 |             conv_bias=ssm_conv_bias,
545 |             # ==========================
546 |             dropout=ssm_drop_rate,
547 |             # bias=False,
548 |             # ==========================
549 |             # dt_min=0.001,
550 |             # dt_max=0.1,
551 |             # dt_init="random",
552 |             # dt_scale="random",
553 |             # dt_init_floor=1e-4,
554 |             simple_init=ssm_simple_init,
555 |             # ==========================
556 |             directions=directions
557 |         )
558 |         self.drop_path = DropPath(drop_path)
559 | 
560 |     def _forward(self, input: torch.Tensor):
561 |         x = input + self.drop_path(self.op(self.norm(input)))
562 |         return x
563 | 
564 |     def forward(self, input: torch.Tensor):
565 |         if self.use_checkpoint:
566 |             return checkpoint.checkpoint(self._forward, input)
567 |         else:
568 |             return self._forward(input)
569 | 
570 | 
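`VSSBlock` is a pre-norm residual wrapper: LayerNorm -> SS2D -> DropPath -> skip connection, optionally gradient-checkpointed. A minimal sketch, assuming the CUDA selective-scan kernel behind `SelectiveScan` (from `mamba-ssm`) is installed, since the SS2D core runs on GPU; the direction group shown is one valid choice drawn from the lists used below:

```python
import torch

block = VSSBlock(hidden_dim=96, drop_path=0.1,
                 directions=['h', 'h_flip', 'v', 'v_flip']).cuda()
x = torch.randn(2, 64, 64, 96).cuda()   # (B, H, W, C)
y = block(x)                             # residual output keeps the input shape
assert y.shape == x.shape
```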
571 | class VSSM(nn.Module):
572 |     def __init__(
573 |         self,
574 |         patch_size=4,
575 |         in_chans=3,
576 |         num_classes=2,
577 |         depths=[2, 2, 9, 2],
578 |         dims=[32, 64, 128, 256],
579 |         # =========================
580 |         ssm_d_state=16,
581 |         ssm_ratio=2.0,
582 |         ssm_dt_rank="auto",
583 |         ssm_act_layer="silu",
584 |         ssm_conv=3,
585 |         ssm_conv_bias=True,
586 |         ssm_drop_rate=0.0,
587 |         ssm_simple_init=False,
588 |         # =========================
589 |         drop_path_rate=0.1,
590 |         patch_norm=True,
591 |         norm_layer="LN",
592 |         use_checkpoint=False,
593 |         directions=[
594 |             ['h', 'h_flip', 'w7', 'w7_flip'],
595 |             ['h_flip', 'v_flip', 'w2', 'w2_flip'],
596 |             ['h_flip', 'v_flip', 'w2_flip', 'w7'],
597 |             ['h_flip', 'v', 'v_flip', 'w2'],
598 |             ['h', 'h_flip', 'v_flip', 'w2_flip'],
599 |             ['h_flip', 'v_flip', 'w2', 'w2_flip'],
600 |             ['h', 'w2_flip', 'w7', 'w7_flip'],
601 |             ['h', 'h_flip', 'v', 'v_flip'],
602 |             ['h', 'v_flip', 'w7', 'w7_flip'],
603 |             ['h_flip', 'v', 'w2', 'w7_flip'],
604 |             ['v', 'v_flip', 'w2', 'w7_flip'],
605 |             ['h', 'h_flip', 'v_flip', 'w2_flip'],
606 |             ['v_flip', 'w2_flip', 'w7', 'w7_flip'],
607 |             ['h_flip', 'v_flip', 'w2_flip', 'w7_flip'],
608 |             ['h_flip', 'v', 'w7', 'w7_flip'],
609 |         ],
610 |         **kwargs,
611 |     ):
612 |         super().__init__()
613 |         self.num_classes = num_classes
614 |         self.num_layers = len(depths)
615 |         if isinstance(dims, int):
616 |             dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
617 |         self.num_features = dims[-1]
618 |         self.dims = dims
619 |         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
620 | 
621 |         _NORMLAYERS = dict(
622 |             ln=nn.LayerNorm,
623 |             bn=nn.BatchNorm2d,
624 |         )
625 | 
626 |         _ACTLAYERS = dict(
627 |             silu=nn.SiLU,
628 |             gelu=nn.GELU,
629 |             relu=nn.ReLU,
630 |             sigmoid=nn.Sigmoid,
631 |         )
632 | 
633 |         if norm_layer.lower() in ["ln"]:
634 |             norm_layer: nn.Module = _NORMLAYERS[norm_layer.lower()]
635 | 
636 |         if ssm_act_layer.lower() in ["silu", "gelu", "relu"]:
637 |             ssm_act_layer: nn.Module = _ACTLAYERS[ssm_act_layer.lower()]
638 | 
639 |         self.patch_embed = nn.Sequential(
640 |             nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=True),
641 |             Permute(0, 2, 3, 1),
642 |             (norm_layer(dims[0]) if patch_norm else nn.Identity()),
643 |         )
644 | 
645 |         self.layers = nn.ModuleList()
646 |         for i_layer in range(self.num_layers):
647 |             downsample = PatchMerging2D(
648 |                 self.dims[i_layer],
649 |                 self.dims[i_layer + 1],
650 |                 norm_layer=norm_layer,
651 |             ) if (i_layer < self.num_layers - 1) else nn.Identity()
652 | 
653 |             self.layers.append(self._make_layer(
654 |                 dim=self.dims[i_layer],
655 |                 drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
656 |                 use_checkpoint=use_checkpoint,
657 |                 norm_layer=norm_layer,
658 |                 downsample=downsample,
659 |                 # =================
660 |                 ssm_d_state=ssm_d_state,
661 |                 ssm_ratio=ssm_ratio,
662 |                 ssm_dt_rank=ssm_dt_rank,
663 |                 ssm_act_layer=ssm_act_layer,
664 |                 ssm_conv=ssm_conv,
665 |                 ssm_conv_bias=ssm_conv_bias,
666 |                 ssm_drop_rate=ssm_drop_rate,
667 |                 ssm_simple_init=ssm_simple_init,
668 |                 # =================
669 |                 directions=None if directions is None else directions[sum(depths[:i_layer]):sum(depths[:i_layer + 1])]
670 |             ))
671 | 
672 | 
673 |         self.apply(self._init_weights)
674 | 
675 |     def _init_weights(self, m: nn.Module):
676 |         if isinstance(m, nn.Linear):
677 |             trunc_normal_(m.weight, std=.02)
678 |             if isinstance(m, nn.Linear) and m.bias is not None:
679 |                 nn.init.constant_(m.bias, 0)
680 |         elif isinstance(m, nn.LayerNorm):
681 |             if m.bias is not None:
682 |                 nn.init.constant_(m.bias, 0)
683 |             if m.weight is not None:
684 |                 nn.init.constant_(m.weight, 1.0)
685 | 
686 |     # used in building optimizer
687 |     @torch.jit.ignore
688 |     def no_weight_decay(self):
689 |         return {}
690 | 
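    # Each of the 15 direction groups in the default `directions` list feeds exactly one
    # VSSBlock: stage i receives directions[sum(depths[:i]):sum(depths[:i + 1])], so with
    # depths=[2, 2, 9, 2] the four stages consume groups [0:2], [2:4], [4:13], and [13:15].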
691 |     @staticmethod
692 |     def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm):
693 |         return nn.Sequential(
694 |             Permute(0, 3, 1, 2),
695 |             nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
696 |             Permute(0, 2, 3, 1),
697 |             norm_layer(out_dim),
698 |         )
699 | 
700 |     @staticmethod
701 |     def _make_layer(
702 |         dim=96,
703 |         drop_path=[0.1, 0.1],
704 |         use_checkpoint=False,
705 |         norm_layer=nn.LayerNorm,
706 |         downsample=nn.Identity(),
707 |         # ===========================
708 |         ssm_d_state=16,
709 |         ssm_ratio=2.0,
710 |         ssm_dt_rank="auto",
711 |         ssm_act_layer=nn.SiLU,
712 |         ssm_conv=3,
713 |         ssm_conv_bias=True,
714 |         ssm_drop_rate=0.0,
715 |         ssm_simple_init=False,
716 |         # ===========================
717 |         directions=None,
718 |         **kwargs,
719 |     ):
720 |         depth = len(drop_path)
721 |         blocks = []
722 |         for d in range(depth):
723 |             blocks.append(VSSBlock(
724 |                 hidden_dim=dim,
725 |                 drop_path=drop_path[d],
726 |                 norm_layer=norm_layer,
727 |                 ssm_d_state=ssm_d_state,
728 |                 ssm_ratio=ssm_ratio,
729 |                 ssm_dt_rank=ssm_dt_rank,
730 |                 ssm_act_layer=ssm_act_layer,
731 |                 ssm_conv=ssm_conv,
732 |                 ssm_conv_bias=ssm_conv_bias,
733 |                 ssm_drop_rate=ssm_drop_rate,
734 |                 ssm_simple_init=ssm_simple_init,
735 |                 use_checkpoint=use_checkpoint,
736 |                 directions=directions[d] if directions is not None else None
737 |             ))
738 | 
739 |         return nn.Sequential(OrderedDict(
740 |             blocks=nn.Sequential(*blocks,),
741 |             downsample=downsample,
742 |         ))
743 | 
744 |     def forward_features(self, x):  # x [1, 3, 256, 256]
745 |         skip_list = []
746 |         x = self.patch_embed(x)  # x [1, 64, 64, 96], dims=[96, 192, 384, 768]
747 |         for layer in self.layers:
748 |             skip_list.append(x)  # x [1, 64, 64, 96]
749 |             x = layer(x)  # after the loop, x is the final encoder output
750 |         return x, skip_list  # x [1, 8, 8, 768], len(skip_list) = 4, skip_list[0] [1, 64, 64, 96]
751 | 
752 |     # def forward_features(self, x):  # x [1, 3, 256, 256]
753 |     #     skip_list = []
754 |     #     x = self.patch_embed(x)  # x [1, 64, 64, 96], dims=[96, 192, 384, 768]
755 |     #     for i, layer in enumerate(self.layers):
756 |     #         # introduce the q-shift operation; its weight decreases as the layers get deeper
757 |     #         with torch.no_grad():
758 |     #             ratio_1_to_almost0 = (1.0 - (i / len(self.layers)))
759 |     #         # the q-shift operation computes the shifted feature x*
760 |     #         xx = q_shift(x, shift_pixel=1, gamma=1/4, patch_resolution=None)
761 | 
762 |     #         x = x * ratio_1_to_almost0 + xx * (1 - ratio_1_to_almost0)
763 |     #         skip_list.append(x)  # x [1, 64, 64, 96]
764 |     #         x = layer(x)  # after the loop, x is the final encoder output
765 |     #     return x, skip_list  # x [1, 8, 8, 768], len(skip_list) = 4, skip_list[0] [1, 64, 64, 96]
766 | 
767 | 
768 |     def forward_features_up(self, x, skip_list):
769 |         for inx, layer_up in enumerate(self.layers_up):  # x [1, 8, 8, 768]
770 |             if inx == 0:
771 |                 x = layer_up(x)
772 |             else:
773 |                 x = layer_up(x + skip_list[-inx])
774 | 
775 |         return x  # [1, 64, 64, 96]
776 | 
777 |     def forward_final(self, x):  # x [1, 64, 64, 96]
778 |         x = self.final_up(x)  # x [1, 256, 256, 24]
779 |         x = x.permute(0, 3, 1, 2)  # x [1, 24, 256, 256]
780 |         x = self.final_conv(x)  # x [1, 1, 256, 256]
781 |         return x
782 | 
783 |     def forward_backbone(self, x):
784 |         x = self.patch_embed(x)
785 |         if self.ape:
786 |             x = x + self.absolute_pos_embed
787 |         x = self.pos_drop(x)
788 | 
789 |         for layer in self.layers:
790 |             x = layer(x)
791 |         return x
792 | 
793 |     def forward_bak(self, x):
794 |         x, skip_list = self.forward_features(x)  # skip_list[0] [2, 64, 64, 96], skip_list[3] [2, 8, 8, 768]
795 |         x = self.forward_features_up(x, skip_list)
796 |         x = self.forward_final(x)
797 | 
798 |         return x
799 |     # modified from sim to the SDI module
800 |     def forward(self, x):
801 |         x, skip_list = self.forward_features(x)
802 |         skip_list[0] = skip_list[0].permute(0, 3, 1, 2)
803 |         skip_list[1] = skip_list[1].permute(0, 3, 1, 2)
804 |         skip_list[2] = skip_list[2].permute(0, 3, 1, 2)
805 |         skip_list[3] = skip_list[3].permute(0, 3, 1, 2)
806 |         for i, o in enumerate(skip_list):  # 4x upsampling
807 |             skip_list[i] = F.interpolate(o, scale_factor=4, mode='bilinear')  # skip_list[0] [2, 64, 64, 96], skip_list[3] [2, 8, 8, 768]
808 |         return skip_list[0], skip_list[1], skip_list[2], skip_list[3]
809 | 
810 | 
811 | 
812 | 
813 | 
--------------------------------------------------------------------------------
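Taken together, this `forward` uses VSSM as a multi-scale feature extractor rather than a full U-shaped network: it returns the four encoder skip maps, permuted to channels-first and bilinearly upsampled 4x, for a downstream decoder (the SDI module referenced in the comment above). A minimal end-to-end sketch, assuming the CUDA selective-scan kernels are installed; the input size is illustrative:

```python
import torch

model = VSSM(patch_size=4, in_chans=3, depths=[2, 2, 9, 2], dims=[32, 64, 128, 256]).cuda()
img = torch.randn(1, 3, 256, 256).cuda()
f0, f1, f2, f3 = model(img)
# the skips start at 64x64, 32x32, 16x16, 8x8 and are then upsampled 4x:
print(f0.shape, f3.shape)  # torch.Size([1, 32, 256, 256]) torch.Size([1, 256, 32, 32])
```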