├── weights
│   └── README.md
├── class_indices.json
├── flops.py
├── README.md
├── my_dataset.py
├── predict.py
├── train.py
├── utils.py
└── vit_model.py


/weights/README.md:
--------------------------------------------------------------------------------
1 | Place pretrained weights here; the weights produced by training are also saved to this folder.


--------------------------------------------------------------------------------
/class_indices.json:
--------------------------------------------------------------------------------
1 | {
2 |     "0": "daisy",
3 |     "1": "dandelion",
4 |     "2": "roses",
5 |     "3": "sunflowers",
6 |     "4": "tulips"
7 | }


--------------------------------------------------------------------------------
/flops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fvcore.nn import FlopCountAnalysis
3 | 
4 | from vit_model import Attention
5 | 
6 | 
7 | def main():
8 |     # Self-Attention
9 |     a1 = Attention(dim=512, num_heads=1)
10 |     a1.proj = torch.nn.Identity()  # remove Wo
11 | 
12 |     # Multi-Head Attention
13 |     a2 = Attention(dim=512, num_heads=8)
14 | 
15 |     # [batch_size, num_tokens, total_embed_dim]
16 |     t = (torch.rand(32, 1024, 512),)
17 | 
18 |     flops1 = FlopCountAnalysis(a1, t)
19 |     print("Self-Attention FLOPs:", flops1.total())
20 | 
21 |     flops2 = FlopCountAnalysis(a2, t)
22 |     print("Multi-Head Attention FLOPs:", flops2.total())
23 | 
24 | 
25 | if __name__ == '__main__':
26 |     main()
27 | 
28 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### About this project
2 | 
3 | The code in this project is mainly based on
4 | 
5 | https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification/vision_transformer
6 | 
7 | with additional explanatory comments written on top of it.
8 | 
9 | ```
10 | ├── vit_model.py: ViT model definition
11 | ├── weights: folder where weight files are stored
12 | ├── train.py: training script
13 | ├── predict.py: single-image prediction script
14 | ├── my_dataset.py: custom Dataset class for reading the dataset
15 | ├── flops.py: code for computing FLOPs
16 | └── utils.py: common helper functions used in this project
17 | ```
18 | 
19 | ### Training
20 | 
21 | In train.py, use the command-line options (opt) to set the dataset and weight paths, the batch_size, and the other parameters.
22 | 
23 | ### Downloads
24 | 
25 | This project uses the flower classification dataset, which has to be downloaded first:
26 | 
27 | http://download.tensorflow.org/example_images/flower_photos.tgz
28 | 
29 | Pretrained weight downloads:
30 | 
31 | https://github.com/google-research/vision_transformer
32 | 
33 | Download link for the PyTorch-version weights used in this project (links for the other variants are given in vit_model.py):
34 | 
35 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
36 | 
37 | My own article explaining ViT:
38 | 
39 | https://zhuanlan.zhihu.com/p/461077472
40 | 


--------------------------------------------------------------------------------
/my_dataset.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | from torch.utils.data import Dataset
4 | 
5 | 
6 | class MyDataSet(Dataset):
7 |     """Custom dataset"""
8 | 
9 |     def __init__(self, images_path: list, images_class: list, transform=None):
10 |         self.images_path = images_path    # image paths
11 |         self.images_class = images_class  # image class labels
12 |         self.transform = transform        # image transforms
13 | 
14 |     def __len__(self):
15 |         return len(self.images_path)
16 | 
17 |     def __getitem__(self, item):
18 |         img = Image.open(self.images_path[item])
19 |         # RGB means a color image, L a grayscale image
20 |         if img.mode != 'RGB':
21 |             raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
22 |         label = self.images_class[item]
23 | 
24 |         if self.transform is not None:
25 |             img = self.transform(img)
26 | 
27 |         return img, label
28 | 
29 |     @staticmethod
30 |     def 
collate_fn(batch): 31 | # 官方实现的default_collate可以参考 32 | # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py 33 | images, labels = tuple(zip(*batch)) 34 | 35 | images = torch.stack(images, dim=0) 36 | labels = torch.as_tensor(labels) 37 | return images, labels 38 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | from PIL import Image 6 | from torchvision import transforms 7 | import matplotlib.pyplot as plt 8 | 9 | from vit_model import vit_base_patch16_224_in21k as create_model 10 | 11 | 12 | def main(): 13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 14 | 15 | data_transform = transforms.Compose( 16 | [transforms.Resize(256), 17 | transforms.CenterCrop(224), 18 | transforms.ToTensor(), 19 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 20 | 21 | # load image 22 | img_path = "../tulip.jpg" 23 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) 24 | img = Image.open(img_path) 25 | plt.imshow(img) 26 | # [N, C, H, W] 27 | img = data_transform(img) 28 | # expand batch dimension 29 | img = torch.unsqueeze(img, dim=0) 30 | 31 | # read class_indict 32 | json_path = './class_indices.json' 33 | assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path) 34 | 35 | json_file = open(json_path, "r") 36 | class_indict = json.load(json_file) 37 | 38 | # create model 39 | model = create_model(num_classes=5, has_logits=False).to(device) 40 | # load model weights 41 | model_weight_path = "./weights/model-9.pth" 42 | model.load_state_dict(torch.load(model_weight_path, map_location=device)) 43 | model.eval() 44 | with torch.no_grad(): 45 | # predict class 46 | output = torch.squeeze(model(img.to(device))).cpu() 47 | predict = torch.softmax(output, dim=0) 48 | predict_cla = torch.argmax(predict).numpy() 49 | 50 | print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], 51 | predict[predict_cla].numpy()) 52 | plt.title(print_res) 53 | for i in range(len(predict)): 54 | print("class: {:10} prob: {:.3}".format(class_indict[str(i)], 55 | predict[i].numpy())) 56 | plt.show() 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | 5 | import torch 6 | import torch.optim as optim 7 | import torch.optim.lr_scheduler as lr_scheduler 8 | # from torch.utils.tensorboard import SummaryWriter 9 | from torchvision import transforms 10 | 11 | 12 | from my_dataset import MyDataSet 13 | from vit_model import vit_base_patch16_224_in21k as create_model 14 | from utils import read_split_data, train_one_epoch, evaluate 15 | 16 | 17 | def main(args): 18 | device = torch.device(args.device if torch.cuda.is_available() else "cpu") 19 | print("using device",device) 20 | 21 | if os.path.exists("./weights") is False: 22 | os.makedirs("./weights") 23 | 24 | # tb_writer = SummaryWriter() 25 | 26 | # 划分验证集和训练集,获得图片的路径和对应的分类标签 27 | train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path) 28 | 29 | # 数据进行transform变换 30 | data_transform = { 31 | "train": transforms.Compose([transforms.RandomResizedCrop(224), 32 | 
transforms.RandomHorizontalFlip(), 33 | transforms.ToTensor(), 34 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 35 | "val": transforms.Compose([transforms.Resize(256), 36 | transforms.CenterCrop(224), 37 | transforms.ToTensor(), 38 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])} 39 | 40 | # 实例化训练数据集 41 | train_dataset = MyDataSet(images_path=train_images_path, 42 | images_class=train_images_label, 43 | transform=data_transform["train"]) 44 | 45 | # 实例化验证数据集 46 | val_dataset = MyDataSet(images_path=val_images_path, 47 | images_class=val_images_label, 48 | transform=data_transform["val"]) 49 | 50 | batch_size = args.batch_size 51 | nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers 52 | print('Using {} dataloader workers every process'.format(nw)) 53 | # 设置训练集的数据加载器 54 | train_loader = torch.utils.data.DataLoader(train_dataset, 55 | batch_size=batch_size, 56 | shuffle=True, 57 | pin_memory=True, 58 | num_workers=nw, 59 | collate_fn=train_dataset.collate_fn) 60 | 61 | # 设置验证集的数据加载器 62 | val_loader = torch.utils.data.DataLoader(val_dataset, 63 | batch_size=batch_size, 64 | shuffle=False, 65 | pin_memory=True, 66 | num_workers=nw, 67 | collate_fn=val_dataset.collate_fn) 68 | 69 | model = create_model(num_classes=5, has_logits=False).to(device) 70 | 71 | if args.weights != "": 72 | assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights) 73 | weights_dict = torch.load(args.weights, map_location=device) 74 | # 删除不需要的权重 75 | del_keys = ['head.weight', 'head.bias'] if model.has_logits \ 76 | else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias'] 77 | for k in del_keys: 78 | del weights_dict[k] 79 | print(model.load_state_dict(weights_dict, strict=False)) 80 | 81 | if args.freeze_layers: 82 | for name, para in model.named_parameters(): 83 | # 除head, pre_logits外,其他权重全部冻结 84 | if "head" not in name and "pre_logits" not in name: 85 | para.requires_grad_(False) 86 | else: 87 | print("training {}".format(name)) 88 | 89 | pg = [p for p in model.parameters() if p.requires_grad] 90 | optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5) 91 | # Scheduler https://arxiv.org/pdf/1812.01187.pdf 92 | # 学习率策略,余弦退火算法 93 | lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine 94 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) 95 | 96 | for epoch in range(args.epochs): 97 | # train,训练一个epoch 98 | train_loss, train_acc = train_one_epoch(model=model, 99 | optimizer=optimizer, 100 | data_loader=train_loader, 101 | device=device, 102 | epoch=epoch) 103 | 104 | scheduler.step() 105 | 106 | # validate,验证一次,计算准确率 107 | val_loss, val_acc = evaluate(model=model, 108 | data_loader=val_loader, 109 | device=device, 110 | epoch=epoch) 111 | 112 | tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"] 113 | # tb_writer.add_scalar(tags[0], train_loss, epoch) 114 | # tb_writer.add_scalar(tags[1], train_acc, epoch) 115 | # tb_writer.add_scalar(tags[2], val_loss, epoch) 116 | # tb_writer.add_scalar(tags[3], val_acc, epoch) 117 | # tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch) 118 | 119 | torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch)) 120 | 121 | 122 | if __name__ == '__main__': 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument('--num_classes', type=int, default=5) 125 | parser.add_argument('--epochs', type=int, default=10) 126 | 
parser.add_argument('--batch-size', type=int, default=2) 127 | parser.add_argument('--lr', type=float, default=0.001) 128 | parser.add_argument('--lrf', type=float, default=0.01) 129 | 130 | # 数据集所在根目录 131 | # http://download.tensorflow.org/example_images/flower_photos.tgz 132 | parser.add_argument('--data-path', type=str, 133 | default="/data/zdata/flower_data/flower_photos") 134 | parser.add_argument('--model-name', default='', help='create model name') 135 | 136 | # 预训练权重路径,如果不想载入就设置为空字符 137 | parser.add_argument('--weights', type=str, default='/data/zdata/weight/vit_base_patch16_224_in21k.pth', 138 | help='initial weights path') 139 | # 是否冻结权重 140 | parser.add_argument('--freeze-layers', type=bool, default=True) 141 | parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)') 142 | 143 | opt = parser.parse_args() 144 | 145 | main(opt) 146 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import pickle 5 | import random 6 | 7 | import torch 8 | from tqdm import tqdm 9 | 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def read_split_data(root: str, val_rate: float = 0.2): 14 | random.seed(0) # 保证随机结果可复现 15 | assert os.path.exists(root), "dataset root: {} does not exist.".format(root) 16 | 17 | # 遍历文件夹,一个文件夹对应一个类别 18 | flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] 19 | # 排序,保证顺序一致 20 | flower_class.sort() 21 | # 生成类别名称以及对应的数字索引 22 | class_indices = dict((k, v) for v, k in enumerate(flower_class)) 23 | # 将类别信息写入json缓存下来 24 | json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) 25 | # 将缓存的json保存至json文件中 26 | with open('class_indices.json', 'w') as json_file: 27 | json_file.write(json_str) 28 | 29 | train_images_path = [] # 存储训练集的所有图片路径 30 | train_images_label = [] # 存储训练集图片对应索引信息 31 | val_images_path = [] # 存储验证集的所有图片路径 32 | val_images_label = [] # 存储验证集图片对应索引信息 33 | every_class_num = [] # 存储每个类别的样本总数 34 | supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型 35 | # 遍历每个文件夹下的文件 36 | for cla in flower_class: 37 | cla_path = os.path.join(root, cla) 38 | # 遍历获取supported支持的所有文件路径 39 | images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) 40 | if os.path.splitext(i)[-1] in supported] 41 | # 获取该类别对应的索引 42 | image_class = class_indices[cla] 43 | # 记录该类别的样本数量 44 | every_class_num.append(len(images)) 45 | # 按比例随机采样验证样本 46 | val_path = random.sample(images, k=int(len(images) * val_rate)) 47 | 48 | for img_path in images: 49 | if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集 50 | val_images_path.append(img_path) 51 | val_images_label.append(image_class) 52 | else: # 否则存入训练集 53 | train_images_path.append(img_path) 54 | train_images_label.append(image_class) 55 | 56 | print("{} images were found in the dataset.".format(sum(every_class_num))) 57 | print("{} images for training.".format(len(train_images_path))) 58 | print("{} images for validation.".format(len(val_images_path))) 59 | 60 | plot_image = False 61 | if plot_image: 62 | # 绘制每种类别个数柱状图 63 | plt.bar(range(len(flower_class)), every_class_num, align='center') 64 | # 将横坐标0,1,2,3,4替换为相应的类别名称 65 | plt.xticks(range(len(flower_class)), flower_class) 66 | # 在柱状图上添加数值标签 67 | for i, v in enumerate(every_class_num): 68 | plt.text(x=i, y=v + 5, s=str(v), ha='center') 69 | # 设置x坐标 70 | plt.xlabel('image class') 71 | # 设置y坐标 72 | plt.ylabel('number of images') 
73 | # 设置柱状图的标题 74 | plt.title('flower class distribution') 75 | plt.show() 76 | 77 | return train_images_path, train_images_label, val_images_path, val_images_label 78 | 79 | 80 | def plot_data_loader_image(data_loader): 81 | batch_size = data_loader.batch_size 82 | plot_num = min(batch_size, 4) 83 | 84 | json_path = './class_indices.json' 85 | assert os.path.exists(json_path), json_path + " does not exist." 86 | json_file = open(json_path, 'r') 87 | class_indices = json.load(json_file) 88 | 89 | for data in data_loader: 90 | images, labels = data 91 | for i in range(plot_num): 92 | # [C, H, W] -> [H, W, C] 93 | img = images[i].numpy().transpose(1, 2, 0) 94 | # 反Normalize操作 95 | img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255 96 | label = labels[i].item() 97 | plt.subplot(1, plot_num, i+1) 98 | plt.xlabel(class_indices[str(label)]) 99 | plt.xticks([]) # 去掉x轴的刻度 100 | plt.yticks([]) # 去掉y轴的刻度 101 | plt.imshow(img.astype('uint8')) 102 | plt.show() 103 | 104 | 105 | def write_pickle(list_info: list, file_name: str): 106 | with open(file_name, 'wb') as f: 107 | pickle.dump(list_info, f) 108 | 109 | 110 | def read_pickle(file_name: str) -> list: 111 | with open(file_name, 'rb') as f: 112 | info_list = pickle.load(f) 113 | return info_list 114 | 115 | 116 | def train_one_epoch(model, optimizer, data_loader, device, epoch): 117 | model.train() 118 | loss_function = torch.nn.CrossEntropyLoss() 119 | accu_loss = torch.zeros(1).to(device) # 累计损失 120 | accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数 121 | optimizer.zero_grad() 122 | 123 | sample_num = 0 124 | data_loader = tqdm(data_loader, file=sys.stdout) 125 | for step, data in enumerate(data_loader): 126 | images, labels = data 127 | sample_num += images.shape[0] 128 | 129 | pred = model(images.to(device)) 130 | pred_classes = torch.max(pred, dim=1)[1] #预测的类别,[1]是标签索引 131 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 132 | 133 | loss = loss_function(pred, labels.to(device)) 134 | loss.backward() 135 | accu_loss += loss.detach() 136 | 137 | data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch, 138 | accu_loss.item() / (step + 1), 139 | accu_num.item() / sample_num) 140 | # 保证loss不会无穷大 141 | if not torch.isfinite(loss): 142 | print('WARNING: non-finite loss, ending training ', loss) 143 | sys.exit(1) 144 | 145 | optimizer.step() #更新 146 | optimizer.zero_grad() #梯度清零 147 | 148 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 149 | 150 | 151 | @torch.no_grad() 152 | def evaluate(model, data_loader, device, epoch): 153 | loss_function = torch.nn.CrossEntropyLoss() 154 | 155 | model.eval() 156 | 157 | accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数 158 | accu_loss = torch.zeros(1).to(device) # 累计损失 159 | 160 | sample_num = 0 161 | data_loader = tqdm(data_loader, file=sys.stdout) 162 | for step, data in enumerate(data_loader): 163 | images, labels = data 164 | sample_num += images.shape[0] 165 | 166 | pred = model(images.to(device)) 167 | pred_classes = torch.max(pred, dim=1)[1] 168 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 169 | 170 | loss = loss_function(pred, labels.to(device)) 171 | accu_loss += loss 172 | 173 | data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch, 174 | accu_loss.item() / (step + 1), 175 | accu_num.item() / sample_num) 176 | 177 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 178 | -------------------------------------------------------------------------------- /vit_model.py: 
--------------------------------------------------------------------------------
1 | """
2 | original code from rwightman:
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
4 | """
5 | from functools import partial
6 | from collections import OrderedDict
7 | 
8 | import torch
9 | import torch.nn as nn
10 | 
11 | """
12 | Note: the dimension-related comments in this file all assume the ViT-Base model.
13 | """
14 | 
15 | def drop_path(x, drop_prob: float = 0., training: bool = False):
16 |     """
17 |     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
18 |     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
19 |     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
20 |     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
21 |     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
22 |     'survival rate' as the argument.
23 |     """
24 |     if drop_prob == 0. or not training:
25 |         return x
26 |     keep_prob = 1 - drop_prob
27 |     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
28 |     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
29 |     random_tensor.floor_()  # binarize
30 |     output = x.div(keep_prob) * random_tensor
31 |     return output
32 | 
33 | class DropPath(nn.Module):
34 |     """
35 |     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
36 |     """
37 |     def __init__(self, drop_prob=None):
38 |         super(DropPath, self).__init__()
39 |         self.drop_prob = drop_prob
40 | 
41 |     def forward(self, x):
42 |         return drop_path(x, self.drop_prob, self.training)
43 | 
44 | class PatchEmbed(nn.Module):
45 |     """
46 |     2D Image to Patch Embedding
47 |     """
48 |     def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
49 |         super().__init__()
50 |         img_size = (img_size, img_size)  # input image size, 224*224
51 |         patch_size = (patch_size, patch_size)  # downsampling factor: each grid cell covers a 16*16 image region
52 |         self.img_size = img_size
53 |         self.patch_size = patch_size
54 |         # grid_size is the spatial size of the feature map after patch embedding
55 |         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
56 |         self.num_patches = self.grid_size[0] * self.grid_size[1]  # number of patches, 14*14=196
57 | 
58 |         # the patch embedding is implemented with a single convolution
59 |         self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
60 |         # if a norm layer is given (e.g. nn.LayerNorm), pass in embed_dim for normalization; otherwise use an identity mapping
61 |         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
62 | 
63 |     def forward(self, x):
64 |         B, C, H, W = x.shape  # batch, channels, height, width
65 |         # the input image size must match the size the model was built for
66 |         assert H == self.img_size[0] and W == self.img_size[1], \
67 |             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
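        # Worked example of the shapes below: with img_size=224, patch_size=16 and
        # embed_dim=768, the Conv2d proj has kernel and stride 16, so a 224x224 input
        # becomes a 14x14 grid (224 // 16 = 14), i.e. 14 * 14 = 196 patch tokens,
        # each projected to a 768-dim vector.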
68 | 
69 |         # proj: [B, C, H, W] -> [B, embed_dim, grid_H, grid_W], e.g. [B,3,224,224] -> [B,768,14,14]
70 |         # flatten: [B, C, H, W] -> [B, C, HW], e.g. [B,768,14,14] -> [B,768,196]
71 |         # transpose: [B, C, HW] -> [B, HW, C], e.g. [B,768,196] -> [B,196,768]
72 |         x = self.proj(x).flatten(2).transpose(1, 2)
73 |         x = self.norm(x)
74 |         return x
75 | 
76 | class Attention(nn.Module):
77 |     """
78 |     Multi-head attention module, the core operation of the Transformer
79 |     """
80 |     def __init__(self,
81 |                  dim,  # dim of an input token, 768
82 |                  num_heads=8,  # number of attention heads; the base-size ViT uses 12 when instantiated
83 |                  qkv_bias=False,
84 |                  qk_scale=None,
85 |                  attn_drop_ratio=0.,
86 |                  proj_drop_ratio=0.):
87 |         super(Attention, self).__init__()
88 |         self.num_heads = num_heads
89 |         head_dim = dim // num_heads  # dimension of each head
90 |         self.scale = qk_scale or head_dim ** -0.5  # scaling factor for the attention scores, 1/sqrt(head_dim)
91 |         # a single linear layer produces q, k and v at once, hence dim * 3
92 |         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
93 |         self.attn_drop = nn.Dropout(attn_drop_ratio)
94 |         self.proj = nn.Linear(dim, dim)  # output projection, the W_o linear layer
95 |         self.proj_drop = nn.Dropout(proj_drop_ratio)
96 | 
97 |     def forward(self, x):
98 |         # [batch_size, num_patches + 1, total_embed_dim], e.g. [batch,197,768]
99 |         B, N, C = x.shape  # N: 197, C: 768
100 | 
101 |         # qkv computes the combined projection; reshape splits the embedding across heads; permute reorders dims for the steps below
102 |         # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim], e.g. [b,197,2304]
103 |         # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head], e.g. [b,197,3,12,64]
104 |         # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
105 |         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
106 |         # q, k and v all have shape [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
107 |         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
108 | 
109 |         # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
110 |         # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
111 |         attn = (q @ k.transpose(-2, -1)) * self.scale  # scaled dot product between queries and keys
112 |         attn = attn.softmax(dim=-1)  # softmax over the key positions for each query (each row sums to 1)
113 |         attn = self.attn_drop(attn)
114 | 
115 |         # [b,12,197,197] @ [b,12,197,64] -> [b,12,197,64]
116 |         # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
117 |         # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
118 |         # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
119 |         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
120 |         x = self.proj(x)  # output projection (linear layer)
121 |         x = self.proj_drop(x)  # dropout
122 |         return x
123 | 
124 | 
125 | class Mlp(nn.Module):
126 |     """
127 |     MLP as used in Vision Transformer, MLP-Mixer and related networks
128 |     """
129 |     def __init__(self, in_features, hidden_features=None, out_features=None,
130 |                  act_layer=nn.GELU,  # GELU is a smoother variant of ReLU
131 |                  drop=0.):
132 |         super().__init__()
133 |         out_features = out_features or in_features  # default out_features to in_features
134 |         hidden_features = hidden_features or in_features  # default hidden_features to in_features
135 |         self.fc1 = nn.Linear(in_features, hidden_features)  # fc layer 1
136 |         self.act = act_layer()  # activation
137 |         self.fc2 = nn.Linear(hidden_features, out_features)  # fc layer 2
138 |         self.drop = nn.Dropout(drop)
139 | 
140 |     def forward(self, x):
141 |         x = self.fc1(x)
142 |         x = self.act(x)
143 |         x = self.drop(x)
144 |         x = self.fc2(x)
145 |         x = self.drop(x)
146 |         return x
147 | 
148 | 
149 | class Block(nn.Module):
150 |     """
151 |     Basic Transformer block
152 |     """
153 |     def __init__(self,
154 |                  dim,
155 |                  num_heads,
156 |                  mlp_ratio=4.,
157 |                  qkv_bias=False,
158 | 
qk_scale=None, 159 | drop_ratio=0., 160 | attn_drop_ratio=0., 161 | drop_path_ratio=0., 162 | act_layer=nn.GELU, 163 | norm_layer=nn.LayerNorm): 164 | super(Block, self).__init__() 165 | self.norm1 = norm_layer(dim) #norm层 166 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 167 | attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio) 168 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 169 | # 代码使用了DropPath,而不是原版的dropout 170 | self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity() 171 | self.norm2 = norm_layer(dim) #norm层 172 | mlp_hidden_dim = int(dim * mlp_ratio) #隐藏层维度扩张后的通道数 173 | # 多层感知机 174 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio) 175 | 176 | def forward(self, x): 177 | x = x + self.drop_path(self.attn(self.norm1(x))) # attention后残差连接 178 | x = x + self.drop_path(self.mlp(self.norm2(x))) # mlp后残差连接 179 | return x 180 | 181 | 182 | class VisionTransformer(nn.Module): 183 | def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000, 184 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, 185 | qk_scale=None, representation_size=None, distilled=False, drop_ratio=0., 186 | attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None, 187 | act_layer=None): 188 | """ 189 | Args: 190 | img_size (int, tuple): input image size 191 | patch_size (int, tuple): patch size 192 | in_c (int): number of input channels 193 | num_classes (int): number of classes for classification head 194 | embed_dim (int): embedding dimension 195 | depth (int): depth of transformer 196 | num_heads (int): number of attention heads 197 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 198 | qkv_bias (bool): enable bias for qkv if True 199 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 200 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 201 | distilled (bool): model includes a distillation token and head as in DeiT models 202 | drop_ratio (float): dropout rate 203 | attn_drop_ratio (float): attention dropout rate 204 | drop_path_ratio (float): stochastic depth rate 205 | embed_layer (nn.Module): patch embedding layer 206 | norm_layer: (nn.Module): normalization layer 207 | """ 208 | super(VisionTransformer, self).__init__() 209 | self.num_classes = num_classes #分类类别数量 210 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 211 | self.num_tokens = 2 if distilled else 1 #distilled在vit中没有使用到 212 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) #层归一化 213 | act_layer = act_layer or nn.GELU #激活函数 214 | 215 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim) 216 | num_patches = self.patch_embed.num_patches 217 | 218 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) #[1,1,768],以0填充 219 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 220 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 221 | self.pos_drop = nn.Dropout(p=drop_ratio) 222 | 223 | # 按照block数量等间距设置drop率 224 | dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule 225 | self.blocks = nn.Sequential(*[ 226 | Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 
qk_scale=qk_scale, 227 | drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i], 228 | norm_layer=norm_layer, act_layer=act_layer) 229 | for i in range(depth) 230 | ]) 231 | self.norm = norm_layer(embed_dim) # layer_norm 232 | 233 | # Representation layer 234 | if representation_size and not distilled: 235 | self.has_logits = True 236 | self.num_features = representation_size 237 | self.pre_logits = nn.Sequential(OrderedDict([ 238 | ("fc", nn.Linear(embed_dim, representation_size)), 239 | ("act", nn.Tanh()) 240 | ])) 241 | else: 242 | self.has_logits = False 243 | self.pre_logits = nn.Identity() 244 | 245 | # Classifier head(s),分类头,self.num_features=768 246 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 247 | self.head_dist = None 248 | if distilled: 249 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 250 | 251 | # Weight init,权重初始化 252 | nn.init.trunc_normal_(self.pos_embed, std=0.02) 253 | if self.dist_token is not None: 254 | nn.init.trunc_normal_(self.dist_token, std=0.02) 255 | 256 | nn.init.trunc_normal_(self.cls_token, std=0.02) 257 | self.apply(_init_vit_weights) 258 | 259 | def forward_features(self, x): 260 | # [B, C, H, W] -> [B, num_patches, embed_dim] 261 | x = self.patch_embed(x) # [B, 196, 768] 262 | # cls_token类别token [1, 1, 768] -> [B, 1, 768],扩张为batch个cls_token 263 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) 264 | if self.dist_token is None: 265 | x = torch.cat((cls_token, x), dim=1) # [B, 196, 768]-> [B, 197, 768],维度1上的cat 266 | else: 267 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 268 | 269 | x = self.pos_drop(x + self.pos_embed) #添加位置嵌入信息 270 | x = self.blocks(x) #通过attention堆叠模块(12个) 271 | x = self.norm(x) #layer_norm 272 | if self.dist_token is None: 273 | return self.pre_logits(x[:, 0]) #返回第一层特征,即为分类值 274 | else: 275 | return x[:, 0], x[:, 1] 276 | 277 | def forward(self, x): 278 | # 分类头 279 | x = self.forward_features(x) # 经过att操作,但是没有进行分类头的前传 280 | if self.head_dist is not None: 281 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) 282 | if self.training and not torch.jit.is_scripting(): 283 | # during inference, return the average of both classifier predictions 284 | return x, x_dist 285 | else: 286 | return (x + x_dist) / 2 287 | else: 288 | x = self.head(x) 289 | return x 290 | 291 | 292 | def _init_vit_weights(m): 293 | """ 294 | ViT weight initialization 295 | :param m: module 296 | """ 297 | if isinstance(m, nn.Linear): # fc层初始化 298 | nn.init.trunc_normal_(m.weight, std=.01) 299 | if m.bias is not None: 300 | nn.init.zeros_(m.bias) 301 | elif isinstance(m, nn.Conv2d): # conv层初始化 302 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 303 | if m.bias is not None: 304 | nn.init.zeros_(m.bias) 305 | elif isinstance(m, nn.LayerNorm): #LayerNorm层初始化 306 | nn.init.zeros_(m.bias) 307 | nn.init.ones_(m.weight) 308 | 309 | 310 | def vit_base_patch16_224(num_classes: int = 1000): 311 | """ 312 | ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 313 | ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
314 | weights ported from official Google JAX impl: 315 | 链接: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA 密码: eu9f 316 | """ 317 | model = VisionTransformer(img_size=224, 318 | patch_size=16, 319 | embed_dim=768, 320 | depth=12, 321 | num_heads=12, 322 | representation_size=None, 323 | num_classes=num_classes) 324 | return model 325 | 326 | 327 | def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True): 328 | """ 329 | ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 330 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 331 | weights ported from official Google JAX impl: 332 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth 333 | """ 334 | model = VisionTransformer(img_size=224, 335 | patch_size=16, 336 | embed_dim=768, 337 | depth=12, 338 | num_heads=12, 339 | representation_size=768 if has_logits else None, 340 | num_classes=num_classes) 341 | return model 342 | 343 | 344 | def vit_base_patch32_224(num_classes: int = 1000): 345 | """ 346 | ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 347 | ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer. 348 | weights ported from official Google JAX impl: 349 | 链接: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg 密码: s5hl 350 | """ 351 | model = VisionTransformer(img_size=224, 352 | patch_size=32, 353 | embed_dim=768, 354 | depth=12, 355 | num_heads=12, 356 | representation_size=None, 357 | num_classes=num_classes) 358 | return model 359 | 360 | 361 | def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True): 362 | """ 363 | ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 364 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 365 | weights ported from official Google JAX impl: 366 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth 367 | """ 368 | model = VisionTransformer(img_size=224, 369 | patch_size=32, 370 | embed_dim=768, 371 | depth=12, 372 | num_heads=12, 373 | representation_size=768 if has_logits else None, 374 | num_classes=num_classes) 375 | return model 376 | 377 | 378 | def vit_large_patch16_224(num_classes: int = 1000): 379 | """ 380 | ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 381 | ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer. 382 | weights ported from official Google JAX impl: 383 | 链接: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ 密码: qqt8 384 | """ 385 | model = VisionTransformer(img_size=224, 386 | patch_size=16, 387 | embed_dim=1024, 388 | depth=24, 389 | num_heads=16, 390 | representation_size=None, 391 | num_classes=num_classes) 392 | return model 393 | 394 | 395 | def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True): 396 | """ 397 | ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 398 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
399 | weights ported from official Google JAX impl: 400 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth 401 | """ 402 | model = VisionTransformer(img_size=224, 403 | patch_size=16, 404 | embed_dim=1024, 405 | depth=24, 406 | num_heads=16, 407 | representation_size=1024 if has_logits else None, 408 | num_classes=num_classes) 409 | return model 410 | 411 | 412 | def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True): 413 | """ 414 | ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 415 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 416 | weights ported from official Google JAX impl: 417 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth 418 | """ 419 | model = VisionTransformer(img_size=224, 420 | patch_size=32, 421 | embed_dim=1024, 422 | depth=24, 423 | num_heads=16, 424 | representation_size=1024 if has_logits else None, 425 | num_classes=num_classes) 426 | return model 427 | 428 | 429 | def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True): 430 | """ 431 | ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 432 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 433 | NOTE: converted weights not currently available, too large for github release hosting. 434 | """ 435 | model = VisionTransformer(img_size=224, 436 | patch_size=14, 437 | embed_dim=1280, 438 | depth=32, 439 | num_heads=16, 440 | representation_size=1280 if has_logits else None, 441 | num_classes=num_classes) 442 | return model 443 | --------------------------------------------------------------------------------
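
Putting the files above together, here is a minimal usage sketch: it builds the 5-class flower model the same way train.py and predict.py do, loads the ImageNet-21k pretrained checkpoint while dropping the classification-head (and pre_logits) keys exactly as train.py does, and runs a dummy forward pass to confirm the output shape. The helper name build_finetune_model and the weight path ./weights/vit_base_patch16_224_in21k.pth are illustrative assumptions; point the path at wherever you saved the checkpoint linked in the README.

```
import os

import torch

from vit_model import vit_base_patch16_224_in21k as create_model


def build_finetune_model(weights_path: str = "./weights/vit_base_patch16_224_in21k.pth",
                         num_classes: int = 5,
                         device: str = "cpu") -> torch.nn.Module:
    # 5-class flower model without the pre_logits layer (has_logits=False), as in train.py / predict.py
    model = create_model(num_classes=num_classes, has_logits=False).to(device)

    if os.path.exists(weights_path):
        weights_dict = torch.load(weights_path, map_location=device)
        # the pretrained head (and pre_logits when has_logits=False) does not match 5 classes,
        # so drop those keys before loading, mirroring the logic in train.py
        del_keys = ['head.weight', 'head.bias'] if model.has_logits \
            else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
        for k in del_keys:
            weights_dict.pop(k, None)
        print(model.load_state_dict(weights_dict, strict=False))
    return model


if __name__ == '__main__':
    m = build_finetune_model()
    m.eval()
    with torch.no_grad():
        out = m(torch.rand(2, 3, 224, 224))  # dummy batch of two 224x224 RGB images
    print(out.shape)  # expected: torch.Size([2, 5])
```

For actual training and inference, run train.py (its argparse block defines --data-path, --weights, --batch-size, --freeze-layers and the other options) and then predict.py, which uses the hard-coded paths ../tulip.jpg and ./weights/model-9.pth in its main().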