├── model
│   ├── __init__.py
│   ├── localextro.py
│   ├── basemodel.py
│   ├── utils.py
│   ├── unet.py
│   ├── vgg.py
│   ├── unetp.py
│   ├── detect_head.py
│   └── caformer.py
├── Imgs
│   ├── BSDS.png
│   └── Imgs-bsds.png
├── requirements.txt
├── inference.py
├── README.md
├── train.py
├── utils.py
├── test.py
├── main.py
└── data_process.py

/model/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/Imgs/BSDS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Li-yachuan/NBED/HEAD/Imgs/BSDS.png
--------------------------------------------------------------------------------
/Imgs/Imgs-bsds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Li-yachuan/NBED/HEAD/Imgs/Imgs-bsds.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
fvcore==0.1.5.post20221221
imageio==2.10.1
numpy==1.20.2
opencv_python==4.5.1.48
Pillow==9.0.1
pysodmetrics==1.4.2
scikit_learn==1.0.2
scipy==1.7.1
thop==0.0.31.post2005241907
timm==0.6.12
torch==1.13.1
torchvision==0.10.0
tqdm==4.51.0
--------------------------------------------------------------------------------
/model/localextro.py:
--------------------------------------------------------------------------------
from torch import nn


class LCAL(nn.Module):
    """"""
    def __init__(self, Dulbrn):
        super(LCAL, self).__init__()

        self.dulbrn = Dulbrn

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, self.dulbrn, (3, 3), stride=1, padding=1),
            nn.ReLU(inplace=True)
        )

        self.conv2 = nn.Sequential(
            # nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(self.dulbrn, self.dulbrn * 2, (3, 3), stride=2, padding=1),
            nn.ReLU(inplace=True)
        )

        self.out_channels = [self.dulbrn, self.dulbrn * 2]

    def forward(self, x):
        f1 = self.conv1(x)
        f2 = self.conv2(f1)
        features = [f1, f2]
        return features
--------------------------------------------------------------------------------
/model/basemodel.py:
--------------------------------------------------------------------------------
from torch import nn
from model.utils import get_encoder, get_decoder, get_head
from torch.nn.functional import interpolate


class Basemodel(nn.Module):
    def __init__(self, encoder_name="caformer-m36", decoder_name="unetp", head_name=None):
        super(Basemodel, self).__init__()

        self.encoder = get_encoder(encoder_name)

        self.decoder, self.decoder_channel = get_decoder(decoder_name, self.encoder.out_channels)

        self.head = get_head(head_name, self.decoder_channel)

    def forward(self, x):
        _, _, H, W = x.size()
        features = self.encoder(x)
        features = self.decoder(features)
        edge = self.head(features)
        return edge
        # return interpolate(edge, (H, W), mode="bilinear")
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
from model.basemodel import Basemodel
from PIL import Image
import torch, torchvision
img = "/workspace/00Dataset/BSDS-yc/test/100007.jpg"
ckpt = "best.pth"
encoder = "DUL-M36"
decoder = "UNETP"
head = "default"
model = Basemodel(encoder_name=encoder,
                  decoder_name=decoder,
                  head_name=head).cuda()
ckpt = torch.load(ckpt, map_location='cpu')['state_dict']

# The released checkpoint stores the conv2 parameters at index 1 (an older
# layer layout); rename the keys so they match the current model, where the
# Conv2d sits at index 0.
value = ckpt.pop('encoder.conv2.1.weight')
ckpt['encoder.conv2.0.weight'] = value
value = ckpt.pop('encoder.conv2.1.bias')
ckpt['encoder.conv2.0.bias'] = value

model.load_state_dict(ckpt)
model.eval()

img = torchvision.transforms.ToTensor()(Image.open(img)).unsqueeze(0)
img = img * 2 - 1

edge = model(img.cuda())

torchvision.utils.save_image(edge, "edge.png")
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
from torch import nn


def get_encoder(nm, Dulbrn=16):
    if "CAFORMER-M36" == nm.upper():
        from model.caformer import caformer_m36_384_in21ft1k
        encoder = caformer_m36_384_in21ft1k(pretrained=True)
    elif "DUL-M36" == nm.upper():
        from model.caformer import caformer_m36_384_in21ft1k
        encoder = caformer_m36_384_in21ft1k(pretrained=True, Dulbrn=Dulbrn)
    elif "DUL-S18" == nm.upper():
        from model.caformer import caformer_s18_384_in21ft1k
        encoder = caformer_s18_384_in21ft1k(pretrained=True, Dulbrn=Dulbrn)
    elif "VGG-16" == nm.upper():
        from model.vgg import VGG16_C
        encoder = VGG16_C(pretrain="model/vgg16.pth")
    elif "LCAL" == nm.upper():
        from model.localextro import LCAL
        encoder = LCAL(Dulbrn=Dulbrn)
    else:
        raise Exception("Error encoder")
    return encoder


def get_head(nm, channels):
    from model.detect_head import CoFusion_head, CSAM_head, CDCM_head, Default_head, Fusion_head
    if nm == "aspp":
        head = CDCM_head(channels)
    elif nm == "atten":
        head = CSAM_head(channels)
    elif nm == "cofusion":
        head = CoFusion_head(channels)
    elif nm == "fusion":
        head = Fusion_head(channels)
    elif nm == "default":
        head = Default_head(channels)
    else:
        raise Exception("Error head")
    return head


def get_decoder(nm, incs, oucs=None):
    if oucs is None:
        # oucs = (32, 32, 64, 128, 256, 512)
        oucs = (32, 32, 64, 128, 384)

    if nm.upper() == "UNETP":
        from model.unetp import UnetDecoder
        decoder = UnetDecoder(incs, oucs[-len(incs):])
    elif nm.upper() == "UNET":
        from model.unet import UnetDecoder
        decoder = UnetDecoder(incs, oucs[-len(incs):])
    elif nm.upper() == "DEFAULT":
        from model.unet import Identity
        decoder = Identity(incs, oucs[-len(incs):])
    else:
        raise Exception("Error decoder")
    return decoder, oucs[-len(incs):]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NBED
Code of the paper [A new baseline for edge detection: Make Encoder-Decoder great again](https://arxiv.org/pdf/2409.14976)

## 0. Pip environment

```
pip install -r requirements.txt
```

## 1. Test NBED
Modify the values of ckpt and img in inference.py. The checkpoint trained on BSDS can be downloaded from
https://drive.google.com/file/d/1TKg37m3KWuv4A8FTlXJ-N20Ar46PmRdH/view?usp=sharing
Then run
```
python inference.py
```

## 2. Training NBED
### 2.1 Preparing the dataset
Download the datasets to any directory and point the code to that directory:
- BSDS500: following the setting of "The Treasure Beneath Multiple Annotations: An Uncertainty-aware Edge Detector"
- NYUDv2: following the setting of "Pixel Difference Networks for Efficient Edge Detection", with random crops to 400×400
- BIPED: following the setting of "Dense Extreme Inception Network for Edge Detection"
### 2.2 Preparing the pretrained weights
Download them from https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth
and put the file into the directory ./model
### 2.3 Training NBED
```
python main.py --batch_size 4 --stepsize 3-4 --gpu 1 --savedir 0305-bsds --encoder Dul-M36 --decoder unetp --head default --note 'training on BSDS500' --dataset BSDS --maxepoch 6
```
### 2.4 Eval NBED
Evaluation follows previous methods such as RCF and PiDiNet; a simplified sketch of the ODS/OIS computation is given below for reference.
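The official benchmark (the BSDS500 / RCF MATLAB evaluation code) thins the predictions and matches boundaries with a distance tolerance, so it should be used for any reported numbers. Purely to illustrate what ODS and OIS mean, a minimal pixel-wise approximation might look like the sketch below; the threshold grid and helper names are assumptions made for this example, and the tolerance-based matching is deliberately omitted.

```
import numpy as np

THRESHOLDS = np.linspace(0.01, 0.99, 99)


def counts(pred, gt, t):
    # pixel-wise true positives / predicted positives / ground-truth positives
    p = pred >= t
    return np.logical_and(p, gt).sum(), p.sum(), gt.sum()


def fscore(tp, n_pred, n_gt):
    prec = tp / n_pred if n_pred else 0.0
    rec = tp / n_gt if n_gt else 0.0
    return 2 * prec * rec / (prec + rec) if prec + rec else 0.0


def ods_ois(preds, gts):
    # preds: list of HxW float maps in [0, 1]; gts: list of HxW boolean maps
    per_img = np.zeros((len(preds), len(THRESHOLDS)))
    totals = np.zeros((len(THRESHOLDS), 3))
    for i, (p, g) in enumerate(zip(preds, gts)):
        for j, t in enumerate(THRESHOLDS):
            c = counts(p, g, t)
            per_img[i, j] = fscore(*c)
            totals[j] += c
    ods = max(fscore(*row) for row in totals)  # one shared threshold for the whole dataset
    ois = per_img.max(axis=1).mean()           # best threshold chosen per image
    return ods, ois
```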
![Result of BSDS](./Imgs/BSDS.png "Result of BSDS")
![Img of BSDS](./Imgs/Imgs-bsds.png "Img of BSDS")

The results on BSDS500 can be downloaded here:
https://drive.google.com/file/d/1PiPklsH7w6zNxdGWW-JpnUsFOdiYLHwG/view?usp=sharing

## 3. UPDATE
### 3.1 Release the checkpoint of BIPED
The checkpoint trained on BIPED is [here](https://drive.google.com/file/d/1IJO3VYrzi1Rp6YS4CzawZTrggxzz5cBx/view?usp=drive_link)
### 3.2 Fix bugs
The ODS/OIS on BIPED was slightly wrong because an incorrect tolerance was used; this is corrected in [version 2](https://arxiv.org/pdf/2409.14976). Thanks to [yx-yyds](https://github.com/yx-yyds) for the reminder.

### 3.3 **We released [DDN](https://github.com/Li-yachuan/DDN), the follow-up work of NBED.**
Main features: by introducing an Evidence Lower Bound (ELBO) loss and learnable Gaussian distributions, DDN is capable of generating multi-granularity edges; a rough, generic sketch of that idea is shown below.
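The sketch below is only a generic illustration of a per-pixel Gaussian edge head trained with an ELBO-style objective; it is not the DDN implementation, and the module name, the standard-normal prior, and the `kl_weight` value are assumptions made for this example (see the DDN repository for the actual method).

```
import torch
from torch import nn
import torch.nn.functional as F


class GaussianEdgeHead(nn.Module):
    """Predicts a per-pixel mean and log-variance and samples an edge logit
    with the reparameterization trick during training."""

    def __init__(self, in_channels):
        super().__init__()
        self.mu = nn.Conv2d(in_channels, 1, 3, padding=1)
        self.logvar = nn.Conv2d(in_channels, 1, 3, padding=1)

    def forward(self, feat):
        mu, logvar = self.mu(feat), self.logvar(feat)
        if self.training:
            logit = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        else:
            logit = mu  # use the mean at inference time
        return torch.sigmoid(logit), mu, logvar


def elbo_style_loss(edge, mu, logvar, label, kl_weight=1e-3):
    # reconstruction term: binary cross-entropy against the edge label
    recon = F.binary_cross_entropy(edge, label, reduction='mean')
    # KL divergence between N(mu, sigma^2) and the standard normal prior
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl_weight * kl
```

Sampling several logits from the learned per-pixel distribution at test time is one simple way such a head could yield edges at different granularities.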
The ODS of DDN on the BSDS500 dataset is 0.867, which is 0.022 higher than that of NBED (0.845) 48 | 49 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils import step_lr_scheduler, Averagvalue, cross_entropy_loss, save_checkpoint 2 | import time 3 | import torchvision 4 | from os.path import join 5 | import os 6 | import torch 7 | 8 | 9 | def train(train_loader, model, optimizer, epoch, args): 10 | optimizer = step_lr_scheduler(optimizer, epoch, args.stepsize) 11 | save_dir = join(args.savedir, 'epoch-%d-training-record' % epoch) 12 | os.makedirs(save_dir, exist_ok=True) 13 | 14 | batch_time = Averagvalue() 15 | data_time = Averagvalue() 16 | losses = Averagvalue() 17 | 18 | # switch to train mode 19 | model.train() 20 | print(epoch, 21 | "Pretrained lr:",optimizer.state_dict()['param_groups'][0]['lr'], 22 | "Unpretrained lr:",optimizer.state_dict()['param_groups'][2]['lr']) 23 | 24 | end = time.time() 25 | epoch_loss = [] 26 | counter = 0 27 | for i, (image, label) in enumerate(train_loader): 28 | # measure data loading time 29 | data_time.update(time.time() - end) 30 | image, label = image.cuda(), label.cuda() 31 | outputs = model(image) 32 | counter += 1 33 | loss = cross_entropy_loss(outputs, label, args.loss_lmbda) 34 | loss = loss / args.itersize 35 | loss.backward() 36 | 37 | if counter == args.itersize: 38 | optimizer.step() 39 | optimizer.zero_grad() 40 | counter = 0 41 | 42 | losses.update(loss, image.size(0)) 43 | epoch_loss.append(loss) 44 | batch_time.update(time.time() - end) 45 | end = time.time() 46 | if i % args.print_freq == 0: 47 | info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, i, len(train_loader)) + \ 48 | 'Time {batch_time.val:.2f} (avg:{batch_time.avg:.2f}) '.format(batch_time=batch_time) + \ 49 | 'Loss {loss.val:.2f} (avg:{loss.avg:.2f}) '.format(loss=losses) 50 | print(info) 51 | label[label == 2] = 0.5 52 | outputs = torch.cat([outputs, label], dim=0) 53 | torchvision.utils.save_image(outputs, join(save_dir, "iter-%d.jpg" % i), nrow=args.batch_size) 54 | 55 | # save_checkpoint({ 56 | # 'epoch': epoch, 57 | # 'state_dict': model.state_dict(), 58 | # }, filename=join(save_dir, "epoch-%d-checkpoint-%d.pth" % (epoch,i))) 59 | 60 | 61 | # save checkpoint 62 | save_checkpoint({ 63 | 'epoch': epoch, 64 | 'state_dict': model.state_dict(), 65 | }, filename=join(save_dir, "epoch-%d-checkpoint.pth" % epoch)) 66 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import os 4 | import sys 5 | 6 | 7 | def cross_entropy_loss(prediction, labelf, beta): 8 | label = labelf.long() 9 | mask = labelf.clone() 10 | num_positive = torch.sum(label == 1).float() 11 | num_negative = torch.sum(label == 0).float() 12 | 13 | mask[label == 1] = 1.0 * num_negative / (num_positive + num_negative) 14 | mask[label == 0] = beta * num_positive / (num_positive + num_negative) 15 | mask[label == 2] = 0 16 | cost = F.binary_cross_entropy( 17 | prediction, labelf, weight=mask, reduction='sum') 18 | 19 | return cost 20 | 21 | 22 | def get_model_parm_nums(model): 23 | total = sum([param.numel() for param in model.parameters()]) 24 | total = float(total) / 1e6 25 | return total 26 | 27 | 28 | class Logger(object): 29 | def __init__(self, fpath=None): 30 | self.console = 
sys.stdout 31 | self.file = None 32 | if fpath is not None: 33 | self.file = open(fpath, 'w') 34 | 35 | def __del__(self): 36 | self.close() 37 | 38 | def __enter__(self): 39 | pass 40 | 41 | def __exit__(self, *args): 42 | self.close() 43 | 44 | def write(self, msg): 45 | self.console.write(msg) 46 | if self.file is not None: 47 | self.file.write(msg) 48 | 49 | def flush(self): 50 | self.console.flush() 51 | if self.file is not None: 52 | self.file.flush() 53 | os.fsync(self.file.fileno()) 54 | 55 | def close(self): 56 | self.console.close() 57 | if self.file is not None: 58 | self.file.close() 59 | 60 | 61 | class Averagvalue(object): 62 | """Computes and stores the average and current value""" 63 | 64 | def __init__(self): 65 | self.reset() 66 | 67 | def reset(self): 68 | self.val = 0 69 | self.avg = 0 70 | self.sum = 0 71 | self.count = 0 72 | 73 | def update(self, val, n=1): 74 | self.val = val 75 | self.sum += val * n 76 | self.count += n 77 | self.avg = self.sum / self.count 78 | 79 | 80 | def save_checkpoint(state, filename='checkpoint.pth'): 81 | torch.save(state, filename) 82 | 83 | 84 | def step_lr_scheduler(optimizer, epoch, lr_decay_epoch): 85 | """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.""" 86 | if epoch in lr_decay_epoch: 87 | for param_group in optimizer.param_groups: 88 | param_group['lr'] = 0.1 * param_group['lr'] 89 | 90 | return optimizer 91 | -------------------------------------------------------------------------------- /model/unet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | class Attention(nn.Module): 7 | def __init__(self, name, **params): 8 | super().__init__() 9 | 10 | if name is None: 11 | self.attention = nn.Identity(**params) 12 | 13 | def forward(self, x): 14 | return self.attention(x) 15 | 16 | 17 | class Conv2dReLU(nn.Sequential): 18 | def __init__( 19 | self, 20 | in_channels, 21 | out_channels, 22 | kernel_size, 23 | padding=0, 24 | stride=1, 25 | use_batchnorm=True, 26 | ): 27 | conv = nn.Conv2d( 28 | in_channels, 29 | out_channels, 30 | kernel_size, 31 | stride=stride, 32 | padding=padding, 33 | bias=not (use_batchnorm), 34 | ) 35 | relu = nn.ReLU(inplace=True) 36 | 37 | bn = nn.InstanceNorm2d(out_channels) 38 | 39 | super(Conv2dReLU, self).__init__(conv, bn, relu) 40 | 41 | 42 | class DecoderBlock(nn.Module): 43 | def __init__( 44 | self, 45 | in_channels, 46 | skip_channels, 47 | out_channels, 48 | use_batchnorm=True, 49 | attention_type=None, 50 | ): 51 | super().__init__() 52 | self.conv1 = Conv2dReLU( 53 | in_channels + skip_channels, 54 | out_channels, 55 | kernel_size=3, 56 | padding=1, 57 | use_batchnorm=use_batchnorm, 58 | ) 59 | self.attention1 = Attention(attention_type, in_channels=in_channels + skip_channels) 60 | self.conv2 = Conv2dReLU( 61 | out_channels, 62 | out_channels, 63 | kernel_size=3, 64 | padding=1, 65 | use_batchnorm=use_batchnorm, 66 | ) 67 | self.attention2 = Attention(attention_type, in_channels=out_channels) 68 | 69 | def forward(self, x, skip=None): 70 | 71 | if skip is not None: 72 | x = F.interpolate(x, size=skip.size()[2:], mode="bilinear") 73 | # x = F.interpolate(x, size=skip.size()[2:], mode="nearest") 74 | x = torch.cat([x, skip], dim=1) 75 | x = self.attention1(x) 76 | else: 77 | x = F.interpolate(x, scale_factor=2, mode="bilinear") 78 | # x = F.interpolate(x, scale_factor=2, mode="nearest") 79 | x = self.conv1(x) 80 | x = self.conv2(x) 81 | x = 
self.attention2(x) 82 | return x 83 | 84 | 85 | class UnetDecoder(nn.Module): 86 | def __init__( 87 | self, 88 | encoder_channels, 89 | decoder_channels, 90 | ): 91 | super().__init__() 92 | 93 | self.depth = len(encoder_channels) 94 | convs = dict() 95 | 96 | for d in range(self.depth - 1): 97 | if d == self.depth - 2: 98 | convs["conv{}".format(d)] = DecoderBlock(encoder_channels[d + 1], 99 | encoder_channels[d], 100 | decoder_channels[d]) 101 | else: 102 | convs["conv{}".format(d)] = DecoderBlock(decoder_channels[d + 1], 103 | encoder_channels[d], 104 | decoder_channels[d]) 105 | 106 | self.convs = nn.ModuleDict(convs) 107 | 108 | # self.final = nn.Sequential( 109 | # nn.Conv2d(decoder_channels[0], 1, 3, padding=1), 110 | # nn.Sigmoid()) 111 | 112 | def forward(self, features): 113 | 114 | for d in range(self.depth - 2, -1, -1): 115 | features[d] = self.convs["conv{}".format(d)](features[d + 1], features[d]) 116 | 117 | return features 118 | # return self.final(features[0]) 119 | 120 | 121 | class Identity(nn.Module): 122 | def __init__( 123 | self, 124 | encoder_channels, 125 | decoder_channels, 126 | ): 127 | super().__init__() 128 | convs = [] 129 | for ec, dc in zip(encoder_channels, decoder_channels): 130 | convs.append(nn.Conv2d(ec, dc, 1)) 131 | self.convs = nn.ModuleList(convs) 132 | 133 | def forward(self, features): 134 | return [c(f) for f, c in zip(features, self.convs)] 135 | -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import torch.nn as nn 5 | import math 6 | 7 | 8 | class VGG16_C(nn.Module): 9 | """""" 10 | 11 | def __init__(self, pretrain=None, logger=None): 12 | super(VGG16_C, self).__init__() 13 | self.conv1_1 = nn.Conv2d(3, 64, (3, 3), stride=1, padding=1) 14 | self.relu1_1 = nn.ReLU(inplace=True) 15 | self.conv1_2 = nn.Conv2d(64, 64, (3, 3), stride=1, padding=1) 16 | self.relu1_2 = nn.ReLU(inplace=True) 17 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 18 | self.conv2_1 = nn.Conv2d(64, 128, (3, 3), stride=1, padding=1) 19 | self.relu2_1 = nn.ReLU(inplace=True) 20 | self.conv2_2 = nn.Conv2d(128, 128, (3, 3), stride=1, padding=1) 21 | self.relu2_2 = nn.ReLU(inplace=True) 22 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 23 | self.conv3_1 = nn.Conv2d(128, 256, (3, 3), stride=1, padding=1) 24 | self.relu3_1 = nn.ReLU(inplace=True) 25 | self.conv3_2 = nn.Conv2d(256, 256, (3, 3), stride=1, padding=1) 26 | self.relu3_2 = nn.ReLU(inplace=True) 27 | self.conv3_3 = nn.Conv2d(256, 256, (3, 3), stride=1, padding=1) 28 | self.relu3_3 = nn.ReLU(inplace=True) 29 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 30 | self.conv4_1 = nn.Conv2d(256, 512, (3, 3), stride=1, padding=1) 31 | self.relu4_1 = nn.ReLU(inplace=True) 32 | self.conv4_2 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=1) 33 | self.relu4_2 = nn.ReLU(inplace=True) 34 | self.conv4_3 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=1) 35 | self.relu4_3 = nn.ReLU(inplace=True) 36 | self.pool4 = nn.MaxPool2d(2, stride=1, padding=1, ceil_mode=True) 37 | self.conv5_1 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2) 38 | self.relu5_1 = nn.ReLU(inplace=True) 39 | self.conv5_2 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2) 40 | self.relu5_2 = nn.ReLU(inplace=True) 41 | self.conv5_3 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2) 42 | self.relu5_3 = 
nn.ReLU(inplace=True) 43 | 44 | self.out_channels = [3, 64, 128, 256, 512, 512] 45 | 46 | if pretrain: 47 | state_dict = torch.load(pretrain) 48 | own_state_dict = self.state_dict() 49 | for name, param in own_state_dict.items(): 50 | if name in state_dict: 51 | if logger: 52 | logger.info('copy the weights of %s from pretrained model' % name) 53 | param.copy_(state_dict[name]) 54 | else: 55 | 56 | if logger: 57 | logger.info('init the weights of %s from mean 0, std 0.01 gaussian distribution' \ 58 | % name) 59 | if 'bias' in name: 60 | param.zero_() 61 | else: 62 | param.normal_(0, 0.01) 63 | else: 64 | self._initialize_weights(logger) 65 | 66 | def forward(self, x): 67 | conv1_1 = self.relu1_1(self.conv1_1(x)) 68 | conv1_2 = self.relu1_2(self.conv1_2(conv1_1)) 69 | pool1 = self.pool1(conv1_2) 70 | conv2_1 = self.relu2_1(self.conv2_1(pool1)) 71 | conv2_2 = self.relu2_2(self.conv2_2(conv2_1)) 72 | pool2 = self.pool2(conv2_2) 73 | conv3_1 = self.relu3_1(self.conv3_1(pool2)) 74 | conv3_2 = self.relu3_2(self.conv3_2(conv3_1)) 75 | conv3_3 = self.relu3_3(self.conv3_3(conv3_2)) 76 | pool3 = self.pool3(conv3_3) 77 | conv4_1 = self.relu4_1(self.conv4_1(pool3)) 78 | conv4_2 = self.relu4_2(self.conv4_2(conv4_1)) 79 | conv4_3 = self.relu4_3(self.conv4_3(conv4_2)) 80 | pool4 = self.pool4(conv4_3) 81 | conv5_1 = self.relu5_1(self.conv5_1(pool4)) 82 | conv5_2 = self.relu5_2(self.conv5_2(conv5_1)) 83 | conv5_3 = self.relu5_3(self.conv5_3(conv5_2)) 84 | side = [x,conv1_2, conv2_2, conv3_3, conv4_3, conv5_3] 85 | 86 | return side 87 | 88 | def _initialize_weights(self, logger=None): 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | if logger: 92 | logger.info('init the weights of %s from mean 0, std 0.01 gaussian distribution' \ 93 | % m) 94 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 95 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 96 | if m.bias is not None: 97 | m.bias.data.zero_() 98 | elif isinstance(m, nn.BatchNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | elif isinstance(m, nn.Linear): 102 | m.weight.data.normal_(0, 0.01) 103 | m.bias.data.zero_() 104 | 105 | 106 | if __name__ == '__main__': 107 | model = VGG16_C() 108 | # im = np.zeros((1,3,100,100)) 109 | # out = model(Variable(torch.from_numpy(im))) 110 | -------------------------------------------------------------------------------- /model/unetp.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | class Attention(nn.Module): 7 | def __init__(self, name, **params): 8 | super().__init__() 9 | 10 | if name is None: 11 | self.attention = nn.Identity(**params) 12 | 13 | def forward(self, x): 14 | return self.attention(x) 15 | 16 | 17 | class Conv2dReLU(nn.Sequential): 18 | def __init__( 19 | self, 20 | in_channels, 21 | out_channels, 22 | kernel_size, 23 | padding=0, 24 | stride=1, 25 | use_batchnorm=True, 26 | ): 27 | conv = nn.Conv2d( 28 | in_channels, 29 | out_channels, 30 | kernel_size, 31 | stride=stride, 32 | padding=padding, 33 | bias=not (use_batchnorm), 34 | ) 35 | relu = nn.ReLU(inplace=True) 36 | 37 | bn = nn.InstanceNorm2d(out_channels) 38 | 39 | super(Conv2dReLU, self).__init__(conv, bn, relu) 40 | 41 | 42 | class DecoderBlock(nn.Module): 43 | def __init__( 44 | self, 45 | in_channels, 46 | skip_channels, 47 | out_channels, 48 | use_batchnorm=True, 49 | attention_type=None, 50 | ): 51 | super().__init__() 52 | self.conv1 = Conv2dReLU( 53 | in_channels + skip_channels, 54 | out_channels, 55 | kernel_size=3, 56 | padding=1, 57 | use_batchnorm=use_batchnorm, 58 | ) 59 | self.attention1 = Attention(attention_type, in_channels=in_channels + skip_channels) 60 | self.conv2 = Conv2dReLU( 61 | out_channels, 62 | out_channels, 63 | kernel_size=3, 64 | padding=1, 65 | use_batchnorm=use_batchnorm, 66 | ) 67 | self.attention2 = Attention(attention_type, in_channels=out_channels) 68 | 69 | def forward(self, x, skip=None): 70 | 71 | if skip is not None: 72 | x = F.interpolate(x, size=skip.size()[2:], mode="bilinear") 73 | # x = F.interpolate(x, size=skip.size()[2:], mode="nearest") 74 | x = torch.cat([x, skip], dim=1) 75 | x = self.attention1(x) 76 | else: 77 | x = F.interpolate(x, scale_factor=2, mode="bilinear") 78 | # x = F.interpolate(x, scale_factor=2, mode="nearest") 79 | x = self.conv1(x) 80 | x = self.conv2(x) 81 | x = self.attention2(x) 82 | return x 83 | 84 | 85 | class UnetDecoder(nn.Module): 86 | def __init__( 87 | self, 88 | encoder_channels, 89 | decoder_channels, 90 | ): 91 | super().__init__() 92 | 93 | # encoder_channels: 96,192,384,576 94 | # decoder_channel: 32,64, 128,256 95 | 96 | # self.conv11 = DecoderBlock(encoder_channels[1], encoder_channels[0], decoder_channels[0]) 97 | # self.conv21 = DecoderBlock(encoder_channels[2], encoder_channels[1], decoder_channels[1]) 98 | # self.conv31 = DecoderBlock(encoder_channels[3], encoder_channels[2], decoder_channels[2]) 99 | # 100 | # self.conv12 = DecoderBlock(decoder_channels[1], decoder_channels[0], decoder_channels[0]) 101 | # self.conv22 = DecoderBlock(decoder_channels[2], decoder_channels[1], decoder_channels[1]) 102 | # 103 | # self.conv13 = DecoderBlock(decoder_channels[1], decoder_channels[0], decoder_channels[0]) 104 | self.width = len(encoder_channels) 105 | convs = dict() 106 | for w in range(self.width - 1): 107 | for d in 
range(self.width - w - 1): 108 | if w == 0: 109 | convs["conv{}_{}".format(d, w)] = DecoderBlock(encoder_channels[d + 1], 110 | encoder_channels[d], 111 | decoder_channels[d]) 112 | else: 113 | convs["conv{}_{}".format(d, w)] = DecoderBlock(decoder_channels[d + 1], 114 | decoder_channels[d], 115 | decoder_channels[d]) 116 | 117 | 118 | self.convs = nn.ModuleDict(convs) 119 | 120 | # self.final = nn.Sequential( 121 | # nn.Conv2d(decoder_channels[0], 1, 3, padding=1), 122 | # nn.Sigmoid()) 123 | 124 | def forward(self, features): 125 | 126 | # features[0] = self.conv11(features[1], features[0]) 127 | # features[1] = self.conv21(features[2], features[1]) 128 | # features[2] = self.conv31(features[3], features[2]) 129 | # 130 | # features[0] = self.conv12(features[1], features[0]) 131 | # features[1] = self.conv22(features[2], features[1]) 132 | # 133 | # features[0] = self.conv13(features[1], features[0]) 134 | for w in range(self.width - 1): 135 | for d in range(self.width - w - 1): 136 | features[d] = self.convs["conv{}_{}".format(d, w)](features[d + 1], features[d]) 137 | # return self.final(features[0]) 138 | return features 139 | -------------------------------------------------------------------------------- /model/detect_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CSAM(nn.Module): 7 | """ 8 | Compact Spatial Attention Module 9 | """ 10 | 11 | def __init__(self, channels): 12 | super(CSAM, self).__init__() 13 | 14 | mid_channels = 4 15 | self.relu1 = nn.ReLU() 16 | self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0) 17 | self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False) 18 | self.sigmoid = nn.Sigmoid() 19 | nn.init.constant_(self.conv1.bias, 0) 20 | 21 | def forward(self, x): 22 | y = self.relu1(x) 23 | y = self.conv1(y) 24 | y = self.conv2(y) 25 | y = self.sigmoid(y) 26 | 27 | return x * y 28 | 29 | 30 | class CSAM_head(nn.Module): 31 | def __init__(self, channels): 32 | super(CSAM_head, self).__init__() 33 | 34 | modulelst = [] 35 | for num in channels: 36 | modulelst.append(nn.Sequential(CSAM(num), nn.Conv2d(num, 1, 1))) 37 | 38 | self.modulelst = nn.ModuleList(modulelst) 39 | 40 | self.final = nn.Sequential(nn.Conv2d(5, 1, 1), nn.Sigmoid()) 41 | 42 | def forward(self, feats): 43 | _, _, H, W = feats[0].size() 44 | for i in range(len(feats)): 45 | feats[i] = F.interpolate(self.modulelst[i](feats[i]), (H, W), mode="bilinear") 46 | 47 | return self.final(torch.cat(feats, dim=1)) 48 | 49 | 50 | class CDCM(nn.Module): 51 | """ 52 | Compact Dilation Convolution based Module 53 | """ 54 | 55 | def __init__(self, in_channels, out_channels): 56 | super(CDCM, self).__init__() 57 | 58 | self.relu1 = nn.ReLU() 59 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 60 | self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False) 61 | self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False) 62 | self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False) 63 | self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False) 64 | nn.init.constant_(self.conv1.bias, 0) 65 | 66 | def forward(self, x): 67 | x = self.relu1(x) 68 | x = self.conv1(x) 69 | x1 = self.conv2_1(x) 70 | x2 = self.conv2_2(x) 71 | x3 = 
self.conv2_3(x) 72 | x4 = self.conv2_4(x) 73 | return x1 + x2 + x3 + x4 74 | 75 | 76 | class CDCM_head(nn.Module): 77 | def __init__(self, channels): 78 | super(CDCM_head, self).__init__() 79 | self.channels = channels 80 | modulelst = [] 81 | for num in channels: 82 | modulelst.append( 83 | nn.Sequential( 84 | CDCM(num, num), 85 | nn.Conv2d(num, 1, 1)) 86 | ) 87 | 88 | self.modulelst = nn.ModuleList(modulelst) 89 | 90 | self.final = nn.Sequential(nn.Conv2d(5, 1, 1), nn.Sigmoid()) 91 | 92 | def forward(self, feats): 93 | _, _, H, W = feats[0].size() 94 | for i in range(len(feats)): 95 | feats[i] = F.interpolate(self.modulelst[i](feats[i]), (H, W), mode="bilinear") 96 | return self.final(torch.cat(feats, dim=1)) 97 | 98 | 99 | class CoFusion(nn.Module): 100 | 101 | def __init__(self, in_ch, out_ch): 102 | super(CoFusion, self).__init__() 103 | self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=3, 104 | stride=1, padding=1) 105 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, 106 | stride=1, padding=1) 107 | self.conv3 = nn.Conv2d(64, out_ch, kernel_size=3, 108 | stride=1, padding=1) 109 | self.relu = nn.ReLU() 110 | 111 | self.norm_layer1 = nn.GroupNorm(4, 64) 112 | self.norm_layer2 = nn.GroupNorm(4, 64) 113 | 114 | def forward(self, x): 115 | attn = self.relu(self.norm_layer1(self.conv1(x))) 116 | attn = self.relu(self.norm_layer2(self.conv2(attn))) 117 | attn = F.softmax(self.conv3(attn), dim=1) 118 | 119 | return ((x * attn).sum(1)).unsqueeze(1) 120 | 121 | 122 | class CoFusion_head(nn.Module): 123 | def __init__(self, channels): 124 | super(CoFusion_head, self).__init__() 125 | 126 | modulelst = [] 127 | for num in channels: 128 | modulelst.append(nn.Sequential( 129 | nn.Conv2d(num, 1, 1), 130 | nn.Sigmoid()) 131 | ) 132 | 133 | self.modulelst = nn.ModuleList(modulelst) 134 | 135 | self.final = CoFusion(5, 5) 136 | 137 | def forward(self, feats): 138 | _, _, H, W = feats[0].size() 139 | for i in range(len(feats)): 140 | feats[i] = F.interpolate(self.modulelst[i](feats[i]), (H, W), mode="bilinear") 141 | return self.final(torch.cat(feats, dim=1)) 142 | 143 | class Fusion_head(nn.Module): 144 | def __init__(self, channels): 145 | super(Fusion_head, self).__init__() 146 | 147 | modulelst = [] 148 | for num in channels: 149 | modulelst.append(nn.Sequential( 150 | nn.Conv2d(num, 1, 1), 151 | nn.Sigmoid()) 152 | ) 153 | 154 | self.modulelst = nn.ModuleList(modulelst) 155 | 156 | self.final = nn.Sequential(nn.Conv2d(5,1,3,padding=1), 157 | nn.Sigmoid() 158 | ) 159 | 160 | def forward(self, feats): 161 | _, _, H, W = feats[0].size() 162 | for i in range(len(feats)): 163 | feats[i] = F.interpolate(self.modulelst[i](feats[i]), (H, W), mode="bilinear") 164 | return self.final(torch.cat(feats, dim=1)) 165 | 166 | 167 | class Default_head(nn.Module): 168 | def __init__(self, channels): 169 | super(Default_head, self).__init__() 170 | 171 | self.final = nn.Sequential( 172 | nn.Conv2d(channels[0], 1, 3, padding=1), 173 | nn.Sigmoid() 174 | ) 175 | 176 | def forward(self, feats): 177 | return self.final(feats[0]) 178 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import os 3 | from tqdm import tqdm 4 | import torch 5 | from PIL import Image 6 | import numpy as np 7 | import scipy.io as sio 8 | 9 | 10 | def test(model, test_loader, save_dir): 11 | print("single scale test") 12 | png_save_dir = os.path.join(save_dir, "png") 13 | mat_save_dir = 
os.path.join(save_dir, "mat") 14 | if not os.path.exists(png_save_dir): 15 | os.makedirs(png_save_dir) 16 | if not os.path.exists(mat_save_dir): 17 | os.makedirs(mat_save_dir) 18 | 19 | model.eval() 20 | for idx, (image, filename) in enumerate(tqdm(test_loader)): 21 | image = image.cuda() 22 | with torch.no_grad(): 23 | result = model(image).squeeze().cpu().numpy() 24 | result_png = Image.fromarray((result * 255).astype(np.uint8)) 25 | result_png.save(join(png_save_dir, "%s.png" % filename)) 26 | 27 | sio.savemat(join(mat_save_dir, "%s.mat" % filename), {'result': result}, do_compression=True) 28 | 29 | 30 | import cv2 31 | 32 | 33 | def multiscale_test(model, test_loader, save_dir, scale_num=7): 34 | png_save_dir = os.path.join(save_dir, "png") 35 | mat_save_dir = os.path.join(save_dir, "mat") 36 | if not os.path.exists(png_save_dir): 37 | os.makedirs(png_save_dir) 38 | if not os.path.exists(mat_save_dir): 39 | os.makedirs(mat_save_dir) 40 | 41 | model.eval() 42 | if scale_num == 7: 43 | print("7 scale test") 44 | scale = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] 45 | else: 46 | print("3 scale test") 47 | scale = [0.5, 1.0, 1.5] 48 | 49 | for idx, (image, filename) in enumerate(tqdm(test_loader)): 50 | image = image[0] 51 | image_in = image.numpy().transpose((1, 2, 0)) 52 | _, H, W = image.shape 53 | multi_fuse = np.zeros((H, W), np.float32) 54 | for k in range(0, len(scale)): 55 | im_ = cv2.resize(image_in, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR) 56 | im_ = torch.from_numpy(im_.transpose((2, 0, 1))).unsqueeze(0) 57 | with torch.no_grad(): 58 | result = model(im_.cuda()).squeeze().cpu().numpy() 59 | fuse = cv2.resize(result, (W, H), interpolation=cv2.INTER_LINEAR) 60 | multi_fuse += fuse 61 | multi_fuse = multi_fuse / len(scale) 62 | 63 | result_png = Image.fromarray((multi_fuse * 255).astype(np.uint8)) 64 | result_png.save(join(png_save_dir, "%s.png" % filename)) 65 | 66 | sio.savemat(join(mat_save_dir, "%s.mat" % filename), {'result': multi_fuse}, do_compression=True) 67 | 68 | 69 | from functools import partial 70 | 71 | 72 | def __identity(x): 73 | return x 74 | 75 | 76 | def enhence_test(model, test_loader, save_dir): 77 | print("rotate enhence test") 78 | png_save_dir = os.path.join(save_dir, "png") 79 | mat_save_dir = os.path.join(save_dir, "mat") 80 | if not os.path.exists(png_save_dir): 81 | os.makedirs(png_save_dir) 82 | if not os.path.exists(mat_save_dir): 83 | os.makedirs(mat_save_dir) 84 | 85 | model.eval() 86 | funcs = [partial(__identity), 87 | partial(cv2.rotate, rotateCode=cv2.ROTATE_90_CLOCKWISE), 88 | partial(cv2.rotate, rotateCode=cv2.ROTATE_180), 89 | partial(cv2.rotate, rotateCode=cv2.ROTATE_90_COUNTERCLOCKWISE)] 90 | 91 | funcs_t = [partial(__identity), 92 | partial(cv2.rotate, rotateCode=cv2.ROTATE_90_COUNTERCLOCKWISE), 93 | partial(cv2.rotate, rotateCode=cv2.ROTATE_180), 94 | partial(cv2.rotate, rotateCode=cv2.ROTATE_90_CLOCKWISE)] 95 | 96 | for idx, (image, filename) in enumerate(tqdm(test_loader)): 97 | 98 | image = image[0] 99 | image_in = image.numpy().transpose((1, 2, 0)) 100 | 101 | H, W, _ = image_in.shape 102 | 103 | multi_fuse = np.zeros((H, W), np.float32) 104 | 105 | for func, funct in zip(funcs, funcs_t): 106 | img = func(image_in) 107 | edge = __enhence_test_single(img, model) 108 | edge = funct(edge) 109 | multi_fuse += edge 110 | 111 | image_inf = cv2.flip(image_in, 1) # shuiping fanzhuan 112 | multi_fuse_f = np.zeros((H, W), np.float32) 113 | 114 | for func, funct in zip(funcs, funcs_t): 115 | img = func(image_inf) 116 | edge = 
__enhence_test_single(img, model) 117 | edge = funct(edge) 118 | multi_fuse_f += edge 119 | 120 | multi_fuse = multi_fuse + cv2.flip(multi_fuse_f, 1) 121 | 122 | multi_fuse = multi_fuse / 8 123 | 124 | result_png = Image.fromarray((multi_fuse * 255).astype(np.uint8)) 125 | result_png.save(join(png_save_dir, "%s.png" % filename)) 126 | 127 | sio.savemat(join(mat_save_dir, "%s.mat" % filename), {'result': multi_fuse}, do_compression=True) 128 | 129 | 130 | def bright_enhence_test(model, test_loader, save_dir): 131 | print("bright enhence test") 132 | png_save_dir = os.path.join(save_dir, "png") 133 | mat_save_dir = os.path.join(save_dir, "mat") 134 | if not os.path.exists(png_save_dir): 135 | os.makedirs(png_save_dir) 136 | if not os.path.exists(mat_save_dir): 137 | os.makedirs(mat_save_dir) 138 | 139 | model.eval() 140 | for idx, (image, filename) in enumerate(tqdm(test_loader)): 141 | image = image[0] 142 | image_in = image.numpy().transpose((1, 2, 0)) 143 | 144 | H, W, _ = image_in.shape 145 | 146 | multi_fuse = np.zeros((H, W), np.float32) 147 | bright_intals = [(0, 0.5), (0.25, 0.75), (0.5, 1)] 148 | for internl in bright_intals: 149 | img = __bright_func(image_in, internl) 150 | edge = __enhence_test_single(img, model) 151 | multi_fuse += edge 152 | 153 | multi_fuse = multi_fuse / len(bright_intals) 154 | 155 | result_png = Image.fromarray((multi_fuse * 255).astype(np.uint8)) 156 | result_png.save(join(png_save_dir, "%s.png" % filename)) 157 | 158 | sio.savemat(join(mat_save_dir, "%s.mat" % filename), {'result': multi_fuse}, do_compression=True) 159 | 160 | 161 | def __bright_func(image_in, internl): 162 | threshold_min = image_in.min() + (image_in.max() - image_in.min())*internl[0] 163 | threshold_max = image_in.min() + (image_in.max() - image_in.min())*internl[1] 164 | 165 | enh_image = np.clip(image_in,threshold_min,threshold_max) 166 | 167 | scale_factor = (image_in.max() - image_in.min()) / (threshold_max - threshold_min) 168 | offset = image_in.min() - scale_factor * threshold_min 169 | enh_image = scale_factor * enh_image + offset 170 | 171 | image = (image_in + enh_image) / 2 172 | return image 173 | 174 | 175 | def __enhence_test_single(image_in, model): 176 | scale = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] 177 | H, W, _ = image_in.shape 178 | multi_fuse = np.zeros((H, W), np.float32) 179 | 180 | for k in range(0, len(scale)): 181 | im_ = cv2.resize(image_in, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR) 182 | im_ = torch.from_numpy(im_.transpose((2, 0, 1))).unsqueeze(0) 183 | with torch.no_grad(): 184 | result = model(im_.cuda()).squeeze().cpu().numpy() 185 | fuse = cv2.resize(result, (W, H), interpolation=cv2.INTER_LINEAR) 186 | multi_fuse += fuse 187 | multi_fuse = multi_fuse / len(scale) 188 | return multi_fuse 189 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import random 5 | import numpy 6 | 7 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 8 | parser = argparse.ArgumentParser(description='PyTorch Training') 9 | parser.add_argument('--batch_size', default=4, type=int, metavar='BT', 10 | help='batch size') 11 | parser.add_argument('--LR', default=0.0001, type=float, 12 | metavar='LR', help='initial learning rate') 13 | parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, 14 | metavar='W', help='default weight decay') 15 | parser.add_argument('--stepsize', 
default="3", type=str, help='learning rate step size') 16 | parser.add_argument('--maxepoch', default=10, type=int, metavar='N', 17 | help='number of total epochs to run') 18 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 19 | help='manual epoch number (useful on restarts)') 20 | parser.add_argument('--print_freq', '-p', default=500, type=int, 21 | metavar='N', help='print frequency (default: 50)') 22 | parser.add_argument('--gpu', default=None, type=str, 23 | help='GPU ID') 24 | parser.add_argument('--loss_lmbda', default=None, type=float, 25 | help='hype-param of loss 1.1 for BSDS 1.3 for NYUD') 26 | parser.add_argument('--itersize', default=1, type=int, 27 | metavar='IS', help='iter size') 28 | parser.add_argument("--encoder", default="Dul-M36", 29 | help="caformer-m36,Dul-M36") 30 | parser.add_argument("--decoder", default="unetp", 31 | help="unet,unetp,default") 32 | parser.add_argument("--head", default="default", 33 | help="default,aspp,atten,cofusion") 34 | 35 | parser.add_argument("--savedir", default="tmp") 36 | parser.add_argument("--colorJ", action="store_true") 37 | parser.add_argument("-plr", "--pretrainlr", type=float, default=0.1) 38 | parser.add_argument("--mode", type=str, default="train", choices=["train", "test"]) 39 | parser.add_argument("--resume", type=str,default=None) 40 | parser.add_argument("--dataset", type=str,default="BSDS") 41 | parser.add_argument("--note", default=None) 42 | 43 | args = parser.parse_args() 44 | 45 | args.stepsize = [int(i) for i in args.stepsize.split("-")] 46 | print(args.stepsize) 47 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 48 | if args.gpu is not None: 49 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 50 | 51 | random_seed = 3407 52 | random.seed(random_seed) 53 | torch.manual_seed(random_seed) 54 | torch.cuda.manual_seed(random_seed) 55 | numpy.random.seed(random_seed) 56 | 57 | import sys 58 | from os.path import join 59 | from data_process import BSDS_Loader,NYUD_Loader,BIPED_Loader,Multicue_Loader 60 | from model.basemodel import Basemodel 61 | from torch.utils.data import DataLoader 62 | from utils import Logger, get_model_parm_nums 63 | from train import train 64 | from test import test,multiscale_test,enhence_test,bright_enhence_test 65 | 66 | 67 | def main(): 68 | if args.dataset == "BSDS": 69 | datadir = "/workspace/00Dataset/BSDS-yc" 70 | 71 | train_dataset = BSDS_Loader(root=datadir, split="train", 72 | colorJitter=args.colorJ) 73 | test_dataset = BSDS_Loader(root=datadir, split="test") 74 | if args.loss_lmbda is None: 75 | args.loss_lmbda = 1.1 76 | elif "NYUD" in args.dataset: 77 | datadir = "/workspace/00Dataset/NYUD" 78 | mode = args.dataset.split("-")[1] 79 | train_dataset = NYUD_Loader(root=datadir, split="train", mode=mode) 80 | test_dataset = NYUD_Loader(root=datadir, split="test", mode=mode) 81 | if args.loss_lmbda is None: 82 | args.loss_lmbda = 1.3 83 | elif args.dataset == "BIPED": 84 | datadir = "/workspace/00Dataset/BIPED" 85 | train_dataset = BIPED_Loader(root=datadir, split="train") 86 | test_dataset = BIPED_Loader(root=datadir, split="test") 87 | if args.loss_lmbda is None: 88 | args.loss_lmbda = 1.1 89 | elif args.dataset == "BIPEDv2": 90 | datadir = "/workspace/00Dataset/BIPEDv2" 91 | # if not os.path.isdir(datadir): 92 | # datadir = "/media/aita130/AIDD/PFHan/Pytorch/BIPEDv2" 93 | # datadir = "/media/aita130/AIDD/PFHan/Pytorch/BIPEDv2" 94 | train_dataset = BIPED_Loader(root=datadir, split="train") 95 | test_dataset = BIPED_Loader(root=datadir, 
split="test") 96 | if args.loss_lmbda is None: 97 | args.loss_lmbda = 1.1 98 | 99 | elif 'Multicue' in args.dataset: 100 | root = "/workspace/00Dataset/multicue_pidinet" 101 | train_dataset = Multicue_Loader(root=root, split="train", setting=args.dataset.split("-")[1:]) 102 | test_dataset = Multicue_Loader(root=root, split="test", setting=args.dataset.split("-")[1:]) 103 | if args.loss_lmbda is None: 104 | args.loss_lmbda = 1.1 105 | elif args.dataset == "UDED": 106 | datadir = "/workspace/00Dataset/UDED" 107 | train_dataset = None 108 | test_dataset = BIPED_Loader(root=datadir, split="test") 109 | if args.loss_lmbda is None: 110 | args.loss_lmbda = 1.1 111 | else: 112 | raise Exception("Error dataset name") 113 | 114 | test_loader = DataLoader( 115 | test_dataset, batch_size=1, shuffle=False) 116 | 117 | model = Basemodel(encoder_name=args.encoder, 118 | decoder_name=args.decoder, 119 | head_name=args.head).cuda() 120 | print("MODEL SIZE: {}".format(get_model_parm_nums(model))) 121 | 122 | # 123 | # new_key = 'new_key' 124 | # if 'old_key' in original_dict: 125 | # original_value = original_dict.pop('old_key') # 移除旧键及其对应的值 126 | # original_dict[new_key] = original_value # 添加新键及原有的值 127 | 128 | 129 | 130 | if args.resume is not None: 131 | ckpt = torch.load(args.resume, map_location='cpu')['state_dict'] 132 | ckpt["encoder.conv2.0.weight"] = ckpt.pop("encoder.conv2.1.weight") 133 | ckpt["encoder.conv2.0.bias"] = ckpt.pop("encoder.conv2.1.bias") 134 | model.load_state_dict(ckpt) 135 | print("load pretrained model, successfully!") 136 | 137 | if args.mode == "test": 138 | assert args.resume is not None 139 | # test(model, test_loader, save_dir=join(args.savedir, 140 | # os.path.basename(args.resume).split(".")[0]+"-ss")) 141 | if "BSDS" in args.dataset.upper(): 142 | # multiscale_test(model, 143 | # test_loader, 144 | # save_dir=join(args.savedir, os.path.basename(args.resume).split(".")[0]+"-ms7")) 145 | # multiscale_test(model, 146 | # test_loader, 147 | # save_dir=join(args.savedir, os.path.basename(args.resume).split(".")[0] + "-ms3"), 148 | # scale_num=3) 149 | # enhence_test(model, 150 | # test_loader, 151 | # save_dir=join(args.savedir, os.path.basename(args.resume).split(".")[0] + "-enh"),) 152 | bright_enhence_test(model, 153 | test_loader, 154 | save_dir=join(args.savedir, os.path.basename(args.resume).split(".")[0] + "-brienh"),) 155 | 156 | 157 | 158 | else: 159 | train_loader = DataLoader( 160 | train_dataset, batch_size=args.batch_size, num_workers=args.batch_size, drop_last=True, shuffle=True) 161 | 162 | parameters = {'pretrained.weight': [], 'pretrained.bias': [], 163 | 'nopretrained.weight': [], 'nopretrained.bias': []} 164 | 165 | for pname, p in model.named_parameters(): 166 | if ("encoder.stages" in pname) or ("encoder.downsample_layers" in pname): 167 | # p.requires_grad = False 168 | if "weight" in pname: 169 | parameters['pretrained.weight'].append(p) 170 | else: 171 | parameters['pretrained.bias'].append(p) 172 | 173 | else: 174 | if "weight" in pname: 175 | parameters['nopretrained.weight'].append(p) 176 | else: 177 | parameters['nopretrained.bias'].append(p) 178 | 179 | optimizer = torch.optim.Adam([ 180 | {'params': parameters['pretrained.weight'], 'lr': args.LR * args.pretrainlr, 'weight_decay': args.weight_decay}, 181 | {'params': parameters['pretrained.bias'], 'lr': args.LR * 2 * args.pretrainlr, 'weight_decay': 0.}, 182 | {'params': parameters['nopretrained.weight'], 'lr': args.LR * 1, 'weight_decay': args.weight_decay}, 183 | {'params': 
parameters['nopretrained.bias'], 'lr': args.LR * 2, 'weight_decay': 0.}, 184 | ], lr=args.LR, weight_decay=args.weight_decay) 185 | 186 | 187 | # optimizer = torch.optim.Adam(model.parameters(), lr=args.LR, weight_decay=args.weight_decay) 188 | # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.LR, 189 | # weight_decay=args.weight_decay) 190 | # test(model, test_loader, save_dir=join(args.savedir, 'epoch-init-testing-record')) 191 | for epoch in range(args.start_epoch, args.maxepoch): 192 | train(train_loader, model, optimizer, epoch, args) 193 | test(model, test_loader, save_dir=join(args.savedir, 'epoch-%d-ss-test' % epoch)) 194 | if "BSDS" in args.dataset.upper(): 195 | multiscale_test(model, test_loader, save_dir=join(args.savedir, 'epoch-%d-ms-test' % epoch)) 196 | log.flush() 197 | 198 | 199 | if __name__ == '__main__': 200 | import datetime 201 | 202 | # 获取当前日期和时间 203 | current_time = datetime.datetime.now() 204 | # 将日期和时间转换为字符串格式 205 | time_string = current_time.strftime("%Y-%m-%d %H:%M:%S") 206 | 207 | args.savedir = join("output-abl", args.savedir) 208 | os.makedirs(args.savedir, exist_ok=True) 209 | log = Logger(join(args.savedir, '%s-log.txt' % (time_string))) 210 | sys.stdout = log 211 | cmds = "python" 212 | for cmd in sys.argv: 213 | if " " in cmd: 214 | cmd = "\'" + cmd + "\'" 215 | cmds = cmds + " " + cmd 216 | print(cmds) 217 | print(args) 218 | 219 | main() 220 | -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | from os.path import join, abspath, splitext, split, isdir, isfile 3 | from PIL import Image 4 | from torchvision import transforms 5 | import torch 6 | import numpy as np 7 | import random 8 | import imageio 9 | from pathlib import Path 10 | from torch.nn.functional import interpolate 11 | 12 | 13 | class BSDS_Loader(data.Dataset): 14 | """ 15 | Dataloader BSDS500 16 | """ 17 | 18 | def __init__(self, root='data/HED-BSDS', split='train', threshold=0.3, colorJitter=False): 19 | self.root = root 20 | self.split = split 21 | self.threshold = threshold 22 | print('Threshold for ground truth: %f on BSDS' % self.threshold) 23 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 24 | std=[0.229, 0.224, 0.225]) 25 | if colorJitter: 26 | self.transform = transforms.Compose([ 27 | transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), 28 | transforms.RandomGrayscale(p=0.2), 29 | transforms.ToTensor(), 30 | normalize]) 31 | else: 32 | self.transform = transforms.Compose([ 33 | transforms.ToTensor(), 34 | normalize]) 35 | if self.split == 'train': 36 | self.filelist = join(self.root, 'train_BSDS.lst') 37 | elif self.split == 'test': 38 | self.filelist = join(self.root, 'test.lst') 39 | else: 40 | raise ValueError("Invalid split type!") 41 | with open(self.filelist, 'r') as f: 42 | self.filelist = f.readlines() 43 | 44 | def __len__(self): 45 | return len(self.filelist) 46 | 47 | def __getitem__(self, index): 48 | if self.split == "train": 49 | img_lb_file = self.filelist[index].strip("\n").split(" ") 50 | img_file = img_lb_file[0] 51 | 52 | label_list = [] 53 | for lb_file in img_lb_file[1:]: 54 | label_list.append( 55 | transforms.ToTensor()(Image.open(join(self.root, lb_file)))) 56 | lb = torch.cat(label_list, 0).mean(0, keepdim=True) 57 | 58 | lb[lb >= self.threshold] = 1 59 | lb[(lb > 0) & (lb < self.threshold)] = 2 60 | 61 | else: 
62 | img_file = self.filelist[index].rstrip() 63 | 64 | with open(join(self.root, img_file), 'rb') as f: 65 | img = Image.open(f) 66 | img = img.convert('RGB') 67 | img = self.transform(img) 68 | 69 | if self.split == "train": 70 | return img, lb 71 | else: 72 | img_name = Path(img_file).stem 73 | return img, img_name 74 | 75 | 76 | class NYUD_Loader(data.Dataset): 77 | """ 78 | Dataloader BSDS500 79 | """ 80 | 81 | def __init__(self, root='data/HED-BSDS_PASCAL', split='train', mode="RGB"): 82 | self.root = root 83 | self.split = split 84 | # 85 | if mode == "RGB": 86 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 87 | std=[0.229, 0.224, 0.225]) 88 | else: # calculate by Archer 89 | normalize = transforms.Normalize(mean=[0.519, 0.370, 0.465], 90 | std=[0.226, 0.246, 0.186]) 91 | 92 | self.transform = transforms.Compose([ 93 | transforms.ToTensor(), 94 | normalize]) 95 | 96 | if self.split == 'train': 97 | if mode == "RGB": 98 | self.filelist = join(root, "image-train.txt") 99 | else: 100 | self.filelist = join(root, "hha-train.txt") 101 | 102 | elif self.split == 'test': 103 | if mode == "RGB": 104 | self.filelist = join(root, "image-test.txt") 105 | else: 106 | self.filelist = join(root, "hha-test.txt") 107 | 108 | else: 109 | raise ValueError("Invalid split type!") 110 | 111 | with open(self.filelist, 'r') as f: 112 | self.filelist = f.readlines() 113 | 114 | def __len__(self): 115 | return len(self.filelist) 116 | 117 | def __getitem__(self, index): 118 | if self.split == "train": 119 | img_file, lb_file = self.filelist[index].strip("\n").split(" ") 120 | 121 | else: 122 | img_file = self.filelist[index].strip("\n").split(" ")[0] 123 | 124 | img = imageio.imread(join(self.root, img_file)) 125 | img = self.transform(img) 126 | 127 | if self.split == "train": 128 | label = transforms.ToTensor()(imageio.imread(join(self.root, lb_file), as_gray=True)) / 255 129 | img, label = self.crop(img, label) 130 | return img, label 131 | 132 | else: 133 | img_name = Path(img_file).stem 134 | return img, img_name 135 | 136 | @staticmethod 137 | def crop(img, lb): 138 | _, h, w = img.size() 139 | crop_size = 400 140 | 141 | if h < crop_size or w < crop_size: 142 | resize_scale = round(max(crop_size / h, crop_size / w) + 0.1, 1) 143 | 144 | img = interpolate(img.unsqueeze(0), scale_factor=resize_scale, mode="bilinear").squeeze(0) 145 | lb = interpolate(lb.unsqueeze(0), scale_factor=resize_scale, mode="nearest").squeeze(0) 146 | _, h, w = img.size() 147 | i = random.randint(0, h - crop_size) 148 | j = random.randint(0, w - crop_size) 149 | img = img[:, i:i + crop_size, j:j + crop_size] 150 | lb = lb[:, i:i + crop_size, j:j + crop_size] 151 | 152 | return img, lb 153 | 154 | 155 | class BIPED_Loader(data.Dataset): 156 | """ 157 | Dataloader BSDS500 158 | """ 159 | 160 | def __init__(self, root=' ', split='train', transform=False): 161 | self.root = root 162 | self.split = split 163 | self.transform = transform 164 | if self.split == 'train': 165 | self.filelist = join(root, "train_pair.lst") 166 | 167 | elif self.split == 'test': 168 | self.filelist = join(root, "test.lst") 169 | else: 170 | raise ValueError("Invalid split type!") 171 | with open(self.filelist, 'r') as f: 172 | self.filelist = f.readlines() 173 | # print(self.filelist) 174 | 175 | def __len__(self): 176 | return len(self.filelist) 177 | 178 | def __getitem__(self, index): 179 | if self.split == "train": 180 | img_file, lb_file = self.filelist[index].strip("\n").split(" ") 181 | 182 | else: 183 | img_file = 
self.filelist[index].rstrip() 184 | 185 | img = imageio.imread(join(self.root, img_file)) 186 | img = transforms.ToTensor()(img) 187 | 188 | if self.split == "train": 189 | label = transforms.ToTensor()(imageio.imread(join(self.root, lb_file), as_gray=True)) / 255 190 | img, label = self.crop(img, label) 191 | return img, label 192 | 193 | else: 194 | img_name = Path(img_file).stem 195 | return img, img_name 196 | 197 | @staticmethod 198 | def crop(img, lb): 199 | _, h, w = img.size() 200 | assert (h > 400) and (w > 400) 201 | crop_size = 400 202 | i = random.randint(0, h - crop_size) 203 | j = random.randint(0, w - crop_size) 204 | img = img[:, i:i + crop_size, j:j + crop_size] 205 | lb = lb[:, i:i + crop_size, j:j + crop_size] 206 | 207 | return img, lb 208 | 209 | 210 | class Multicue_Loader(data.Dataset): 211 | """ 212 | Dataloader for Multicue 213 | """ 214 | 215 | def __init__(self, root='data/', split='train', transform=False, threshold=0.3, setting=['boundary', '1']): 216 | """ 217 | setting[0] should be 'boundary' or 'edge' 218 | setting[1] should be '1' or '2' or '3' 219 | """ 220 | self.root = root 221 | self.split = split 222 | self.threshold = threshold 223 | print('Threshold for ground truth: %f on setting %s' % (self.threshold, str(setting))) 224 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 225 | std=[0.229, 0.224, 0.225]) 226 | self.transform = transforms.Compose([ 227 | transforms.ToTensor(), 228 | normalize]) 229 | if self.split == 'train': 230 | self.filelist = join(self.root, 'train_pair_%s_set_%s.lst' % (setting[0], setting[1])) 231 | elif self.split == 'test': 232 | self.filelist = join(self.root, 'test_%s_set_%s.lst' % (setting[0], setting[1])) 233 | else: 234 | raise ValueError("Invalid split type!") 235 | with open(self.filelist, 'r') as f: 236 | self.filelist = f.readlines() 237 | 238 | def __len__(self): 239 | return len(self.filelist) 240 | 241 | def __getitem__(self, index): 242 | if self.split == "train": 243 | img_file, lb_file = self.filelist[index].split() 244 | img_file = img_file.strip() 245 | lb_file = lb_file.strip() 246 | 247 | lb = transforms.ToTensor()(Image.open(join(self.root, lb_file)).convert("L")) 248 | 249 | lb[lb > self.threshold] = 1 250 | lb[(lb > 0) & (lb < self.threshold)] = 2 251 | 252 | else: 253 | img_file = self.filelist[index].rstrip() 254 | 255 | with open(join(self.root, img_file), 'rb') as f: 256 | img = Image.open(f) 257 | img = img.convert('RGB') 258 | img = self.transform(img) 259 | 260 | if self.split == "train": 261 | return self.crop(img, lb) 262 | else: 263 | img_name = Path(img_file).stem 264 | return img, img_name 265 | 266 | @staticmethod 267 | def crop(img, lb): 268 | _, h, w = img.size() 269 | crop_size = 400 270 | 271 | if (h < crop_size) or (w < crop_size): 272 | resize_scale = round(max(crop_size / h, crop_size / w) + 0.1, 1) 273 | 274 | img = interpolate(img.unsqueeze(0), scale_factor=resize_scale, mode="bilinear").squeeze(0) 275 | lb = interpolate(lb.unsqueeze(0), scale_factor=resize_scale, mode="nearest").squeeze(0) 276 | _, h, w = img.size() 277 | i = random.randint(0, h - crop_size) 278 | j = random.randint(0, w - crop_size) 279 | img = img[:, i:i + crop_size, j:j + crop_size] 280 | lb = lb[:, i:i + crop_size, j:j + crop_size] 281 | 282 | return img, lb 283 | -------------------------------------------------------------------------------- /model/caformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited 2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, 17 | ConvFormer and CAFormer. 18 | Some implementations are modified from timm (https://github.com/rwightman/pytorch-image-models). 19 | """ 20 | from functools import partial 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | from timm.models.layers import trunc_normal_, DropPath 26 | from timm.models.registry import register_model 27 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 28 | from timm.models.layers.helpers import to_2tuple 29 | 30 | 31 | def _cfg(url='', **kwargs): 32 | return { 33 | 'url': url, 34 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 35 | 'crop_pct': 1.0, 'interpolation': 'bicubic', 36 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', 37 | **kwargs 38 | } 39 | 40 | 41 | default_cfgs = { 42 | 'identityformer_s12': _cfg( 43 | url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'), 44 | 'identityformer_s24': _cfg( 45 | url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'), 46 | 'identityformer_s36': _cfg( 47 | url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'), 48 | 'identityformer_m36': _cfg( 49 | url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'), 50 | 'identityformer_m48': _cfg( 51 | url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'), 52 | 53 | 'randformer_s12': _cfg( 54 | url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'), 55 | 'randformer_s24': _cfg( 56 | url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'), 57 | 'randformer_s36': _cfg( 58 | url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'), 59 | 'randformer_m36': _cfg( 60 | url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'), 61 | 'randformer_m48': _cfg( 62 | url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'), 63 | 64 | 'poolformerv2_s12': _cfg( 65 | url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'), 66 | 'poolformerv2_s24': _cfg( 67 | url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'), 68 | 'poolformerv2_s36': _cfg( 69 | url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'), 70 | 'poolformerv2_m36': _cfg( 71 | url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'), 72 | 'poolformerv2_m48': _cfg( 73 | url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'), 74 | 75 | 'convformer_s18': _cfg( 76 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), 77 | 'convformer_s18_384': _cfg( 78 | 
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', 79 | input_size=(3, 384, 384)), 80 | 'convformer_s18_in21ft1k': _cfg( 81 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), 82 | 'convformer_s18_384_in21ft1k': _cfg( 83 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', 84 | input_size=(3, 384, 384)), 85 | 'convformer_s18_in21k': _cfg( 86 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', 87 | num_classes=21841), 88 | 89 | 'convformer_s36': _cfg( 90 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'), 91 | 'convformer_s36_384': _cfg( 92 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', 93 | input_size=(3, 384, 384)), 94 | 'convformer_s36_in21ft1k': _cfg( 95 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'), 96 | 'convformer_s36_384_in21ft1k': _cfg( 97 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', 98 | input_size=(3, 384, 384)), 99 | 'convformer_s36_in21k': _cfg( 100 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', 101 | num_classes=21841), 102 | 103 | 'convformer_m36': _cfg( 104 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'), 105 | 'convformer_m36_384': _cfg( 106 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', 107 | input_size=(3, 384, 384)), 108 | 'convformer_m36_in21ft1k': _cfg( 109 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'), 110 | 'convformer_m36_384_in21ft1k': _cfg( 111 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', 112 | input_size=(3, 384, 384)), 113 | 'convformer_m36_in21k': _cfg( 114 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', 115 | num_classes=21841), 116 | 117 | 'convformer_b36': _cfg( 118 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'), 119 | 'convformer_b36_384': _cfg( 120 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', 121 | input_size=(3, 384, 384)), 122 | 'convformer_b36_in21ft1k': _cfg( 123 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), 124 | 'convformer_b36_384_in21ft1k': _cfg( 125 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', 126 | input_size=(3, 384, 384)), 127 | 'convformer_b36_in21k': _cfg( 128 | url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', 129 | num_classes=21841), 130 | 131 | 'caformer_s18': _cfg( 132 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'), 133 | 'caformer_s18_384': _cfg( 134 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', 135 | input_size=(3, 384, 384)), 136 | 'caformer_s18_in21ft1k': _cfg( 137 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'), 138 | 'caformer_s18_384_in21ft1k': _cfg( 139 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', 140 | input_size=(3, 384, 384)), 141 | 'caformer_s18_in21k': _cfg( 142 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', 143 | num_classes=21841), 144 | 145 | 
'caformer_s36': _cfg( 146 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'), 147 | 'caformer_s36_384': _cfg( 148 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', 149 | input_size=(3, 384, 384)), 150 | 'caformer_s36_in21ft1k': _cfg( 151 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'), 152 | 'caformer_s36_384_in21ft1k': _cfg( 153 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', 154 | input_size=(3, 384, 384)), 155 | 'caformer_s36_in21k': _cfg( 156 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', 157 | num_classes=21841), 158 | 159 | 'caformer_m36': _cfg( 160 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'), 161 | 'caformer_m36_384': _cfg( 162 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', 163 | input_size=(3, 384, 384)), 164 | 'caformer_m36_in21ft1k': _cfg( 165 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'), 166 | 'caformer_m36_384_in21ft1k': _cfg( 167 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', 168 | input_size=(3, 384, 384)), 169 | 'caformer_m36_in21k': _cfg( 170 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', 171 | num_classes=21841), 172 | 173 | 'caformer_b36': _cfg( 174 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'), 175 | 'caformer_b36_384': _cfg( 176 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', 177 | input_size=(3, 384, 384)), 178 | 'caformer_b36_in21ft1k': _cfg( 179 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'), 180 | 'caformer_b36_384_in21ft1k': _cfg( 181 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', 182 | input_size=(3, 384, 384)), 183 | 'caformer_b36_in21k': _cfg( 184 | url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', 185 | num_classes=21841), 186 | } 187 | 188 | 189 | class Downsampling(nn.Module): 190 | """ 191 | Downsampling implemented by a layer of convolution. 192 | """ 193 | 194 | def __init__(self, in_channels, out_channels, 195 | kernel_size, stride=1, padding=0, 196 | pre_norm=None, post_norm=None, pre_permute=False): 197 | super().__init__() 198 | self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() 199 | self.pre_permute = pre_permute 200 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 201 | stride=stride, padding=padding) 202 | self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() 203 | 204 | def forward(self, x): 205 | x = self.pre_norm(x) 206 | if self.pre_permute: 207 | # if take [B, H, W, C] as input, permute it to [B, C, H, W] 208 | x = x.permute(0, 3, 1, 2) 209 | x = self.conv(x) 210 | x = x.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C] 211 | x = self.post_norm(x) 212 | return x 213 | 214 | 215 | class Scale(nn.Module): 216 | """ 217 | Scale vector by element multiplications. 
218 | """ 219 | 220 | def __init__(self, dim, init_value=1.0, trainable=True): 221 | super().__init__() 222 | self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable) 223 | 224 | def forward(self, x): 225 | return x * self.scale 226 | 227 | 228 | class SquaredReLU(nn.Module): 229 | """ 230 | Squared ReLU: https://arxiv.org/abs/2109.08668 231 | """ 232 | 233 | def __init__(self, inplace=False): 234 | super().__init__() 235 | self.relu = nn.ReLU(inplace=inplace) 236 | 237 | def forward(self, x): 238 | return torch.square(self.relu(x)) 239 | 240 | 241 | class StarReLU(nn.Module): 242 | """ 243 | StarReLU: s * relu(x) ** 2 + b 244 | """ 245 | 246 | def __init__(self, scale_value=1.0, bias_value=0.0, 247 | scale_learnable=True, bias_learnable=True, 248 | mode=None, inplace=False): 249 | super().__init__() 250 | self.inplace = inplace 251 | self.relu = nn.ReLU(inplace=inplace) 252 | self.scale = nn.Parameter(scale_value * torch.ones(1), 253 | requires_grad=scale_learnable) 254 | self.bias = nn.Parameter(bias_value * torch.ones(1), 255 | requires_grad=bias_learnable) 256 | 257 | def forward(self, x): 258 | return self.scale * self.relu(x) ** 2 + self.bias 259 | 260 | 261 | class Attention(nn.Module): 262 | """ 263 | Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762. 264 | Modified from timm. 265 | """ 266 | 267 | def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False, 268 | attn_drop=0., proj_drop=0., proj_bias=False, **kwargs): 269 | super().__init__() 270 | 271 | self.head_dim = head_dim 272 | self.scale = head_dim ** -0.5 273 | 274 | self.num_heads = num_heads if num_heads else dim // head_dim 275 | if self.num_heads == 0: 276 | self.num_heads = 1 277 | 278 | self.attention_dim = self.num_heads * self.head_dim 279 | 280 | self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias) 281 | self.attn_drop = nn.Dropout(attn_drop) 282 | self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias) 283 | self.proj_drop = nn.Dropout(proj_drop) 284 | 285 | def forward(self, x): 286 | B, H, W, C = x.shape 287 | N = H * W 288 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 289 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 290 | 291 | attn = (q @ k.transpose(-2, -1)) * self.scale 292 | attn = attn.softmax(dim=-1) 293 | attn = self.attn_drop(attn) 294 | 295 | x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim) 296 | x = self.proj(x) 297 | x = self.proj_drop(x) 298 | return x 299 | 300 | 301 | class RandomMixing(nn.Module): 302 | def __init__(self, num_tokens=196, **kwargs): 303 | super().__init__() 304 | self.random_matrix = nn.parameter.Parameter( 305 | data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), 306 | requires_grad=False) 307 | 308 | def forward(self, x): 309 | B, H, W, C = x.shape 310 | x = x.reshape(B, H * W, C) 311 | x = torch.einsum('mn, bnc -> bmc', self.random_matrix, x) 312 | x = x.reshape(B, H, W, C) 313 | return x 314 | 315 | 316 | class LayerNormGeneral(nn.Module): 317 | r""" General LayerNorm for different situations. 318 | 319 | Args: 320 | affine_shape (int, list or tuple): The shape of affine weight and bias. 321 | Usually the affine_shape=C, but in some implementation, like torch.nn.LayerNorm, 322 | the affine_shape is the same as normalized_dim by default. 323 | To adapt to different situations, we offer this argument here. 324 | normalized_dim (tuple or list): Which dims to compute mean and variance. 
325 | scale (bool): Flag indicates whether to use scale or not.
326 | bias (bool): Flag indicates whether to use bias or not.
327 |
328 | We give several examples to show how to specify the arguments.
329 |
330 | LayerNorm (https://arxiv.org/abs/1607.06450):
331 | For input shape of (B, *, C) like (B, N, C) or (B, H, W, C),
332 | affine_shape=C, normalized_dim=(-1, ), scale=True, bias=True;
333 | For input shape of (B, C, H, W),
334 | affine_shape=(C, 1, 1), normalized_dim=(1, ), scale=True, bias=True.
335 |
336 | Modified LayerNorm (https://arxiv.org/abs/2111.11418)
337 | that is identical to partial(torch.nn.GroupNorm, num_groups=1):
338 | For input shape of (B, N, C),
339 | affine_shape=C, normalized_dim=(1, 2), scale=True, bias=True;
340 | For input shape of (B, H, W, C),
341 | affine_shape=C, normalized_dim=(1, 2, 3), scale=True, bias=True;
342 | For input shape of (B, C, H, W),
343 | affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3), scale=True, bias=True.
344 |
345 | For the several metaformer baselines,
346 | IdentityFormer, RandFormer and PoolFormerV2 utilize Modified LayerNorm without bias (bias=False);
347 | ConvFormer and CAFormer utilize LayerNorm without bias (bias=False).
348 | """
349 |
350 | def __init__(self, affine_shape=None, normalized_dim=(-1,), scale=True,
351 | bias=True, eps=1e-5):
352 | super().__init__()
353 | self.normalized_dim = normalized_dim
354 | self.use_scale = scale
355 | self.use_bias = bias
356 | self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
357 | self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
358 | self.eps = eps
359 |
360 | def forward(self, x):
361 | c = x - x.mean(self.normalized_dim, keepdim=True)
362 | s = c.pow(2).mean(self.normalized_dim, keepdim=True)
363 | x = c / torch.sqrt(s + self.eps)
364 | if self.use_scale:
365 | x = x * self.weight
366 | if self.use_bias:
367 | x = x + self.bias
368 | return x
369 |
370 |
371 | class LayerNormWithoutBias(nn.Module):
372 | """
373 | Equal to partial(LayerNormGeneral, bias=False) but faster,
374 | because it directly utilizes optimized F.layer_norm
375 | """
376 |
377 | def __init__(self, normalized_shape, eps=1e-5, **kwargs):
378 | super().__init__()
379 | self.eps = eps
380 | self.bias = None
381 | if isinstance(normalized_shape, int):
382 | normalized_shape = (normalized_shape,)
383 | self.weight = nn.Parameter(torch.ones(normalized_shape))
384 | self.normalized_shape = normalized_shape
385 |
386 | def forward(self, x):
387 | return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
388 |
389 |
390 | class SepConv(nn.Module):
391 | r"""
392 | Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.
393 | """ 394 | 395 | def __init__(self, dim, expansion_ratio=2, 396 | act1_layer=StarReLU, act2_layer=nn.Identity, 397 | bias=False, kernel_size=7, padding=3, 398 | **kwargs, ): 399 | super().__init__() 400 | med_channels = int(expansion_ratio * dim) 401 | self.pwconv1 = nn.Linear(dim, med_channels, bias=bias) 402 | self.act1 = act1_layer() 403 | self.dwconv = nn.Conv2d( 404 | med_channels, med_channels, kernel_size=kernel_size, 405 | padding=padding, groups=med_channels, bias=bias) # depthwise conv 406 | self.act2 = act2_layer() 407 | self.pwconv2 = nn.Linear(med_channels, dim, bias=bias) 408 | 409 | def forward(self, x): 410 | x = self.pwconv1(x) 411 | x = self.act1(x) 412 | x = x.permute(0, 3, 1, 2) 413 | x = self.dwconv(x) 414 | x = x.permute(0, 2, 3, 1) 415 | x = self.act2(x) 416 | x = self.pwconv2(x) 417 | return x 418 | 419 | 420 | class Pooling(nn.Module): 421 | """ 422 | Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418 423 | Modfiled for [B, H, W, C] input 424 | """ 425 | 426 | def __init__(self, pool_size=3, **kwargs): 427 | super().__init__() 428 | self.pool = nn.AvgPool2d( 429 | pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) 430 | 431 | def forward(self, x): 432 | y = x.permute(0, 3, 1, 2) 433 | y = self.pool(y) 434 | y = y.permute(0, 2, 3, 1) 435 | return y - x 436 | 437 | 438 | class Mlp(nn.Module): 439 | """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks. 440 | Mostly copied from timm. 441 | """ 442 | 443 | def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs): 444 | super().__init__() 445 | in_features = dim 446 | out_features = out_features or in_features 447 | hidden_features = int(mlp_ratio * in_features) 448 | drop_probs = to_2tuple(drop) 449 | 450 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 451 | self.act = act_layer() 452 | self.drop1 = nn.Dropout(drop_probs[0]) 453 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 454 | self.drop2 = nn.Dropout(drop_probs[1]) 455 | 456 | def forward(self, x): 457 | x = self.fc1(x) 458 | x = self.act(x) 459 | x = self.drop1(x) 460 | x = self.fc2(x) 461 | x = self.drop2(x) 462 | return x 463 | 464 | 465 | class MlpHead(nn.Module): 466 | """ MLP classification head 467 | """ 468 | 469 | def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=SquaredReLU, 470 | norm_layer=nn.LayerNorm, head_dropout=0., bias=True): 471 | super().__init__() 472 | hidden_features = int(mlp_ratio * dim) 473 | self.fc1 = nn.Linear(dim, hidden_features, bias=bias) 474 | self.act = act_layer() 475 | self.norm = norm_layer(hidden_features) 476 | self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) 477 | self.head_dropout = nn.Dropout(head_dropout) 478 | 479 | def forward(self, x): 480 | x = self.fc1(x) 481 | x = self.act(x) 482 | x = self.norm(x) 483 | x = self.head_dropout(x) 484 | x = self.fc2(x) 485 | return x 486 | 487 | 488 | class MetaFormerBlock(nn.Module): 489 | """ 490 | Implementation of one MetaFormer block. 491 | """ 492 | 493 | def __init__(self, dim, 494 | token_mixer=nn.Identity, mlp=Mlp, 495 | norm_layer=nn.LayerNorm, 496 | drop=0., drop_path=0., 497 | layer_scale_init_value=None, res_scale_init_value=None 498 | ): 499 | super().__init__() 500 | 501 | self.norm1 = norm_layer(dim) 502 | self.token_mixer = token_mixer(dim=dim, drop=drop) 503 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 504 | self.layer_scale1 = Scale(dim=dim, init_value=layer_scale_init_value) \ 505 | if layer_scale_init_value else nn.Identity() 506 | self.res_scale1 = Scale(dim=dim, init_value=res_scale_init_value) \ 507 | if res_scale_init_value else nn.Identity() 508 | 509 | self.norm2 = norm_layer(dim) 510 | self.mlp = mlp(dim=dim, drop=drop) 511 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 512 | self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) \ 513 | if layer_scale_init_value else nn.Identity() 514 | self.res_scale2 = Scale(dim=dim, init_value=res_scale_init_value) \ 515 | if res_scale_init_value else nn.Identity() 516 | 517 | def forward(self, x): 518 | x = self.res_scale1(x) + \ 519 | self.layer_scale1( 520 | self.drop_path1( 521 | self.token_mixer(self.norm1(x)) 522 | ) 523 | ) 524 | x = self.res_scale2(x) + \ 525 | self.layer_scale2( 526 | self.drop_path2( 527 | self.mlp(self.norm2(x)) 528 | ) 529 | ) 530 | return x 531 | 532 | 533 | r""" 534 | downsampling (stem) for the first stage is a layer of conv with k7, s4 and p2 535 | downsamplings for the last 3 stages is a layer of conv with k3, s2 and p1 536 | DOWNSAMPLE_LAYERS_FOUR_STAGES format: [Downsampling, Downsampling, Downsampling, Downsampling] 537 | use `partial` to specify some arguments 538 | """ 539 | DOWNSAMPLE_LAYERS_FOUR_STAGES = [partial(Downsampling, 540 | kernel_size=7, stride=4, padding=2, 541 | post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6) 542 | )] + \ 543 | [partial(Downsampling, 544 | kernel_size=3, stride=2, padding=1, 545 | pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=True 546 | )] * 3 547 | 548 | 549 | class MetaFormer(nn.Module): 550 | r""" MetaFormer 551 | A PyTorch impl of : `MetaFormer Baselines for Vision` - 552 | https://arxiv.org/abs/2210.13452 553 | 554 | Args: 555 | in_chans (int): Number of input image channels. Default: 3. 556 | num_classes (int): Number of classes for classification head. Default: 1000. 557 | depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2]. 558 | dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512]. 559 | downsample_layers: (list or tuple): Downsampling layers before each stage. 560 | token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity. 561 | mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp. 562 | norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False). 563 | drop_path_rate (float): Stochastic depth rate. Default: 0. 564 | head_dropout (float): dropout for MLP classifier. Default: 0. 565 | layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: None. 566 | None means not use the layer scale. Form: https://arxiv.org/abs/2103.17239. 567 | res_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: [None, None, 1.0, 1.0]. 568 | None means not use the layer scale. From: https://arxiv.org/abs/2110.09456. 569 | output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6). 570 | head_fn: classification head. Default: nn.Linear. 
571 | """ 572 | 573 | def __init__(self, in_chans=3, num_classes=1000, 574 | depths=[2, 2, 6, 2], 575 | dims=[64, 128, 320, 512], 576 | downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES, 577 | token_mixers=nn.Identity, 578 | mlps=Mlp, 579 | norm_layers=partial(LayerNormWithoutBias, eps=1e-6), 580 | # partial(LayerNormGeneral, eps=1e-6, bias=False), 581 | drop_path_rate=0., 582 | head_dropout=0.0, 583 | layer_scale_init_values=None, 584 | res_scale_init_values=[None, None, 1.0, 1.0], 585 | output_norm=partial(nn.LayerNorm, eps=1e-6), 586 | head_fn=nn.Linear, 587 | get_feat=False, 588 | dulbrn=0, 589 | **kwargs, 590 | ): 591 | super().__init__() 592 | self.dulbrn = dulbrn 593 | if self.dulbrn: 594 | print("add local conv layer") 595 | # self.local_conv = nn.Sequential( 596 | # nn.Conv2d(3, self.dulbrn, 3, stride=2, padding=1), 597 | # nn.ReLU(True), 598 | # nn.Conv2d(self.dulbrn, self.dulbrn, 3, padding=1), 599 | # nn.ReLU(True), 600 | # ) 601 | self.conv1 = nn.Sequential( 602 | nn.Conv2d(3, self.dulbrn, (3, 3), stride=1, padding=1), 603 | nn.ReLU(inplace=True) 604 | ) 605 | 606 | self.conv2 = nn.Sequential( 607 | # nn.MaxPool2d(2, stride=2, ceil_mode=True), 608 | nn.Conv2d(self.dulbrn, self.dulbrn * 2, (3, 3), stride=2, padding=1), 609 | nn.ReLU(inplace=True) 610 | ) 611 | 612 | self.out_channels = [self.dulbrn, self.dulbrn * 2] 613 | # self.out_channels = [self.dulbrn] 614 | self.out_channels.extend(dims[:-1]) 615 | else: 616 | print("not add local conv layer") 617 | self.out_channels = dims 618 | 619 | self.num_classes = num_classes 620 | 621 | if not isinstance(depths, (list, tuple)): 622 | depths = [depths] # it means the model has only one stage 623 | if not isinstance(dims, (list, tuple)): 624 | dims = [dims] 625 | ## add by Archer, remove the last SA layer 626 | num_stage = len(depths) - 1 627 | # num_stage = len(depths) 628 | self.num_stage = num_stage 629 | 630 | if not isinstance(downsample_layers, (list, tuple)): 631 | downsample_layers = [downsample_layers] * num_stage 632 | down_dims = [in_chans] + dims 633 | self.downsample_layers = nn.ModuleList( 634 | [downsample_layers[i](down_dims[i], down_dims[i + 1]) for i in range(num_stage)] 635 | ) 636 | 637 | if not isinstance(token_mixers, (list, tuple)): 638 | token_mixers = [token_mixers] * num_stage 639 | 640 | if not isinstance(mlps, (list, tuple)): 641 | mlps = [mlps] * num_stage 642 | 643 | if not isinstance(norm_layers, (list, tuple)): 644 | norm_layers = [norm_layers] * num_stage 645 | 646 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 647 | 648 | if not isinstance(layer_scale_init_values, (list, tuple)): 649 | layer_scale_init_values = [layer_scale_init_values] * num_stage 650 | if not isinstance(res_scale_init_values, (list, tuple)): 651 | res_scale_init_values = [res_scale_init_values] * num_stage 652 | 653 | self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks 654 | cur = 0 655 | for i in range(num_stage): 656 | stage = nn.Sequential( 657 | *[MetaFormerBlock(dim=dims[i], 658 | token_mixer=token_mixers[i], 659 | mlp=mlps[i], 660 | norm_layer=norm_layers[i], 661 | drop_path=dp_rates[cur + j], 662 | layer_scale_init_value=layer_scale_init_values[i], 663 | res_scale_init_value=res_scale_init_values[i], 664 | ) for j in range(depths[i])] 665 | ) 666 | self.stages.append(stage) 667 | cur += depths[i] 668 | 669 | @torch.jit.ignore 670 | def no_weight_decay(self): 671 | return {'norm'} 672 | 673 | def forward(self, x): 674 | if self.dulbrn: 675 | f1 = self.conv1(x) 676 
| f2 = self.conv2(f1) 677 | features = [f1, f2] 678 | # features = [self.local_conv(x)] 679 | else: 680 | features = [] 681 | for i in range(self.num_stage): 682 | x = self.downsample_layers[i](x) 683 | x = self.stages[i](x) 684 | features.append(x.permute(0, 3, 1, 2)) 685 | 686 | return features 687 | 688 | 689 | @register_model 690 | def identityformer_s12(pretrained=False, **kwargs): 691 | model = MetaFormer( 692 | depths=[2, 2, 6, 2], 693 | dims=[64, 128, 320, 512], 694 | token_mixers=nn.Identity, 695 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 696 | **kwargs) 697 | model.default_cfg = default_cfgs['identityformer_s12'] 698 | if pretrained: 699 | state_dict = torch.hub.load_state_dict_from_url( 700 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 701 | model.load_state_dict(state_dict) 702 | return model 703 | 704 | 705 | @register_model 706 | def identityformer_s24(pretrained=False, **kwargs): 707 | model = MetaFormer( 708 | depths=[4, 4, 12, 4], 709 | dims=[64, 128, 320, 512], 710 | token_mixers=nn.Identity, 711 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 712 | **kwargs) 713 | model.default_cfg = default_cfgs['identityformer_s24'] 714 | if pretrained: 715 | state_dict = torch.hub.load_state_dict_from_url( 716 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 717 | model.load_state_dict(state_dict) 718 | return model 719 | 720 | 721 | @register_model 722 | def identityformer_s36(pretrained=False, **kwargs): 723 | model = MetaFormer( 724 | depths=[6, 6, 18, 6], 725 | dims=[64, 128, 320, 512], 726 | token_mixers=nn.Identity, 727 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 728 | **kwargs) 729 | model.default_cfg = default_cfgs['identityformer_s36'] 730 | if pretrained: 731 | state_dict = torch.hub.load_state_dict_from_url( 732 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 733 | model.load_state_dict(state_dict) 734 | return model 735 | 736 | 737 | @register_model 738 | def identityformer_m36(pretrained=False, **kwargs): 739 | model = MetaFormer( 740 | depths=[6, 6, 18, 6], 741 | dims=[96, 192, 384, 768], 742 | token_mixers=nn.Identity, 743 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 744 | **kwargs) 745 | model.default_cfg = default_cfgs['identityformer_m36'] 746 | if pretrained: 747 | state_dict = torch.hub.load_state_dict_from_url( 748 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 749 | model.load_state_dict(state_dict) 750 | return model 751 | 752 | 753 | @register_model 754 | def identityformer_m48(pretrained=False, **kwargs): 755 | model = MetaFormer( 756 | depths=[8, 8, 24, 8], 757 | dims=[96, 192, 384, 768], 758 | token_mixers=nn.Identity, 759 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 760 | **kwargs) 761 | model.default_cfg = default_cfgs['identityformer_m48'] 762 | if pretrained: 763 | state_dict = torch.hub.load_state_dict_from_url( 764 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 765 | model.load_state_dict(state_dict) 766 | return model 767 | 768 | 769 | @register_model 770 | def randformer_s12(pretrained=False, **kwargs): 771 | model = MetaFormer( 772 | depths=[2, 2, 6, 2], 773 | dims=[64, 128, 320, 512], 774 | token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], 775 | norm_layers=partial(LayerNormGeneral, 
normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 776 | **kwargs) 777 | model.default_cfg = default_cfgs['randformer_s12'] 778 | if pretrained: 779 | state_dict = torch.hub.load_state_dict_from_url( 780 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 781 | model.load_state_dict(state_dict) 782 | return model 783 | 784 | 785 | @register_model 786 | def randformer_s24(pretrained=False, **kwargs): 787 | model = MetaFormer( 788 | depths=[4, 4, 12, 4], 789 | dims=[64, 128, 320, 512], 790 | token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], 791 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 792 | **kwargs) 793 | model.default_cfg = default_cfgs['randformer_s24'] 794 | if pretrained: 795 | state_dict = torch.hub.load_state_dict_from_url( 796 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 797 | model.load_state_dict(state_dict) 798 | return model 799 | 800 | 801 | @register_model 802 | def randformer_s36(pretrained=False, **kwargs): 803 | model = MetaFormer( 804 | depths=[6, 6, 18, 6], 805 | dims=[64, 128, 320, 512], 806 | token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], 807 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 808 | **kwargs) 809 | model.default_cfg = default_cfgs['randformer_s36'] 810 | if pretrained: 811 | state_dict = torch.hub.load_state_dict_from_url( 812 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 813 | model.load_state_dict(state_dict) 814 | return model 815 | 816 | 817 | @register_model 818 | def randformer_m36(pretrained=False, **kwargs): 819 | model = MetaFormer( 820 | depths=[6, 6, 18, 6], 821 | dims=[96, 192, 384, 768], 822 | token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], 823 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 824 | **kwargs) 825 | model.default_cfg = default_cfgs['randformer_m36'] 826 | if pretrained: 827 | state_dict = torch.hub.load_state_dict_from_url( 828 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 829 | model.load_state_dict(state_dict) 830 | return model 831 | 832 | 833 | @register_model 834 | def randformer_m48(pretrained=False, **kwargs): 835 | model = MetaFormer( 836 | depths=[8, 8, 24, 8], 837 | dims=[96, 192, 384, 768], 838 | token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], 839 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 840 | **kwargs) 841 | model.default_cfg = default_cfgs['randformer_m48'] 842 | if pretrained: 843 | state_dict = torch.hub.load_state_dict_from_url( 844 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 845 | model.load_state_dict(state_dict) 846 | return model 847 | 848 | 849 | @register_model 850 | def poolformerv2_s12(pretrained=False, **kwargs): 851 | model = MetaFormer( 852 | depths=[2, 2, 6, 2], 853 | dims=[64, 128, 320, 512], 854 | token_mixers=Pooling, 855 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 856 | **kwargs) 857 | model.default_cfg = default_cfgs['poolformerv2_s12'] 858 | if pretrained: 859 | state_dict = torch.hub.load_state_dict_from_url( 860 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 861 | model.load_state_dict(state_dict) 862 | return model 863 | 864 | 865 | @register_model 866 | def 
poolformerv2_s24(pretrained=False, **kwargs): 867 | model = MetaFormer( 868 | depths=[4, 4, 12, 4], 869 | dims=[64, 128, 320, 512], 870 | token_mixers=Pooling, 871 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 872 | **kwargs) 873 | model.default_cfg = default_cfgs['poolformerv2_s24'] 874 | if pretrained: 875 | state_dict = torch.hub.load_state_dict_from_url( 876 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 877 | model.load_state_dict(state_dict) 878 | return model 879 | 880 | 881 | @register_model 882 | def poolformerv2_s36(pretrained=False, **kwargs): 883 | model = MetaFormer( 884 | depths=[6, 6, 18, 6], 885 | dims=[64, 128, 320, 512], 886 | token_mixers=Pooling, 887 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 888 | **kwargs) 889 | model.default_cfg = default_cfgs['poolformerv2_s36'] 890 | if pretrained: 891 | state_dict = torch.hub.load_state_dict_from_url( 892 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 893 | model.load_state_dict(state_dict) 894 | return model 895 | 896 | 897 | @register_model 898 | def poolformerv2_m36(pretrained=False, **kwargs): 899 | model = MetaFormer( 900 | depths=[6, 6, 18, 6], 901 | dims=[96, 192, 384, 768], 902 | token_mixers=Pooling, 903 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 904 | **kwargs) 905 | model.default_cfg = default_cfgs['poolformerv2_m36'] 906 | if pretrained: 907 | state_dict = torch.hub.load_state_dict_from_url( 908 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 909 | model.load_state_dict(state_dict) 910 | return model 911 | 912 | 913 | @register_model 914 | def poolformerv2_m48(pretrained=False, **kwargs): 915 | model = MetaFormer( 916 | depths=[8, 8, 24, 8], 917 | dims=[96, 192, 384, 768], 918 | token_mixers=Pooling, 919 | norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), 920 | **kwargs) 921 | model.default_cfg = default_cfgs['poolformerv2_m48'] 922 | if pretrained: 923 | state_dict = torch.hub.load_state_dict_from_url( 924 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 925 | model.load_state_dict(state_dict) 926 | return model 927 | 928 | 929 | @register_model 930 | def convformer_s18(pretrained=False, **kwargs): 931 | model = MetaFormer( 932 | depths=[3, 3, 9, 3], 933 | dims=[64, 128, 320, 512], 934 | token_mixers=SepConv, 935 | head_fn=MlpHead, 936 | **kwargs) 937 | model.default_cfg = default_cfgs['convformer_s18'] 938 | if pretrained: 939 | state_dict = torch.hub.load_state_dict_from_url( 940 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 941 | model.load_state_dict(state_dict) 942 | return model 943 | 944 | 945 | @register_model 946 | def convformer_s18_384(pretrained=False, **kwargs): 947 | model = MetaFormer( 948 | depths=[3, 3, 9, 3], 949 | dims=[64, 128, 320, 512], 950 | token_mixers=SepConv, 951 | head_fn=MlpHead, 952 | **kwargs) 953 | model.default_cfg = default_cfgs['convformer_s18_384'] 954 | if pretrained: 955 | state_dict = torch.hub.load_state_dict_from_url( 956 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 957 | model.load_state_dict(state_dict) 958 | return model 959 | 960 | 961 | @register_model 962 | def convformer_s18_in21ft1k(pretrained=False, **kwargs): 963 | model = MetaFormer( 964 | depths=[3, 3, 9, 3], 965 | dims=[64, 128, 320, 512], 966 | token_mixers=SepConv, 967 | head_fn=MlpHead, 968 | **kwargs) 969 | 
model.default_cfg = default_cfgs['convformer_s18_in21ft1k'] 970 | if pretrained: 971 | state_dict = torch.hub.load_state_dict_from_url( 972 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 973 | model.load_state_dict(state_dict) 974 | return model 975 | 976 | 977 | @register_model 978 | def convformer_s18_384_in21ft1k(pretrained=False, **kwargs): 979 | model = MetaFormer( 980 | depths=[3, 3, 9, 3], 981 | dims=[64, 128, 320, 512], 982 | token_mixers=SepConv, 983 | head_fn=MlpHead, 984 | **kwargs) 985 | model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k'] 986 | if pretrained: 987 | state_dict = torch.hub.load_state_dict_from_url( 988 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 989 | model.load_state_dict(state_dict) 990 | return model 991 | 992 | 993 | @register_model 994 | def convformer_s18_in21k(pretrained=False, **kwargs): 995 | model = MetaFormer( 996 | depths=[3, 3, 9, 3], 997 | dims=[64, 128, 320, 512], 998 | token_mixers=SepConv, 999 | head_fn=MlpHead, 1000 | **kwargs) 1001 | model.default_cfg = default_cfgs['convformer_s18_in21k'] 1002 | if pretrained: 1003 | state_dict = torch.hub.load_state_dict_from_url( 1004 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1005 | model.load_state_dict(state_dict) 1006 | return model 1007 | 1008 | 1009 | @register_model 1010 | def convformer_s36(pretrained=False, **kwargs): 1011 | model = MetaFormer( 1012 | depths=[3, 12, 18, 3], 1013 | dims=[64, 128, 320, 512], 1014 | token_mixers=SepConv, 1015 | head_fn=MlpHead, 1016 | **kwargs) 1017 | model.default_cfg = default_cfgs['convformer_s36'] 1018 | if pretrained: 1019 | state_dict = torch.hub.load_state_dict_from_url( 1020 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1021 | model.load_state_dict(state_dict) 1022 | return model 1023 | 1024 | 1025 | @register_model 1026 | def convformer_s36_384(pretrained=False, **kwargs): 1027 | model = MetaFormer( 1028 | depths=[3, 12, 18, 3], 1029 | dims=[64, 128, 320, 512], 1030 | token_mixers=SepConv, 1031 | head_fn=MlpHead, 1032 | **kwargs) 1033 | model.default_cfg = default_cfgs['convformer_s36_384'] 1034 | if pretrained: 1035 | state_dict = torch.hub.load_state_dict_from_url( 1036 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1037 | model.load_state_dict(state_dict) 1038 | return model 1039 | 1040 | 1041 | @register_model 1042 | def convformer_s36_in21ft1k(pretrained=False, **kwargs): 1043 | model = MetaFormer( 1044 | depths=[3, 12, 18, 3], 1045 | dims=[64, 128, 320, 512], 1046 | token_mixers=SepConv, 1047 | head_fn=MlpHead, 1048 | **kwargs) 1049 | model.default_cfg = default_cfgs['convformer_s36_in21ft1k'] 1050 | if pretrained: 1051 | state_dict = torch.hub.load_state_dict_from_url( 1052 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1053 | model.load_state_dict(state_dict) 1054 | return model 1055 | 1056 | 1057 | @register_model 1058 | def convformer_s36_384_in21ft1k(pretrained=False, **kwargs): 1059 | model = MetaFormer( 1060 | depths=[3, 12, 18, 3], 1061 | dims=[64, 128, 320, 512], 1062 | token_mixers=SepConv, 1063 | head_fn=MlpHead, 1064 | **kwargs) 1065 | 1066 | if pretrained: 1067 | state_dict = torch.load("model/convformer_s36_384_in21ft1k.pth", 1068 | map_location="cpu") 1069 | model_dict = model.state_dict() 1070 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 1071 | model_dict.update(pretrained_dict) 1072 | model.load_state_dict(model_dict) 1073 | # from collections import 
OrderedDict 1074 | # new_state_dict = OrderedDict() 1075 | # for k, v in state_dict.items(): 1076 | # name = k.replace('module.', '') # 去除 "module." 的 prefix 1077 | # new_state_dict[name] = v 1078 | # model.load_state_dict(new_state_dict) 1079 | return model 1080 | 1081 | 1082 | @register_model 1083 | def convformer_s36_in21k(pretrained=False, **kwargs): 1084 | model = MetaFormer( 1085 | depths=[3, 12, 18, 3], 1086 | dims=[64, 128, 320, 512], 1087 | token_mixers=SepConv, 1088 | head_fn=MlpHead, 1089 | **kwargs) 1090 | model.default_cfg = default_cfgs['convformer_s36_in21k'] 1091 | if pretrained: 1092 | state_dict = torch.hub.load_state_dict_from_url( 1093 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1094 | model.load_state_dict(state_dict) 1095 | return model 1096 | 1097 | 1098 | @register_model 1099 | def convformer_m36(pretrained=False, **kwargs): 1100 | model = MetaFormer( 1101 | depths=[3, 12, 18, 3], 1102 | dims=[96, 192, 384, 576], 1103 | token_mixers=SepConv, 1104 | head_fn=MlpHead, 1105 | **kwargs) 1106 | model.default_cfg = default_cfgs['convformer_m36'] 1107 | if pretrained: 1108 | state_dict = torch.hub.load_state_dict_from_url( 1109 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1110 | model.load_state_dict(state_dict) 1111 | return model 1112 | 1113 | 1114 | @register_model 1115 | def convformer_m36_384(pretrained=False, **kwargs): 1116 | model = MetaFormer( 1117 | depths=[3, 12, 18, 3], 1118 | dims=[96, 192, 384, 576], 1119 | token_mixers=SepConv, 1120 | head_fn=MlpHead, 1121 | **kwargs) 1122 | model.default_cfg = default_cfgs['convformer_m36_384'] 1123 | if pretrained: 1124 | state_dict = torch.hub.load_state_dict_from_url( 1125 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1126 | model.load_state_dict(state_dict) 1127 | return model 1128 | 1129 | 1130 | @register_model 1131 | def convformer_m36_in21ft1k(pretrained=False, **kwargs): 1132 | model = MetaFormer( 1133 | depths=[3, 12, 18, 3], 1134 | dims=[96, 192, 384, 576], 1135 | token_mixers=SepConv, 1136 | head_fn=MlpHead, 1137 | **kwargs) 1138 | model.default_cfg = default_cfgs['convformer_m36_in21ft1k'] 1139 | if pretrained: 1140 | state_dict = torch.hub.load_state_dict_from_url( 1141 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1142 | model.load_state_dict(state_dict) 1143 | return model 1144 | 1145 | 1146 | def convformer_m36_384_in21ft1k(pretrained=False, **kwargs): 1147 | model = MetaFormer( 1148 | depths=[3, 12, 18, 3], 1149 | dims=[96, 192, 384, 576], 1150 | token_mixers=SepConv, 1151 | head_fn=MlpHead, 1152 | get_feat=True, 1153 | **kwargs) 1154 | if pretrained: 1155 | state_dict = torch.load("model/convformer_m36_384_in21ft1k.pth", 1156 | map_location="cpu") 1157 | model_dict = model.state_dict() 1158 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 1159 | model_dict.update(pretrained_dict) 1160 | model.load_state_dict(model_dict) 1161 | 1162 | # from collections import OrderedDict 1163 | # new_state_dict = OrderedDict() 1164 | # for k, v in state_dict.items(): 1165 | # name = k.replace('module.', '') # 去除 "module." 
的 prefix 1166 | # new_state_dict[name] = v 1167 | # model.load_state_dict(new_state_dict) 1168 | return model 1169 | 1170 | 1171 | @register_model 1172 | def convformer_m36_in21k(pretrained=False, **kwargs): 1173 | model = MetaFormer( 1174 | depths=[3, 12, 18, 3], 1175 | dims=[96, 192, 384, 576], 1176 | token_mixers=SepConv, 1177 | head_fn=MlpHead, 1178 | **kwargs) 1179 | model.default_cfg = default_cfgs['convformer_m36_in21k'] 1180 | if pretrained: 1181 | state_dict = torch.hub.load_state_dict_from_url( 1182 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1183 | model.load_state_dict(state_dict) 1184 | return model 1185 | 1186 | 1187 | @register_model 1188 | def convformer_b36(pretrained=False, **kwargs): 1189 | model = MetaFormer( 1190 | depths=[3, 12, 18, 3], 1191 | dims=[128, 256, 512, 768], 1192 | token_mixers=SepConv, 1193 | head_fn=MlpHead, 1194 | **kwargs) 1195 | model.default_cfg = default_cfgs['convformer_b36'] 1196 | if pretrained: 1197 | state_dict = torch.hub.load_state_dict_from_url( 1198 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1199 | model.load_state_dict(state_dict) 1200 | return model 1201 | 1202 | 1203 | @register_model 1204 | def convformer_b36_384(pretrained=False, **kwargs): 1205 | model = MetaFormer( 1206 | depths=[3, 12, 18, 3], 1207 | dims=[128, 256, 512, 768], 1208 | token_mixers=SepConv, 1209 | head_fn=MlpHead, 1210 | **kwargs) 1211 | model.default_cfg = default_cfgs['convformer_b36_384'] 1212 | if pretrained: 1213 | state_dict = torch.hub.load_state_dict_from_url( 1214 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1215 | model.load_state_dict(state_dict) 1216 | return model 1217 | 1218 | 1219 | @register_model 1220 | def convformer_b36_in21ft1k(pretrained=False, **kwargs): 1221 | model = MetaFormer( 1222 | depths=[3, 12, 18, 3], 1223 | dims=[128, 256, 512, 768], 1224 | token_mixers=SepConv, 1225 | head_fn=MlpHead, 1226 | **kwargs) 1227 | model.default_cfg = default_cfgs['convformer_b36_in21ft1k'] 1228 | if pretrained: 1229 | state_dict = torch.hub.load_state_dict_from_url( 1230 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1231 | model.load_state_dict(state_dict) 1232 | return model 1233 | 1234 | 1235 | @register_model 1236 | def convformer_b36_384_in21ft1k(pretrained=False, **kwargs): 1237 | model = MetaFormer( 1238 | depths=[3, 12, 18, 3], 1239 | dims=[128, 256, 512, 768], 1240 | token_mixers=SepConv, 1241 | head_fn=MlpHead, 1242 | **kwargs) 1243 | model.default_cfg = default_cfgs['convformer_b36_384_in21ft1k'] 1244 | if pretrained: 1245 | state_dict = torch.hub.load_state_dict_from_url( 1246 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1247 | model.load_state_dict(state_dict) 1248 | return model 1249 | 1250 | 1251 | @register_model 1252 | def convformer_b36_in21k(pretrained=False, **kwargs): 1253 | model = MetaFormer( 1254 | depths=[3, 12, 18, 3], 1255 | dims=[128, 256, 512, 768], 1256 | token_mixers=SepConv, 1257 | head_fn=MlpHead, 1258 | **kwargs) 1259 | model.default_cfg = default_cfgs['convformer_b36_in21k'] 1260 | if pretrained: 1261 | state_dict = torch.hub.load_state_dict_from_url( 1262 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1263 | model.load_state_dict(state_dict) 1264 | return model 1265 | 1266 | 1267 | @register_model 1268 | def caformer_s18(pretrained=False, **kwargs): 1269 | model = MetaFormer( 1270 | depths=[3, 3, 9, 3], 1271 | dims=[64, 128, 320, 512], 1272 | token_mixers=[SepConv, SepConv, 
Attention, Attention], 1273 | head_fn=MlpHead, 1274 | **kwargs) 1275 | model.default_cfg = default_cfgs['caformer_s18'] 1276 | if pretrained: 1277 | state_dict = torch.hub.load_state_dict_from_url( 1278 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1279 | model.load_state_dict(state_dict) 1280 | return model 1281 | 1282 | 1283 | @register_model 1284 | def caformer_s18_384(pretrained=False, **kwargs): 1285 | model = MetaFormer( 1286 | depths=[3, 3, 9, 3], 1287 | dims=[64, 128, 320, 512], 1288 | token_mixers=[SepConv, SepConv, Attention, Attention], 1289 | head_fn=MlpHead, 1290 | **kwargs) 1291 | model.default_cfg = default_cfgs['caformer_s18_384'] 1292 | if pretrained: 1293 | state_dict = torch.hub.load_state_dict_from_url( 1294 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1295 | model.load_state_dict(state_dict) 1296 | return model 1297 | 1298 | 1299 | @register_model 1300 | def caformer_s18_in21ft1k(pretrained=False, **kwargs): 1301 | model = MetaFormer( 1302 | depths=[3, 3, 9, 3], 1303 | dims=[64, 128, 320, 512], 1304 | token_mixers=[SepConv, SepConv, Attention, Attention], 1305 | head_fn=MlpHead, 1306 | **kwargs) 1307 | model.default_cfg = default_cfgs['caformer_s18_in21ft1k'] 1308 | if pretrained: 1309 | state_dict = torch.hub.load_state_dict_from_url( 1310 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1311 | model.load_state_dict(state_dict) 1312 | return model 1313 | 1314 | 1315 | @register_model 1316 | def caformer_s18_384_in21ft1k(pretrained=False, Dulbrn=0,**kwargs): 1317 | model = MetaFormer( 1318 | depths=[3, 3, 9, 3], 1319 | dims=[64, 128, 320, 512], 1320 | token_mixers=[SepConv, SepConv, Attention, Attention], 1321 | head_fn=MlpHead, 1322 | dulbrn=Dulbrn, 1323 | **kwargs) 1324 | if pretrained: 1325 | state_dict = torch.load("model/caformer_s18_384_in21ft1k.pth", 1326 | map_location="cpu") 1327 | model_dict = model.state_dict() 1328 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 1329 | model_dict.update(pretrained_dict) 1330 | model.load_state_dict(model_dict) 1331 | return model 1332 | 1333 | @register_model 1334 | def caformer_s18_in21k(pretrained=False, **kwargs): 1335 | model = MetaFormer( 1336 | depths=[3, 3, 9, 3], 1337 | dims=[64, 128, 320, 512], 1338 | token_mixers=[SepConv, SepConv, Attention, Attention], 1339 | head_fn=MlpHead, 1340 | **kwargs) 1341 | model.default_cfg = default_cfgs['caformer_s18_in21k'] 1342 | if pretrained: 1343 | state_dict = torch.hub.load_state_dict_from_url( 1344 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1345 | model.load_state_dict(state_dict) 1346 | return model 1347 | 1348 | 1349 | @register_model 1350 | def caformer_s36(pretrained=False, **kwargs): 1351 | model = MetaFormer( 1352 | depths=[3, 12, 18, 3], 1353 | dims=[64, 128, 320, 512], 1354 | token_mixers=[SepConv, SepConv, Attention, Attention], 1355 | head_fn=MlpHead, 1356 | **kwargs) 1357 | model.default_cfg = default_cfgs['caformer_s36'] 1358 | if pretrained: 1359 | state_dict = torch.hub.load_state_dict_from_url( 1360 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1361 | model.load_state_dict(state_dict) 1362 | return model 1363 | 1364 | 1365 | @register_model 1366 | def caformer_s36_384(pretrained=False, **kwargs): 1367 | model = MetaFormer( 1368 | depths=[3, 12, 18, 3], 1369 | dims=[64, 128, 320, 512], 1370 | token_mixers=[SepConv, SepConv, Attention, Attention], 1371 | head_fn=MlpHead, 1372 | **kwargs) 1373 | model.default_cfg = 
default_cfgs['caformer_s36_384'] 1374 | if pretrained: 1375 | state_dict = torch.hub.load_state_dict_from_url( 1376 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1377 | model.load_state_dict(state_dict) 1378 | return model 1379 | 1380 | 1381 | @register_model 1382 | def caformer_s36_in21ft1k(pretrained=False, **kwargs): 1383 | model = MetaFormer( 1384 | depths=[3, 12, 18, 3], 1385 | dims=[64, 128, 320, 512], 1386 | token_mixers=[SepConv, SepConv, Attention, Attention], 1387 | head_fn=MlpHead, 1388 | **kwargs) 1389 | model.default_cfg = default_cfgs['caformer_s36_in21ft1k'] 1390 | if pretrained: 1391 | state_dict = torch.hub.load_state_dict_from_url( 1392 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1393 | model.load_state_dict(state_dict) 1394 | return model 1395 | 1396 | 1397 | @register_model 1398 | def caformer_s36_384_in21ft1k(pretrained=False, **kwargs): 1399 | model = MetaFormer( 1400 | depths=[3, 12, 18, 3], 1401 | dims=[64, 128, 320, 512], 1402 | token_mixers=[SepConv, SepConv, Attention, Attention], 1403 | head_fn=MlpHead, 1404 | **kwargs) 1405 | if pretrained: 1406 | state_dict = torch.load("model/caformer_s36_384_in21ft1k.pth", 1407 | map_location="cpu") 1408 | model_dict = model.state_dict() 1409 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 1410 | model_dict.update(pretrained_dict) 1411 | model.load_state_dict(model_dict) 1412 | return model 1413 | 1414 | 1415 | @register_model 1416 | def caformer_s36_in21k(pretrained=False, **kwargs): 1417 | model = MetaFormer( 1418 | depths=[3, 12, 18, 3], 1419 | dims=[64, 128, 320, 512], 1420 | token_mixers=[SepConv, SepConv, Attention, Attention], 1421 | head_fn=MlpHead, 1422 | **kwargs) 1423 | model.default_cfg = default_cfgs['caformer_s36_in21k'] 1424 | if pretrained: 1425 | state_dict = torch.hub.load_state_dict_from_url( 1426 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1427 | model.load_state_dict(state_dict) 1428 | return model 1429 | 1430 | 1431 | @register_model 1432 | def caformer_m36(pretrained=False, **kwargs): 1433 | model = MetaFormer( 1434 | depths=[3, 12, 18, 3], 1435 | dims=[96, 192, 384, 576], 1436 | token_mixers=[SepConv, SepConv, Attention, Attention], 1437 | head_fn=MlpHead, 1438 | **kwargs) 1439 | model.default_cfg = default_cfgs['caformer_m36'] 1440 | if pretrained: 1441 | state_dict = torch.hub.load_state_dict_from_url( 1442 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1443 | model.load_state_dict(state_dict) 1444 | return model 1445 | 1446 | 1447 | @register_model 1448 | def caformer_m36_384(pretrained=False, **kwargs): 1449 | model = MetaFormer( 1450 | depths=[3, 12, 18, 3], 1451 | dims=[96, 192, 384, 576], 1452 | token_mixers=[SepConv, SepConv, Attention, Attention], 1453 | head_fn=MlpHead, 1454 | **kwargs) 1455 | model.default_cfg = default_cfgs['caformer_m36_384'] 1456 | if pretrained: 1457 | state_dict = torch.hub.load_state_dict_from_url( 1458 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1459 | model.load_state_dict(state_dict) 1460 | return model 1461 | 1462 | 1463 | @register_model 1464 | def caformer_m36_in21ft1k(pretrained=False, **kwargs): 1465 | model = MetaFormer( 1466 | depths=[3, 12, 18, 3], 1467 | dims=[96, 192, 384, 576], 1468 | token_mixers=[SepConv, SepConv, Attention, Attention], 1469 | head_fn=MlpHead, 1470 | **kwargs) 1471 | model.default_cfg = default_cfgs['caformer_m36_in21ft1k'] 1472 | if pretrained: 1473 | state_dict = 
torch.hub.load_state_dict_from_url(
1474 | url=model.default_cfg['url'], map_location="cpu", check_hash=True)
1475 | model.load_state_dict(state_dict)
1476 | return model
1477 |
1478 |
1479 | @register_model
1480 | def caformer_m36_384_in21ft1k(pretrained=False, Dulbrn=0, **kwargs):
1481 | model = MetaFormer(
1482 | depths=[3, 12, 18, 3],
1483 | dims=[96, 192, 384, 576],
1484 | token_mixers=[SepConv, SepConv, Attention, Attention],
1485 | head_fn=MlpHead,
1486 | dulbrn=Dulbrn,
1487 | **kwargs)
1488 | if pretrained:
1489 | state_dict = torch.load("model/caformer_m36_384_in21ft1k.pth",
1490 | map_location="cpu")
1491 | model_dict = model.state_dict()
1492 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
1493 | model_dict.update(pretrained_dict)
1494 | model.load_state_dict(model_dict)
1495 |
1496 | return model
1497 |
1498 |
1499 | @register_model
1500 | def caformer_m36_in21k(pretrained=False, **kwargs):
1501 | model = MetaFormer(
1502 | depths=[3, 12, 18, 3],
1503 | dims=[96, 192, 384, 576],
1504 | token_mixers=[SepConv, SepConv, Attention, Attention],
1505 | head_fn=MlpHead,
1506 | **kwargs)
1507 | model.default_cfg = default_cfgs['caformer_m36_in21k']
1508 | if pretrained:
1509 | state_dict = torch.hub.load_state_dict_from_url(
1510 | url=model.default_cfg['url'], map_location="cpu", check_hash=True)
1511 | model.load_state_dict(state_dict)
1512 | return model
1513 |
1514 |
1515 | @register_model
1516 | def caformer_b36(pretrained=False, **kwargs):
1517 | model = MetaFormer(
1518 | depths=[3, 12, 18, 3],
1519 | dims=[128, 256, 512, 768],
1520 | token_mixers=[SepConv, SepConv, Attention, Attention],
1521 | head_fn=MlpHead,
1522 | **kwargs)
1523 | model.default_cfg = default_cfgs['caformer_b36']
1524 | if pretrained:
1525 | state_dict = torch.hub.load_state_dict_from_url(
1526 | url=model.default_cfg['url'], map_location="cpu", check_hash=True)
1527 | model.load_state_dict(state_dict)
1528 | return model
1529 |
1530 |
1531 | @register_model
1532 | def caformer_b36_384(pretrained=False, **kwargs):
1533 | model = MetaFormer(
1534 | depths=[3, 12, 18, 3],
1535 | dims=[128, 256, 512, 768],
1536 | token_mixers=[SepConv, SepConv, Attention, Attention],
1537 | head_fn=MlpHead,
1538 | **kwargs)
1539 | model.default_cfg = default_cfgs['caformer_b36_384']
1540 | if pretrained:
1541 | state_dict = torch.hub.load_state_dict_from_url(
1542 | url=model.default_cfg['url'], map_location="cpu", check_hash=True)
1543 | model.load_state_dict(state_dict)
1544 | return model
1545 |
1546 |
1547 | @register_model
1548 | def caformer_b36_in21ft1k(pretrained=False, **kwargs):
1549 | model = MetaFormer(
1550 | depths=[3, 12, 18, 3],
1551 | dims=[128, 256, 512, 768],
1552 | token_mixers=[SepConv, SepConv, Attention, Attention],
1553 | head_fn=MlpHead,
1554 | **kwargs)
1555 | model.default_cfg = default_cfgs['caformer_b36_in21ft1k']
1556 | if pretrained:
1557 | state_dict = torch.hub.load_state_dict_from_url(
1558 | url=model.default_cfg['url'], map_location="cpu", check_hash=True)
1559 | model.load_state_dict(state_dict)
1560 | return model
1561 |
1562 |
1563 | @register_model
1564 | def caformer_b36_384_in21ft1k(pretrained=False, **kwargs):
1565 | model = MetaFormer(
1566 | depths=[3, 12, 18, 3],
1567 | dims=[128, 256, 512, 768],
1568 | token_mixers=[SepConv, SepConv, Attention, Attention],
1569 | head_fn=MlpHead,
1570 | **kwargs)
1571 | model.default_cfg = default_cfgs['caformer_b36_384_in21ft1k']
1572 | if pretrained:
1573 | state_dict = torch.hub.load_state_dict_from_url(
1574 |
url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1575 | model.load_state_dict(state_dict) 1576 | return model 1577 | 1578 | 1579 | @register_model 1580 | def caformer_b36_in21k(pretrained=False, **kwargs): 1581 | model = MetaFormer( 1582 | depths=[3, 12, 18, 3], 1583 | dims=[128, 256, 512, 768], 1584 | token_mixers=[SepConv, SepConv, Attention, Attention], 1585 | head_fn=MlpHead, 1586 | **kwargs) 1587 | model.default_cfg = default_cfgs['caformer_b36_in21k'] 1588 | if pretrained: 1589 | state_dict = torch.hub.load_state_dict_from_url( 1590 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 1591 | model.load_state_dict(state_dict) 1592 | return model 1593 | --------------------------------------------------------------------------------
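Usage note (not part of the repository files above): a minimal sketch of how the modified CAFormer backbone defined in model/caformer.py can be exercised on its own. It assumes the repo root is on PYTHONPATH and uses pretrained=False, so no local checkpoint is needed; with Dulbrn=16 the reported channel counts are [16, 32, 96, 192, 384], since Dulbrn > 0 adds the two-layer local conv stem and the last attention stage is removed (num_stage = len(depths) - 1).

```
import torch
from model.caformer import caformer_m36_384_in21ft1k  # assumes repo root is on PYTHONPATH

# Randomly initialised backbone; Dulbrn=16 enables the local conv stem.
encoder = caformer_m36_384_in21ft1k(pretrained=False, Dulbrn=16)
print(encoder.out_channels)  # [16, 32, 96, 192, 384]

x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    feats = encoder(x)  # list of feature maps at strides 1, 2, 4, 8, 16
for f in feats:
    print(tuple(f.shape))
```

Passing Dulbrn=0 skips the local conv stem, and only the MetaFormer stage features are returned.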