├── __pycache__ └── config.cpython-38.pyc ├── classes.txt ├── cls_test.txt ├── cls_train.txt ├── config.py ├── datasets ├── test │ └── 存放测试数据集.txt └── train │ └── 存放训练数据.txt ├── eval.py ├── losses └── focal_loss.py ├── nets ├── ConvMixer.py ├── Convnext.py ├── MlpMixer.py ├── __pycache__ │ ├── ConvMixer.cpython-38.pyc │ └── MlpMixer.cpython-38.pyc ├── attention.py ├── resnet.py └── vit.py ├── predict.py ├── readme.md ├── torch.yaml ├── train.py ├── txt_annotation.py └── utils ├── __pycache__ ├── dataloader.cpython-38.pyc ├── train_one_epoch.cpython-38.pyc └── utils.cpython-38.pyc ├── dataloader.py ├── train_one_epoch.py ├── training_utils.py └── utils.py /__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiantenggei/torch-classification/ed5a39941bda5897aa3c2e3b1c5bb3d00806aea7/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /classes.txt: -------------------------------------------------------------------------------- 1 | cats 2 | dogs -------------------------------------------------------------------------------- /cls_test.txt: -------------------------------------------------------------------------------- 1 | #存放测试集路径和标签 -------------------------------------------------------------------------------- /cls_train.txt: -------------------------------------------------------------------------------- 1 | #存放训练集路径和标签 -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | Cuda = True #是否使用GPU 没有为Flase 2 | 3 | input_shape = [48,48] # 输入图片大小 4 | is_grayscale = False 5 | batch_size = 2 # 自己可以更改 6 | lr = 1e-3 7 | 8 | classes_path = 'classes.txt' 9 | 10 | 11 | num_workers = 0 # 是否开启多进程 12 | 13 | 14 | train_annotation_path = 'cls_train.txt' 15 | 16 | val_annotation_path = 'cls_test.txt' 17 
| 18 | 19 | 20 | 21 | resume ='' # 加载训练权重路径 22 | 23 | log_dir = 'logs' # 日志路径 tensorboard 保存 24 | 25 | #------------------------------------------# 26 | # FocalLoss :处理样本不均衡 27 | # alpha 28 | # gamma >0 当 gamma=0 时就是交叉熵损失函数 29 | # 论文中gamma = [0,0.5,1,2,5] 30 | # 一般而言当γ增加的时候,a需要减小一点 31 | # reduction : 就平均:'mean' 求和 'sum' 32 | # 还未ti 33 | #------------------------------------------# 34 | #Focal_loss = True # True Focal loss 处理原本不均衡 False 使用 CrossEntropyLoss() # 还未使用成功 35 | 36 | #label_smoothing 防止过拟合 37 | label_smoothing = True # 38 | 39 | smoothing_value = 0.1 #[0,1] 之间 40 | 41 | 42 | 43 | #学习率变化策略 44 | scheduler = 'cos' #[None,reduce,cos] None保持不变 reduce 按epoch 来减少 cos 余弦下降算法 45 | 46 | 47 | -------------------------------------------------------------------------------- /datasets/test/存放测试数据集.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiantenggei/torch-classification/ed5a39941bda5897aa3c2e3b1c5bb3d00806aea7/datasets/test/存放测试数据集.txt -------------------------------------------------------------------------------- /datasets/train/存放训练数据.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiantenggei/torch-classification/ed5a39941bda5897aa3c2e3b1c5bb3d00806aea7/datasets/train/存放训练数据.txt -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from unicodedata import name 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from utils.dataloader import DataGenerator, detection_collate 5 | from utils.utils import cvtColor,letterbox_image,get_classes,load_dict 6 | from config import input_shape,Cuda,classes_path # 来源于config.py 中的Cuda 7 | import numpy as np 8 | from PIL import Image 9 | from nets.ConvMixer import ConvMixer_768_32 10 | from tqdm import tqdm 11 | import cv2 12 | import sys 13 | @torch.no_grad() 
14 | def evaluate(model, data_loader, epoch): 15 | loss_function = torch.nn.CrossEntropyLoss() 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | model.eval() 19 | 20 | accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数 21 | accu_loss = torch.zeros(1).to(device) # 累计损失 22 | model.to(device) 23 | sample_num = 0 24 | data_loader = tqdm(data_loader, file=sys.stdout) 25 | for step, data in enumerate(data_loader): 26 | images, labels = data 27 | images = torch.from_numpy(images).type(torch.FloatTensor) 28 | labels = torch.from_numpy(labels).type(torch.FloatTensor).long() 29 | sample_num += images.shape[0] 30 | 31 | pred = model(images.to(device)) 32 | pred_classes = torch.max(pred, dim=1)[1] 33 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 34 | 35 | loss = loss_function(pred, labels.to(device)) 36 | accu_loss += loss 37 | 38 | data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format( 39 | epoch, 40 | accu_loss.item() / (step + 1), 41 | accu_num.item() / sample_num 42 | ) 43 | 44 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 45 | 46 | class eval_top: 47 | 48 | def __init__(self,anno_lines,model) -> None: 49 | 50 | self.anno_lines = anno_lines 51 | self.model = model 52 | self.class_name,_ = get_classes(classes_path) 53 | 54 | def detect_img(self,image,mode='predict'): 55 | 56 | # image -> RGB 57 | image = cvtColor(image) 58 | 59 | image_data = letterbox_image(image,input_shape) 60 | 61 | # 归一化 62 | image_data = np.array(image_data, np.float32) 63 | image_data = image_data/127.5 64 | image_data -= 1.0 65 | 66 | 67 | # 添加bacth_size 维度 68 | image_data = np.expand_dims(image_data,0) 69 | #[batch_size,width,height,channel] -> [batch_size,channel,width,height] 70 | image_data = np.transpose(image_data,(0, 3, 1, 2)) 71 | 72 | with torch.no_grad(): 73 | img = torch.from_numpy(image_data) 74 | if Cuda: 75 | img = img.cuda() 76 | self.model.cuda() 77 | 78 | pred = 
torch.softmax(self.model(img)[0],dim=-1).cpu().numpy() 79 | 80 | name = self.class_name[np.argmax(pred)] 81 | 82 | # 预测 83 | if mode == 'predict': 84 | probability = np.max(pred) 85 | #---------------------------------------------------# 86 | # 绘图并写字 87 | #---------------------------------------------------# 88 | plt.subplot(1, 1, 1) 89 | plt.imshow(np.array(image)) 90 | plt.title('Class:%s Probability:%.3f' %(name, probability)) 91 | plt.show() 92 | return name 93 | #top1 94 | elif mode == 'top1': 95 | return np.argmax(pred) 96 | ##top5 97 | elif mode == 'top5' : 98 | arg_pred = np.argsort(pred)[::-1] 99 | arg_pred_top5 = arg_pred[:5] 100 | return arg_pred_top5 101 | 102 | #---------------------------------------------------# 103 | # eval_top1 104 | #---------------------------------------------------# 105 | def eval_top1(self): 106 | print('Eval Top1....') 107 | correct = 0 108 | total = len(self.anno_lines) 109 | with tqdm(total=total,postfix=dict,mininterval=0.3) as pbar: 110 | for idx,line in enumerate(self.anno_lines): 111 | annotation_path = line.split(';')[1].split()[0] 112 | x = Image.open(annotation_path) 113 | y = int(line.split(';')[0]) 114 | 115 | pred = self.detect_img(x,mode='top1') 116 | correct += pred == y 117 | pbar.update(1) 118 | return correct / total 119 | 120 | #---------------------------------------------------# 121 | # eval_top5 更新进度条 122 | #---------------------------------------------------# 123 | def eval_top5(self): 124 | correct = 0 125 | total = len(self.anno_lines) 126 | print('Eval Top5....') 127 | with tqdm(total=total,postfix=dict,mininterval=0.3) as pbar: 128 | for idx,line in enumerate(self.anno_lines): 129 | annotation_path = line.split(';')[1].split()[0] 130 | x = Image.open(annotation_path) 131 | y = int(line.split(';')[0]) 132 | 133 | pred = self.detect_img(x,'top5') 134 | correct += y in pred 135 | pbar.update(1) 136 | return correct / total 137 | if __name__ == "__main__": 138 | from torch.utils.data import DataLoader 
139 | # 读取测试集路劲和标签 140 | with open("./cls_test.txt","r") as f: 141 | lines = f.readlines() 142 | #---------------------------------------------------# 143 | # 权重和模型 144 | # 注意:训练时设置的模型需要和权重匹配, 145 | # 也就是训练的啥模型使用啥权重 146 | #---------------------------------------------------# 147 | model_path = 'logs\ep300-loss0.990-val_loss1.105.pth' #训练好的权重路径 148 | from nets.resnet import ResNet18 149 | model = ResNet18() 150 | 151 | model = load_dict(model_path,model) 152 | 153 | 154 | dataset = DataGenerator(lines, input_shape, False,is_grayscale=True) 155 | 156 | gen_val = DataLoader(dataset, batch_size=128, num_workers=0, pin_memory=True, 157 | drop_last=True, collate_fn=detection_collate) 158 | 159 | print(evaluate(model=model,data_loader=gen_val,epoch=0)) -------------------------------------------------------------------------------- /losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | # version 1: use torch.autograd 6 | class FocalLoss(nn.Module): 7 | #------------------------------------------# 8 | # FocalLoss :处理样本不均衡 9 | # alpha 10 | # gamma >0 当 gamma=0 时就是交叉熵损失函数 11 | # 论文中gamma = [0,0.5,1,2,5] 12 | # 一般而言当γ增加的时候,a需要减小一点 13 | # reduction : 就平均:'mean' 求和 'sum' 14 | #------------------------------------------# 15 | def __init__(self,alpha=0.25,gamma=2,reduction='mean',): 16 | super(FocalLoss, self).__init__() 17 | self.alpha = alpha 18 | self.gamma = gamma 19 | self.reduction = reduction 20 | self.crit = nn.BCEWithLogitsLoss(reduction='none') 21 | 22 | 23 | def forward(self, logits, label): 24 | probs = torch.sigmoid(logits) 25 | coeff = torch.abs(label - probs).pow(self.gamma).neg() 26 | log_probs = torch.where(logits >= 0, 27 | F.softplus(logits, -1, 50), 28 | logits - F.softplus(logits, 1, 50)) 29 | log_1_probs = torch.where(logits >= 0, 30 | -logits + F.softplus(logits, -1, 50), 31 | -F.softplus(logits, 1, 50)) 32 | loss = label * 
self.alpha * log_probs + (1. - label) * (1. - self.alpha) * log_1_probs 33 | loss = loss * coeff 34 | 35 | if self.reduction == 'mean': 36 | loss = loss.mean() 37 | if self.reduction == 'sum': 38 | loss = loss.sum() 39 | return loss -------------------------------------------------------------------------------- /nets/ConvMixer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchsummary import summary 4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 5 | #-----------------------------------------------------------------------# 6 | # Conv-Mixer 网络 7 | # 论文地址:https://openreview.net/pdf?id=TVHS5Y4dNvM 8 | # 我的博客 :https://blog.csdn.net/qq_38676487/article/details/120705254 9 | #-------------------------------------------------------------------------# 10 | class ConvMixerLayer(nn.Module): 11 | def __init__(self,dim,kernel_size = 9): 12 | super().__init__() 13 | #残差结构 14 | self.Resnet = nn.Sequential( 15 | nn.Conv2d(dim,dim,kernel_size=kernel_size,groups=dim,padding='same'), 16 | nn.GELU(), 17 | nn.BatchNorm2d(dim) 18 | ) 19 | #逐点卷积 20 | self.Conv_1x1 = nn.Sequential( 21 | nn.Conv2d(dim,dim,kernel_size=1), 22 | nn.GELU(), 23 | nn.BatchNorm2d(dim) 24 | ) 25 | def forward(self,x): 26 | x = x +self.Resnet(x) 27 | x = self.Conv_1x1(x) 28 | return x 29 | 30 | class ConvMixer(nn.Module): 31 | def __init__(self,dim,depth,kernel_size=9, patch_size=7, n_classes=1000): 32 | super().__init__() 33 | self.conv2d1 = nn.Sequential( 34 | nn.Conv2d(3,dim,kernel_size=patch_size,stride=patch_size), 35 | nn.GELU(), 36 | nn.BatchNorm2d(dim) 37 | ) 38 | self.ConvMixer_blocks =nn.ModuleList([]) 39 | 40 | for _ in range(depth): 41 | self.ConvMixer_blocks.append(ConvMixerLayer(dim=dim,kernel_size=kernel_size)) 42 | 43 | self.head = nn.Sequential( 44 | nn.AdaptiveAvgPool2d((1,1)), 45 | nn.Flatten(), 46 | nn.Linear(dim,n_classes) 47 | ) 48 | 49 | def forward(self,x): 50 | x = 
self.conv2d1(x) 51 | 52 | for ConvMixer_block in self.ConvMixer_blocks: 53 | x = ConvMixer_block(x) 54 | 55 | x = self.head(x) 56 | 57 | return x 58 | #-----------------------------------------------------------------------# 59 | # 论文中给出的配置: 60 | # ConvMixer_h_d h:dim 隐藏层维度 d:depth 网络深度 61 | #-------------------------------------------------------------------------# 62 | def ConvMixer_1536_20(n_classes = 1000): 63 | return ConvMixer(dim=1536,depth=20,patch_size=7,kernel_size=9,n_classes=n_classes) 64 | 65 | def ConvMixer_768_32(n_classes = 1000): 66 | return ConvMixer(dim=768,depth=32,patch_size=7,kernel_size=7,n_classes=n_classes) 67 | 68 | # 自定义的 ConvMixer 不传参 为 ConvMixer_768_32 69 | def custom_ConvMixer(dim=768,depth=32,patch_size=7,kernel_size=7,n_classes=1000): 70 | return ConvMixer(dim=dim,depth=depth,patch_size = patch_size,kernel_size=kernel_size,n_classes=n_classes) 71 | 72 | if __name__ == '__main__': 73 | model =ConvMixer_1536_20().to(device) 74 | summary(model, (3, 224, 224)) -------------------------------------------------------------------------------- /nets/Convnext.py: -------------------------------------------------------------------------------- 1 | from statistics import mode 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from timm.models.layers import trunc_normal_,DropPath 6 | from timm.models.registry import register_model 7 | from torchsummary import summary 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | #-----------------------------------------------------------------------# 10 | # ConvNeXt 网络 11 | # 论文地址: https://arxiv.org/pdf/2201.03545.pdf 12 | # 我的博客 :https://blog.csdn.net/qq_38676487/article/details/123298605 13 | #-------------------------------------------------------------------------# 14 | class Block(nn.Module): 15 | #-----------------------------------------------------------------------# 16 | # ConvNeXt Block 块 两种实现方式 17 | # (1) 深度可分离卷积 + 1x1 的卷积 18 | # DwConv 
-> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 19 | # (2)深度可分离卷积 + Linear 全连接来代替 1x1 卷积 ,发现在pytorch 更快 20 | # DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 21 | # 参数: 22 | # dim:维度 drop_path:0~1 layer_scale_init_value: 23 | #-------------------------------------------------------------------------# 24 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 25 | super().__init__() 26 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # 深度课分离卷积 27 | self.norm = LayerNorm(dim, eps=1e-6) 28 | self.pwconv1 = nn.Linear(dim, 4 * dim) # 用全连接代替1x1的卷积 29 | self.act = nn.GELU() 30 | self.pwconv2 = nn.Linear(4 * dim, dim) 31 | # 一个可学习的参数 32 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 33 | requires_grad=True) if layer_scale_init_value > 0 else None 34 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 35 | 36 | def forward(self, x): 37 | input = x 38 | x = self.dwconv(x) 39 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 40 | x = self.norm(x) 41 | x = self.pwconv1(x) 42 | x = self.act(x) 43 | x = self.pwconv2(x) 44 | if self.gamma is not None: 45 | x = self.gamma * x 46 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 47 | 48 | x = input + self.drop_path(x) 49 | return x 50 | 51 | 52 | class LayerNorm(nn.Module): 53 | #-----------------------------------------------------------------------# 54 | # 自定义 LayerNorm 默认channels_last 55 | # channels_last [batch_size, height, width, channels] 56 | # channels_first [batch_size, channels, height, width] 57 | #-------------------------------------------------------------------------# 58 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 59 | super().__init__() 60 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 61 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 62 | self.eps = eps 63 | 
self.data_format = data_format 64 | if self.data_format not in ["channels_last", "channels_first"]: 65 | raise NotImplementedError 66 | self.normalized_shape = (normalized_shape, ) 67 | 68 | def forward(self, x): 69 | if self.data_format == "channels_last": 70 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 71 | elif self.data_format == "channels_first": 72 | u = x.mean(1, keepdim=True) 73 | s = (x - u).pow(2).mean(1, keepdim=True) 74 | x = (x - u) / torch.sqrt(s + self.eps) 75 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 76 | return x 77 | 78 | 79 | class ConvNeXt(nn.Module): 80 | def __init__(self, in_chans=3, num_classes=1000, 81 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 82 | layer_scale_init_value=1e-6, head_init_scale=1., 83 | ): 84 | super().__init__() 85 | 86 | 87 | #保存stem 和下采样 88 | self.downsample_layers = nn.ModuleList() # 89 | #[batch_size,3,224,224] -> [batch_size,dim[0],56,56] 90 | stem = nn.Sequential( 91 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 92 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 93 | ) 94 | 95 | self.downsample_layers.append(stem) 96 | 97 | 98 | #下采样 -> 用2x2 步长为2 的卷积来代替池化 99 | # 这里一次将所有stage的下采样放入 100 | for i in range(3): 101 | downsample_layer = nn.Sequential( 102 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 103 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 104 | ) 105 | self.downsample_layers.append(downsample_layer) 106 | #添加stage 107 | self.stages = nn.ModuleList() 108 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 109 | cur = 0 110 | for i in range(4): 111 | stage = nn.Sequential( 112 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 113 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 114 | ) 115 | self.stages.append(stage) 116 | cur += depths[i] 117 | 118 | #最后的分类输出 119 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 
120 | self.head = nn.Linear(dims[-1], num_classes) 121 | 122 | def forward_features(self, x): 123 | for i in range(4): 124 | x = self.downsample_layers[i](x) 125 | x = self.stages[i](x) 126 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 127 | 128 | def forward(self, x): 129 | x = self.forward_features(x) 130 | x = self.head(x) 131 | return x 132 | 133 | 134 | #-----------------------------------------------------------------------# 135 | # 论文中的 model, 以及其预训练权重 136 | # model.head = torch.nn.Linear(768,num_classes) 137 | # 加载预训练权重后 仍然可以调整分类数 138 | # 训练数据需要是三通道彩图 139 | #-------------------------------------------------------------------------# 140 | model_urls = { 141 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 142 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 143 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 144 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 145 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 146 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 147 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 148 | } 149 | 150 | 151 | @register_model 152 | def convnext_tiny(pretrained=False, num_classes=1000,**kwargs): 153 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 154 | if pretrained: 155 | url = model_urls['convnext_tiny_1k'] 156 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 157 | model.load_state_dict(checkpoint["model"],strict=False) 158 | #更改model head 使其能够符合自己的分类数 159 | model.head = torch.nn.Linear(768,num_classes) 160 | return model 161 | 162 | 163 | @register_model 164 | def 
convnext_small(pretrained=False,num_classes=1000, **kwargs): 165 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 166 | if pretrained: 167 | url = model_urls['convnext_small_1k'] 168 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 169 | model.load_state_dict(checkpoint["model"]) 170 | model.head = torch.nn.Linear(768,num_classes) 171 | return model 172 | 173 | 174 | @register_model 175 | def convnext_base(pretrained=False, in_22k=False,num_classes=1000, **kwargs): 176 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 177 | if pretrained: 178 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 179 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 180 | model.load_state_dict(checkpoint["model"]) 181 | model.head = torch.nn.Linear(1024,num_classes) 182 | return model 183 | 184 | @register_model 185 | def convnext_large(pretrained=False, in_22k=False,num_classes=1000, **kwargs): 186 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 187 | if pretrained: 188 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 189 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 190 | model.load_state_dict(checkpoint["model"]) 191 | model.head = torch.nn.Linear(1536,num_classes) 192 | return model 193 | 194 | @register_model 195 | def convnext_xlarge(pretrained=False, in_22k=False, num_classes=1000, **kwargs): 196 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 197 | if pretrained: 198 | url = model_urls['convnext_xlarge_22k'] if in_22k else model_urls['convnext_xlarge_1k'] 199 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 200 | model.load_state_dict(checkpoint["model"]) 201 | model.head = 
torch.nn.Linear(2048,num_classes) 202 | return model 203 | 204 | if __name__ == '__main__': 205 | # 测试自定义分类数 206 | model_tiny =convnext_tiny(pretrained=False,num_classes=2).to(device) 207 | summary(model_tiny, (3, 48, 48)) 208 | # model_xlarge =convnext_xlarge(pretrained=False,num_classes=2).to(device) 209 | # summary(model_xlarge, (3, 224, 224)) 210 | # model_large =convnext_large(pretrained=False,num_classes=2).to(device) 211 | # summary(model_large, (3, 224, 224)) 212 | # model_base =convnext_base(pretrained=False,num_classes=2).to(device) 213 | # summary(model_base, (3, 224, 224)) 214 | # model_small =convnext_small(pretrained=False,num_classes=2).to(device) 215 | # summary(model_small, (3, 224, 224)) 216 | 217 | 218 | -------------------------------------------------------------------------------- /nets/MlpMixer.py: -------------------------------------------------------------------------------- 1 | #定义多层感知机 2 | import torch 3 | import numpy as np 4 | from torch import nn 5 | from einops.layers.torch import Rearrange 6 | from torchsummary import summary 7 | import torch.nn.functional as F 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | #-----------------------------------------------------------------------# 10 | # MLP-Mixer 网络 11 | # 论文地址:https://arxiv.org/abs/2105.01601v1 12 | # 我的博客 :https://blog.csdn.net/qq_38676487/article/details/118640096 13 | #-------------------------------------------------------------------------# 14 | class FeedForward(nn.Module): 15 | def __init__(self,dim,hidden_dim,dropout=0.): 16 | super().__init__() 17 | self.net=nn.Sequential( 18 | #由此可以看出 FeedForward 的输入和输出维度是一致的 19 | nn.Linear(dim,hidden_dim), 20 | #激活函数 21 | nn.GELU(), 22 | #防止过拟合 23 | nn.Dropout(dropout), 24 | #重复上述过程 25 | nn.Linear(hidden_dim,dim), 26 | 27 | nn.Dropout(dropout) 28 | ) 29 | def forward(self,x): 30 | x=self.net(x) 31 | return x 32 | 33 | 34 | class MixerBlock(nn.Module): 35 | def 
__init__(self,dim,num_patch,token_dim,channel_dim,dropout=0.): 36 | super().__init__() 37 | self.token_mixer=nn.Sequential( 38 | nn.LayerNorm(dim), 39 | Rearrange('b n d -> b d n'), #这里是[batch_size, num_patch, dim] -> [batch_size, dim, num_patch] 40 | FeedForward(num_patch,token_dim,dropout), 41 | Rearrange('b d n -> b n d') #[batch_size, dim, num_patch] -> [batch_size, num_patch, dim] 42 | 43 | ) 44 | self.channel_mixer=nn.Sequential( 45 | nn.LayerNorm(dim), 46 | FeedForward(dim,channel_dim,dropout) 47 | ) 48 | def forward(self,x): 49 | 50 | x=x+self.token_mixer(x) 51 | 52 | x=x+self.channel_mixer(x) 53 | 54 | return x 55 | 56 | class MLPMixer(nn.Module): 57 | def __init__(self,in_channels,dim,num_classes,patch_size,image_size,depth,token_dim,channel_dim,dropout=0.): 58 | super().__init__() 59 | assert image_size%patch_size==0 60 | self.num_patches=(image_size//patch_size)**2 61 | #embedding 操作,用卷积来分成一小块一小块的 62 | self.to_embedding=nn.Sequential(nn.Conv2d(in_channels=in_channels,out_channels=dim,kernel_size=patch_size,stride=patch_size), 63 | Rearrange('b c h w -> b (h w) c') 64 | ) 65 | #经过Mixer Layer 的次数 66 | self.mixer_blocks=nn.ModuleList([]) 67 | for _ in range(depth): 68 | self.mixer_blocks.append(MixerBlock(dim,self.num_patches,token_dim,channel_dim,dropout)) 69 | self.layer_normal=nn.LayerNorm(dim) 70 | 71 | self.mlp_head=nn.Sequential( 72 | nn.Linear(dim,num_classes) 73 | ) 74 | def forward(self,x): 75 | x=self.to_embedding(x) 76 | for mixer_block in self.mixer_blocks: 77 | x=mixer_block(x) 78 | x=self.layer_normal(x) 79 | x=x.mean(dim=1) 80 | 81 | x=self.mlp_head(x) 82 | 83 | return x 84 | 85 | #测试Mlp-Mixer 86 | if __name__ == '__main__': 87 | model = MLPMixer(in_channels=3, dim=512, num_classes=1000, patch_size=16, image_size=224, depth=1, token_dim=256, 88 | channel_dim=2048).to(device) 89 | summary(model,(3,224,224)) 90 | -------------------------------------------------------------------------------- /nets/__pycache__/ConvMixer.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiantenggei/torch-classification/ed5a39941bda5897aa3c2e3b1c5bb3d00806aea7/nets/__pycache__/ConvMixer.cpython-38.pyc -------------------------------------------------------------------------------- /nets/__pycache__/MlpMixer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiantenggei/torch-classification/ed5a39941bda5897aa3c2e3b1c5bb3d00806aea7/nets/__pycache__/MlpMixer.cpython-38.pyc -------------------------------------------------------------------------------- /nets/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | #---------------------------------------------------------------------------- 5 | # senet 通道注意力 6 | # 参考链接: 7 | # https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py 8 | #--------------------------------------------------------------------------- 9 | class SELayer(nn.Module): 10 | #--------------------------- 11 | # channel 输入通道数 12 | # reduction 压缩比 13 | #---------------------------- 14 | def __init__(self, channel, reduction=16): 15 | super(SELayer, self).__init__() 16 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 17 | self.fc = nn.Sequential( 18 | nn.Linear(channel, channel // reduction, bias=False), 19 | nn.ReLU(inplace=True), 20 | nn.Linear(channel // reduction, channel, bias=False), 21 | nn.Sigmoid() 22 | ) 23 | 24 | def forward(self, x): 25 | b, c, _, _ = x.size() 26 | y = self.avg_pool(x).view(b, c) 27 | y = self.fc(y).view(b, c, 1, 1) 28 | return x * y.expand_as(x) 29 | 30 | 31 | #---------------------------------------------------------------------------- 32 | # cbam 通道和空间注意力 同时注意空间和通道 33 | # 参考链接: 34 | # https://github.com/luuuyi/CBAM.PyTorch/blob/master/model/resnet_cbam.py 35 | 
#-------------------------------------------------------------------------- 36 | 37 | class ChannelAttention(nn.Module): 38 | def __init__(self, in_planes, ratio=16): 39 | super(ChannelAttention, self).__init__() 40 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 41 | self.max_pool = nn.AdaptiveMaxPool2d(1) 42 | #1x1 的卷积替换全链接 43 | self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), 44 | nn.ReLU(), 45 | nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)) 46 | self.sigmoid = nn.Sigmoid() 47 | 48 | def forward(self, x): 49 | avg_out = self.fc(self.avg_pool(x)) 50 | max_out = self.fc(self.max_pool(x)) 51 | out = avg_out + max_out 52 | return self.sigmoid(out) 53 | 54 | class SpatialAttention(nn.Module): 55 | def __init__(self, kernel_size=7): 56 | super(SpatialAttention, self).__init__() 57 | 58 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) 59 | self.sigmoid = nn.Sigmoid() 60 | 61 | def forward(self, x): 62 | avg_out = torch.mean(x, dim=1, keepdim=True) 63 | max_out, _ = torch.max(x, dim=1, keepdim=True) 64 | x = torch.cat([avg_out, max_out], dim=1) 65 | x = self.conv1(x) 66 | return self.sigmoid(x) 67 | 68 | class cbam_block(nn.Module): 69 | def __init__(self, channel, ratio=8, kernel_size=7): 70 | super(cbam_block, self).__init__() 71 | self.channelattention = ChannelAttention(channel, ratio=ratio) 72 | self.spatialattention = SpatialAttention(kernel_size=kernel_size) 73 | 74 | def forward(self, x): 75 | x = x * self.channelattention(x) 76 | x = x * self.spatialattention(x) 77 | return x 78 | 79 | #---------------------------------------------------------------------------- 80 | # eca 通道和空间注意力 同时注意空间和通道 81 | # 参考链接: 82 | # https://github.com/digantamisra98/Reproducibilty-Challenge-ECANET/tree/main/models 83 | #-------------------------------------------------------------------------- 84 | class eca_block(nn.Module): 85 | def __init__(self, channel, b=1, gamma=2): 86 | super(eca_block, self).__init__() 87 | 
kernel_size = int(abs((math.log(channel, 2) + b) / gamma)) 88 | kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1 89 | 90 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 91 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) 92 | self.sigmoid = nn.Sigmoid() 93 | 94 | def forward(self, x): 95 | y = self.avg_pool(x) 96 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 97 | y = self.sigmoid(y) 98 | return x * y.expand_as(x) -------------------------------------------------------------------------------- /nets/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, in_planes, planes, stride=1): 10 | super(BasicBlock, self).__init__() 11 | self.conv1 = nn.Conv2d( 12 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(planes) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 15 | stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != self.expansion*planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, self.expansion*planes, 22 | kernel_size=1, stride=stride, bias=False), 23 | nn.BatchNorm2d(self.expansion*planes) 24 | ) 25 | 26 | def forward(self, x): 27 | out = F.relu(self.bn1(self.conv1(x))) 28 | out = self.bn2(self.conv2(out)) 29 | out += self.shortcut(x) 30 | out = F.relu(out) 31 | return out 32 | 33 | 34 | class Bottleneck(nn.Module): 35 | expansion = 4 36 | def __init__(self, in_planes, planes, stride=1): 37 | super(Bottleneck, self).__init__() 38 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.conv2 = nn.Conv2d(planes, planes, 
kernel_size=3, 41 | stride=stride, padding=1, bias=False) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.conv3 = nn.Conv2d(planes, self.expansion * 44 | planes, kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 46 | 47 | self.shortcut = nn.Sequential() 48 | if stride != 1 or in_planes != self.expansion*planes: 49 | self.shortcut = nn.Sequential( 50 | nn.Conv2d(in_planes, self.expansion*planes, 51 | kernel_size=1, stride=stride, bias=False), 52 | nn.BatchNorm2d(self.expansion*planes) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(self.conv1(x))) 57 | out = F.relu(self.bn2(self.conv2(out))) 58 | out = self.bn3(self.conv3(out)) 59 | out += self.shortcut(x) 60 | out = F.relu(out) 61 | return out 62 | 63 | 64 | class ResNet(nn.Module): 65 | def __init__(self, block, num_blocks, num_classes=7): 66 | super(ResNet, self).__init__() 67 | self.in_planes = 64 68 | 69 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 70 | #self.conv1 = nn.Conv2d(1, 64, kernel_size = 3, stride = 1, padding = 1, bias = False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 76 | self.linear = nn.Linear(512*block.expansion, num_classes) 77 | 78 | def _make_layer(self, block, planes, num_blocks, stride): 79 | strides = [stride] + [1]*(num_blocks-1) 80 | layers = [] 81 | for stride in strides: 82 | layers.append(block(self.in_planes, planes, stride)) 83 | self.in_planes = planes * block.expansion 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | out = F.relu(self.bn1(self.conv1(x))) 88 | out = self.layer1(out) 89 | out = self.layer2(out) 90 | out = self.layer3(out) 91 | out = self.layer4(out) 92 | out = F.avg_pool2d(out, 4) 
93 | out = out.view(out.size(0), -1) 94 | out = self.linear(out) 95 | return out 96 | 97 | 98 | def ResNet18(): 99 | return ResNet(BasicBlock, [2, 2, 2, 2]) 100 | 101 | 102 | def ResNet34(): 103 | return ResNet(BasicBlock, [3, 4, 6, 3]) 104 | 105 | 106 | def ResNet50(): 107 | return ResNet(Bottleneck, [3, 4, 6, 3]) 108 | 109 | 110 | def ResNet101(): 111 | return ResNet(Bottleneck, [3, 4, 23, 3]) 112 | 113 | 114 | def ResNet152(): 115 | return ResNet(Bottleneck, [3, 8, 36, 3]) -------------------------------------------------------------------------------- /nets/vit.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | #-----------------------------------------------------------------------# 9 | # Vit 网络 10 | #-------------------------------------------------------------------------# 11 | 12 | 13 | #-------------------------------------------- 14 | # 用户判断 是否为正方形输入 如224x224 15 | # 如果不是(例如:224) 自动调整 16 | #-------------------------------------------- 17 | def pair(t): 18 | return t if isinstance(t, tuple) else (t, t) 19 | 20 | 21 | #-------------------------------------------- 22 | # Normal + FN 23 | # FN 为传入变量 传入 multi-head attention 24 | #-------------------------------------------- 25 | class PreNorm(nn.Module): 26 | def __init__(self, dim, fn): 27 | super().__init__() 28 | self.norm = nn.LayerNorm(dim) 29 | self.fn = fn 30 | def forward(self, x, **kwargs): 31 | return self.fn(self.norm(x), **kwargs) 32 | 33 | 34 | #-------------------------------------------- 35 | # Tranformer Encoder 中 MLP 块 36 | #-------------------------------------------- 37 | class FeedForward(nn.Module): 38 | def __init__(self, dim, hidden_dim, dropout = 0.): 39 | super().__init__() 40 | self.net = nn.Sequential( 41 | nn.Linear(dim, hidden_dim), 42 | nn.GELU(), 43 | nn.Dropout(dropout), 44 | nn.Linear(hidden_dim, dim), 
45 | nn.Dropout(dropout) 46 | ) 47 | def forward(self, x): 48 | return self.net(x) 49 | 50 | #------------------------------------------------- 51 | # Tranformer Encoder 中 multi-head attention 块 52 | #------------------------------------------------- 53 | class Attention(nn.Module): 54 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 55 | super().__init__() 56 | inner_dim = dim_head * heads 57 | project_out = not (heads == 1 and dim_head == dim) 58 | 59 | self.heads = heads 60 | self.scale = dim_head ** -0.5 61 | 62 | self.attend = nn.Softmax(dim = -1) 63 | self.dropout = nn.Dropout(dropout) 64 | 65 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 66 | 67 | self.to_out = nn.Sequential( 68 | nn.Linear(inner_dim, dim), 69 | nn.Dropout(dropout) 70 | ) if project_out else nn.Identity() 71 | 72 | def forward(self, x): 73 | qkv = self.to_qkv(x).chunk(3, dim = -1) 74 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 75 | 76 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 77 | 78 | attn = self.attend(dots) 79 | attn = self.dropout(attn) 80 | 81 | out = torch.matmul(attn, v) 82 | out = rearrange(out, 'b h n d -> b n (h d)') 83 | return self.to_out(out) 84 | 85 | #------------------------------------------------- 86 | # Tranformer Encoder 块 87 | #------------------------------------------------- 88 | class Transformer(nn.Module): 89 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 90 | super().__init__() 91 | self.layers = nn.ModuleList([]) 92 | for _ in range(depth): 93 | self.layers.append(nn.ModuleList([ 94 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 95 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 96 | ])) 97 | def forward(self, x): 98 | for attn, ff in self.layers: 99 | x = attn(x) + x 100 | x = ff(x) + x 101 | return x 102 | 103 | class ViT(nn.Module): 104 | def __init__(self, *, image_size, patch_size, 
num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 105 | super().__init__() 106 | image_height, image_width = pair(image_size) 107 | patch_height, patch_width = pair(patch_size) 108 | 109 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 110 | 111 | num_patches = (image_height // patch_height) * (image_width // patch_width) 112 | patch_dim = channels * patch_height * patch_width 113 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 114 | 115 | self.to_patch_embedding = nn.Sequential( 116 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 117 | nn.Linear(patch_dim, dim), 118 | ) 119 | 120 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 121 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 122 | self.dropout = nn.Dropout(emb_dropout) 123 | 124 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 125 | 126 | self.pool = pool 127 | self.to_latent = nn.Identity() 128 | 129 | self.mlp_head = nn.Sequential( 130 | nn.LayerNorm(dim), 131 | nn.Linear(dim, num_classes) 132 | ) 133 | 134 | def forward(self, img): 135 | x = self.to_patch_embedding(img) 136 | b, n, _ = x.shape 137 | 138 | cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b) 139 | x = torch.cat((cls_tokens, x), dim=1) 140 | x += self.pos_embedding[:, :(n + 1)] 141 | x = self.dropout(x) 142 | 143 | x = self.transformer(x) 144 | 145 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 146 | 147 | x = self.to_latent(x) 148 | return self.mlp_head(x) 149 | if __name__ == "__main__": 150 | 151 | from torchsummary import summary 152 | v = ViT( 153 | image_size = 256, 154 | patch_size = 32, 155 | num_classes = 1000, 156 | dim = 1024, 157 | depth = 6, 158 | heads = 16, 159 | mlp_dim = 2048, 160 | dropout = 0.1, 161 
| emb_dropout = 0.1 162 | ) 163 | 164 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 165 | model = v.to(device) 166 | summary(model,(3,256,256)) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from eval import eval_top 3 | from nets.ConvMixer import ConvMixer_768_32 4 | from utils.utils import load_dict 5 | #加载模型 6 | model_path = 'logs\ep050-loss0.414-val_loss0.376.pth' 7 | model = ConvMixer_768_32(n_classes=2) 8 | model = load_dict(model_path,model) 9 | eval = eval_top(anno_lines=None,model=model) 10 | 11 | while True: 12 | img = input('Input image filename:') 13 | try: 14 | image = Image.open(img) 15 | except: 16 | print('Open Error! Try again!') 17 | continue 18 | else: 19 | class_name = eval.detect_img(image,mode='predict') 20 | print(class_name) 21 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # classification 2 | Pytorch 用于训练自己的图像分类模型 3 | 4 | # 环境要求 5 | pytorch 6 | opencv 7 | 8 | # conda虚拟环境一键导入: 9 | ```bash 10 | conda env create -f torch.yaml 11 | ``` 12 | 13 | # 理论地址: 14 | csdn 博客地址: 15 | MLP-Mixer:https://no-coding.blog.csdn.net/article/details/121682740 16 | 17 | Conv-Mixer:https://blog.csdn.net/qq_38676487/article/details/120705254 18 | 19 | ConvNeXt:https://blog.csdn.net/qq_38676487/article/details/123298605 20 | 21 | 22 | # How2Train 23 | ## 1.数据集: 24 | 25 | ```bash 26 | ─dataset 27 | ├─train 28 | │ └─cats 29 | │ └─xxjj.jpg 30 | │ └─dogs 31 | │ └─xxx.jpg 32 | ├─test 33 | │ └─cats 34 | │ └─dogs 35 | ``` 36 | 37 | classes.txt: 38 | 39 | ``` 40 | cats 41 | dogs 42 | ``` 43 | 44 | 在txt_annotation.py 中 calsses 与上述文件classes.txt 分类顺序一致,运行txt_annotation.py 生成 cls_train.txt, cls_text.txt 45 | 46 | ``` 47 | classes = ["cats", "dogs"] 48 | sets = 
["train", "test"] 49 | ``` 50 | 51 | ## 2.训练 52 | 53 | 在config 中配置训练参数: 54 | 55 | ```python 56 | Cuda = True #是否使用GPU 没有为False 57 | 58 | input_shape = [48,48] # 输入图片大小 59 | is_grayscale = False #灰度图训练为True 非灰度图为False 60 | batch_size = 2 # 自己可以更改 61 | lr = 1e-3 62 | 63 | classes_path = 'classes.txt' 64 | 65 | 66 | num_workers = 0 # 是否开启多进程 67 | 68 | 69 | train_annotation_path = 'cls_train.txt' # 训练集用于训练 70 | 71 | val_annotation_path = 'cls_test.txt' # 测试集用于每个epoch 训练完后进行测试,以确认训练是否充分 72 | 73 | 74 | 75 | 76 | resume ='' # 加载训练权重路径 77 | 78 | log_dir = 'logs' # 日志路径 tensorboard 保存 79 | 80 | #------------------------------------------# 81 | # FocalLoss :处理样本不均衡 82 | # alpha 83 | # gamma >0 当 gamma=0 时就是交叉熵损失函数 84 | # 论文中gamma = [0,0.5,1,2,5] 85 | # 一般而言当γ增加的时候,α需要减小一点 86 | # reduction : 求平均:'mean' 求和 'sum' 87 | # 还未调通 88 | #------------------------------------------# 89 | #Focal_loss = True # True Focal loss 处理样本不均衡 False 使用 CrossEntropyLoss() # 还未使用成功 90 | 91 | #label_smoothing 防止过拟合 92 | label_smoothing = True # 93 | 94 | smoothing_value = 0.1 #[0,1] 之间 95 | 96 | 97 | 98 | #学习率变化策略 99 | scheduler = 'cos' #[None,reduce,cos] None保持不变 reduce 在验证集准确率不再提升时降低学习率 cos 余弦下降算法 100 | 101 | 102 | 103 | ``` 104 | 105 | 在 train.py 中 106 | 107 | ```python 108 | #---------------------------------------------------# 109 | # 定义模型,可在nets 导入自己的模型去训练, 110 | # 目前支持MLP-Mixer Conv-Mixer ConvNeXt系列模型 111 | # 只有ConvNeXt 支持pretrain 官方提供的权重 112 | #---------------------------------------------------# 113 | model = ConvMixer_768_32(n_classes=num_classes) 114 | ``` 115 | 116 | ## 日志查看 117 | 118 | 每次启动训练时,都会在logs 文件夹下创建一个日志文件夹。如: 119 | 120 | ```bash 121 | tensorboard --logdir=logs\loss_2022_03_06_12_11_30 122 | ``` 123 | 124 | # How2Eval 125 | 126 | 在eval.py 中: 127 | 128 | ```python 129 | if __name__ == "__main__": 130 | 131 | # 读取测试集路径和标签 132 | with open("./cls_test.txt","r") as f: 133 | lines = f.readlines() 134 | #---------------------------------------------------# 135 | # 权重和模型 136 | # 
注意:训练时设置的模型需要和权重匹配, 137 | # 也就是训练时用的是哪个模型,就使用哪个模型的权重 138 | #---------------------------------------------------# 139 | model_path = '' #训练好的权重路径 140 | model = ConvMixer_768_32(n_classes=2) # 自己训练好的模型 141 | 142 | model = load_dict(model_path,model) # 加载权重 143 | eval = eval_top(anno_lines=lines,model=model) 144 | #---------------------------------------------------# 145 | # top1 预测概率最高的值与真实标签一致 √ 146 | # top5 预测概率前五个值中有一个与真实标签一致 √ 147 | #---------------------------------------------------# 148 | print('start eval.....') 149 | top1 = eval.eval_top1() 150 | 151 | top5 = eval.eval_top5() 152 | print('top1:%.3f,top5:%.3f'%(top1,top5)) 153 | print('Eval Finished') 154 | 155 | ``` 156 | 157 | 158 | 159 | # How2Predict 160 | 161 | predict.py 中,设置好模型和权重,控制台输入图片路径。 162 | 163 | ```python 164 | #加载模型 165 | model_path = 'logs\ep050-loss0.414-val_loss0.376.pth' 166 | model = ConvMixer_768_32(n_classes=2) 167 | model = load_dict(model_path,model) 168 | eval = eval_top(anno_lines=None,model=model) 169 | 170 | while True: 171 | img = input('Input image filename:') 172 | try: 173 | image = Image.open(img) 174 | except: 175 | print('Open Error! Try again!') 176 | continue 177 | else: 178 | class_name = eval.detect_img(image,mode='predict') 179 | print(class_name) 180 | 181 | ``` 182 | 183 | 控制台: 184 | 185 | ```bash 186 | Loading weights into state dict...
187 | Input image filename:d:\Classification\torch\datasets\test\cats\cat.4006.jpg 188 | ``` 189 | 190 | 191 | ## 训练技巧和练丹 192 | 193 | + [ ] Focl_loss(样本不均衡策略) 194 | 195 | + [x] label_smoothing (训练样本偏少时,防止过拟合策略) 196 | + [x] 学习率衰减(使模型收敛更充分) 197 | 198 | 存在bug及其他问题私信:1308659229@qq.com 199 | 200 | # 其他 201 | 202 | 该仓库可能存在bug,希望大家在使用过程中能及时反馈,或者留下一些代码修改意见。我们一起让它变得更好。 203 | 204 | 205 | 206 | **如果觉得有用清给我点star** 207 | -------------------------------------------------------------------------------- /torch.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiantenggei/torch-classification/ed5a39941bda5897aa3c2e3b1c5bb3d00806aea7/torch.yaml -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import numpy as np 3 | from config import * 4 | import torch.backends.cudnn as cudnn 5 | from utils.train_one_epoch import fit_one_epoch 6 | from utils.dataloader import DataGenerator,detection_collate 7 | from utils.utils import get_classes,weights_init,create_tbWriter 8 | from nets.resnet import ResNet18 9 | from nets.ConvMixer import ConvMixer_768_32 10 | import torch 11 | from torch import nn 12 | import torch.optim as optim 13 | from torch.utils.data import DataLoader 14 | import os 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | from utils.training_utils import cross_entropy,smooth_one_hot 17 | import config 18 | import torch.utils.model_zoo as model_zoo 19 | 20 | def train(): 21 | # 获得分类数 22 | class_names, num_classes = get_classes(classes_path) 23 | #---------------------------------------------------# 24 | # 定义模型,可在nets 导入自己的模型去训练, 25 | # 目前支持MLP-Mixer Conv-Mixer ConvNeXt系列模型 26 | # 只有ConvNeXt 支持pretrain 官方提供的权重 27 | #---------------------------------------------------# 28 | model = ConvMixer_768_32(n_classes=num_classes) 29 | 
#初始化 30 | if resume == '': 31 | #---------------------------------------------------# 32 | # 初始模型的方式: str: normal xavier kaiming orthogonal 33 | #---------------------------------------------------# 34 | weights_init(model,init_type='normal') 35 | 36 | else: 37 | #载入训练过的权重 38 | print('Loading weights into state dict...') 39 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 40 | model_dict = model.state_dict() 41 | pretrained_dict = torch.load(resume, map_location=device) 42 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} 43 | model_dict.update(pretrained_dict) 44 | model.load_state_dict(model_dict) 45 | 46 | #配置训练 47 | model_train = model.train() 48 | 49 | # 使用gpu 50 | if Cuda: 51 | model_train = torch.nn.DataParallel(model) 52 | cudnn.benchmark = True 53 | model_train = model_train.cuda() 54 | #这里我将创建一个文件下 将所有的配置参数都记录下来 55 | hyperparameters = 'shape_'+str(config.input_shape)+'_batch_size_'+ str(config.batch_size)+'_lr_'+str(config.lr)+'_isgary_'\ 56 | +str(config.is_grayscale)+'_lrscheduler_'+str(config.scheduler)+'_labelsmooth_' + str(config.label_smoothing)+str(config.smoothing_value) 57 | 58 | 59 | tb_writer = create_tbWriter(log_dir=log_dir,hyperparameters=hyperparameters) 60 | # 设置loss 61 | if label_smoothing: 62 | criterion = nn.CrossEntropyLoss(label_smoothing=smoothing_value) 63 | else: 64 | criterion = nn.CrossEntropyLoss() 65 | #读取 train 66 | with open(train_annotation_path, "r") as f: 67 | lines = f.readlines() 68 | 69 | np.random.seed(10101) 70 | np.random.shuffle(lines) 71 | np.random.seed(None) 72 | 73 | with open(val_annotation_path,'r') as f: 74 | val_lines = f.readlines() 75 | # num_val = int(len(lines) * val_split) 76 | num_val = len(val_lines) 77 | num_train = len(lines) 78 | 79 | #配置训练参数 80 | lr = config.lr 81 | Batch_size = config.batch_size 82 | Init_Epoch = 0 83 | End_Epoch = 50 84 | 85 | epoch_step = num_train // Batch_size 86 | epoch_step_val = num_val // 
Batch_size 87 | 88 | if epoch_step == 0 or epoch_step_val == 0: 89 | raise ValueError("数据集过小,无法进行训练,请扩充数据集。") 90 | 91 | optimizer = optim.Adam(model_train.parameters(), lr, weight_decay = 5e-4) 92 | 93 | if config.scheduler== 'reduce': 94 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,mode='max',factor=0.95,verbose=True) 95 | 96 | if config.scheduler == 'cos': 97 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=End_Epoch) 98 | 99 | 100 | train_dataset = DataGenerator(lines, input_shape, True,is_grayscale=is_grayscale) 101 | val_dataset = DataGenerator(val_lines, input_shape, False,is_grayscale=is_grayscale) 102 | gen = DataLoader(train_dataset, batch_size=Batch_size, num_workers=num_workers, pin_memory=True, 103 | drop_last=True, collate_fn=detection_collate) 104 | gen_val = DataLoader(val_dataset, batch_size=Batch_size, num_workers=num_workers, pin_memory=True, 105 | drop_last=True, collate_fn=detection_collate) 106 | 107 | for epoch in range(Init_Epoch,End_Epoch): 108 | train_loss,train_accuracy,val_loss,val_accuracy=fit_one_epoch(model_train, model, tb_writer, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, End_Epoch, Cuda,criterion) 109 | #学习率 衰减 110 | if config.scheduler == 'reduce': 111 | scheduler.step(val_accuracy) 112 | 113 | if config.scheduler == 'cos': 114 | scheduler.step() 115 | 116 | 117 | 118 | 119 | if __name__ == "__main__": 120 | train() 121 | 122 | -------------------------------------------------------------------------------- /txt_annotation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import getcwd 3 | 4 | #---------------------------------------------------# 5 | # 训练自己的数据集的时候一定要注意修改classes 6 | # 修改成自己数据集所区分的种类 7 | # 种类顺序需要和训练时用到的classes.txt一样 8 | # 生成cls_train.txt, cls_text.txt 9 | #---------------------------------------------------# 10 | classes = ["cats", "dogs"] 11 | sets = ["train", "test"] 12 | 
if __name__ == "__main__": 13 | wd = getcwd() 14 | for se in sets: 15 | list_file = open('cls_' + se + '.txt', 'w') 16 | 17 | datasets_path = "datasets/" + se 18 | types_name = os.listdir(datasets_path) 19 | for type_name in types_name: 20 | if type_name not in classes: 21 | continue 22 | cls_id = classes.index(type_name) 23 | 24 | photos_path = os.path.join(datasets_path, type_name) 25 | photos_name = os.listdir(photos_path) 26 | for photo_name in photos_name: 27 | _, postfix = os.path.splitext(photo_name) 28 | if postfix not in ['.jpg', '.png', '.jpeg']: 29 | continue 30 | list_file.write(str(cls_id) + ";" + '%s/%s'%(wd, os.path.join(photos_path, photo_name))) 31 | list_file.write('\n') 32 | list_file.close() 33 | 34 | -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiantenggei/torch-classification/ed5a39941bda5897aa3c2e3b1c5bb3d00806aea7/utils/__pycache__/dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/train_one_epoch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiantenggei/torch-classification/ed5a39941bda5897aa3c2e3b1c5bb3d00806aea7/utils/__pycache__/train_one_epoch.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiantenggei/torch-classification/ed5a39941bda5897aa3c2e3b1c5bb3d00806aea7/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | 
from __future__ import annotations 2 | from random import shuffle 3 | 4 | import cv2 5 | import numpy as np 6 | import torch.utils.data as data 7 | from PIL import Image 8 | from .utils import cvtColor 9 | 10 | class DataGenerator(data.Dataset): 11 | 12 | def __init__(self,annotation_lines, input_shape, random=True,is_grayscale=False): 13 | super().__init__() 14 | self.annotation_lines = annotation_lines 15 | self.input_shape = input_shape 16 | self.random = random 17 | self.is_grayscale = is_grayscale 18 | 19 | def __len__(self): 20 | return len(self.annotation_lines) 21 | 22 | def __getitem__(self, index): 23 | annotation_path = self.annotation_lines[index].split(';')[1].split()[0] 24 | image = Image.open(annotation_path) 25 | image = self.get_random_data(image, self.input_shape, random=self.random) 26 | #如果是是读取图,读取之后需要做一个转换 27 | if self.is_grayscale: 28 | image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY) 29 | #[w,h] - > [w,h,1] 30 | image = np.expand_dims(image,-1) 31 | #归一化处理吧通道提前[w,h,c] - > [c,w,h] 32 | #print(np.array(image).shape) 33 | image = np.transpose(self.preprocess_input(np.array(image).astype(np.float32)), [2, 0, 1]) 34 | y = int(self.annotation_lines[index].split(';')[0]) 35 | return image, y 36 | def rand(self, a=0, b=1): 37 | return np.random.rand()*(b-a) + a 38 | # 图片归一化处理 39 | def preprocess_input(self,image): 40 | image /= 127.5 41 | image -= 1. 
42 | return image 43 | 44 | 45 | 46 | def get_random_data(self, image, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True): 47 | #------------------------------# 48 | # 读取图像并转换成RGB图像 49 | #------------------------------# 50 | image = cvtColor(image) 51 | #------------------------------# 52 | # 获得图像的高宽与目标高宽 53 | #------------------------------# 54 | iw, ih = image.size 55 | h, w = input_shape 56 | 57 | if not random: 58 | scale = min(w/iw, h/ih) 59 | nw = int(iw*scale) 60 | nh = int(ih*scale) 61 | dx = (w-nw)//2 62 | dy = (h-nh)//2 63 | 64 | #---------------------------------# 65 | # 将图像多余的部分加上灰条 66 | #---------------------------------# 67 | image = image.resize((nw,nh), Image.BICUBIC) 68 | new_image = Image.new('RGB', (w,h), (128,128,128)) 69 | new_image.paste(image, (dx, dy)) 70 | image_data = np.array(new_image, np.float32) 71 | 72 | return image_data 73 | 74 | #------------------------------------------# 75 | # 对图像进行缩放并且进行长和宽的扭曲 76 | #------------------------------------------# 77 | new_ar = w/h * self.rand(1-jitter,1+jitter)/self.rand(1-jitter,1+jitter) 78 | scale = self.rand(.75, 1.25) 79 | if new_ar < 1: 80 | nh = int(scale*h) 81 | nw = int(nh*new_ar) 82 | else: 83 | nw = int(scale*w) 84 | nh = int(nw/new_ar) 85 | image = image.resize((nw,nh), Image.BICUBIC) 86 | 87 | #------------------------------------------# 88 | # 将图像多余的部分加上灰条 89 | #------------------------------------------# 90 | dx = int(self.rand(0, w-nw)) 91 | dy = int(self.rand(0, h-nh)) 92 | new_image = Image.new('RGB', (w,h), (128,128,128)) 93 | new_image.paste(image, (dx, dy)) 94 | image = new_image 95 | 96 | #------------------------------------------# 97 | # 翻转图像 98 | #------------------------------------------# 99 | flip = self.rand()<.5 100 | if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT) 101 | 102 | rotate = self.rand()<.5 103 | if rotate: 104 | angle = np.random.randint(-15,15) 105 | a,b = w/2,h/2 106 | M = cv2.getRotationMatrix2D((a,b),angle,1) 107 | image = 
cv2.warpAffine(np.array(image), M, (w,h), borderValue=[128,128,128]) 108 | 109 | #------------------------------------------# 110 | # 色域扭曲 111 | #------------------------------------------# 112 | hue = self.rand(-hue, hue) 113 | sat = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat) 114 | val = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val) 115 | x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV) 116 | x[..., 1] *= sat 117 | x[..., 2] *= val 118 | x[x[:,:, 0]>360, 0] = 360 119 | x[:, :, 1:][x[:, :, 1:]>1] = 1 120 | x[x<0] = 0 121 | image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255 122 | return image_data 123 | 124 | def detection_collate(batch): 125 | images = [] 126 | targets = [] 127 | for image, y in batch: 128 | images.append(image) 129 | targets.append(y) 130 | images = np.array(images) 131 | targets = np.array(targets) 132 | return images, targets 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /utils/train_one_epoch.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from tqdm import tqdm 6 | from utils.utils import get_classes 7 | from config import classes_path 8 | import config 9 | #---------------------------------------------------# 10 | # 获得学习率 11 | #---------------------------------------------------# 12 | def get_lr(optimizer): 13 | for param_group in optimizer.param_groups: 14 | return param_group['lr'] 15 | #nn.BCEWithLogitsLoss是对网络的输出进行Sigmoid(); 交叉熵则是采用的Softmax 16 | def fit_one_epoch(model_train, model, tb_writer, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda,criterion): 17 | 18 | # 记录日志啊 19 | 20 | tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"] 21 | train_loss = 0 22 | train_accuracy = 0 23 | val_accuracy = 0 24 | val_loss = 0 25 | 
_,classes = get_classes(classes_path) 26 | 27 | 28 | model_train.train() 29 | print('Start Train') 30 | with tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar: 31 | for iteration, batch in enumerate(gen): 32 | if iteration >= epoch_step: 33 | break 34 | images, targets = batch 35 | with torch.no_grad(): 36 | images = torch.from_numpy(images).type(torch.FloatTensor) 37 | targets = torch.from_numpy(targets).type(torch.FloatTensor).long() 38 | if cuda: 39 | images = images.cuda() 40 | targets = targets.cuda() 41 | optimizer.zero_grad() 42 | outputs = model_train(images) 43 | loss_value = criterion(outputs,targets) 44 | loss_value.backward() 45 | optimizer.step() 46 | 47 | train_loss += loss_value.item() 48 | with torch.no_grad(): # 训练集准确率 49 | accuracy = torch.mean((torch.argmax(F.softmax(outputs, dim=-1), dim=-1) == targets).type(torch.FloatTensor)) 50 | train_accuracy += accuracy.item() 51 | 52 | pbar.set_postfix(**{'train_loss': train_loss / (iteration + 1), 53 | 'train_accuracy' : train_accuracy / (iteration + 1), 54 | 'lr' : get_lr(optimizer)}) 55 | pbar.update(1) 56 | 57 | print('Finish Train') 58 | 59 | model_train.eval() 60 | print('Start Validation') 61 | with tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar: 62 | for iteration, batch in enumerate(gen_val): 63 | if iteration >= epoch_step_val: 64 | break 65 | images, targets = batch 66 | with torch.no_grad(): 67 | images = torch.from_numpy(images).type(torch.FloatTensor) 68 | targets = torch.from_numpy(targets).type(torch.FloatTensor).long() 69 | if cuda: 70 | images = images.cuda() 71 | targets = targets.cuda() 72 | 73 | optimizer.zero_grad() 74 | 75 | outputs = model_train(images) 76 | loss_value = criterion(outputs,targets) 77 | 78 | val_loss += loss_value.item() 79 | # 验证集准确率 80 | accuracy = torch.mean((torch.argmax(F.softmax(outputs, dim=-1), dim=-1) == targets).type(torch.FloatTensor)) 81 | val_accuracy += 
accuracy.item() 82 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1), 83 | 'val_accuracy' : val_accuracy / (iteration + 1), 84 | 'lr' : get_lr(optimizer)}) 85 | pbar.update(1) 86 | 87 | tb_writer.add_scalar(tags[0], train_loss/epoch_step, epoch) 88 | tb_writer.add_scalar(tags[1], train_accuracy/epoch_step, epoch) 89 | tb_writer.add_scalar(tags[2], val_loss/epoch_step_val, epoch) 90 | tb_writer.add_scalar(tags[3], val_accuracy/epoch_step_val, epoch) 91 | tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch) 92 | 93 | print('Finish Validation') 94 | print('Epoch:' + str(epoch + 1) + '/' + str(Epoch)) 95 | print('Total Loss: %.3f || Val Loss: %.3f ' % (train_loss / epoch_step, val_loss / epoch_step_val)) 96 | torch.save(model.state_dict(), 'logs/ep%03d-loss%.3f-val_loss%.3f.pth'%((epoch + 1), train_loss / epoch_step, val_loss / epoch_step_val)) 97 | 98 | return train_loss/epoch_step,train_accuracy/epoch_step,val_loss/epoch_step_val,val_accuracy/epoch_step_val 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /utils/training_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import torch 8 | 9 | #------------------------------ 10 | # smooth_labels loss 函数 11 | #------------------------------ 12 | def cross_entropy(outputs, smooth_labels): 13 | loss = torch.nn.KLDivLoss(reduction='batchmean') # KL 散度 14 | return loss(F.log_softmax(outputs, dim=1), smooth_labels) 15 | 16 | 17 | 18 | def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0): 19 | """ 20 | if smoothing == 0, it's one-hot method 21 | if 0 < smoothing < 1, it's smooth method 22 | 23 | """ 24 | device = true_labels.device 25 | true_labels = torch.nn.functional.one_hot( 26 | true_labels, classes).detach().cpu() 27 | assert 0 <= smoothing < 1 28 | 
confidence = 1.0 - smoothing 29 | label_shape = torch.Size((true_labels.size(0), classes)) 30 | with torch.no_grad(): 31 | true_dist = torch.empty( 32 | size=label_shape, device=true_labels.device) 33 | true_dist.fill_(smoothing / (classes - 1)) 34 | _, index = torch.max(true_labels, 1) 35 | 36 | true_dist.scatter_(1, torch.LongTensor( 37 | index.unsqueeze(1)), confidence) 38 | return true_dist.to(device) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from statistics import mode 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | import datetime 6 | import os 7 | import cv2 8 | from torch.utils.tensorboard import SummaryWriter 9 | #---------------------------------------------------# 10 | # 转化为RGB 格式 11 | #---------------------------------------------------# 12 | def cvtColor(image): 13 | if len(np.shape(image)) == 3 and np.shape(image)[-2] == 3: 14 | return image 15 | else: 16 | image = image.convert('RGB') 17 | return image 18 | 19 | 20 | #---------------------------------------------------# 21 | # 获得类 22 | #---------------------------------------------------# 23 | def get_classes(classes_path): 24 | with open(classes_path, encoding='utf-8') as f: 25 | class_names = f.readlines() 26 | class_names = [c.strip() for c in class_names] 27 | return class_names, len(class_names) 28 | 29 | #---------------------------------------------------# 30 | # 初始化权重 31 | #---------------------------------------------------# 32 | def weights_init(net, init_type='normal', init_gain=0.02): 33 | def init_func(m): 34 | classname = m.__class__.__name__ 35 | if hasattr(m, 'weight') and classname.find('Conv') != -1: 36 | if init_type == 'normal': 37 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain) 38 | elif init_type == 'xavier': 39 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain) 40 | elif init_type == 'kaiming': 41 | 
torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 42 | elif init_type == 'orthogonal': 43 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain) 44 | else: 45 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 46 | elif classname.find('BatchNorm2d') != -1: 47 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 48 | torch.nn.init.constant_(m.bias.data, 0.0) 49 | print('initialize network with %s type' % init_type) 50 | net.apply(init_func) 51 | 52 | 53 | #---------------------------------------------------# 54 | # 创建一个文件下,以超参数文件名 记录logs 55 | #---------------------------------------------------# 56 | def create_tbWriter(log_dir:str,hyperparameters:str): 57 | 58 | save_path = os.path.join(log_dir, "loss_" + str(hyperparameters)) 59 | 60 | os.makedirs(save_path) 61 | 62 | tb_writer = SummaryWriter(log_dir=save_path) 63 | 64 | return tb_writer 65 | 66 | #---------------------------------------------------# 67 | # 加灰条的resize 防止图片是真 68 | #---------------------------------------------------# 69 | def letterbox_image(image, size): 70 | iw, ih = image.size 71 | h, w = size 72 | scale = min(w/iw, h/ih) 73 | nw = int(iw*scale) 74 | nh = int(ih*scale) 75 | 76 | image = image.resize((nw,nh), Image.BICUBIC) 77 | new_image = Image.new('RGB', size, (128,128,128)) 78 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 79 | return new_image 80 | 81 | #---------------------------------------------------# 82 | # 加载权重 83 | #---------------------------------------------------# 84 | def load_dict(model_path,model): 85 | #异常 86 | if model_path == "": 87 | raise ValueError("请设置模型权重路径!") 88 | print('Loading weights into state dict...') 89 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 90 | model_dict = model.state_dict() 91 | pretrained_dict = torch.load(model_path, map_location=device) 92 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} 93 | 
model_dict.update(pretrained_dict) 94 | model.load_state_dict(model_dict) 95 | return model 96 | 97 | def get_config()->str: 98 | 99 | s = '' 100 | 101 | if __name__=='__main__': 102 | # 测试tb_writer 103 | t = create_tbWriter('logs') 104 | 105 | # image = Image.open(r'C:\Users\Jian\Pictures\Camera Roll\\1.jpg') 106 | # image = cvtColor(image) 107 | # image = letterbox_image(image,(2224,2225)) 108 | # image = np.array(image) 109 | # cv2.imshow('c',image) 110 | # cv2.waitKey(0) 111 | 112 | 113 | --------------------------------------------------------------------------------