├── fcmae_model.py
├── model.py
├── model_print.py
├── my_dataset.py
├── predict.py
├── pretrain.py
├── readme.md
├── sparse_model.py
├── train.py
└── utils.py

/fcmae_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | 
3 | # All rights reserved.
4 | 
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | 
9 | import torch
10 | import torch.nn as nn
11 | 
12 | from timm.models.layers import trunc_normal_
13 | from sparse_model import SparseConvNeXtV2
14 | from model import Block
15 | 
16 | class FCMAE(nn.Module):
17 |     """ Fully Convolutional Masked Autoencoder with ConvNeXtV2 backbone
18 |     """
19 |     def __init__(
20 |                 self,
21 |                 img_size=224,
22 |                 in_chans=3,
23 |                 depths=[3, 3, 9, 3],
24 |                 dims=[96, 192, 384, 768],
25 |                 decoder_depth=1,
26 |                 decoder_embed_dim=512,
27 |                 patch_size=32,
28 |                 mask_ratio=0.6,
29 |                 norm_pix_loss=False):
30 |         super().__init__()
31 | 
32 |         # configs
33 |         self.img_size = img_size
34 |         self.depths = depths
35 |         self.dims = dims
36 |         self.patch_size = patch_size
37 |         self.mask_ratio = mask_ratio
38 |         self.num_patches = (img_size // patch_size) ** 2
39 |         self.decoder_embed_dim = decoder_embed_dim
40 |         self.decoder_depth = decoder_depth
41 |         self.norm_pix_loss = norm_pix_loss
42 | 
43 |         # encoder
44 |         self.encoder = SparseConvNeXtV2(
45 |             in_chans=in_chans, depths=depths, dims=dims, D=2)
46 |         # decoder
47 |         self.proj = nn.Conv2d(
48 |             in_channels=dims[-1],
49 |             out_channels=decoder_embed_dim,
50 |             kernel_size=1)
51 |         # mask tokens
52 |         self.mask_token = nn.Parameter(torch.zeros(1, decoder_embed_dim, 1, 1))
53 |         decoder = [Block(
54 |             dim=decoder_embed_dim,
55 |             drop_path=0.)
for i in range(decoder_depth)] 56 | self.decoder = nn.Sequential(*decoder) 57 | # pred 58 | self.pred = nn.Conv2d( 59 | in_channels=decoder_embed_dim, 60 | out_channels=patch_size ** 2 * in_chans, 61 | kernel_size=1) 62 | 63 | 64 | 65 | 66 | 67 | def patchify(self, imgs): 68 | """ 69 | imgs: (N, 3, H, W) 70 | x: (N, L, patch_size**2 *3) 71 | """ 72 | p = self.patch_size 73 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 74 | 75 | h = w = imgs.shape[2] // p 76 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 77 | x = torch.einsum('nchpwq->nhwpqc', x) 78 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 79 | return x 80 | 81 | def unpatchify(self, x): 82 | """ 83 | x: (N, L, patch_size**2 *3) 84 | imgs: (N, 3, H, W) 85 | """ 86 | p = self.patch_size 87 | h = w = int(x.shape[1]**.5) 88 | assert h * w == x.shape[1] 89 | 90 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 91 | x = torch.einsum('nhwpqc->nchpwq', x) 92 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 93 | return imgs 94 | 95 | def gen_random_mask(self, x, mask_ratio): 96 | N = x.shape[0] 97 | L = (x.shape[2] // self.patch_size) ** 2 98 | len_keep = int(L * (1 - mask_ratio)) 99 | 100 | noise = torch.randn(N, L, device=x.device) 101 | 102 | # sort noise for each sample 103 | ids_shuffle = torch.argsort(noise, dim=1) 104 | ids_restore = torch.argsort(ids_shuffle, dim=1) 105 | 106 | # generate the binary mask: 0 is keep 1 is remove 107 | mask = torch.ones([N, L], device=x.device) 108 | mask[:, :len_keep] = 0 109 | # unshuffle to get the binary mask 110 | mask = torch.gather(mask, dim=1, index=ids_restore) 111 | return mask 112 | 113 | def upsample_mask(self, mask, scale): 114 | assert len(mask.shape) == 2 115 | p = int(mask.shape[1] ** .5) 116 | return mask.reshape(-1, p, p).\ 117 | repeat_interleave(scale, axis=1).\ 118 | repeat_interleave(scale, axis=2) 119 | 120 | def forward_encoder(self, imgs, mask_ratio): 121 | # generate random masks 122 | mask = self.gen_random_mask(imgs, mask_ratio) 123 | # encoding 124 | x = self.encoder(imgs, mask) 125 | return x, mask 126 | 127 | def forward_decoder(self, x, mask): 128 | x = self.proj(x) 129 | # append mask token 130 | n, c, h, w = x.shape 131 | mask = mask.reshape(-1, h, w).unsqueeze(1).type_as(x) 132 | mask_token = self.mask_token.repeat(x.shape[0], 1, x.shape[2], x.shape[3]) 133 | x = x * (1. 
- mask) + mask_token * mask 134 | # decoding 135 | x = self.decoder(x) 136 | # pred 137 | pred = self.pred(x) 138 | return pred 139 | 140 | def forward_loss(self, imgs, pred, mask): 141 | """ 142 | imgs: [N, 3, H, W] 143 | pred: [N, L, p*p*3] 144 | mask: [N, L], 0 is keep, 1 is remove 145 | """ 146 | if len(pred.shape) == 4: 147 | n, c, _, _ = pred.shape 148 | pred = pred.reshape(n, c, -1) 149 | pred = torch.einsum('ncl->nlc', pred) 150 | 151 | target = self.patchify(imgs) 152 | if self.norm_pix_loss: 153 | mean = target.mean(dim=-1, keepdim=True) 154 | var = target.var(dim=-1, keepdim=True) 155 | target = (target - mean) / (var + 1.e-6)**.5 156 | loss = (pred - target) ** 2 157 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 158 | 159 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 160 | return loss 161 | 162 | def forward(self, imgs, labels=None, mask_ratio=0.6): 163 | x, mask = self.forward_encoder(imgs, mask_ratio) 164 | pred = self.forward_decoder(x, mask) 165 | loss = self.forward_loss(imgs, pred, mask) 166 | return loss, pred, mask 167 | 168 | def convnextv2_atto(**kwargs): 169 | model = FCMAE( 170 | depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) 171 | return model 172 | 173 | def convnextv2_femto(**kwargs): 174 | model = FCMAE( 175 | depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) 176 | return model 177 | 178 | def convnextv2_pico(**kwargs): 179 | model = FCMAE( 180 | depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) 181 | return model 182 | 183 | def convnextv2_nano(**kwargs): 184 | model = FCMAE( 185 | depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs) 186 | return model 187 | 188 | def convnextv2_tiny(**kwargs): 189 | model = FCMAE( 190 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 191 | return model 192 | 193 | def convnextv2_base(**kwargs): 194 | model = FCMAE( 195 | depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 196 | return model 197 | 198 | def convnextv2_large(**kwargs): 199 | model = FCMAE( 200 | depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 201 | return model 202 | 203 | def convnextv2_huge(**kwargs): 204 | model = FCMAE( 205 | depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs) 206 | return model -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def drop_path(x, drop_prob: float = 0., training: bool = False): 7 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 8 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 9 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 10 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 11 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 12 | 'survival rate' as the argument. 13 | """ 14 | if drop_prob == 0. 
or not training: 15 | return x 16 | keep_prob = 1 - drop_prob 17 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 18 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 19 | random_tensor.floor_() # binarize 20 | output = x.div(keep_prob) * random_tensor 21 | return output 22 | 23 | 24 | class DropPath(nn.Module): 25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 26 | """ 27 | def __init__(self, drop_prob=None): 28 | super(DropPath, self).__init__() 29 | self.drop_prob = drop_prob 30 | 31 | def forward(self, x): 32 | return drop_path(x, self.drop_prob, self.training) 33 | 34 | 35 | class LayerNorm(nn.Module): 36 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 37 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 38 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 39 | with shape (batch_size, channels, height, width). 40 | """ 41 | 42 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 43 | super().__init__() 44 | self.weight = nn.Parameter(torch.ones(normalized_shape), requires_grad=True) 45 | self.bias = nn.Parameter(torch.zeros(normalized_shape), requires_grad=True) 46 | self.eps = eps 47 | self.data_format = data_format 48 | if self.data_format not in ["channels_last", "channels_first"]: 49 | raise ValueError(f"not support data format '{self.data_format}'") 50 | self.normalized_shape = (normalized_shape,) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | if self.data_format == "channels_last": 54 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 55 | elif self.data_format == "channels_first": 56 | # [batch_size, channels, height, width] 57 | mean = x.mean(1, keepdim=True) 58 | var = (x - mean).pow(2).mean(1, keepdim=True) 59 | x = (x - mean) / torch.sqrt(var + self.eps) 60 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 61 | return x 62 | 63 | class GRN(nn.Module): 64 | """ GRN (Global Response Normalization) layer 65 | """ 66 | def __init__(self, dim): 67 | super().__init__() 68 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 69 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 70 | 71 | def forward(self, x): 72 | Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True) 73 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 74 | return self.gamma * (x * Nx) + self.beta + x 75 | 76 | class Block(nn.Module): 77 | """ ConvNeXtV2 Block. 78 | 79 | Args: 80 | dim (int): Number of input channels. 81 | drop_path (float): Stochastic depth rate. Default: 0.0 82 | """ 83 | def __init__(self, dim, drop_path=0.): 84 | super().__init__() 85 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 86 | self.norm = LayerNorm(dim, eps=1e-6) 87 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 88 | self.act = nn.GELU() 89 | self.grn = GRN(4 * dim) 90 | self.pwconv2 = nn.Linear(4 * dim, dim) 91 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 92 | 93 | def forward(self, x): 94 | input = x 95 | x = self.dwconv(x) 96 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 97 | x = self.norm(x) 98 | x = self.pwconv1(x) 99 | x = self.act(x) 100 | x = self.grn(x) 101 | x = self.pwconv2(x) 102 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 103 | 104 | x = input + self.drop_path(x) 105 | return x 106 | 107 | class ConvNeXtV2(nn.Module): 108 | """ ConvNeXt V2 109 | 110 | Args: 111 | in_chans (int): Number of input image channels. Default: 3 112 | num_classes (int): Number of classes for classification head. Default: 1000 113 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 114 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 115 | drop_path_rate (float): Stochastic depth rate. Default: 0. 116 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 117 | """ 118 | def __init__(self, in_chans=3, num_classes=1000, 119 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 120 | drop_path_rate=0., head_init_scale=1. 121 | ): 122 | super().__init__() 123 | self.depths = depths 124 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 125 | stem = nn.Sequential( 126 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 127 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 128 | ) 129 | self.downsample_layers.append(stem) 130 | for i in range(3): 131 | downsample_layer = nn.Sequential( 132 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 133 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 134 | ) 135 | self.downsample_layers.append(downsample_layer) 136 | 137 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 138 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 139 | cur = 0 140 | for i in range(4): 141 | stage = nn.Sequential( 142 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] 143 | ) 144 | self.stages.append(stage) 145 | cur += depths[i] 146 | 147 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 148 | self.head = nn.Linear(dims[-1], num_classes) 149 | 150 | self.apply(self._init_weights) 151 | self.head.weight.data.mul_(head_init_scale) 152 | self.head.bias.data.mul_(head_init_scale) 153 | 154 | def _init_weights(self, m): 155 | if isinstance(m, (nn.Conv2d, nn.Linear)): 156 | nn.init.trunc_normal_(m.weight, std=.02) 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def forward_features(self, x): 160 | for i in range(4): 161 | x = self.downsample_layers[i](x) 162 | x = self.stages[i](x) 163 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 164 | 165 | def forward(self, x): 166 | x = self.forward_features(x) 167 | x = self.head(x) 168 | return x 169 | 170 | def convnextv2_atto(num_classes: int): 171 | #https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt 172 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], num_classes=num_classes) 173 | return model 174 | 175 | def convnextv2_femto(num_classes: int): 176 | #https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt 177 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], num_classes=num_classes) 178 | return model 179 | 180 | def convnext_pico(num_classes: int): 181 | 
#https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt 182 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], num_classes=num_classes) 183 | return model 184 | 185 | def convnextv2_nano(num_classes: int): 186 | #https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt 187 | model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], num_classes=num_classes) 188 | return model 189 | 190 | def convnextv2_tiny(num_classes: int): 191 | #https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt 192 | model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], num_classes=num_classes) 193 | return model 194 | 195 | def convnextv2_base(num_classes: int): 196 | #https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt 197 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], num_classes=num_classes) 198 | return model 199 | 200 | def convnextv2_large(num_classes: int): 201 | #https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt 202 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], num_classes=num_classes) 203 | return model 204 | 205 | def convnextv2_huge(num_classes: int): 206 | #https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt 207 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], num_classes=num_classes) 208 | return model -------------------------------------------------------------------------------- /model_print.py: -------------------------------------------------------------------------------- 1 | from model import convnextv2_base 2 | from torchinfo import summary 3 | 4 | model = convnextv2_base(num_classes=2) 5 | out = summary(model, (1, 3, 384,384)) -------------------------------------------------------------------------------- /my_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | from torch.utils.data import Dataset 4 | import cv2 5 | import numpy as np 6 | class MyDataSet(Dataset): 7 | """自定义数据集""" 8 | 9 | def __init__(self, images_path: list, images_class: list, transform=None): 10 | self.images_path = images_path 11 | self.images_class = images_class 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.images_path) 16 | 17 | def __getitem__(self, item): 18 | img = Image.open(self.images_path[item]).convert('RGB') 19 | # RGB为彩色图片,L为灰度图片 20 | 21 | 22 | #raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item])) 23 | 24 | 25 | label = self.images_class[item] 26 | 27 | if self.transform is not None: 28 | img = self.transform(img) 29 | 30 | return img, label 31 | 32 | @staticmethod 33 | def collate_fn(batch): 34 | # 官方实现的default_collate可以参考 35 | # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py 36 | images, labels = tuple(zip(*batch)) 37 | 38 | images = torch.stack(images, dim=0) 39 | labels = torch.as_tensor(labels) 40 | return images, labels -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | from PIL import Image 6 | from torchvision import transforms 7 | import matplotlib.pyplot as plt 8 | from tqdm import tqdm 9 | from model import 
convnextv2_base as create_model
10 | import os
11 | from pandas.core.frame import DataFrame
12 | paths = os.listdir(r'tests')
13 | 
14 | def main():
15 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16 |     print(f"using {device} device.")
17 | 
18 |     num_classes = 12
19 |     img_size = 224
20 |     data_transform = transforms.Compose(
21 |         [transforms.Resize(224),
22 |          transforms.CenterCrop(img_size),
23 |          transforms.ToTensor(),
24 |          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
25 | 
26 |     # load image
27 |     clas = []
28 |     pathss = []
29 |     model = create_model(num_classes=num_classes).to(device)
30 |     model_weight_path = "./weights/best_model.pth"
31 |     model.load_state_dict(torch.load(model_weight_path, map_location=device))
32 |     for i in tqdm(range(len(paths))):
33 |         img_path = "/kaggle/input/cats-12-end/cat_12_test/" + paths[i]
34 |         assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
35 |         img = Image.open(img_path).convert('RGB')
36 |         pathss.append(paths[i])
37 | 
38 |         # [N, C, H, W]
39 |         img = data_transform(img)
40 |         # expand batch dimension
41 |         img = torch.unsqueeze(img, dim=0)
42 | 
43 |         # read class_indict
44 |         json_path = './class_indices.json'
45 |         assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
46 | 
47 |         with open(json_path, "r") as f:
48 |             class_indict = json.load(f)
49 | 
50 |         # create model
51 | 
52 |         # load model weights
53 | 
54 |         model.eval()
55 |         with torch.no_grad():
56 |             # predict class
57 |             output = torch.squeeze(model(img.to(device))).cpu()
58 |             predict = torch.softmax(output, dim=0)
59 |             predict_cla = torch.argmax(predict).numpy()
60 | 
61 |         print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
62 |                                                     predict[predict_cla].numpy())
63 | 
64 |         #print("class: {} prob: {:.3}".format(class_indict[str(list(predict.numpy()).index(predict.numpy().max()))],predict.numpy().max()))
65 |         clas.append(class_indict[str(list(predict.numpy()).index(predict.numpy().max()))][:-1])
66 |     c = {"a": pathss, "b": clas}  # pack the two result lists into a dict
67 |     data = DataFrame(c)  # convert the dict into a DataFrame
68 |     outputpath = 'results.csv'
69 |     data.to_csv(outputpath, sep=',', index=False, header=False)
70 | if __name__ == '__main__':
71 |     main()
--------------------------------------------------------------------------------
/pretrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import sys
4 | import torch
5 | import torch.optim as optim
6 | from torchvision import transforms
7 | import torch
8 | import random
9 | import json
10 | from torch import optim as optim
11 | from utils import read_split_data, create_lr_scheduler, get_params_groups
12 | from fcmae_model import convnextv2_base
13 | from tqdm import tqdm
14 | 
15 | 
16 | from my_dataset import MyDataSet
17 | def read_split_data(root: str, val_rate: float = 0):
18 |     random.seed(0)  # fix the seed so the split is reproducible
19 |     assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
20 | 
21 |     # iterate over the folders; each folder corresponds to one class
22 |     flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
23 |     # sort to keep the order consistent across platforms
24 |     flower_class.sort()
25 |     # build the mapping from class name to numeric index
26 |     class_indices = dict((k, v) for v, k in enumerate(flower_class))
27 |     json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
28 |     with open('class_indices.json', 'w') as json_file:
29 |         json_file.write(json_str)
30 | 
31 |     train_images_path = []  # paths of all training-set images
32 |     train_images_label = []  # class indices of the training-set images
33 |
val_images_path = [] # 存储验证集的所有图片路径 34 | val_images_label = [] # 存储验证集图片对应索引信息 35 | every_class_num = [] # 存储每个类别的样本总数 36 | supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型 37 | # 遍历每个文件夹下的文件 38 | for cla in flower_class: 39 | cla_path = os.path.join(root, cla) 40 | # 遍历获取supported支持的所有文件路径 41 | images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) 42 | if os.path.splitext(i)[-1] in supported] 43 | # 排序,保证各平台顺序一致 44 | images.sort() 45 | # 获取该类别对应的索引 46 | image_class = class_indices[cla] 47 | # 记录该类别的样本数量 48 | every_class_num.append(len(images)) 49 | # 按比例随机采样验证样本 50 | val_path = random.sample(images, k=int(len(images) * val_rate)) 51 | 52 | for img_path in images: 53 | if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集 54 | val_images_path.append(img_path) 55 | val_images_label.append(image_class) 56 | else: # 否则存入训练集 57 | train_images_path.append(img_path) 58 | train_images_label.append(image_class) 59 | 60 | 61 | return train_images_path, train_images_label, val_images_path, val_images_label 62 | 63 | def train_one_epoch(model, optimizer, data_loader, device, epoch, lr_scheduler): 64 | model.train() 65 | 66 | accu_loss = torch.zeros(1).to(device) # 累计损失 67 | accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数 68 | optimizer.zero_grad() 69 | 70 | sample_num = 0 71 | data_loader = tqdm(data_loader, file=sys.stdout) 72 | for step, data in enumerate(data_loader): 73 | images, labels = data 74 | sample_num += images.shape[0] 75 | 76 | loss, pred, mask = model(images.to(device)) 77 | 78 | 79 | loss.backward() 80 | accu_loss += loss.detach() 81 | 82 | data_loader.desc = "[train epoch {}] loss: {:.3f}, lr: {:.5f}".format( 83 | epoch, 84 | accu_loss.item() / (step + 1), 85 | optimizer.param_groups[0]["lr"] 86 | ) 87 | 88 | if not torch.isfinite(loss): 89 | print('WARNING: non-finite loss, ending training ', loss) 90 | sys.exit(1) 91 | 92 | optimizer.step() 93 | optimizer.zero_grad() 94 | # update lr 95 | lr_scheduler.step() 96 | 97 | return accu_loss.item() / (step + 1) 98 | 99 | 100 | def main(args): 101 | device = torch.device(args.device if torch.cuda.is_available() else "cpu") 102 | print(f"using {device} device.") 103 | 104 | if os.path.exists("./weights") is False: 105 | os.makedirs("./weights") 106 | 107 | 108 | 109 | train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path) 110 | 111 | img_size = 224 112 | data_transform = { 113 | "train": transforms.Compose([transforms.Resize((img_size,img_size)), 114 | transforms.CenterCrop(img_size), 115 | transforms.RandomHorizontalFlip(p=0.5), 116 | transforms.ToTensor(), 117 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} 118 | train_dataset = MyDataSet(images_path=train_images_path, 119 | images_class=train_images_label, 120 | transform=data_transform["train"]) 121 | batch_size = args.batch_size 122 | nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers 123 | print('Using {} dataloader workers every process'.format(nw)) 124 | train_loader = torch.utils.data.DataLoader(train_dataset, 125 | batch_size=batch_size, 126 | shuffle=True, 127 | pin_memory=True, 128 | num_workers=nw, 129 | collate_fn=train_dataset.collate_fn) 130 | model = convnextv2_base().to(device) 131 | pg = get_params_groups(model, weight_decay=args.wd) 132 | optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=args.wd) 133 | lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, 134 | warmup=True, warmup_epochs=1) 135 | 136 | for epoch in 
range(args.epochs): 137 | # train 138 | train_loss = train_one_epoch(model=model, 139 | optimizer=optimizer, 140 | data_loader=train_loader, 141 | device=device, 142 | epoch=epoch, 143 | lr_scheduler=lr_scheduler) 144 | 145 | # validate 146 | 147 | 148 | 149 | torch.save(model.state_dict(), "./weights/pre_last_model.pth") 150 | 151 | 152 | 153 | if __name__ == '__main__': 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument('--num_classes', type=int, default=5) 156 | parser.add_argument('--epochs', type=int, default=10) 157 | parser.add_argument('--batch-size', type=int, default=8) 158 | parser.add_argument('--lr', type=float, default=2e-5) 159 | parser.add_argument('--wd', type=float, default=5e-2) 160 | 161 | # 数据集所在根目录 162 | # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz 163 | parser.add_argument('--data-path', type=str, 164 | default="flower_photos") 165 | 166 | 167 | parser.add_argument('--weights', type=str, default='./convnext_tiny_1k_224_ema.pth', 168 | help='initial weights path') 169 | # 是否冻结head以外所有权重 170 | parser.add_argument('--freeze-layers', type=bool, default=False) 171 | parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)') 172 | 173 | opt = parser.parse_args() 174 | 175 | main(opt) 176 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # ConvNeXt-V2 pytorch 复现(提醒:这是复现,重要参考官方代码) 2 | [`训练Training..........`](https://github.com/Jacky-Android/convnext-v2-pytorch/tree/main#%E4%BB%A3%E7%A0%81%E4%BD%BF%E7%94%A8%E7%AE%80%E4%BB%8B) 3 | # 2023年10月11日更新 4 | ## 加入FCMAE 5 | 全卷积掩码自编码器(FCMAE)框架是一种基于卷积神经网络的自监督学习方法,它的思想是在输入图像上随机掩盖一些区域,然后让模型尝试恢复被掩盖的部分。这样可以迫使模型学习到图像的全局和局部特征,从而提高其泛化能力。 6 | 7 | FCMAE 框架与传统的掩码自编码器(MAE)框架相比,有两个优势:一是它使用了全卷积结构,而不是使用全连接层来生成掩码和重建图像,这样可以减少参数量和计算量,同时保持空间信息;二是它使用了多尺度掩码策略,而不是使用固定大小的掩码,这样可以增加模型对不同尺度特征的感知能力。 8 | [FCMAE(fully convolutional masked autoencoder framework)](https://github.com/Jacky-Android/convnext-v2-pytorch/blob/main/fcmae_model.py) 9 | 10 | ![image](https://github.com/Jacky-Android/convnext-v2-pytorch/assets/55181594/cb3f3944-c0b6-4bba-86b3-d38f75fadcc6) 11 | 12 | 输入tensor[1,3,224,224],返回loss,pred,mask 13 | 14 | torch.Size([]) torch.Size([1, 3072, 7, 7]) torch.Size([1, 49]) 15 | # FCMAE训练文件更新 16 | [pretrain.py](https://github.com/Jacky-Android/convnext-v2-pytorch/blob/main/pretrain.py) 17 | ### torchinfo输出代码 18 | ```python 19 | from fcmae_model import convnextv2_pico 20 | from torchinfo import summary 21 | import torch 22 | 23 | #use pico 24 | models = convnextv2_pico().cuda() 25 | x = torch.randn([1,3,224,224]).cuda() 26 | print(models(x)[0],models(x)[1].shape,models(x)[2].shape) 27 | out = summary(models, (1, 3, 224,224)) 28 | ``` 29 | ### 模型torchinfo输出 30 | ```python 31 | =============================================================================================== 32 | Layer (type:depth-idx) Output Shape Param # 33 | =============================================================================================== 34 | FCMAE -- 512 35 | ├─SparseConvNeXtV2: 1-1 [1, 512, 7, 7] -- 36 | │ └─ModuleList: 2-7 -- (recursive) 37 | │ │ └─Sequential: 3-1 [1, 64, 56, 56] 3,264 38 | │ └─ModuleList: 2-8 -- (recursive) 39 | │ │ └─Sequential: 3-2 [1, 64, 56, 56] 73,856 40 | │ └─ModuleList: 2-7 -- (recursive) 41 | │ │ └─Sequential: 3-3 [1, 128, 28, 28] 33,024 42 | │ └─ModuleList: 2-8 -- (recursive) 43 | │ │ └─Sequential: 3-4 [1, 128, 28, 28] 278,784 44 
| │ └─ModuleList: 2-7 -- (recursive) 45 | │ │ └─Sequential: 3-5 [1, 256, 14, 14] 131,584 46 | │ └─ModuleList: 2-8 -- (recursive) 47 | │ │ └─Sequential: 3-6 [1, 256, 14, 14] 3,245,568 48 | │ └─ModuleList: 2-7 -- (recursive) 49 | │ │ └─Sequential: 3-7 [1, 512, 7, 7] 525,312 50 | │ └─ModuleList: 2-8 -- (recursive) 51 | │ │ └─Sequential: 3-8 [1, 512, 7, 7] 4,260,864 52 | ├─Conv2d: 1-2 [1, 512, 7, 7] 262,656 53 | ├─Sequential: 1-3 [1, 512, 7, 7] -- 54 | │ └─Block: 2-9 [1, 512, 7, 7] -- 55 | │ │ └─Conv2d: 3-9 [1, 512, 7, 7] 25,600 56 | │ │ └─LayerNorm: 3-10 [1, 7, 7, 512] 1,024 57 | │ │ └─Linear: 3-11 [1, 7, 7, 2048] 1,050,624 58 | │ │ └─GELU: 3-12 [1, 7, 7, 2048] -- 59 | │ │ └─GRN: 3-13 [1, 7, 7, 2048] 4,096 60 | │ │ └─Linear: 3-14 [1, 7, 7, 512] 1,049,088 61 | │ │ └─Identity: 3-15 [1, 512, 7, 7] -- 62 | ├─Conv2d: 1-4 [1, 3072, 7, 7] 1,575,936 63 | =============================================================================================== 64 | Total params: 12,521,792 65 | Trainable params: 12,521,792 66 | Non-trainable params: 0 67 | Total mult-adds (M): 235.88 68 | =============================================================================================== 69 | Input size (MB): 0.60 70 | Forward/backward pass size (MB): 94.93 71 | Params size (MB): 50.09 72 | Estimated Total Size (MB): 145.62 73 | =============================================================================================== 74 | ``` 75 | ## 代码使用简介 76 | [参考代码](https://github.com/facebookresearch/ConvNeXt-V2) 77 | 78 | [论文ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808) 79 | 80 | 81 | 1. 下载好数据集,代码中默认使用的是花分类数据集,下载地址: [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz), 82 | 如果下载不了的话可以通过kaggle链接下载: https://www.kaggle.com/datasets/l3llff/flowers 83 | 2. 在`train.py`脚本中将`--data-path`设置成解压后的`flower_photos`文件夹绝对路径 84 | 3. 下载预训练权重,在`model.py`文件中每个模型都有提供预训练权重的下载地址,根据自己使用的模型下载对应预训练权重 85 | 4. 在`train.py`脚本中将`--weights`参数设成下载好的预训练权重路径 86 | 5. 设置好数据集的路径`--data-path`以及预训练权重的路径`--weights`就能使用`train.py`脚本开始训练了(训练过程中会自动生成`class_indices.json`文件) 87 | 6. 在`predict.py`脚本中导入和训练脚本中同样的模型,并将`model_weight_path`设置成训练好的模型权重路径(默认保存在weights文件夹下) 88 | 7. 在`predict.py`脚本中将`img_path`设置成你自己需要预测的图片的文件夹绝对路径,最后生成results.csv 89 | 8. 设置好权重路径`model_weight_path`和预测的图片路径`img_path`就能使用`predict.py`脚本进行预测了 90 | 9. 
如果要使用自己的数据集,请按照花分类数据集的文件结构进行摆放(即一个类别对应一个文件夹),并且将训练以及预测脚本中的`num_classes`设置成你自己数据的类别数 91 | 92 | ## Results and Pre-trained Models 93 | ### ImageNet-1K FCMAE pre-trained weights (*self-supervised*) 94 | | name | resolution | #params | model | 95 | |:---:|:---:|:---:|:---:| 96 | | ConvNeXt V2-A | 224x224 | 3.7M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt) | 97 | | ConvNeXt V2-F | 224x224 | 5.2M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt) | 98 | | ConvNeXt V2-P | 224x224 | 9.1M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt) | 99 | | ConvNeXt V2-N | 224x224 | 15.6M| [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt) | 100 | | ConvNeXt V2-T | 224x224 | 28.6M| [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt) | 101 | | ConvNeXt V2-B | 224x224 | 89M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt) | 102 | | ConvNeXt V2-L | 224x224 | 198M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt) | 103 | | ConvNeXt V2-H | 224x224 | 660M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt) | 104 | 105 | ### ImageNet-1K fine-tuned models 106 | 107 | | name | resolution |acc@1 | #params | FLOPs | model | 108 | |:---:|:---:|:---:|:---:| :---:|:---:| 109 | | ConvNeXt V2-A | 224x224 | 76.7 | 3.7M | 0.55G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt) | 110 | | ConvNeXt V2-F | 224x224 | 78.5 | 5.2M | 0.78G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt) | 111 | | ConvNeXt V2-P | 224x224 | 80.3 | 9.1M | 1.37G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt) | 112 | | ConvNeXt V2-N | 224x224 | 81.9 | 15.6M | 2.45G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt) | 113 | | ConvNeXt V2-T | 224x224 | 83.0 | 28.6M | 4.47G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt) | 114 | | ConvNeXt V2-B | 224x224 | 84.9 | 89M | 15.4G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt) | 115 | | ConvNeXt V2-L | 224x224 | 85.8 | 198M | 34.4G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt) | 116 | | ConvNeXt V2-H | 224x224 | 86.3 | 660M | 115G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt) | 117 | 118 | ### ImageNet-22K fine-tuned models 119 | 120 | | name | resolution |acc@1 | #params | FLOPs | model | 121 | |:---:|:---:|:---:|:---:| :---:| :---:| 122 | | ConvNeXt V2-N | 224x224 | 82.1 | 15.6M | 2.45G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt)| 123 | | ConvNeXt V2-N | 384x384 | 83.4 | 15.6M | 7.21G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt)| 124 | | ConvNeXt V2-T | 224x224 | 83.9 | 28.6M | 4.47G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt)| 125 | | ConvNeXt V2-T | 384x384 | 85.1 | 28.6M | 13.1G | 
[model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt)| 126 | | ConvNeXt V2-B | 224x224 | 86.8 | 89M | 15.4G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt)| 127 | | ConvNeXt V2-B | 384x384 | 87.7 | 89M | 45.2G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt)| 128 | | ConvNeXt V2-L | 224x224 | 87.3 | 198M | 34.4G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt)| 129 | | ConvNeXt V2-L | 384x384 | 88.2 | 198M | 101.1G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt)| 130 | | ConvNeXt V2-H | 384x384 | 88.7 | 660M | 337.9G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt)| 131 | | ConvNeXt V2-H | 512x512 | 88.9 | 660M | 600.8G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt)| 132 | ## Star History 133 | 134 | [![Star History Chart](https://api.star-history.com/svg?repos=Jacky-Android/convnext-v2-pytorch&type=Date)](https://star-history.com/#Jacky-Android/convnext-v2-pytorch&Date) 135 | -------------------------------------------------------------------------------- /sparse_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import trunc_normal_,DropPath 4 | import torch.nn.functional as F 5 | 6 | #trunc_normal_是一个用于初始化神经网络参数的函数,它可以用截断的正态分布来填充输入张量。截断的正态分布是指在一定范围内的正态分布,例如 [a, b],如果生成的值超出这个范围,就重新生成,直到在范围内为止。这样可以避免生成一些极端的值,影响网络的训练。 7 | class LayerNorm(nn.Module): 8 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 9 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 10 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 11 | with shape (batch_size, channels, height, width). 
12 | """ 13 | 14 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 15 | super().__init__() 16 | self.weight = nn.Parameter(torch.ones(normalized_shape), requires_grad=True) 17 | self.bias = nn.Parameter(torch.zeros(normalized_shape), requires_grad=True) 18 | self.eps = eps 19 | self.data_format = data_format 20 | if self.data_format not in ["channels_last", "channels_first"]: 21 | raise ValueError(f"not support data format '{self.data_format}'") 22 | self.normalized_shape = (normalized_shape,) 23 | 24 | def forward(self, x: torch.Tensor) -> torch.Tensor: 25 | if self.data_format == "channels_last": 26 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 27 | elif self.data_format == "channels_first": 28 | # [batch_size, channels, height, width] 29 | mean = x.mean(1, keepdim=True) 30 | var = (x - mean).pow(2).mean(1, keepdim=True) 31 | x = (x - mean) / torch.sqrt(var + self.eps) 32 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 33 | return x 34 | 35 | class GRN(nn.Module): 36 | """ GRN (Global Response Normalization) layer 37 | """ 38 | def __init__(self, dim): 39 | super().__init__() 40 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 41 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 42 | 43 | def forward(self, x): 44 | Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True) 45 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 46 | return self.gamma * (x * Nx) + self.beta + x 47 | 48 | class Block(nn.Module): 49 | def __init__(self, dim, drop_path=0.): 50 | super().__init__() 51 | 52 | #官方源代码self.dwconv = MinkowskiDepthwiseConvolution(dim, kernel_size=7, bias=True, dimension=D) 53 | 54 | self.dwconv = nn.Conv2d(dim,dim, kernel_size=7, groups=dim,padding=3)#depthwise Conv 55 | self.norm = LayerNorm(dim, eps=1e-6) 56 | self.pwconv1 = nn.Linear(dim, 4 * dim) 57 | self.act = nn.GELU() 58 | self.pwconv2 = nn.Linear(4 * dim, dim) 59 | self.grn = GRN(4 * dim) 60 | self.drop_path = DropPath(drop_path) 61 | 62 | def forward(self, x): 63 | input = x 64 | x = self.dwconv(x) 65 | x = x.permute(0, 2, 3, 1) 66 | #print(x.shape) 67 | x = self.norm(x) 68 | 69 | x = self.pwconv1(x) 70 | x = self.act(x) 71 | x = self.grn(x) 72 | x = self.pwconv2(x) 73 | x = x.permute(0, 3,1, 2) 74 | x = input + self.drop_path(x) 75 | return x 76 | 77 | class SparseConvNeXtV2(nn.Module): 78 | """ Sparse ConvNeXtV2. 79 | 80 | Args: 81 | in_chans (int): Number of input image channels. Default: 3 82 | num_classes (int): Number of classes for classification head. Default: 1000 83 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 84 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 85 | drop_path_rate (float): Stochastic depth rate. Default: 0. 86 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
87 | """ 88 | def __init__(self, 89 | in_chans=3, 90 | num_classes=1000, 91 | depths=[3, 3, 9, 3], 92 | dims=[96, 192, 384, 768], 93 | drop_path_rate=0., 94 | D=2): 95 | super().__init__() 96 | self.depths = depths 97 | self.num_classes = num_classes 98 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 99 | #这里使用LayerNorm,data_format在LayerNorm是pytorch2.1的写法,所以重写了LayerNorm类 100 | 101 | 102 | stem = nn.Sequential( 103 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 104 | LayerNorm(dims[0], eps=1e-6,data_format="channels_first") 105 | ) 106 | self.downsample_layers.append(stem) 107 | for i in range(3): 108 | downsample_layer = nn.Sequential( 109 | 110 | LayerNorm(dims[i], eps=1e-6,data_format="channels_first"), 111 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2) 112 | ) 113 | self.downsample_layers.append(downsample_layer) 114 | 115 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 116 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 117 | cur = 0 118 | for i in range(4): 119 | stage = nn.Sequential( 120 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] 121 | ) 122 | self.stages.append(stage) 123 | cur += depths[i] 124 | 125 | #self.apply(self._init_weights) 126 | 127 | 128 | 129 | def upsample_mask(self, mask, scale): 130 | assert len(mask.shape) == 2 131 | p = int(mask.shape[1] ** .5) 132 | return mask.reshape(-1, p, p).\ 133 | repeat_interleave(scale, axis=1).\ 134 | repeat_interleave(scale, axis=2) 135 | 136 | def forward(self, x, mask): 137 | num_stages = len(self.stages) 138 | mask = self.upsample_mask(mask, 2**(num_stages-1)) 139 | mask = mask.unsqueeze(1).type_as(x) 140 | 141 | # patch embedding 142 | x = self.downsample_layers[0](x) 143 | x *= (1.-mask) 144 | 145 | # sparse encoding 146 | #x = torch.Tensor.to_sparse(x) 147 | 148 | for i in range(4): 149 | x = self.downsample_layers[i](x) if i > 0 else x 150 | x = self.stages[i](x) 151 | 152 | # densify 153 | #x = x.dense()[0] 154 | return x -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | import torch.optim as optim 6 | from torch.utils.tensorboard import SummaryWriter 7 | from torchvision import transforms 8 | import torch 9 | from torch import optim as optim 10 | 11 | from my_dataset import MyDataSet 12 | from model import convnextv2_base as create_model 13 | from utils import read_split_data, create_lr_scheduler, get_params_groups, train_one_epoch, evaluate 14 | 15 | 16 | def main(args): 17 | device = torch.device(args.device if torch.cuda.is_available() else "cpu") 18 | print(f"using {device} device.") 19 | 20 | if os.path.exists("./weights") is False: 21 | os.makedirs("./weights") 22 | 23 | tb_writer = SummaryWriter() 24 | 25 | train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path) 26 | 27 | img_size = 384 28 | data_transform = { 29 | "train": transforms.Compose([transforms.Resize((img_size,img_size)), 30 | transforms.CenterCrop(img_size), 31 | transforms.RandomHorizontalFlip(p=0.5), 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), 34 | 35 | "val": transforms.Compose([transforms.Resize((img_size,img_size)), 36 | transforms.CenterCrop(img_size), 37 | transforms.ToTensor(), 
38 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} 39 | 40 | # 实例化训练数据集 41 | #mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375] 42 | train_dataset = MyDataSet(images_path=train_images_path, 43 | images_class=train_images_label, 44 | transform=data_transform["train"]) 45 | 46 | # 实例化验证数据集 47 | val_dataset = MyDataSet(images_path=val_images_path, 48 | images_class=val_images_label, 49 | transform=data_transform["val"]) 50 | 51 | batch_size = args.batch_size 52 | nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers 53 | print('Using {} dataloader workers every process'.format(nw)) 54 | train_loader = torch.utils.data.DataLoader(train_dataset, 55 | batch_size=batch_size, 56 | shuffle=True, 57 | pin_memory=True, 58 | num_workers=nw, 59 | collate_fn=train_dataset.collate_fn) 60 | 61 | val_loader = torch.utils.data.DataLoader(val_dataset, 62 | batch_size=batch_size, 63 | shuffle=False, 64 | pin_memory=True, 65 | num_workers=nw, 66 | collate_fn=val_dataset.collate_fn) 67 | 68 | model = create_model(num_classes=args.num_classes).to(device) 69 | 70 | if args.weights != "": 71 | assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights) 72 | weights_dict = torch.load(args.weights, map_location=device)["model"] 73 | # 删除有关分类类别的权重 74 | for k in list(weights_dict.keys()): 75 | if "head" in k: 76 | del weights_dict[k] 77 | print(model.load_state_dict(weights_dict, strict=False)) 78 | 79 | if args.freeze_layers: 80 | for name, para in model.named_parameters(): 81 | # 除head外,其他权重全部冻结 82 | if "head" not in name: 83 | para.requires_grad_(False) 84 | else: 85 | print("training {}".format(name)) 86 | 87 | # pg = [p for p in model.parameters() if p.requires_grad] 88 | pg = get_params_groups(model, weight_decay=args.wd) 89 | optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=args.wd) 90 | lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, 91 | warmup=True, warmup_epochs=1) 92 | 93 | best_acc = 0. 
94 | for epoch in range(args.epochs): 95 | # train 96 | train_loss, train_acc = train_one_epoch(model=model, 97 | optimizer=optimizer, 98 | data_loader=train_loader, 99 | device=device, 100 | epoch=epoch, 101 | lr_scheduler=lr_scheduler) 102 | 103 | # validate 104 | val_loss, val_acc = evaluate(model=model, 105 | data_loader=val_loader, 106 | device=device, 107 | epoch=epoch) 108 | 109 | tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"] 110 | tb_writer.add_scalar(tags[0], train_loss, epoch) 111 | tb_writer.add_scalar(tags[1], train_acc, epoch) 112 | tb_writer.add_scalar(tags[2], val_loss, epoch) 113 | tb_writer.add_scalar(tags[3], val_acc, epoch) 114 | tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch) 115 | 116 | if best_acc < val_acc: 117 | torch.save(model.state_dict(), "./weights/best_model.pth") 118 | best_acc = val_acc 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('--num_classes', type=int, default=5) 124 | parser.add_argument('--epochs', type=int, default=10) 125 | parser.add_argument('--batch-size', type=int, default=8) 126 | parser.add_argument('--lr', type=float, default=2e-5) 127 | parser.add_argument('--wd', type=float, default=5e-2) 128 | 129 | # 数据集所在根目录 130 | # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz 131 | parser.add_argument('--data-path', type=str, 132 | default="flower_photos") 133 | 134 | # 预训练权重路径,如果不想载入就设置为空字符 135 | # 链接: https://pan.baidu.com/s/1aNqQW4n_RrUlWUBNlaJRHA 密码: i83t 136 | parser.add_argument('--weights', type=str, default='./convnext_tiny_1k_224_ema.pth', 137 | help='initial weights path') 138 | # 是否冻结head以外所有权重 139 | parser.add_argument('--freeze-layers', type=bool, default=False) 140 | parser.add_argument('--device', default='cuda:0', help='device id (i.e. 
0 or 0,1 or cpu)') 141 | 142 | opt = parser.parse_args() 143 | 144 | main(opt) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import pickle 5 | import random 6 | import math 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | def read_split_data(root: str, val_rate: float = 0.2): 15 | random.seed(0) # 保证随机结果可复现 16 | assert os.path.exists(root), "dataset root: {} does not exist.".format(root) 17 | 18 | # 遍历文件夹,一个文件夹对应一个类别 19 | flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] 20 | # 排序,保证各平台顺序一致 21 | flower_class.sort() 22 | # 生成类别名称以及对应的数字索引 23 | class_indices = dict((k, v) for v, k in enumerate(flower_class)) 24 | json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) 25 | with open('class_indices.json', 'w') as json_file: 26 | json_file.write(json_str) 27 | 28 | train_images_path = [] # 存储训练集的所有图片路径 29 | train_images_label = [] # 存储训练集图片对应索引信息 30 | val_images_path = [] # 存储验证集的所有图片路径 31 | val_images_label = [] # 存储验证集图片对应索引信息 32 | every_class_num = [] # 存储每个类别的样本总数 33 | supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型 34 | # 遍历每个文件夹下的文件 35 | for cla in flower_class: 36 | cla_path = os.path.join(root, cla) 37 | # 遍历获取supported支持的所有文件路径 38 | images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) 39 | if os.path.splitext(i)[-1] in supported] 40 | # 排序,保证各平台顺序一致 41 | images.sort() 42 | # 获取该类别对应的索引 43 | image_class = class_indices[cla] 44 | # 记录该类别的样本数量 45 | every_class_num.append(len(images)) 46 | # 按比例随机采样验证样本 47 | val_path = random.sample(images, k=int(len(images) * val_rate)) 48 | 49 | for img_path in images: 50 | if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集 51 | val_images_path.append(img_path) 52 | val_images_label.append(image_class) 53 | else: # 否则存入训练集 54 | train_images_path.append(img_path) 55 | train_images_label.append(image_class) 56 | 57 | print("{} images were found in the dataset.".format(sum(every_class_num))) 58 | print("{} images for training.".format(len(train_images_path))) 59 | print("{} images for validation.".format(len(val_images_path))) 60 | assert len(train_images_path) > 0, "number of training images must greater than 0." 61 | assert len(val_images_path) > 0, "number of validation images must greater than 0." 62 | 63 | plot_image = False 64 | if plot_image: 65 | # 绘制每种类别个数柱状图 66 | plt.bar(range(len(flower_class)), every_class_num, align='center') 67 | # 将横坐标0,1,2,3,4替换为相应的类别名称 68 | plt.xticks(range(len(flower_class)), flower_class) 69 | # 在柱状图上添加数值标签 70 | for i, v in enumerate(every_class_num): 71 | plt.text(x=i, y=v + 5, s=str(v), ha='center') 72 | # 设置x坐标 73 | plt.xlabel('image class') 74 | # 设置y坐标 75 | plt.ylabel('number of images') 76 | # 设置柱状图的标题 77 | plt.title('flower class distribution') 78 | plt.show() 79 | 80 | return train_images_path, train_images_label, val_images_path, val_images_label 81 | 82 | 83 | def plot_data_loader_image(data_loader): 84 | batch_size = data_loader.batch_size 85 | plot_num = min(batch_size, 4) 86 | 87 | json_path = './class_indices.json' 88 | assert os.path.exists(json_path), json_path + " does not exist." 
89 | json_file = open(json_path, 'r') 90 | class_indices = json.load(json_file) 91 | 92 | for data in data_loader: 93 | images, labels = data 94 | for i in range(plot_num): 95 | # [C, H, W] -> [H, W, C] 96 | img = images[i].numpy().transpose(1, 2, 0) 97 | # 反Normalize操作 98 | img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255 99 | label = labels[i].item() 100 | plt.subplot(1, plot_num, i+1) 101 | plt.xlabel(class_indices[str(label)]) 102 | plt.xticks([]) # 去掉x轴的刻度 103 | plt.yticks([]) # 去掉y轴的刻度 104 | plt.imshow(img.astype('uint8')) 105 | plt.show() 106 | 107 | 108 | def write_pickle(list_info: list, file_name: str): 109 | with open(file_name, 'wb') as f: 110 | pickle.dump(list_info, f) 111 | 112 | 113 | def read_pickle(file_name: str) -> list: 114 | with open(file_name, 'rb') as f: 115 | info_list = pickle.load(f) 116 | return info_list 117 | 118 | 119 | def train_one_epoch(model, optimizer, data_loader, device, epoch, lr_scheduler): 120 | model.train() 121 | loss_function = torch.nn.CrossEntropyLoss() 122 | accu_loss = torch.zeros(1).to(device) # 累计损失 123 | accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数 124 | optimizer.zero_grad() 125 | 126 | sample_num = 0 127 | data_loader = tqdm(data_loader, file=sys.stdout) 128 | for step, data in enumerate(data_loader): 129 | images, labels = data 130 | sample_num += images.shape[0] 131 | 132 | pred = model(images.to(device)) 133 | pred_classes = torch.max(pred, dim=1)[1] 134 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 135 | 136 | loss = loss_function(pred, labels.to(device)) 137 | loss.backward() 138 | accu_loss += loss.detach() 139 | 140 | data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}, lr: {:.5f}".format( 141 | epoch, 142 | accu_loss.item() / (step + 1), 143 | accu_num.item() / sample_num, 144 | optimizer.param_groups[0]["lr"] 145 | ) 146 | 147 | if not torch.isfinite(loss): 148 | print('WARNING: non-finite loss, ending training ', loss) 149 | sys.exit(1) 150 | 151 | optimizer.step() 152 | optimizer.zero_grad() 153 | # update lr 154 | lr_scheduler.step() 155 | 156 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 157 | 158 | 159 | @torch.no_grad() 160 | def evaluate(model, data_loader, device, epoch): 161 | loss_function = torch.nn.CrossEntropyLoss() 162 | 163 | model.eval() 164 | 165 | accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数 166 | accu_loss = torch.zeros(1).to(device) # 累计损失 167 | 168 | sample_num = 0 169 | data_loader = tqdm(data_loader, file=sys.stdout) 170 | for step, data in enumerate(data_loader): 171 | images, labels = data 172 | sample_num += images.shape[0] 173 | 174 | pred = model(images.to(device)) 175 | pred_classes = torch.max(pred, dim=1)[1] 176 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 177 | 178 | loss = loss_function(pred, labels.to(device)) 179 | accu_loss += loss 180 | 181 | data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format( 182 | epoch, 183 | accu_loss.item() / (step + 1), 184 | accu_num.item() / sample_num 185 | ) 186 | 187 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 188 | 189 | 190 | def create_lr_scheduler(optimizer, 191 | num_step: int, 192 | epochs: int, 193 | warmup=True, 194 | warmup_epochs=1, 195 | warmup_factor=1e-3, 196 | end_factor=1e-6): 197 | assert num_step > 0 and epochs > 0 198 | if warmup is False: 199 | warmup_epochs = 0 200 | 201 | def f(x): 202 | """ 203 | 根据step数返回一个学习率倍率因子, 204 | 注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法 205 | """ 206 | if warmup is True and x 
<= (warmup_epochs * num_step):
207 |             alpha = float(x) / (warmup_epochs * num_step)
208 |             # during warmup the lr factor ramps from warmup_factor -> 1
209 |             return warmup_factor * (1 - alpha) + alpha
210 |         else:
211 |             current_step = (x - warmup_epochs * num_step)
212 |             cosine_steps = (epochs - warmup_epochs) * num_step
213 |             # after warmup the lr factor follows a cosine decay from 1 -> end_factor
214 |             return ((1 + math.cos(current_step * math.pi / cosine_steps)) / 2) * (1 - end_factor) + end_factor
215 | 
216 |     return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
217 | 
218 | 
219 | 
220 | def get_params_groups(model: torch.nn.Module, weight_decay: float = 1e-5):
221 |     # parameter groups the optimizer will train
222 |     parameter_group_vars = {"decay": {"params": [], "weight_decay": weight_decay},
223 |                             "no_decay": {"params": [], "weight_decay": 0.}}
224 | 
225 |     # the corresponding parameter names (kept only for logging)
226 |     parameter_group_names = {"decay": {"params": [], "weight_decay": weight_decay},
227 |                              "no_decay": {"params": [], "weight_decay": 0.}}
228 | 
229 |     for name, param in model.named_parameters():
230 |         if not param.requires_grad:
231 |             continue  # frozen weights
232 | 
233 |         if len(param.shape) == 1 or name.endswith(".bias"):
234 |             group_name = "no_decay"
235 |         else:
236 |             group_name = "decay"
237 | 
238 |         parameter_group_vars[group_name]["params"].append(param)
239 |         parameter_group_names[group_name]["params"].append(name)
240 | 
241 |     print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
242 |     return list(parameter_group_vars.values())
243 | 
--------------------------------------------------------------------------------
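
Not part of the repository: a minimal sketch of how the FCMAE-pretrained encoder saved by `pretrain.py` could be loaded into the `ConvNeXtV2` classifier from `model.py` before fine-tuning with `train.py`. The checkpoint path matches what `pretrain.py` writes, but `num_classes` and the overall flow are assumptions; the FCMAE state dict keeps the backbone under the `encoder.` prefix, and its `downsample_layers`/`stages` parameter names line up with `ConvNeXtV2`.

```python
# Hypothetical helper (not in the repo): transfer FCMAE encoder weights into ConvNeXtV2.
import torch
from model import convnextv2_base

model = convnextv2_base(num_classes=5)                                   # num_classes is an assumption
ckpt = torch.load("./weights/pre_last_model.pth", map_location="cpu")    # written by pretrain.py

# Keep only the backbone weights and strip the "encoder." prefix so the keys
# match ConvNeXtV2's downsample_layers.* / stages.* parameters.
encoder_state = {k[len("encoder."):]: v for k, v in ckpt.items() if k.startswith("encoder.")}
missing, unexpected = model.load_state_dict(encoder_state, strict=False)
print("missing:", missing)          # final norm and classification head stay randomly initialized
print("unexpected:", unexpected)    # expected to be empty
```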
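
Also not in the repo: a short sketch of mapping the FCMAE output back to image space for a visual check of the reconstruction. It assumes a normalized 224x224 batch and reuses the same reshape that `forward_loss` applies to `pred`, plus the model's own `unpatchify` and `upsample_mask` helpers.

```python
# Hypothetical visualization sketch; shapes follow the readme example ([1, 3, 224, 224] input).
import torch
from fcmae_model import convnextv2_pico

model = convnextv2_pico().eval()
imgs = torch.randn(1, 3, 224, 224)                 # stand-in for a real, normalized batch

with torch.no_grad():
    loss, pred, mask = model(imgs)                 # pred: [N, patch_size**2 * 3, 7, 7], mask: [N, 49]

n, c, _, _ = pred.shape
pred = pred.reshape(n, c, -1).permute(0, 2, 1)     # -> [N, L, patch_size**2 * 3], as in forward_loss
recon = model.unpatchify(pred)                     # -> [N, 3, 224, 224]
visible = imgs * (1.0 - model.upsample_mask(mask, model.patch_size).unsqueeze(1))  # masked input
```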