├── .gitignore ├── LICENSE ├── README.md ├── classification.py ├── datasets ├── test │ └── README.md └── train │ └── README.md ├── eval.py ├── img ├── cat.jpg └── dog.jpg ├── logs └── README.md ├── model_data ├── cls_classes.txt ├── mobilenet025_catvsdog.h5 └── mobilenet_2_5_224_tf_no_top.h5 ├── nets ├── __init__.py ├── mobilenetv1.py ├── mobilenetv2.py ├── resnet.py ├── swin_transformer.py ├── vgg.py └── vision_transformer.py ├── predict.py ├── requirements.txt ├── summary.py ├── train.py ├── txt_annotation.py └── utils ├── __init__.py ├── callbacks.py ├── dataloader.py ├── utils.py ├── utils_aug.py ├── utils_fit.py └── utils_metrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore map, miou, datasets 2 | map_out/ 3 | miou_out/ 4 | VOCdevkit/ 5 | datasets/ 6 | Medical_Datasets/ 7 | lfw/ 8 | logs/ 9 | model_data/ 10 | metrics_out/ 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
106 | __pypackages__/
107 | 
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 | 
112 | # SageMath parsed files
113 | *.sage.py
114 | 
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 | 
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 | 
128 | # Rope project settings
129 | .ropeproject
130 | 
131 | # mkdocs documentation
132 | /site
133 | 
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 | 
139 | # Pyre type checker
140 | .pyre/
141 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Bubbliiiing
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Classification: image classification models implemented in TensorFlow 2
2 | ---
3 | 
4 | ## Contents
5 | 1. [Top News](#top-news)
6 | 2. [Environment](#environment)
7 | 3. [Download](#download)
8 | 4. [How2train](#how2train)
9 | 5. [How2predict](#how2predict)
10 | 6. [How2eval](#how2eval)
11 | 7. [Reference](#reference)
12 | 
13 | ## Top News
14 | **`2022-03`**: **Major update: added step and cosine learning-rate schedules, a choice between the adam and sgd optimizers, and automatic scaling of the learning rate with batch_size.**
15 | The original repository used in the BiliBili videos is at: https://github.com/bubbliiiing/classification-tf2/tree/bilibili
16 | 
17 | **`2021-01`**: **Repository created: model training, extensive comments, many tunable parameters, and Top-1/Top-5 accuracy evaluation.**
18 | 
19 | ## Environment
20 | tensorflow-gpu==2.2.0
21 | 
22 | ## Download
23 | All pre-trained weights needed for training can be downloaded from Baidu Netdisk.
24 | Link: https://pan.baidu.com/s/1z8zDREL1gFaGOhiVB5Xmdw
25 | Extraction code: 4nmp
26 | 
27 | The example cat-vs-dog dataset used for training can also be downloaded from Baidu Netdisk.
28 | Link: https://pan.baidu.com/s/1hYBNG0TnGIeWw1-SwkzqpA
29 | Extraction code: ass8
30 | 
31 | ## How2train
32 | 1. The images stored under the datasets folder are split into two parts: train holds the training images and test holds the test images.
33 | 2. Before training, prepare the dataset: inside train (and likewise inside test) create one folder per class, name each folder after its class, and put that class's images inside it. The layout looks like this:
34 | ```
35 | |-datasets
36 |     |-train
37 |         |-cat
38 |             |-123.jpg
39 |             |-234.jpg
40 |         |-dog
41 |             |-345.jpg
42 |             |-456.jpg
43 |         |-...
44 |     |-test
45 |         |-cat
46 |             |-567.jpg
47 |             |-678.jpg
48 |         |-dog
49 |             |-789.jpg
50 |             |-890.jpg
51 |         |-...
52 | ```
53 | 3. With the dataset in place, run txt_annotation.py from the repository root to generate the cls_train.txt needed for training; before running it, edit its classes so that they match the classes you want to distinguish.
54 | 4. Then edit model_data/cls_classes.txt so that it also matches your classes.
55 | 5. After selecting the network and weights you want in train.py, you can start training! A minimal quickstart sketch follows.
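For reference, a full run from the repository root then looks roughly like this (a sketch assuming the layout above and the default settings; adjust the classes first as described in steps 3 and 4):

```
python txt_annotation.py   # writes the cls_train.txt annotation file (cls_test.txt is generated the same way for evaluation)
python train.py            # trains the network and weights selected inside train.py
```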
56 | 
57 | ## How2predict
58 | ### a. Using the pre-trained weights
59 | 1. After downloading and unzipping the repository, model_data already contains a trained cat-vs-dog model, mobilenet025_catvsdog.h5. Run predict.py and enter
60 | ```python
61 | img/cat.jpg
62 | ```
63 | ### b. Using your own trained weights
64 | 1. Train as described in the training steps.
65 | 2. In classification.py, modify model_path, classes_path, backbone and alpha in the section below so that they correspond to your trained files; **model_path points to the weight file under the logs folder, classes_path lists the classes that model_path was trained on, backbone names the backbone feature-extraction network, and alpha is the width multiplier used with mobilenet backbones**.
66 | ```python
67 | _defaults = {
68 |     #--------------------------------------------------------------------------#
69 |     #   To predict with your own trained model you must change model_path and classes_path!
70 |     #   model_path points to the weight file under logs; classes_path points to the txt under model_data
71 |     #   If a shape mismatch occurs, also double-check the model_path and classes_path used during training
72 |     #--------------------------------------------------------------------------#
73 |     "model_path"        : 'model_data/mobilenet025_catvsdog.h5',
74 |     "classes_path"      : 'model_data/cls_classes.txt',
75 |     #--------------------------------------------------------------------#
76 |     #   Input image size
77 |     #--------------------------------------------------------------------#
78 |     "input_shape"       : [224, 224],
79 |     #--------------------------------------------------------------------#
80 |     #   Model type:
81 |     #   mobilenetv1, mobilenetv2, resnet50, vgg16, vit_b_16 and the swin_transformer variants are supported
82 |     #--------------------------------------------------------------------#
83 |     "backbone"          : 'mobilenetv1',
84 |     #--------------------------------------------------------------------#
85 |     #   alpha value used with mobilenetv1
86 |     #   only takes effect when backbone='mobilenetv1'
87 |     #--------------------------------------------------------------------#
88 |     "alpha"             : 0.25
89 | }
90 | ```
91 | 3. Run predict.py and enter
92 | ```python
93 | img/cat.jpg
94 | ```
95 | 
96 | 
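predict.py itself just reads the path you type, opens the image and hands it to the Classification class configured above; a minimal programmatic sketch of the same call (illustrative, not a copy of predict.py):

```python
from PIL import Image

from classification import Classification

classification = Classification()
image      = Image.open("img/cat.jpg")            # the path you would type into predict.py
class_name = classification.detect_image(image)   # displays the image and returns the class
print(class_name)
```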
97 | ## How2eval
98 | 1. The images stored under the datasets folder are split into two parts: train holds the training images and test holds the test images. Evaluation uses the images in the test folder.
99 | 2. Before evaluating, prepare the dataset the same way: inside train (and likewise inside test) create one folder per class, name each folder after its class, and put that class's images inside it. The layout looks like this:
100 | ```
101 | |-datasets
102 |     |-train
103 |         |-cat
104 |             |-123.jpg
105 |             |-234.jpg
106 |         |-dog
107 |             |-345.jpg
108 |             |-456.jpg
109 |         |-...
110 |     |-test
111 |         |-cat
112 |             |-567.jpg
113 |             |-678.jpg
114 |         |-dog
115 |             |-789.jpg
116 |             |-890.jpg
117 |         |-...
118 | ```
119 | 3. With the dataset in place, run txt_annotation.py from the repository root to generate the cls_test.txt needed for evaluation; before running it, edit its classes so that they match your classes.
120 | 4. Then, in classification.py, modify model_path, classes_path, backbone and alpha in the section below so that they correspond to your trained files; **model_path points to the weight file under the logs folder, classes_path lists the classes that model_path was trained on, backbone names the backbone feature-extraction network, and alpha is the width multiplier used with mobilenet backbones**.
121 | ```python
122 | _defaults = {
123 |     #--------------------------------------------------------------------------#
124 |     #   To predict with your own trained model you must change model_path and classes_path!
125 |     #   model_path points to the weight file under logs; classes_path points to the txt under model_data
126 |     #   If a shape mismatch occurs, also double-check the model_path and classes_path used during training
127 |     #--------------------------------------------------------------------------#
128 |     "model_path"        : 'model_data/mobilenet025_catvsdog.h5',
129 |     "classes_path"      : 'model_data/cls_classes.txt',
130 |     #--------------------------------------------------------------------#
131 |     #   Input image size
132 |     #--------------------------------------------------------------------#
133 |     "input_shape"       : [224, 224],
134 |     #--------------------------------------------------------------------#
135 |     #   Model type:
136 |     #   mobilenetv1, mobilenetv2, resnet50, vgg16, vit_b_16 and the swin_transformer variants are supported
137 |     #--------------------------------------------------------------------#
138 |     "backbone"          : 'mobilenetv1',
139 |     #--------------------------------------------------------------------#
140 |     #   alpha value used with mobilenetv1
141 |     #   only takes effect when backbone='mobilenetv1'
142 |     #--------------------------------------------------------------------#
143 |     "alpha"             : 0.25
144 | }
145 | ```
146 | 5. Run eval.py to evaluate the model; it reports Top-1 and Top-5 accuracy together with mean Recall and Precision.
147 | 
148 | ## Reference
149 | https://github.com/keras-team/keras-applications
--------------------------------------------------------------------------------
/classification.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | 
6 | from nets import get_model_from_name
7 | from utils.utils import (cvtColor, get_classes, letterbox_image,
8 |                          preprocess_input, show_config)
9 | 
10 | 
11 | #--------------------------------------------#
12 | #   To predict with your own trained model,
13 | #   four parameters need changing:
14 | #   model_path, classes_path, backbone and alpha!
15 | #--------------------------------------------#
16 | class Classification(object):
17 |     _defaults = {
18 |         #--------------------------------------------------------------------------#
19 |         #   To predict with your own trained model you must change model_path and classes_path!
20 |         #   model_path points to the weight file under logs; classes_path points to the txt under model_data
21 |         #   If a shape mismatch occurs, also double-check the model_path and classes_path used during training
22 |         #--------------------------------------------------------------------------#
23 |         "model_path"        : 'model_data/mobilenet025_catvsdog.h5',
24 |         "classes_path"      : 'model_data/cls_classes.txt',
25 |         #--------------------------------------------------------------------#
26 |         #   Input image size
27 |         #--------------------------------------------------------------------#
28 |         "input_shape"       : [224, 224],
29 |         #--------------------------------------------------------------------#
30 |         #   Model type:
31 |         #   mobilenetv1, mobilenetv2, resnet50, vgg16,
32 |         #   vit_b_16,
33 |         #   swin_transformer_tiny, swin_transformer_small, swin_transformer_base
34 |         #--------------------------------------------------------------------#
35 |         "backbone"          : 'mobilenetv1',
36 |         #--------------------------------------------------------------------#
37 |         #   alpha value used with mobilenetv1
38 |         #   only takes effect when backbone='mobilenetv1'
39 |         #--------------------------------------------------------------------#
40 |         "alpha"             : 0.25,
41 |         #--------------------------------------------------------------------#
42 |         #   Whether to use letterbox_image for a distortion-free resize of the input image;
43 |         #   otherwise the image is center-cropped
44 |         #--------------------------------------------------------------------#
45 |         "letterbox_image"   : False,
46 |     }
47 | 
48 |     @classmethod
49 |     def get_defaults(cls, n):
50 |         if n in cls._defaults:
51 |             return cls._defaults[n]
52 |         else:
53 |             return "Unrecognized attribute name '" + n + "'"
54 | 
55 |     #---------------------------------------------------#
56 |     #   Initialize the classifier
57 |     #---------------------------------------------------#
58 |     def __init__(self, **kwargs):
59 |         self.__dict__.update(self._defaults)
60 |         for name, value in kwargs.items():
61 |             setattr(self, name, value)
62 | 
63 |         #---------------------------------------------------#
64 |         #   Get the classes
65 |         #---------------------------------------------------#
66 |         self.class_names, self.num_classes = get_classes(self.classes_path)
67 |         self.generate()
68 | 
69 |         show_config(**self._defaults)
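    # (usage note) every key in _defaults can be overridden per instance through the
    # keyword arguments consumed above, e.g. Classification(backbone='resnet50',
    # model_path='logs/your_weights.h5') -- the paths in this example are hypothetical.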
70 | 
71 |     #---------------------------------------------------#
72 |     #   Load the model
73 |     #---------------------------------------------------#
74 |     def generate(self):
75 |         model_path = os.path.expanduser(self.model_path)
76 |         assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
77 | 
78 |         #---------------------------------------------------#
79 |         #   Load the model and its weights
80 |         #---------------------------------------------------#
81 |         if self.backbone == "mobilenetv1":
82 |             self.model = get_model_from_name[self.backbone](input_shape = [self.input_shape[0], self.input_shape[1], 3], classes = self.num_classes, alpha = self.alpha)
83 |         else:
84 |             self.model = get_model_from_name[self.backbone](input_shape = [self.input_shape[0], self.input_shape[1], 3], classes = self.num_classes)
85 |         self.model.load_weights(self.model_path, by_name=True)
86 |         print('{} model and classes loaded.'.format(model_path))
87 | 
88 |     #---------------------------------------------------#
89 |     #   Classify an image
90 |     #---------------------------------------------------#
91 |     def detect_image(self, image):
92 |         #---------------------------------------------------------#
93 |         #   Convert the image to RGB here to avoid errors when predicting on grayscale input.
94 |         #   The code only supports prediction on RGB images; every other image type is converted to RGB.
95 |         #---------------------------------------------------------#
96 |         image = cvtColor(image)
97 |         #---------------------------------------------------#
98 |         #   Resize the image without distortion
99 |         #---------------------------------------------------#
100 |         image_data = letterbox_image(image, [self.input_shape[1], self.input_shape[0]], self.letterbox_image)
101 |         #---------------------------------------------------------#
102 |         #   Normalize and add the batch_size dimension
103 |         #---------------------------------------------------------#
104 |         image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)
105 | 
106 |         #---------------------------------------------------#
107 |         #   Feed the image through the network to get the prediction
108 |         #---------------------------------------------------#
109 |         preds = self.model.predict(image_data)[0]
110 |         #---------------------------------------------------#
111 |         #   Pick the predicted class
112 |         #---------------------------------------------------#
113 |         class_name = self.class_names[np.argmax(preds)]
114 |         probability = np.max(preds)
115 | 
116 |         #---------------------------------------------------#
117 |         #   Draw the image and write the result onto it
118 |         #---------------------------------------------------#
119 |         plt.subplot(1, 1, 1)
120 |         plt.imshow(np.array(image))
121 |         plt.title('Class:%s Probability:%.3f' %(class_name, probability))
122 |         plt.show()
123 |         return class_name
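        # (note) preds holds one probability per class, so a top-k readout is equally
        # possible at this point, e.g. np.argsort(preds)[::-1][:5] gives the indices of
        # the five most likely classes (illustrative; this method only returns the top-1).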
--------------------------------------------------------------------------------
/datasets/test/README.md:
--------------------------------------------------------------------------------
1 | Used to store the test images
--------------------------------------------------------------------------------
/datasets/train/README.md:
--------------------------------------------------------------------------------
1 | Used to store the training images
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------#
2 | #   This eval script automatically computes
3 | #   Top-1 acc
4 | #   Top-5 acc
5 | #   Recall
6 | #   Precision
7 | #   The results are saved in the metrics_out folder
8 | #------------------------------------------------------#
9 | import os
10 | 
11 | import numpy as np
12 | import tensorflow as tf
13 | 
14 | from classification import Classification, cvtColor, preprocess_input
15 | from utils.utils import letterbox_image
16 | from utils.utils_metrics import evaluteTop1_5
17 | 
18 | gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
19 | for gpu in gpus:
20 |     tf.config.experimental.set_memory_growth(gpu, True)
21 | 
22 | #------------------------------------------------------#
23 | #   test_annotation_path    path to the test images and their labels
24 | #------------------------------------------------------#
25 | test_annotation_path = 'cls_test.txt'
26 | #------------------------------------------------------#
27 | #   metrics_out_path        folder the metrics are saved to
28 | #------------------------------------------------------#
29 | metrics_out_path = "metrics_out"
30 | 
31 | class Eval_Classification(Classification):
32 |     def detect_image(self, image):
33 |         #---------------------------------------------------------#
34 |         #   Convert the image to RGB here to avoid errors when predicting on grayscale input.
35 |         #   The code only supports prediction on RGB images; every other image type is converted to RGB.
36 |         #---------------------------------------------------------#
37 |         image = cvtColor(image)
38 |         #---------------------------------------------------#
39 |         #   Resize the image without distortion
40 |         #---------------------------------------------------#
41 |         image_data = letterbox_image(image, [self.input_shape[1], self.input_shape[0]], self.letterbox_image)
42 |         #---------------------------------------------------------#
43 |         #   Normalize and add the batch_size dimension
44 |         #---------------------------------------------------------#
45 |         image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)
46 | 
47 |         #---------------------------------------------------#
48 |         #   Feed the image through the network to get the prediction
49 |         #---------------------------------------------------#
50 |         preds = self.model.predict(image_data)[0]
51 |         return preds
52 | 
53 | if __name__ == "__main__":
54 |     if not os.path.exists(metrics_out_path):
55 |         os.makedirs(metrics_out_path)
56 | 
57 |     classification = Eval_Classification()
58 | 
59 |     with open(test_annotation_path, "r") as f:
60 |         lines = f.readlines()
61 |     top1, top5, Recall, Precision = evaluteTop1_5(classification, lines, metrics_out_path)
62 |     print("top-1 accuracy = %.2f%%" % (top1*100))
63 |     print("top-5 accuracy = %.2f%%" % (top5*100))
64 |     print("mean Recall = %.2f%%" % (np.mean(Recall)*100))
65 |     print("mean Precision = %.2f%%" % (np.mean(Precision)*100))
66 | 
--------------------------------------------------------------------------------
/img/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bubbliiiing/classification-tf2/112c1f8aac02830070ffa62f743c114cbf0d3608/img/cat.jpg
--------------------------------------------------------------------------------
/img/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bubbliiiing/classification-tf2/112c1f8aac02830070ffa62f743c114cbf0d3608/img/dog.jpg
--------------------------------------------------------------------------------
/logs/README.md:
--------------------------------------------------------------------------------
1 | Stores the trained models
--------------------------------------------------------------------------------
/model_data/cls_classes.txt:
--------------------------------------------------------------------------------
1 | cat
2 | dog
--------------------------------------------------------------------------------
/model_data/mobilenet025_catvsdog.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bubbliiiing/classification-tf2/112c1f8aac02830070ffa62f743c114cbf0d3608/model_data/mobilenet025_catvsdog.h5
--------------------------------------------------------------------------------
/model_data/mobilenet_2_5_224_tf_no_top.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bubbliiiing/classification-tf2/112c1f8aac02830070ffa62f743c114cbf0d3608/model_data/mobilenet_2_5_224_tf_no_top.h5 -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mobilenetv1 import MobileNetV1 2 | from .mobilenetv2 import MobileNetV2 3 | from .resnet import ResNet50 4 | from .swin_transformer import (swin_transformer_base, swin_transformer_small, 5 | swin_transformer_tiny) 6 | from .vgg import VGG16 7 | from .vision_transformer import VisionTransformer 8 | 9 | get_model_from_name = { 10 | "mobilenetv1" : MobileNetV1, 11 | "mobilenetv2" : MobileNetV2, 12 | "resnet50" : ResNet50, 13 | "vgg16" : VGG16, 14 | "vit_b_16" : VisionTransformer, 15 | "swin_transformer_tiny" : swin_transformer_tiny, 16 | "swin_transformer_small" : swin_transformer_small, 17 | "swin_transformer_base" : swin_transformer_base 18 | } 19 | 20 | freeze_layers = { 21 | "mobilenetv1" : 81, 22 | "mobilenetv2" : 151, 23 | "resnet50" : 173, 24 | "vgg16" : 19, 25 | "vit_b_16" : 130, 26 | "swin_transformer_tiny" : 181, 27 | "swin_transformer_small" : 350, 28 | "swin_transformer_base" : 350 29 | } 30 | -------------------------------------------------------------------------------- /nets/mobilenetv1.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import backend as K 2 | from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, 3 | DepthwiseConv2D, Dropout, 4 | GlobalAveragePooling2D, Input, Reshape) 5 | from tensorflow.keras.models import Model 6 | 7 | 8 | def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)): 9 | filters = int(filters * alpha) 10 | x = Conv2D(filters, kernel, 11 | padding='same', 12 | use_bias=False, 13 | strides=strides, 14 | name='conv1')(inputs) 15 | x = BatchNormalization(name='conv1_bn')(x) 16 | return Activation(relu6, name='conv1_relu')(x) 17 | 18 | 19 | def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha, 20 | depth_multiplier=1, strides=(1, 1), block_id=1): 21 | 22 | pointwise_conv_filters = int(pointwise_conv_filters * alpha) 23 | 24 | x = DepthwiseConv2D((3, 3), 25 | padding='same', 26 | depth_multiplier=depth_multiplier, 27 | strides=strides, 28 | use_bias=False, 29 | name='conv_dw_%d' % block_id)(inputs) 30 | 31 | x = BatchNormalization(name='conv_dw_%d_bn' % block_id)(x) 32 | x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x) 33 | 34 | x = Conv2D(pointwise_conv_filters, (1, 1), 35 | padding='same', 36 | use_bias=False, 37 | strides=(1, 1), 38 | name='conv_pw_%d' % block_id)(x) 39 | x = BatchNormalization(name='conv_pw_%d_bn' % block_id)(x) 40 | return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x) 41 | 42 | def MobileNetV1(input_shape=None, 43 | alpha=1.0, 44 | depth_multiplier=1, 45 | dropout=1e-3, 46 | classes=1000): 47 | 48 | img_input = Input(shape=input_shape) 49 | 50 | # 224,224,3 -> 112,112,32 51 | x = _conv_block(img_input, 32, alpha, strides=(2, 2)) 52 | 53 | # 112,112,32 -> 112,112,64 54 | x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1) 55 | 56 | 57 | # 112,112,64 -> 56,56,128 58 | x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, 59 | strides=(2, 2), block_id=2) 60 | x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3) 61 | 62 | 63 | # 56,56,128 -> 28,28,256 64 | x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, 65 | strides=(2, 
2), block_id=4) 66 | x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5) 67 | 68 | 69 | # 28,28,256 -> 14,14,512 70 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, 71 | strides=(2, 2), block_id=6) 72 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7) 73 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8) 74 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9) 75 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10) 76 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11) 77 | 78 | # 14,14,512 -> 7,7,1024 79 | x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, 80 | strides=(2, 2), block_id=12) 81 | x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13) 82 | 83 | # 7,7,1024 -> 1,1,1024 84 | x = GlobalAveragePooling2D()(x) 85 | 86 | shape = (1, 1, int(1024 * alpha)) 87 | 88 | x = Reshape(shape, name='reshape_1')(x) 89 | x = Dropout(dropout, name='dropout')(x) 90 | 91 | x = Conv2D(classes, (1, 1),padding='same', name='conv_preds')(x) 92 | x = Activation('softmax', name='act_softmax')(x) 93 | x = Reshape((classes,), name='reshape_2')(x) 94 | 95 | inputs = img_input 96 | 97 | model = Model(inputs, x, name='mobilenet_%0.2f' % (alpha)) 98 | return model 99 | 100 | def relu6(x): 101 | return K.relu(x, max_value=6) 102 | 103 | if __name__ == '__main__': 104 | model = MobileNetV1(input_shape=(224, 224, 3)) 105 | model.summary() 106 | -------------------------------------------------------------------------------- /nets/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | #-------------------------------------------------------------# 2 | # MobileNetV2的网络部分 3 | #-------------------------------------------------------------# 4 | from tensorflow.keras import backend as K 5 | from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, ZeroPadding2D, Add, Dense, 6 | DepthwiseConv2D, Dropout, 7 | GlobalAveragePooling2D, Input, Reshape) 8 | from tensorflow.keras.models import Model 9 | from tensorflow.keras.initializers import RandomNormal 10 | 11 | 12 | # relu6! 
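# (note) relu6 below is the standard ReLU clipped at 6; the MobileNet papers favour it
# because the bounded activation range stays robust under low-precision computation.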
13 | def relu6(x): 14 | return K.relu(x, max_value=6) 15 | 16 | # 用于计算padding的大小 17 | def correct_pad(inputs, kernel_size): 18 | img_dim = 1 19 | input_size = K.int_shape(inputs)[img_dim:(img_dim + 2)] 20 | 21 | if isinstance(kernel_size, int): 22 | kernel_size = (kernel_size, kernel_size) 23 | 24 | if input_size[0] is None: 25 | adjust = (1, 1) 26 | else: 27 | adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) 28 | 29 | correct = (kernel_size[0] // 2, kernel_size[1] // 2) 30 | 31 | return ((correct[0] - adjust[0], correct[0]), 32 | (correct[1] - adjust[1], correct[1])) 33 | 34 | # 使其结果可以被8整除,因为使用到了膨胀系数α 35 | def _make_divisible(v, divisor, min_value=None): 36 | if min_value is None: 37 | min_value = divisor 38 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 39 | if new_v < 0.9 * v: 40 | new_v += divisor 41 | return new_v 42 | 43 | def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id): 44 | in_channels = K.int_shape(inputs)[-1] 45 | pointwise_conv_filters = int(filters * alpha) 46 | pointwise_filters = _make_divisible(pointwise_conv_filters, 8) 47 | 48 | x = inputs 49 | prefix = 'block_{}_'.format(block_id) 50 | # part1 数据扩张 51 | if block_id: 52 | # Expand 53 | x = Conv2D(expansion * in_channels, 54 | kernel_initializer=RandomNormal(stddev=0.02), 55 | kernel_size=1, 56 | padding='same', 57 | use_bias=False, 58 | activation=None, 59 | name=prefix + 'expand')(x) 60 | x = BatchNormalization(epsilon=1e-3, 61 | momentum=0.999, 62 | name=prefix + 'expand_BN')(x) 63 | x = Activation(relu6, name=prefix + 'expand_relu')(x) 64 | else: 65 | prefix = 'expanded_conv_' 66 | 67 | if stride == 2: 68 | x = ZeroPadding2D(padding=correct_pad(x, 3), 69 | name=prefix + 'pad')(x) 70 | 71 | # part2 可分离卷积 72 | x = DepthwiseConv2D(kernel_size=3, 73 | depthwise_initializer=RandomNormal(stddev=0.02), 74 | strides=stride, 75 | activation=None, 76 | use_bias=False, 77 | padding='same' if stride == 1 else 'valid', 78 | name=prefix + 'depthwise')(x) 79 | x = BatchNormalization(epsilon=1e-3, 80 | momentum=0.999, 81 | name=prefix + 'depthwise_BN')(x) 82 | 83 | x = Activation(relu6, name=prefix + 'depthwise_relu')(x) 84 | 85 | # part3压缩特征,而且不使用relu函数,保证特征不被破坏 86 | x = Conv2D(pointwise_filters, 87 | kernel_initializer=RandomNormal(stddev=0.02), 88 | kernel_size=1, 89 | padding='same', 90 | use_bias=False, 91 | activation=None, 92 | name=prefix + 'project')(x) 93 | 94 | x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'project_BN')(x) 95 | 96 | if in_channels == pointwise_filters and stride == 1: 97 | return Add(name=prefix + 'add')([inputs, x]) 98 | return x 99 | 100 | def MobileNetV2(input_shape=[224,224,3], 101 | alpha=1.0, 102 | classes=1000): 103 | 104 | rows = input_shape[0] 105 | 106 | img_input = Input(shape=input_shape) 107 | 108 | # stem部分 109 | # 224,224,3 -> 112,112,32 110 | first_block_filters = _make_divisible(32 * alpha, 8) 111 | x = ZeroPadding2D(padding=correct_pad(img_input, 3), 112 | name='Conv1_pad')(img_input) 113 | x = Conv2D(first_block_filters, 114 | kernel_size=3, 115 | strides=(2, 2), 116 | padding='valid', 117 | use_bias=False, 118 | name='Conv1')(x) 119 | x = BatchNormalization(epsilon=1e-3, 120 | momentum=0.999, 121 | name='bn_Conv1')(x) 122 | x = Activation(relu6, name='Conv1_relu')(x) 123 | 124 | # 112,112,32 -> 112,112,16 125 | x = _inverted_res_block(x, filters=16, alpha=alpha, stride=1, 126 | expansion=1, block_id=0) 127 | 128 | # 112,112,16 -> 56,56,24 129 | x = _inverted_res_block(x, filters=24, alpha=alpha, stride=2, 130 | 
expansion=6, block_id=1) 131 | x = _inverted_res_block(x, filters=24, alpha=alpha, stride=1, 132 | expansion=6, block_id=2) 133 | 134 | # 56,56,24 -> 28,28,32 135 | x = _inverted_res_block(x, filters=32, alpha=alpha, stride=2, 136 | expansion=6, block_id=3) 137 | x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1, 138 | expansion=6, block_id=4) 139 | x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1, 140 | expansion=6, block_id=5) 141 | 142 | # 28,28,32 -> 14,14,64 143 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=2, 144 | expansion=6, block_id=6) 145 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, 146 | expansion=6, block_id=7) 147 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, 148 | expansion=6, block_id=8) 149 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, 150 | expansion=6, block_id=9) 151 | 152 | # 14,14,64 -> 14,14,96 153 | x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, 154 | expansion=6, block_id=10) 155 | x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, 156 | expansion=6, block_id=11) 157 | x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, 158 | expansion=6, block_id=12) 159 | # 14,14,96 -> 7,7,160 160 | x = _inverted_res_block(x, filters=160, alpha=alpha, stride=2, 161 | expansion=6, block_id=13) 162 | x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, 163 | expansion=6, block_id=14) 164 | x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, 165 | expansion=6, block_id=15) 166 | 167 | # 7,7,160 -> 7,7,320 168 | x = _inverted_res_block(x, filters=320, alpha=alpha, stride=1, 169 | expansion=6, block_id=16) 170 | 171 | if alpha > 1.0: 172 | last_block_filters = _make_divisible(1280 * alpha, 8) 173 | else: 174 | last_block_filters = 1280 175 | 176 | # 7,7,320 -> 7,7,1280 177 | x = Conv2D(last_block_filters, 178 | kernel_size=1, 179 | use_bias=False, 180 | name='Conv_1')(x) 181 | x = BatchNormalization(epsilon=1e-3, 182 | momentum=0.999, 183 | name='Conv_1_bn')(x) 184 | x = Activation(relu6, name='out_relu')(x) 185 | 186 | # 7,7,1280 -> 1,1,1280 187 | x = GlobalAveragePooling2D()(x) 188 | x = Dense(classes, activation='softmax', 189 | use_bias=True, name='Logits')(x) 190 | 191 | inputs = img_input 192 | 193 | model = Model(inputs, x, name='mobilenetv2_%0.2f_%s' % (alpha, rows)) 194 | 195 | return model 196 | 197 | 198 | -------------------------------------------------------------------------------- /nets/resnet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import layers 2 | from tensorflow.keras.layers import (Activation, AveragePooling2D, 3 | BatchNormalization, Conv2D, Dense, 4 | Flatten, Input, MaxPooling2D, 5 | ZeroPadding2D) 6 | from tensorflow.keras.models import Model 7 | 8 | 9 | def identity_block(input_tensor, kernel_size, filters, stage, block): 10 | 11 | filters1, filters2, filters3 = filters 12 | 13 | conv_name_base = 'res' + str(stage) + block + '_branch' 14 | bn_name_base = 'bn' + str(stage) + block + '_branch' 15 | 16 | # 减少通道数 17 | x = Conv2D(filters1, (1, 1), name=conv_name_base + '2a')(input_tensor) 18 | x = BatchNormalization(name=bn_name_base + '2a')(x) 19 | x = Activation('relu')(x) 20 | 21 | # 3x3卷积 22 | x = Conv2D(filters2, kernel_size,padding='same', name=conv_name_base + '2b')(x) 23 | x = BatchNormalization(name=bn_name_base + '2b')(x) 24 | x = Activation('relu')(x) 25 | 26 | # 上升通道数 27 | x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x) 28 
| x = BatchNormalization(name=bn_name_base + '2c')(x) 29 | 30 | x = layers.add([x, input_tensor]) 31 | x = Activation('relu')(x) 32 | return x 33 | 34 | 35 | def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)): 36 | filters1, filters2, filters3 = filters 37 | 38 | conv_name_base = 'res' + str(stage) + block + '_branch' 39 | bn_name_base = 'bn' + str(stage) + block + '_branch' 40 | 41 | # 减少通道数 42 | x = Conv2D(filters1, (1, 1), strides=strides, name=conv_name_base + '2a')(input_tensor) 43 | x = BatchNormalization(name=bn_name_base + '2a')(x) 44 | x = Activation('relu')(x) 45 | 46 | # 3x3卷积 47 | x = Conv2D(filters2, kernel_size, padding='same', name=conv_name_base + '2b')(x) 48 | x = BatchNormalization(name=bn_name_base + '2b')(x) 49 | x = Activation('relu')(x) 50 | 51 | # 上升通道数 52 | x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x) 53 | x = BatchNormalization(name=bn_name_base + '2c')(x) 54 | 55 | # 残差边 56 | shortcut = Conv2D(filters3, (1, 1), strides=strides, 57 | name=conv_name_base + '1')(input_tensor) 58 | shortcut = BatchNormalization(name=bn_name_base + '1')(shortcut) 59 | 60 | x = layers.add([x, shortcut]) 61 | x = Activation('relu')(x) 62 | return x 63 | 64 | 65 | def ResNet50(input_shape=[224,224,3], classes=1000): 66 | img_input = Input(shape=input_shape) 67 | 68 | x = ZeroPadding2D((3, 3))(img_input) 69 | # 224,224,3 -> 112,112,64 70 | x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x) 71 | x = BatchNormalization(name='bn_conv1')(x) 72 | x = Activation('relu')(x) 73 | 74 | # 112,112,64 -> 56,56,64 75 | x = MaxPooling2D((3, 3), strides=(2, 2))(x) 76 | 77 | # 56,56,64 -> 56,56,256 78 | x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) 79 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') 80 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') 81 | 82 | # 56,56,256 -> 28,28,512 83 | x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') 84 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') 85 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') 86 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') 87 | 88 | # 28,28,512 -> 14,14,1024 89 | x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') 90 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') 91 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') 92 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') 93 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') 94 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') 95 | 96 | # 14,14,1024 -> 7,7,2048 97 | x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') 98 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') 99 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') 100 | 101 | # 1,1,2048 102 | x = AveragePooling2D((7, 7), name='avg_pool')(x) 103 | 104 | # 进行预测 105 | # 2048 106 | x = Flatten()(x) 107 | 108 | # num_classes 109 | x = Dense(classes, activation='softmax', name='fc1000')(x) 110 | 111 | model = Model(img_input, x, name='resnet50') 112 | 113 | return model 114 | 115 | 116 | if __name__ == '__main__': 117 | model = ResNet50() 118 | model.summary() 119 | -------------------------------------------------------------------------------- /nets/swin_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow import keras 6 | from 
tensorflow.keras import backend as K 7 | from tensorflow.keras.layers import (Add, Conv2D, Dense, Dropout, 8 | GlobalAveragePooling1D, Input, Lambda, 9 | Layer, Reshape, Softmax) 10 | 11 | 12 | def drop_path(inputs, drop_prob, is_training): 13 | if (not is_training) or (drop_prob == 0.): 14 | return inputs 15 | 16 | # Compute keep_prob 17 | keep_prob = 1.0 - drop_prob 18 | 19 | # Compute drop_connect tensor 20 | random_tensor = keep_prob 21 | shape = (tf.shape(inputs)[0],) + (1,) * (len(tf.shape(inputs)) - 1) 22 | random_tensor += tf.random.uniform(shape, dtype=inputs.dtype) 23 | binary_tensor = tf.floor(random_tensor) 24 | output = tf.math.divide(inputs, keep_prob) * binary_tensor 25 | return output 26 | 27 | class DropPath(keras.layers.Layer): 28 | def __init__(self, drop_prob=None): 29 | super().__init__() 30 | self.drop_prob = drop_prob 31 | 32 | def call(self, x, training=None): 33 | return drop_path(x, self.drop_prob, training) 34 | 35 | #--------------------------------------# 36 | # LayerNormalization 37 | # 层标准化的实现 38 | #--------------------------------------# 39 | class LayerNormalization(keras.layers.Layer): 40 | def __init__(self, 41 | center=True, 42 | scale=True, 43 | epsilon=None, 44 | gamma_initializer='ones', 45 | beta_initializer='zeros', 46 | gamma_regularizer=None, 47 | beta_regularizer=None, 48 | gamma_constraint=None, 49 | beta_constraint=None, 50 | **kwargs): 51 | """Layer normalization layer 52 | 53 | See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) 54 | 55 | :param center: Add an offset parameter if it is True. 56 | :param scale: Add a scale parameter if it is True. 57 | :param epsilon: Epsilon for calculating variance. 58 | :param gamma_initializer: Initializer for the gamma weight. 59 | :param beta_initializer: Initializer for the beta weight. 60 | :param gamma_regularizer: Optional regularizer for the gamma weight. 61 | :param beta_regularizer: Optional regularizer for the beta weight. 62 | :param gamma_constraint: Optional constraint for the gamma weight. 63 | :param beta_constraint: Optional constraint for the beta weight. 
64 | :param kwargs: 65 | """ 66 | super(LayerNormalization, self).__init__(**kwargs) 67 | self.supports_masking = True 68 | self.center = center 69 | self.scale = scale 70 | if epsilon is None: 71 | epsilon = K.epsilon() * K.epsilon() 72 | self.epsilon = epsilon 73 | self.gamma_initializer = keras.initializers.get(gamma_initializer) 74 | self.beta_initializer = keras.initializers.get(beta_initializer) 75 | self.gamma_regularizer = keras.regularizers.get(gamma_regularizer) 76 | self.beta_regularizer = keras.regularizers.get(beta_regularizer) 77 | self.gamma_constraint = keras.constraints.get(gamma_constraint) 78 | self.beta_constraint = keras.constraints.get(beta_constraint) 79 | self.gamma, self.beta = None, None 80 | 81 | def get_config(self): 82 | config = { 83 | 'center': self.center, 84 | 'scale': self.scale, 85 | 'epsilon': self.epsilon, 86 | 'gamma_initializer': keras.initializers.serialize(self.gamma_initializer), 87 | 'beta_initializer': keras.initializers.serialize(self.beta_initializer), 88 | 'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer), 89 | 'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer), 90 | 'gamma_constraint': keras.constraints.serialize(self.gamma_constraint), 91 | 'beta_constraint': keras.constraints.serialize(self.beta_constraint), 92 | } 93 | base_config = super(LayerNormalization, self).get_config() 94 | return dict(list(base_config.items()) + list(config.items())) 95 | 96 | def compute_output_shape(self, input_shape): 97 | return input_shape 98 | 99 | def compute_mask(self, inputs, input_mask=None): 100 | return input_mask 101 | 102 | def build(self, input_shape): 103 | shape = input_shape[-1:] 104 | if self.scale: 105 | self.gamma = self.add_weight( 106 | shape=shape, 107 | initializer=self.gamma_initializer, 108 | regularizer=self.gamma_regularizer, 109 | constraint=self.gamma_constraint, 110 | name='gamma', 111 | ) 112 | if self.center: 113 | self.beta = self.add_weight( 114 | shape=shape, 115 | initializer=self.beta_initializer, 116 | regularizer=self.beta_regularizer, 117 | constraint=self.beta_constraint, 118 | name='beta', 119 | ) 120 | super(LayerNormalization, self).build(input_shape) 121 | 122 | def call(self, inputs, training=None): 123 | mean = K.mean(inputs, axis=-1, keepdims=True) 124 | variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True) 125 | std = K.sqrt(variance + self.epsilon) 126 | outputs = (inputs - mean) / std 127 | if self.scale: 128 | outputs *= self.gamma 129 | if self.center: 130 | outputs += self.beta 131 | return outputs 132 | 133 | #--------------------------------------# 134 | # Gelu激活函数的实现 135 | # 利用近似的数学公式 136 | #--------------------------------------# 137 | class Gelu(Layer): 138 | def __init__(self, **kwargs): 139 | super(Gelu, self).__init__(**kwargs) 140 | self.supports_masking = True 141 | 142 | def call(self, inputs): 143 | return 0.5 * inputs * (1 + tf.tanh(tf.sqrt(2 / math.pi) * (inputs + 0.044715 * tf.pow(inputs, 3)))) 144 | 145 | def get_config(self): 146 | config = super(Gelu, self).get_config() 147 | return config 148 | 149 | def compute_output_shape(self, input_shape): 150 | return input_shape 151 | 152 | class Mlp(): 153 | def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., name=""): 154 | super().__init__() 155 | out_features = out_features or in_features 156 | hidden_features = hidden_features or in_features 157 | 158 | self.fc1 = Dense(hidden_features, name=name + '.fc1') 159 | self.act_layer = Gelu() 160 | self.fc2 = 
Dense(out_features, name=name + '.fc2') 161 | self.drop = Dropout(drop) 162 | 163 | def call(self, x): 164 | x = self.fc1(x) 165 | x = self.act_layer(x) 166 | x = self.drop(x) 167 | x = self.fc2(x) 168 | x = self.drop(x) 169 | return x 170 | 171 | def window_partition(x, window_size): 172 | B, H, W, C = x.get_shape().as_list() 173 | x = tf.reshape(x, shape=[-1, H // window_size, window_size, W // window_size, window_size, C]) 174 | x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5]) 175 | windows = tf.reshape(x, shape=[-1, window_size, window_size, C]) 176 | return windows 177 | 178 | def window_reverse(windows, window_size, H, W, C): 179 | x = tf.reshape(windows, shape=[-1, H // window_size, W // window_size, window_size, window_size, C]) 180 | x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5]) 181 | x = tf.reshape(x, shape=[-1, H, W, C]) 182 | return x 183 | 184 | class SwinTransformerBlock_pre(keras.layers.Layer): 185 | def __init__(self, input_resolution, window_size=7, shift_size=0): 186 | super().__init__() 187 | self.input_resolution = input_resolution 188 | self.window_size = window_size 189 | self.shift_size = shift_size 190 | 191 | def build(self, input_shape): 192 | if self.shift_size > 0: 193 | H, W = self.input_resolution 194 | img_mask = np.zeros([1, H, W, 1]) 195 | h_slices = (slice(0, -self.window_size), 196 | slice(-self.window_size, -self.shift_size), 197 | slice(-self.shift_size, None)) 198 | w_slices = (slice(0, -self.window_size), 199 | slice(-self.window_size, -self.shift_size), 200 | slice(-self.shift_size, None)) 201 | cnt = 0 202 | for h in h_slices: 203 | for w in w_slices: 204 | img_mask[:, h, w, :] = cnt 205 | cnt += 1 206 | 207 | img_mask = tf.convert_to_tensor(img_mask) 208 | mask_windows = window_partition(img_mask, self.window_size) 209 | mask_windows = tf.reshape( 210 | mask_windows, shape=[-1, self.window_size * self.window_size]) 211 | attn_mask = tf.expand_dims( 212 | mask_windows, axis=1) - tf.expand_dims(mask_windows, axis=2) 213 | attn_mask = tf.where(tf.not_equal(attn_mask, 0), -100.0 * tf.ones_like(attn_mask), attn_mask) 214 | attn_mask = tf.where(tf.equal(attn_mask, 0), tf.zeros_like(attn_mask), attn_mask) 215 | 216 | self.attn_mask = tf.Variable( 217 | initial_value=attn_mask, trainable=False) 218 | else: 219 | self.attn_mask = None 220 | 221 | self.built = True 222 | 223 | def compute_output_shape(self, input_shape): 224 | return (None, self.window_size * self.window_size, input_shape[2]) 225 | 226 | def call(self, x): 227 | H, W = self.input_resolution 228 | B, L, C = x.get_shape().as_list() 229 | 230 | x = tf.reshape(x, shape=[-1, H, W, C]) 231 | 232 | # 56, 56, 96 233 | if self.shift_size > 0: 234 | shifted_x = tf.roll( 235 | x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]) 236 | else: 237 | shifted_x = x 238 | 239 | # 56, 56, 96 -> 8, 7, 8, 7, 96 -> 8, 8, 7, 7, 96 -> 64, 7, 7, 96 -> 64, 49, 96 240 | x_windows = window_partition(shifted_x, self.window_size) 241 | x_windows = tf.reshape(x_windows, shape=[-1, self.window_size * self.window_size, C]) 242 | return x_windows 243 | 244 | 245 | class WindowAttention_pre(keras.layers.Layer): 246 | def __init__(self, dim, window_size, num_heads, qk_scale=None, attn_drop=0, name=""): 247 | super().__init__(name=name) 248 | self.dim = dim 249 | self.window_size = window_size 250 | self.num_heads = num_heads 251 | 252 | head_dim = dim // num_heads 253 | self.scale = qk_scale or head_dim ** -0.5 254 | self.attn_drop = Dropout(attn_drop) 255 | 256 | def build(self, input_shape): 257 | 
self.relative_position_bias_table = self.add_weight( 258 | f'attn/relative_position_bias_table', 259 | shape = ((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads), 260 | initializer = tf.initializers.zeros(), 261 | trainable = True 262 | ) 263 | 264 | coords_h = np.arange(self.window_size[0]) 265 | coords_w = np.arange(self.window_size[1]) 266 | coords = np.stack(np.meshgrid(coords_h, coords_w, indexing='ij')) 267 | coords_flatten = coords.reshape(2, -1) 268 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 269 | relative_coords = relative_coords.transpose([1, 2, 0]) 270 | relative_coords[:, :, 0] += self.window_size[0] - 1 271 | relative_coords[:, :, 1] += self.window_size[1] - 1 272 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 273 | relative_position_index = relative_coords.sum(-1).astype(np.int64) 274 | 275 | self.relative_position_index = tf.Variable(initial_value=tf.convert_to_tensor( 276 | relative_position_index), trainable=False, name=f'attn/relative_position_index') 277 | self.built = True 278 | 279 | def compute_output_shape(self, input_shape): 280 | return (input_shape[0], input_shape[1], input_shape[2] // 3) 281 | 282 | def call(self, x, mask=None): 283 | B_, N, C = x.get_shape().as_list() 284 | C = C // 3 285 | # [B_, N, C] -> [B_, N, 3 * C] -> [B_, N, 3, num_heads, C / num_heads] -> [3, B_, num_heads, N, C / num_heads] 286 | qkv = tf.transpose(tf.reshape(x, shape=[-1, N, 3, self.num_heads, C // self.num_heads]), perm=[2, 0, 3, 1, 4]) 287 | # [B_, num_heads, N, C / num_heads] 288 | q, k, v = qkv[0], qkv[1], qkv[2] 289 | 290 | # [B_, num_heads, N, N] 291 | q = q * self.scale 292 | attn = (q @ tf.transpose(k, perm=[0, 1, 3, 2])) 293 | 294 | relative_position_bias = tf.gather(self.relative_position_bias_table, tf.reshape(self.relative_position_index, shape=[-1])) 295 | relative_position_bias = tf.reshape(relative_position_bias, shape=[ 296 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1]) 297 | relative_position_bias = tf.transpose(relative_position_bias, perm=[2, 0, 1]) 298 | # [B_, num_heads, N, N] 299 | attn = attn + tf.expand_dims(relative_position_bias, axis=0) 300 | 301 | if mask is not None: 302 | nW = mask.get_shape()[0] # tf.shape(mask)[0] 303 | attn = tf.reshape(attn, shape=[-1, nW, self.num_heads, N, N]) + \ 304 | tf.cast(tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32) 305 | 306 | attn = tf.reshape(attn, shape=[-1, self.num_heads, N, N]) 307 | attn = tf.nn.softmax(attn, axis=-1) 308 | else: 309 | # [B_, num_heads, N, N] 310 | attn = tf.nn.softmax(attn, axis=-1) 311 | 312 | attn = self.attn_drop(attn) 313 | 314 | # [B_, num_heads, N, C / num_heads] -> [B_, N, num_heads, C / num_heads] -> [B_, N, C] 315 | x = tf.transpose((attn @ v), perm=[0, 2, 1, 3]) 316 | x = tf.reshape(x, shape=[-1, N, C]) 317 | return x 318 | 319 | class WindowAttention(): 320 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., name=""): 321 | super().__init__() 322 | 323 | self.qkv = Dense(dim * 3, use_bias=qkv_bias, name=name + ".qkv") 324 | self.pre = WindowAttention_pre(dim, window_size, num_heads, qk_scale, attn_drop, name=name) 325 | 326 | self.proj = Dense(dim, name=name + ".proj") 327 | self.proj_drop = Dropout(proj_drop) 328 | 329 | 330 | def call(self, x, mask=None): 331 | B_, N, C = x.get_shape().as_list() 332 | x = self.qkv(x) 333 | x = self.pre(x, mask = mask) 334 | 335 | # [B_, N, C] -> [B_, N, C] 336 | x 
= self.proj(x)
337 |         x = self.proj_drop(x)
338 |         return x
339 | 
340 | class SwinTransformerBlock_post(keras.layers.Layer):
341 |     def __init__(self, dim, input_resolution, window_size=7, shift_size=0):
342 |         super().__init__()
343 |         self.dim = dim
344 |         self.input_resolution = input_resolution
345 |         self.window_size = window_size
346 |         self.shift_size = shift_size
347 | 
348 |     def compute_output_shape(self, input_shape):
349 |         return (None, self.input_resolution[0] * self.input_resolution[1], input_shape[2])
350 | 
351 |     def call(self, x):
352 |         H, W = self.input_resolution
353 | 
354 |         # 64, 49, 96 -> 64, 7, 7, 96 -> 8, 8, 7, 7, 96 -> 8, 7, 8, 7, 96 -> 56, 56, 96
355 |         attn_windows = tf.reshape(x, shape=[-1, self.window_size, self.window_size, self.dim])
356 |         shifted_x = window_reverse(attn_windows, self.window_size, H, W, self.dim)
357 |         # 56, 56, 96
358 |         if self.shift_size > 0:
359 |             x = tf.roll(shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2])
360 |         else:
361 |             x = shifted_x
362 | 
363 |         # 56 * 56, 96
364 |         x = tf.reshape(x, shape=[-1, H * W, self.dim])
365 |         return x
366 | 
367 | class SwinTransformerBlock():
368 |     def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4.,
369 |                  qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path_prob=0., name=""):
370 |         super().__init__()
371 |         self.dim = dim
372 |         self.input_resolution = input_resolution
373 |         self.num_heads = num_heads
374 |         self.window_size = window_size
375 |         self.shift_size = shift_size
376 |         self.mlp_ratio = mlp_ratio
377 |         if min(self.input_resolution) <= self.window_size:
378 |             self.shift_size = 0
379 |             self.window_size = min(self.input_resolution)
380 |         assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
381 | 
393 |         self.norm1 = LayerNormalization(epsilon=1e-5, name=name + ".norm1")
394 |         self.pre = SwinTransformerBlock_pre(self.input_resolution, self.window_size, self.shift_size)
395 |         self.attn = WindowAttention(
396 |             dim,
397 |             window_size = (self.window_size, self.window_size),
398 |             num_heads = num_heads,
399 |             qkv_bias = qkv_bias,
400 |             qk_scale = qk_scale,
401 |             attn_drop = attn_drop,
402 |             proj_drop = drop,
403 |             name = name + ".attn"
404 |         )
405 |         self.post = SwinTransformerBlock_post(self.dim, self.input_resolution, self.window_size, self.shift_size)
406 |         self.drop_path = DropPath(drop_path_prob if drop_path_prob > 0. else 0.)
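        # (note) DropPath implements stochastic depth: during training, drop_path above
        # zeroes the entire residual branch for a random subset of samples with
        # probability drop_prob and rescales the survivors by 1/keep_prob, so the
        # expected value of the output is unchanged.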
407 | self.norm2 = LayerNormalization(epsilon=1e-5, name=name + ".norm2") 408 | mlp_hidden_dim = int(dim * mlp_ratio) 409 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop, name=name + ".mlp") 410 | self.add = Add() 411 | 412 | 413 | def call(self, x): 414 | H, W = self.input_resolution 415 | B, L, C = x.get_shape().as_list() 416 | assert L == H * W, "input feature has wrong size" 417 | # 56, 56, 96 418 | 419 | shortcut = x 420 | 421 | x = self.norm1(x) 422 | x = self.pre(x) 423 | # 64, 49, 97 -> 64, 49, 97 424 | x = self.attn.call(x, mask=self.pre.attn_mask) 425 | x = self.post(x) 426 | 427 | # FFN 428 | # 56 * 56, 96 429 | x = self.add([shortcut, self.drop_path(x)]) 430 | x = self.add([x, self.drop_path(self.mlp.call(self.norm2(x)))]) 431 | return x 432 | 433 | class PatchMerging(keras.layers.Layer): 434 | def __init__(self, input_resolution): 435 | super().__init__() 436 | self.input_resolution = input_resolution 437 | 438 | def compute_output_shape(self, input_shape): 439 | return (input_shape[0], input_shape[1] // 4, input_shape[2] * 4) 440 | 441 | def call(self, x): 442 | H, W = self.input_resolution 443 | B, L, C = x.get_shape().as_list() 444 | assert L == H * W, "input feature has wrong size" 445 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 446 | 447 | # 56, 56, 96 448 | x = tf.reshape(x, shape=[-1, H, W, C]) 449 | 450 | # 28, 28, 96 451 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 452 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 453 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 454 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 455 | # 28, 28, 384 456 | x = tf.concat([x0, x1, x2, x3], axis=-1) 457 | # 784, 384 458 | x = tf.reshape(x, shape=[-1, (H // 2) * (W // 2), 4 * C]) 459 | 460 | return x 461 | 462 | def BasicLayer( 463 | x, dim, input_resolution, depth, num_heads, window_size, 464 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path_prob=0., name="" 465 | ): 466 | for i in range(depth): 467 | x = SwinTransformerBlock( 468 | dim = dim, 469 | input_resolution = input_resolution, 470 | num_heads = num_heads, 471 | window_size = window_size, 472 | shift_size = 0 if (i % 2 == 0) else window_size // 2, 473 | mlp_ratio = mlp_ratio, 474 | qkv_bias = qkv_bias, 475 | qk_scale = qk_scale, 476 | drop = drop, 477 | attn_drop = attn_drop, 478 | drop_path_prob = drop_path_prob[i] if isinstance(drop_path_prob, list) else drop_path_prob, 479 | name = name + ".blocks." 
+ str(i), 480 | ).call(x) 481 | return x 482 | 483 | def build_model(input_shape = [224, 224], patch_size=(4, 4), classes=1000, 484 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 485 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 486 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1): 487 | #-----------------------------------------------# 488 | # 224, 224, 3 489 | #-----------------------------------------------# 490 | inputs = Input(shape = (input_shape[0], input_shape[1], 3)) 491 | 492 | #-----------------------------------------------# 493 | # 224, 224, 3 -> 56, 56, 768 494 | #-----------------------------------------------# 495 | x = Conv2D(embed_dim, patch_size, strides = patch_size, padding = "valid", name = "patch_embed.proj")(inputs) 496 | #-----------------------------------------------# 497 | # 56, 56, 768 -> 3136, 768 498 | #-----------------------------------------------# 499 | x = Reshape(((input_shape[0] // patch_size[0]) * (input_shape[1] // patch_size[0]), embed_dim))(x) 500 | x = LayerNormalization(epsilon=1e-5, name = "patch_embed.norm")(x) 501 | x = Dropout(drop_rate)(x) 502 | 503 | num_layers = len(depths) 504 | patches_resolution = [input_shape[0] // patch_size[0], input_shape[1] // patch_size[1]] 505 | dpr = [x for x in np.linspace(0., drop_path_rate, sum(depths))] 506 | #-----------------------------------------------# 507 | # 3136, 768 -> 3136, 49 508 | #-----------------------------------------------# 509 | for i_layer in range(num_layers): 510 | dim = int(embed_dim * 2 ** i_layer) 511 | input_resolution = (patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)) 512 | x = BasicLayer( 513 | x, 514 | dim = dim, 515 | input_resolution = input_resolution, 516 | depth = depths[i_layer], 517 | num_heads = num_heads[i_layer], 518 | window_size = window_size, 519 | mlp_ratio = mlp_ratio, 520 | qkv_bias = qkv_bias, qk_scale=qk_scale, 521 | drop = drop_rate, attn_drop=attn_drop_rate, 522 | drop_path_prob = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 523 | name = "layers." + str(i_layer) 524 | ) 525 | if (i_layer < num_layers - 1): 526 | x = PatchMerging(input_resolution)(x) 527 | x = LayerNormalization(epsilon=1e-5, name = "layers." + str(i_layer) + ".downsample.norm")(x) 528 | x = Dense(2 * dim, use_bias=False, name = "layers." 
+ str(i_layer) + ".downsample.reduction")(x) 529 | 530 | x = LayerNormalization(epsilon=1e-5, name="norm")(x) 531 | x = GlobalAveragePooling1D()(x) 532 | x = Dense(classes, name="head")(x) 533 | x = Softmax()(x) 534 | return keras.models.Model(inputs, x) 535 | 536 | def swin_transformer_tiny(input_shape=[224, 224], classes=1000): 537 | model = build_model(input_shape, classes=classes, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], embed_dim=96, drop_path_rate=0.2) 538 | return model 539 | 540 | def swin_transformer_small(input_shape=[224, 224], classes=1000): 541 | model = build_model(input_shape, classes=classes, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], embed_dim=96, drop_path_rate=0.3) 542 | return model 543 | 544 | def swin_transformer_base(input_shape=[224, 224], classes=1000): 545 | model = build_model(input_shape, classes=classes, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], embed_dim=128, drop_path_rate=0.5) 546 | return model -------------------------------------------------------------------------------- /nets/vgg.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input, MaxPooling2D 2 | from tensorflow.keras.models import Model 3 | 4 | 5 | def VGG16(input_shape=None, classes=1000): 6 | img_input = Input(shape=input_shape) 7 | 8 | # Block 1 9 | # 224, 224, 3 -> 224, 224, 64 10 | x = Conv2D(64, (3, 3), 11 | activation='relu', 12 | padding='same', 13 | name='block1_conv1')(img_input) 14 | x = Conv2D(64, (3, 3), 15 | activation='relu', 16 | padding='same', 17 | name='block1_conv2')(x) 18 | # 224, 224, 64 -> 112, 112, 64 19 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) 20 | 21 | # Block 2 22 | # 112, 112, 64 -> 112, 112, 128 23 | x = Conv2D(128, (3, 3), 24 | activation='relu', 25 | padding='same', 26 | name='block2_conv1')(x) 27 | x = Conv2D(128, (3, 3), 28 | activation='relu', 29 | padding='same', 30 | name='block2_conv2')(x) 31 | # 112, 112, 128 -> 56, 56, 128 32 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) 33 | 34 | # Block 3 35 | # 56, 56, 128 -> 56, 56, 256 36 | x = Conv2D(256, (3, 3), 37 | activation='relu', 38 | padding='same', 39 | name='block3_conv1')(x) 40 | x = Conv2D(256, (3, 3), 41 | activation='relu', 42 | padding='same', 43 | name='block3_conv2')(x) 44 | x = Conv2D(256, (3, 3), 45 | activation='relu', 46 | padding='same', 47 | name='block3_conv3')(x) 48 | # 56, 56, 256 -> 28, 28, 256 49 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) 50 | 51 | # Block 4 52 | # 28, 28, 256 -> 28, 28, 512 53 | x = Conv2D(512, (3, 3), 54 | activation='relu', 55 | padding='same', 56 | name='block4_conv1')(x) 57 | x = Conv2D(512, (3, 3), 58 | activation='relu', 59 | padding='same', 60 | name='block4_conv2')(x) 61 | x = Conv2D(512, (3, 3), 62 | activation='relu', 63 | padding='same', 64 | name='block4_conv3')(x) 65 | 66 | # 28, 28, 512 -> 14, 14, 512 67 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) 68 | 69 | # Block 5 70 | # 14, 14, 512 -> 14, 14, 512 71 | x = Conv2D(512, (3, 3), 72 | activation='relu', 73 | padding='same', 74 | name='block5_conv1')(x) 75 | x = Conv2D(512, (3, 3), 76 | activation='relu', 77 | padding='same', 78 | name='block5_conv2')(x) 79 | x = Conv2D(512, (3, 3), 80 | activation='relu', 81 | padding='same', 82 | name='block5_conv3')(x) 83 | # 14, 14, 512 -> 7, 7, 512 84 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) 85 | 86 | x = Flatten(name='flatten')(x) 87 | x = 
Dense(4096, activation='relu', name='fc1')(x) 88 | x = Dense(4096, activation='relu', name='fc2')(x) 89 | x = Dense(classes, activation='softmax', name='predictions')(x) 90 | 91 | inputs = img_input 92 | 93 | model = Model(inputs, x, name='vgg16') 94 | return model 95 | 96 | if __name__ == '__main__': 97 | model = VGG16(input_shape=(224, 224, 3)) 98 | model.summary() 99 | -------------------------------------------------------------------------------- /nets/vision_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import tensorflow as tf 4 | from tensorflow import keras 5 | from tensorflow.keras import backend as K 6 | from tensorflow.keras.layers import (Add, Conv2D, Dense, Dropout, Input, 7 | Lambda, Layer, Reshape, Softmax) 8 | 9 | 10 | #--------------------------------------# 11 | # LayerNormalization 12 | # 层标准化的实现 13 | #--------------------------------------# 14 | class LayerNormalization(keras.layers.Layer): 15 | def __init__(self, 16 | center=True, 17 | scale=True, 18 | epsilon=None, 19 | gamma_initializer='ones', 20 | beta_initializer='zeros', 21 | gamma_regularizer=None, 22 | beta_regularizer=None, 23 | gamma_constraint=None, 24 | beta_constraint=None, 25 | **kwargs): 26 | """Layer normalization layer 27 | 28 | See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) 29 | 30 | :param center: Add an offset parameter if it is True. 31 | :param scale: Add a scale parameter if it is True. 32 | :param epsilon: Epsilon for calculating variance. 33 | :param gamma_initializer: Initializer for the gamma weight. 34 | :param beta_initializer: Initializer for the beta weight. 35 | :param gamma_regularizer: Optional regularizer for the gamma weight. 36 | :param beta_regularizer: Optional regularizer for the beta weight. 37 | :param gamma_constraint: Optional constraint for the gamma weight. 38 | :param beta_constraint: Optional constraint for the beta weight. 
39 | :param kwargs: 40 | """ 41 | super(LayerNormalization, self).__init__(**kwargs) 42 | self.supports_masking = True 43 | self.center = center 44 | self.scale = scale 45 | if epsilon is None: 46 | epsilon = K.epsilon() * K.epsilon() 47 | self.epsilon = epsilon 48 | self.gamma_initializer = keras.initializers.get(gamma_initializer) 49 | self.beta_initializer = keras.initializers.get(beta_initializer) 50 | self.gamma_regularizer = keras.regularizers.get(gamma_regularizer) 51 | self.beta_regularizer = keras.regularizers.get(beta_regularizer) 52 | self.gamma_constraint = keras.constraints.get(gamma_constraint) 53 | self.beta_constraint = keras.constraints.get(beta_constraint) 54 | self.gamma, self.beta = None, None 55 | 56 | def get_config(self): 57 | config = { 58 | 'center': self.center, 59 | 'scale': self.scale, 60 | 'epsilon': self.epsilon, 61 | 'gamma_initializer': keras.initializers.serialize(self.gamma_initializer), 62 | 'beta_initializer': keras.initializers.serialize(self.beta_initializer), 63 | 'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer), 64 | 'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer), 65 | 'gamma_constraint': keras.constraints.serialize(self.gamma_constraint), 66 | 'beta_constraint': keras.constraints.serialize(self.beta_constraint), 67 | } 68 | base_config = super(LayerNormalization, self).get_config() 69 | return dict(list(base_config.items()) + list(config.items())) 70 | 71 | def compute_output_shape(self, input_shape): 72 | return input_shape 73 | 74 | def compute_mask(self, inputs, input_mask=None): 75 | return input_mask 76 | 77 | def build(self, input_shape): 78 | shape = input_shape[-1:] 79 | if self.scale: 80 | self.gamma = self.add_weight( 81 | shape=shape, 82 | initializer=self.gamma_initializer, 83 | regularizer=self.gamma_regularizer, 84 | constraint=self.gamma_constraint, 85 | name='gamma', 86 | ) 87 | if self.center: 88 | self.beta = self.add_weight( 89 | shape=shape, 90 | initializer=self.beta_initializer, 91 | regularizer=self.beta_regularizer, 92 | constraint=self.beta_constraint, 93 | name='beta', 94 | ) 95 | super(LayerNormalization, self).build(input_shape) 96 | 97 | def call(self, inputs, training=None): 98 | mean = K.mean(inputs, axis=-1, keepdims=True) 99 | variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True) 100 | std = K.sqrt(variance + self.epsilon) 101 | outputs = (inputs - mean) / std 102 | if self.scale: 103 | outputs *= self.gamma 104 | if self.center: 105 | outputs += self.beta 106 | return outputs 107 | 108 | #--------------------------------------# 109 | # Gelu激活函数的实现 110 | # 利用近似的数学公式 111 | #--------------------------------------# 112 | class Gelu(Layer): 113 | def __init__(self, **kwargs): 114 | super(Gelu, self).__init__(**kwargs) 115 | self.supports_masking = True 116 | 117 | def call(self, inputs): 118 | return 0.5 * inputs * (1 + tf.tanh(tf.sqrt(2 / math.pi) * (inputs + 0.044715 * tf.pow(inputs, 3)))) 119 | 120 | def get_config(self): 121 | config = super(Gelu, self).get_config() 122 | return config 123 | 124 | def compute_output_shape(self, input_shape): 125 | return input_shape 126 | 127 | #--------------------------------------------------------------------------------------------------------------------# 128 | # classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。 129 | # 130 | # 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。 131 | # 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。 132 | # 
在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。 133 | #--------------------------------------------------------------------------------------------------------------------# 134 | class ClassToken(Layer): 135 | def __init__(self, cls_initializer='zeros', cls_regularizer=None, cls_constraint=None, **kwargs): 136 | super(ClassToken, self).__init__(**kwargs) 137 | self.cls_initializer = keras.initializers.get(cls_initializer) 138 | self.cls_regularizer = keras.regularizers.get(cls_regularizer) 139 | self.cls_constraint = keras.constraints.get(cls_constraint) 140 | 141 | def get_config(self): 142 | config = { 143 | 'cls_initializer': keras.initializers.serialize(self.cls_initializer), 144 | 'cls_regularizer': keras.regularizers.serialize(self.cls_regularizer), 145 | 'cls_constraint': keras.constraints.serialize(self.cls_constraint), 146 | } 147 | base_config = super(ClassToken, self).get_config() 148 | return dict(list(base_config.items()) + list(config.items())) 149 | 150 | def compute_output_shape(self, input_shape): 151 | return (input_shape[0], input_shape[1] + 1, input_shape[2]) 152 | 153 | def build(self, input_shape): 154 | self.num_features = input_shape[-1] 155 | self.cls = self.add_weight( 156 | shape = (1, 1, self.num_features), 157 | initializer = self.cls_initializer, 158 | regularizer = self.cls_regularizer, 159 | constraint = self.cls_constraint, 160 | name = 'cls', 161 | ) 162 | super(ClassToken, self).build(input_shape) 163 | 164 | def call(self, inputs): 165 | batch_size = tf.shape(inputs)[0] 166 | cls_broadcasted = tf.cast(tf.broadcast_to(self.cls, [batch_size, 1, self.num_features]), dtype = inputs.dtype) 167 | return tf.concat([cls_broadcasted, inputs], 1) 168 | 169 | #--------------------------------------------------------------------------------------------------------------------# 170 | # 为网络提取到的特征添加上位置信息。 171 | # 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768 172 | # 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。 173 | #--------------------------------------------------------------------------------------------------------------------# 174 | class AddPositionEmbs(Layer): 175 | def __init__(self, image_shape, patch_size, pe_initializer='zeros', pe_regularizer=None, pe_constraint=None, **kwargs): 176 | super(AddPositionEmbs, self).__init__(**kwargs) 177 | self.image_shape = image_shape 178 | self.patch_size = patch_size 179 | self.pe_initializer = keras.initializers.get(pe_initializer) 180 | self.pe_regularizer = keras.regularizers.get(pe_regularizer) 181 | self.pe_constraint = keras.constraints.get(pe_constraint) 182 | 183 | def get_config(self): 184 | config = { 185 | 'pe_initializer': keras.initializers.serialize(self.pe_initializer), 186 | 'pe_regularizer': keras.regularizers.serialize(self.pe_regularizer), 187 | 'pe_constraint': keras.constraints.serialize(self.pe_constraint), 188 | } 189 | base_config = super(AddPositionEmbs, self).get_config() 190 | return dict(list(base_config.items()) + list(config.items())) 191 | 192 | def compute_output_shape(self, input_shape): 193 | return input_shape 194 | 195 | def build(self, input_shape): 196 | assert (len(input_shape) == 3), f"Number of dimensions should be 3, got {len(input_shape)}" 197 | length = (224 // self.patch_size) * (224 // self.patch_size) + 1 198 | self.pe = self.add_weight( 199 | # shape = [1, input_shape[1], input_shape[2]], 200 | shape = [1, length, input_shape[2]], 201 | initializer = self.pe_initializer, 202 | regularizer = self.pe_regularizer, 203 | constraint = 
self.pe_constraint, 204 | name = 'pos_embedding', 205 | ) 206 | super(AddPositionEmbs, self).build(input_shape) 207 | 208 | def call(self, inputs): 209 | num_features = tf.shape(inputs)[2] 210 | 211 | cls_token_pe = self.pe[:, 0:1, :] 212 | img_token_pe = self.pe[:, 1: , :] 213 | 214 | img_token_pe = tf.reshape(img_token_pe, [1, (224 // self.patch_size), (224 // self.patch_size), num_features]) 215 | img_token_pe = tf.compat.v1.image.resize_images(img_token_pe, (self.image_shape[0] // self.patch_size, self.image_shape[1] // self.patch_size), tf.image.ResizeMethod.BICUBIC, align_corners=False) 216 | img_token_pe = tf.reshape(img_token_pe, [1, -1, num_features]) 217 | 218 | pe = tf.concat([cls_token_pe, img_token_pe], axis = 1) 219 | 220 | return inputs + tf.cast(pe, dtype=inputs.dtype) 221 | 222 | #--------------------------------------------------------------------------------------------------------------------# 223 | # Attention机制 224 | # 将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。 225 | # 然后利用 查询向量query 点乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。 226 | # 然后利用 score 点乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。 227 | #--------------------------------------------------------------------------------------------------------------------# 228 | class Attention(Layer): 229 | def __init__(self, num_features, num_heads, **kwargs): 230 | super(Attention, self).__init__(**kwargs) 231 | self.num_features = num_features 232 | self.num_heads = num_heads 233 | self.projection_dim = num_features // num_heads 234 | 235 | def get_config(self): 236 | base_config = super(Attention, self).get_config() 237 | return dict(list(base_config.items())) 238 | 239 | def compute_output_shape(self, input_shape): 240 | return (input_shape[0], input_shape[1], input_shape[2] // 3) 241 | 242 | def call(self, inputs): 243 | #-----------------------------------------------# 244 | # 获得batch_size 245 | #-----------------------------------------------# 246 | bs = tf.shape(inputs)[0] 247 | 248 | #-----------------------------------------------# 249 | # b, 197, 3 * 768 -> b, 197, 3, 12, 64 250 | #-----------------------------------------------# 251 | inputs = tf.reshape(inputs, [bs, -1, 3, self.num_heads, self.projection_dim]) 252 | #-----------------------------------------------# 253 | # b, 197, 3, 12, 64 -> 3, b, 12, 197, 64 254 | #-----------------------------------------------# 255 | inputs = tf.transpose(inputs, [2, 0, 3, 1, 4]) 256 | #-----------------------------------------------# 257 | # 将query, key, value划分开 258 | # query b, 12, 197, 64 259 | # key b, 12, 197, 64 260 | # value b, 12, 197, 64 261 | #-----------------------------------------------# 262 | query, key, value = inputs[0], inputs[1], inputs[2] 263 | #-----------------------------------------------# 264 | # b, 12, 197, 64 @ b, 12, 197, 64 = b, 12, 197, 197 265 | #-----------------------------------------------# 266 | score = tf.matmul(query, key, transpose_b=True) 267 | #-----------------------------------------------# 268 | # 进行数量级的缩放 269 | #-----------------------------------------------# 270 | scaled_score = score / tf.math.sqrt(tf.cast(self.projection_dim, score.dtype)) 271 | #-----------------------------------------------# 272 | # b, 12, 197, 197 -> b, 12, 197, 197 273 | #-----------------------------------------------# 274 | weights = tf.nn.softmax(scaled_score, axis=-1) 275 | #-----------------------------------------------# 276 | # b, 12, 197, 197 @ b, 12, 197, 64 = b, 12, 197, 64 277 | 
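As a sanity check on the shape comments in this `call`, the same scaled dot-product attention can be reproduced with a few standalone tensor ops. A minimal sketch on dummy tensors — the batch, head, and token counts are assumed purely for illustration, and this is not code from this file:

```
import tensorflow as tf

q = tf.random.normal([2, 12, 197, 64])   # b, heads, tokens, dims per head
k = tf.random.normal([2, 12, 197, 64])
v = tf.random.normal([2, 12, 197, 64])

score   = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(64.)  # 2, 12, 197, 197
weights = tf.nn.softmax(score, axis=-1)                          # each row sums to 1
out     = tf.matmul(weights, v)                                  # 2, 12, 197, 64
print(out.shape)
```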
#-----------------------------------------------# 278 | value = tf.matmul(weights, value) 279 | 280 | #-----------------------------------------------# 281 | # b, 12, 197, 64 -> b, 197, 12, 64 282 | #-----------------------------------------------# 283 | value = tf.transpose(value, perm=[0, 2, 1, 3]) 284 | #-----------------------------------------------# 285 | # b, 197, 12, 64 -> b, 197, 768 286 | #-----------------------------------------------# 287 | output = tf.reshape(value, (bs, -1, self.num_features)) 288 | return output 289 | 290 | def MultiHeadSelfAttention(inputs, num_features, num_heads, dropout, name): 291 | #-----------------------------------------------# 292 | # qkv b, 197, 768 -> b, 197, 3 * 768 293 | #-----------------------------------------------# 294 | qkv = Dense(int(num_features * 3), name = name + "qkv")(inputs) 295 | #-----------------------------------------------# 296 | # b, 197, 3 * 768 -> b, 197, 768 297 | #-----------------------------------------------# 298 | x = Attention(num_features, num_heads)(qkv) 299 | #-----------------------------------------------# 300 | # 197, 768 -> 197, 768 301 | #-----------------------------------------------# 302 | x = Dense(num_features, name = name + "proj")(x) 303 | x = Dropout(dropout)(x) 304 | return x 305 | 306 | def MLP(y, num_features, mlp_dim, dropout, name): 307 | y = Dense(mlp_dim, name = name + "fc1")(y) 308 | y = Gelu()(y) 309 | y = Dropout(dropout)(y) 310 | y = Dense(num_features, name = name + "fc2")(y) 311 | return y 312 | 313 | def TransformerBlock(inputs, num_features, num_heads, mlp_dim, dropout, name): 314 | #-----------------------------------------------# 315 | # 施加层标准化 316 | #-----------------------------------------------# 317 | x = LayerNormalization(epsilon=1e-6, name = name + "norm1")(inputs) 318 | #-----------------------------------------------# 319 | # 施加多头注意力机制 320 | #-----------------------------------------------# 321 | x = MultiHeadSelfAttention(x, num_features, num_heads, dropout, name = name + "attn.") 322 | x = Dropout(dropout)(x) 323 | #-----------------------------------------------# 324 | # 施加残差结构 325 | #-----------------------------------------------# 326 | x = Add()([x, inputs]) 327 | 328 | #-----------------------------------------------# 329 | # 施加层标准化 330 | #-----------------------------------------------# 331 | y = LayerNormalization(epsilon=1e-6, name = name + "norm2")(x) 332 | #-----------------------------------------------# 333 | # 施加两次全连接 334 | #-----------------------------------------------# 335 | y = MLP(y, num_features, mlp_dim, dropout, name = name + "mlp.") 336 | y = Dropout(dropout)(y) 337 | #-----------------------------------------------# 338 | # 施加残差结构 339 | #-----------------------------------------------# 340 | y = Add()([x, y]) 341 | return y 342 | 343 | def VisionTransformer(input_shape = [224, 224], patch_size = 16, num_layers = 12, num_features = 768, num_heads = 12, mlp_dim = 3072, 344 | classes = 1000, dropout = 0.1): 345 | #-----------------------------------------------# 346 | # 224, 224, 3 347 | #-----------------------------------------------# 348 | inputs = Input(shape = (input_shape[0], input_shape[1], 3)) 349 | 350 | #-----------------------------------------------# 351 | # 224, 224, 3 -> 14, 14, 768 352 | #-----------------------------------------------# 353 | x = Conv2D(num_features, patch_size, strides = patch_size, padding = "valid", name = "patch_embed.proj")(inputs) 354 | #-----------------------------------------------# 355 | # 14, 14, 768 -> 196, 
768 356 | #-----------------------------------------------# 357 | x = Reshape(((input_shape[0] // patch_size) * (input_shape[1] // patch_size), num_features))(x) 358 | #-----------------------------------------------# 359 | # 196, 768 -> 197, 768 360 | #-----------------------------------------------# 361 | x = ClassToken(name="cls_token")(x) 362 | #-----------------------------------------------# 363 | # 197, 768 -> 197, 768 364 | #-----------------------------------------------# 365 | x = AddPositionEmbs(input_shape, patch_size, name="pos_embed")(x) 366 | #-----------------------------------------------# 367 | # 197, 768 -> 197, 768 12次 368 | #-----------------------------------------------# 369 | for n in range(num_layers): 370 | x = TransformerBlock( 371 | x, 372 | num_features= num_features, 373 | num_heads = num_heads, 374 | mlp_dim = mlp_dim, 375 | dropout = dropout, 376 | name = "blocks." + str(n) + ".", 377 | ) 378 | x = LayerNormalization( 379 | epsilon=1e-6, name="norm" 380 | )(x) 381 | x = Lambda(lambda v: v[:, 0], name="ExtractToken")(x) 382 | x = Dense(classes, name="head")(x) 383 | x = Softmax()(x) 384 | return keras.models.Model(inputs, x) 385 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | ''' 2 | predict.py有几个注意点 3 | 1、无法进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。 4 | 2、如果想要将预测结果保存成txt,可以利用open打开txt文件,使用write方法写入txt,可以参考一下txt_annotation.py文件。 5 | ''' 6 | import tensorflow as tf 7 | from PIL import Image 8 | 9 | from classification import Classification 10 | 11 | gpus = tf.config.experimental.list_physical_devices(device_type='GPU') 12 | for gpu in gpus: 13 | tf.config.experimental.set_memory_growth(gpu, True) 14 | 15 | classfication = Classification() 16 | 17 | while True: 18 | img = input('Input image filename:') 19 | try: 20 | image = Image.open(img) 21 | except: 22 | print('Open Error! 
Try again!') 23 | continue 24 | else: 25 | class_name = classfication.detect_image(image) 26 | print(class_name) 27 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.4.1 2 | numpy==1.18.4 3 | matplotlib==3.2.1 4 | opencv_python==4.2.0.34 5 | tensorflow_gpu==2.2.0 6 | tqdm==4.46.1 7 | Pillow==8.2.0 8 | h5py==2.10.0 9 | -------------------------------------------------------------------------------- /summary.py: -------------------------------------------------------------------------------- 1 | #--------------------------------------------# 2 | # 该部分代码只用于看网络结构,并非测试代码 3 | #--------------------------------------------# 4 | from nets import get_model_from_name 5 | from utils.utils import net_flops 6 | 7 | if __name__ == "__main__": 8 | input_shape = [224, 224] 9 | num_classes = 1000 10 | backbone = "swin_transformer_tiny" 11 | 12 | model = get_model_from_name[backbone]([input_shape[0], input_shape[1], 3], classes=num_classes) 13 | #--------------------------------------------# 14 | # 查看网络结构 15 | #--------------------------------------------# 16 | model.summary() 17 | #--------------------------------------------# 18 | # 计算网络的FLOPS 19 | #--------------------------------------------# 20 | net_flops(model, table=False) 21 | 22 | #--------------------------------------------# 23 | # 获得网络每个层的名称与序号 24 | #--------------------------------------------# 25 | # for i,layer in enumerate(model.layers): 26 | # print(i,layer.name) 27 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | from functools import partial 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.keras import backend as K 8 | from tensorflow.keras.callbacks import (EarlyStopping, LearningRateScheduler, 9 | TensorBoard) 10 | from tensorflow.keras.optimizers import SGD, Adam 11 | 12 | from nets import freeze_layers, get_model_from_name 13 | from utils.callbacks import (ExponentDecayScheduler, LossHistory, 14 | ModelCheckpoint) 15 | from utils.dataloader import ClsDatasets 16 | from utils.utils import get_classes, get_lr_scheduler, show_config 17 | from utils.utils_fit import fit_one_epoch 18 | 19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 20 | 21 | #----------------------------------------# 22 | # 主函数 23 | #----------------------------------------# 24 | if __name__ == "__main__": 25 | #----------------------------------------------------# 26 | # 是否使用eager模式训练 27 | #----------------------------------------------------# 28 | eager = False 29 | #---------------------------------------------------------------------# 30 | # train_gpu 训练用到的GPU 31 | # 默认为第一张卡、双卡为[0, 1]、三卡为[0, 1, 2] 32 | # 在使用多GPU时,每个卡上的batch为总batch除以卡的数量。 33 | #---------------------------------------------------------------------# 34 | train_gpu = [0,] 35 | #---------------------------------------------------------------------# 36 | # classes_path 指向model_data下的txt,与自己训练的数据集相关 37 | # 训练前一定要修改classes_path,使其对应自己的数据集 38 | #---------------------------------------------------------------------# 39 | classes_path = 'model_data/cls_classes.txt' 40 | #------------------------------------------------------# 41 | # input_shape 输入的shape大小 42 | #------------------------------------------------------# 43 | input_shape = [224, 224]
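For reference, `classes_path` above is consumed by `get_classes` in utils/utils.py (shown later in this dump), which reads one class name per line and returns the list plus its length. A minimal usage sketch — the file contents here are assumed purely for illustration:

```
# model_data/cls_classes.txt is assumed to contain, e.g.:
#   cat
#   dog
from utils.utils import get_classes

class_names, num_classes = get_classes('model_data/cls_classes.txt')
print(class_names, num_classes)   # ['cat', 'dog'] 2
```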
44 | #------------------------------------------------------# 45 | # 所用模型种类: 46 | # mobilenetv1、mobilenetv2、resnet50、vgg16、 47 | # vit_b_16、 48 | # swin_transformer_tiny、swin_transformer_small、swin_transformer_base 49 | #------------------------------------------------------# 50 | backbone = "mobilenetv1" 51 | #------------------------------------------------------# 52 | # 当使用mobilenetv1时的alpha值 53 | # 仅在backbone='mobilenetv1'的时候有效 54 | #------------------------------------------------------# 55 | alpha = 0.25 56 | #----------------------------------------------------------------------------------------------------------------------------# 57 | # 权值文件的下载请看README,可以通过网盘下载。模型的 预训练权重 对不同数据集是通用的,因为特征是通用的。 58 | # 模型的 预训练权重 比较重要的部分是 主干特征提取网络的权值部分,用于进行特征提取。 59 | # 预训练权重对于99%的情况都必须要用,不用的话主干部分的权值太过随机,特征提取效果不明显,网络训练的结果也不会好 60 | # 61 | # 如果训练过程中存在中断训练的操作,可以将model_path设置成logs文件夹下的权值文件,将已经训练了一部分的权值再次载入。 62 | # 同时修改下方的 冻结阶段 或者 解冻阶段 的参数,来保证模型epoch的连续性。 63 | # 64 | # 当model_path = ''的时候不加载整个模型的权值。 65 | # 66 | # 此处使用的是整个模型的权重,因此是在train.py进行加载的。 67 | # 如果想要让模型从主干的预训练权值开始训练,则设置model_path为主干网络的权值,此时仅加载主干。 68 | # 如果想要让模型从0开始训练,则设置model_path = '',Freeze_Train = False,此时从0开始训练,且没有冻结主干的过程。 69 | #----------------------------------------------------------------------------------------------------------------------------# 70 | model_path = "model_data/mobilenet_2_5_224_tf_no_top.h5" 71 | 72 | #----------------------------------------------------------------------------------------------------------------------------# 73 | # 训练分为两个阶段,分别是冻结阶段和解冻阶段。设置冻结阶段是为了满足机器性能不足的同学的训练需求。 74 | # 冻结训练需要的显存较小,显卡非常差的情况下,可设置Freeze_Epoch等于UnFreeze_Epoch,此时仅仅进行冻结训练。 75 | # 76 | # 在此提供若干参数设置建议,各位训练者根据自己的需求进行灵活调整: 77 | # (一)从整个模型的预训练权重开始训练: 78 | # Adam: 79 | # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 1e-3。(冻结) 80 | # Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-3。(不冻结) 81 | # SGD: 82 | # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 200,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 1e-2。(冻结) 83 | # Init_Epoch = 0,UnFreeze_Epoch = 200,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2。(不冻结) 84 | # 其中:UnFreeze_Epoch可以在100-300之间调整。 85 | # (二)从0开始训练: 86 | # Adam: 87 | # Init_Epoch = 0,UnFreeze_Epoch = 300,Unfreeze_batch_size >= 16,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-3。(不冻结) 88 | # SGD: 89 | # Init_Epoch = 0,UnFreeze_Epoch = 300,Unfreeze_batch_size >= 16,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2。(不冻结) 90 | # 其中:UnFreeze_Epoch尽量不小于300。 91 | # (三)batch_size的设置: 92 | # 在显卡能够接受的范围内,以大为好。显存不足与数据集大小无关,提示显存不足(OOM或者CUDA out of memory)请调小batch_size。 93 | # 受到BatchNorm层影响,batch_size最小为2,不能为1。 94 | # 正常情况下Freeze_batch_size建议为Unfreeze_batch_size的1-2倍。不建议设置的差距过大,因为关系到学习率的自动调整。 95 | #----------------------------------------------------------------------------------------------------------------------------# 96 | #------------------------------------------------------------------# 97 | # 冻结阶段训练参数 98 | # 此时模型的主干被冻结了,特征提取网络不发生改变 99 | # 占用的显存较小,仅对网络进行微调 100 | # Init_Epoch 模型当前开始的训练世代,其值可以大于Freeze_Epoch,如设置: 101 | # Init_Epoch = 60、Freeze_Epoch = 50、UnFreeze_Epoch = 100 102 | # 会跳过冻结阶段,直接从60代开始,并调整对应的学习率。 103 | # (断点续练时使用) 104 | # Freeze_Epoch 模型冻结训练的Freeze_Epoch 105 | # (当Freeze_Train=False时失效) 106 | # Freeze_batch_size 模型冻结训练的batch_size 107 | # (当Freeze_Train=False时失效) 108 | #------------------------------------------------------------------# 109 | Init_Epoch = 0 110 | Freeze_Epoch = 50 111 | 
Freeze_batch_size = 32 112 | #------------------------------------------------------------------# 113 | # 解冻阶段训练参数 114 | # 此时模型的主干不被冻结了,特征提取网络会发生改变 115 | # 占用的显存较大,网络所有的参数都会发生改变 116 | # UnFreeze_Epoch 模型总共训练的epoch 117 | # Unfreeze_batch_size 模型在解冻后的batch_size 118 | #------------------------------------------------------------------# 119 | UnFreeze_Epoch = 200 120 | Unfreeze_batch_size = 32 121 | #------------------------------------------------------------------# 122 | # Freeze_Train 是否进行冻结训练 123 | # 默认先冻结主干训练后解冻训练。 124 | #------------------------------------------------------------------# 125 | Freeze_Train = True 126 | 127 | #------------------------------------------------------------------# 128 | # 其它训练参数:学习率、优化器、学习率下降有关 129 | #------------------------------------------------------------------# 130 | #------------------------------------------------------------------# 131 | # Init_lr 模型的最大学习率 132 | # 当使用Adam优化器时建议设置 Init_lr=1e-3 133 | # 当使用SGD优化器时建议设置 Init_lr=1e-2 134 | # Min_lr 模型的最小学习率,默认为最大学习率的0.01 135 | #------------------------------------------------------------------# 136 | Init_lr = 1e-2 137 | Min_lr = Init_lr * 0.01 138 | #------------------------------------------------------------------# 139 | # optimizer_type 使用到的优化器种类,可选的有adam、sgd 140 | # 当使用Adam优化器时建议设置 Init_lr=1e-3 141 | # 当使用SGD优化器时建议设置 Init_lr=1e-2 142 | # momentum 优化器内部使用到的momentum参数 143 | #------------------------------------------------------------------# 144 | optimizer_type = "sgd" 145 | momentum = 0.9 146 | #------------------------------------------------------------------# 147 | # lr_decay_type 使用到的学习率下降方式,可选的有'step'、'cos' 148 | #------------------------------------------------------------------# 149 | lr_decay_type = 'cos' 150 | #------------------------------------------------------------------# 151 | # save_period 多少个epoch保存一次权值 152 | #------------------------------------------------------------------# 153 | save_period = 10 154 | #------------------------------------------------------------------# 155 | # save_dir 权值与日志文件保存的文件夹 156 | #------------------------------------------------------------------# 157 | save_dir = 'logs' 158 | #------------------------------------------------------------------# 159 | # num_workers 用于设置是否使用多线程读取数据,1代表关闭多线程 160 | # 开启后会加快数据读取速度,但是会占用更多内存 161 | # keras里开启多线程有些时候速度反而慢了许多 162 | # 在IO为瓶颈的时候再开启多线程,即GPU运算速度远大于读取图片的速度。 163 | #------------------------------------------------------------------# 164 | num_workers = 1 165 | 166 | #------------------------------------------------------# 167 | # train_annotation_path 训练图片路径和标签 168 | # test_annotation_path 验证图片路径和标签(使用测试集代替验证集) 169 | #------------------------------------------------------# 170 | train_annotation_path = "cls_train.txt" 171 | test_annotation_path = 'cls_test.txt' 172 | 173 | #------------------------------------------------------# 174 | # 设置用到的显卡 175 | #------------------------------------------------------# 176 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in train_gpu) 177 | ngpus_per_node = len(train_gpu) 178 | 179 | gpus = tf.config.experimental.list_physical_devices(device_type='GPU') 180 | for gpu in gpus: 181 | tf.config.experimental.set_memory_growth(gpu, True) 182 | 183 | #------------------------------------------------------# 184 | # 判断当前使用的GPU数量与机器上实际的GPU数量 185 | #------------------------------------------------------# 186 | if ngpus_per_node > 1 and ngpus_per_node > len(gpus): 187 | raise ValueError("The number of GPUs specified for training is more than the GPUs on the machine") 188 | 189 | 
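To make the device bookkeeping above concrete: `train_gpu` is joined into `CUDA_VISIBLE_DEVICES`, and the branch below wraps model construction in a `MirroredStrategy` when more than one GPU is requested, so the per-card batch is the global batch divided by the card count. A standalone sketch of that pattern (example values, not this file's code):

```
import os
import tensorflow as tf

train_gpu = [0, 1]                                                        # example value
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in train_gpu)  # -> "0,1"

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Variables created inside the scope are mirrored across the visible GPUs;
    # a global batch of 32 would be split into 16 per card in this example.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
```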
if ngpus_per_node > 1: 190 | strategy = tf.distribute.MirroredStrategy() 191 | else: 192 | strategy = None 193 | print('Number of devices: {}'.format(ngpus_per_node)) 194 | 195 | #------------------------------------------------------# 196 | # 获取classes 197 | #------------------------------------------------------# 198 | class_names, num_classes = get_classes(classes_path) 199 | 200 | if ngpus_per_node > 1: 201 | with strategy.scope(): 202 | #------------------------------------------------------# 203 | # 创建分类模型 204 | #------------------------------------------------------# 205 | if backbone == "mobilenetv1": 206 | model = get_model_from_name[backbone](input_shape=[input_shape[0], input_shape[1], 3], classes=num_classes, alpha=alpha) 207 | else: 208 | model = get_model_from_name[backbone](input_shape=[input_shape[0], input_shape[1], 3], classes=num_classes) 209 | if model_path != "": 210 | #------------------------------------------------------# 211 | # 载入预训练权重 212 | #------------------------------------------------------# 213 | print('Load weights {}.'.format(model_path)) 214 | model.load_weights(model_path, by_name=True, skip_mismatch=True) 215 | else: 216 | #------------------------------------------------------# 217 | # 创建分类模型 218 | #------------------------------------------------------# 219 | if backbone == "mobilenetv1": 220 | model = get_model_from_name[backbone](input_shape=[input_shape[0], input_shape[1], 3], classes=num_classes, alpha=alpha) 221 | else: 222 | model = get_model_from_name[backbone](input_shape=[input_shape[0], input_shape[1], 3], classes=num_classes) 223 | if model_path != "": 224 | #------------------------------------------------------# 225 | # 载入预训练权重 226 | #------------------------------------------------------# 227 | print('Load weights {}.'.format(model_path)) 228 | model.load_weights(model_path, by_name=True, skip_mismatch=True) 229 | 230 | #---------------------------# 231 | # 读取数据集对应的txt 232 | #---------------------------# 233 | with open(train_annotation_path, encoding='utf-8') as f: 234 | train_lines = f.readlines() 235 | with open(test_annotation_path, encoding='utf-8') as f: 236 | val_lines = f.readlines() 237 | num_train = len(train_lines) 238 | num_val = len(val_lines) 239 | np.random.seed(10101) 240 | np.random.shuffle(train_lines) 241 | np.random.seed(None) 242 | 243 | show_config( 244 | num_classes = num_classes, backbone = backbone, model_path = model_path, input_shape = input_shape, \ 245 | Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \ 246 | Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \ 247 | save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val 248 | ) 249 | #---------------------------------------------------------# 250 | # 总训练世代指的是遍历全部数据的总次数 251 | # 总训练步长指的是梯度下降的总次数 252 | # 每个训练世代包含若干训练步长,每个训练步长进行一次梯度下降。 253 | # 此处仅建议最低训练世代,上不封顶,计算时只考虑了解冻部分 254 | #----------------------------------------------------------# 255 | wanted_step = 3e4 if optimizer_type == "sgd" else 1e4 256 | total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch 257 | if total_step <= wanted_step: 258 | wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1 259 | print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m"%(optimizer_type, wanted_step)) 260 | 
print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step)) 261 | print("\033[1;33;44m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d。\033[0m"%(total_step, wanted_step, wanted_epoch)) 262 | 263 | #------------------------------------------------------# 264 | # 主干特征提取网络特征通用,冻结训练可以加快训练速度 265 | # 也可以在训练初期防止权值被破坏。 266 | # Init_Epoch为起始世代 267 | # Freeze_Epoch为冻结训练的世代 268 | # UnFreeze_Epoch总训练世代 269 | # 提示OOM或者显存不足请调小Batch_size 270 | #------------------------------------------------------# 271 | if True: 272 | if Freeze_Train: 273 | freeze_layers = freeze_layers[backbone] 274 | for i in range(freeze_layers): model.layers[i].trainable = False 275 | print('Freeze the first {} layers of total {} layers.'.format(freeze_layers, len(model.layers))) 276 | 277 | #-------------------------------------------------------------------# 278 | # 如果不冻结训练的话,直接设置batch_size为Unfreeze_batch_size 279 | #-------------------------------------------------------------------# 280 | batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size 281 | start_epoch = Init_Epoch 282 | end_epoch = Freeze_Epoch if Freeze_Train else UnFreeze_Epoch 283 | 284 | #-------------------------------------------------------------------# 285 | # 判断当前batch_size,自适应调整学习率 286 | #-------------------------------------------------------------------# 287 | nbs = 64 288 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 1e-1 289 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 290 | if backbone in ['vit_b_16', 'swin_transformer_tiny', 'swin_transformer_small', 'swin_transformer_base']: 291 | nbs = 256 292 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 1e-1 293 | lr_limit_min = 1e-5 if optimizer_type == 'adam' else 5e-4 294 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) 295 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 296 | 297 | #---------------------------------------# 298 | # 获得学习率下降的公式 299 | #---------------------------------------# 300 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) 301 | 302 | epoch_step = num_train // batch_size 303 | epoch_step_val = num_val // batch_size 304 | 305 | if epoch_step == 0 or epoch_step_val == 0: 306 | raise ValueError('数据集过小,无法进行训练,请扩充数据集。') 307 | 308 | train_dataloader = ClsDatasets(train_lines, input_shape, batch_size, num_classes, train = True) 309 | val_dataloader = ClsDatasets(val_lines, input_shape, batch_size, num_classes, train = False) 310 | 311 | optimizer = { 312 | 'adam' : Adam(lr = Init_lr_fit, beta_1 = momentum), 313 | 'sgd' : SGD(lr = Init_lr_fit, momentum = momentum, nesterov=True) 314 | }[optimizer_type] 315 | 316 | if eager: 317 | start_epoch = Init_Epoch 318 | end_epoch = UnFreeze_Epoch 319 | UnFreeze_flag = False 320 | 321 | gen = tf.data.Dataset.from_generator(partial(train_dataloader.generate), (tf.float32, tf.float32)) 322 | gen_val = tf.data.Dataset.from_generator(partial(val_dataloader.generate), (tf.float32, tf.float32)) 323 | 324 | gen = gen.shuffle(buffer_size = batch_size).prefetch(buffer_size = batch_size) 325 | gen_val = gen_val.shuffle(buffer_size = batch_size).prefetch(buffer_size = batch_size) 326 | 327 | if ngpus_per_node > 1: 328 | gen = strategy.experimental_distribute_dataset(gen) 329 | gen_val = strategy.experimental_distribute_dataset(gen_val) 330 | 331 | time_str = 
datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S') 332 | log_dir = os.path.join(save_dir, "loss_" + str(time_str)) 333 | loss_history = LossHistory(log_dir) 334 | #---------------------------------------# 335 | # 开始模型训练 336 | #---------------------------------------# 337 | for epoch in range(start_epoch, end_epoch): 338 | #---------------------------------------# 339 | # 如果模型有冻结学习部分 340 | # 则解冻,并设置参数 341 | #---------------------------------------# 342 | if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train: 343 | batch_size = Unfreeze_batch_size 344 | 345 | #-------------------------------------------------------------------# 346 | # 判断当前batch_size,自适应调整学习率 347 | #-------------------------------------------------------------------# 348 | nbs = 64 349 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 1e-1 350 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 351 | if backbone in ['vit_b_16', 'swin_transformer_tiny', 'swin_transformer_small', 'swin_transformer_base']: 352 | nbs = 256 353 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 1e-1 354 | lr_limit_min = 1e-5 if optimizer_type == 'adam' else 5e-4 355 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) 356 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 357 | #---------------------------------------# 358 | # 获得学习率下降的公式 359 | #---------------------------------------# 360 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) 361 | 362 | for i in range(len(model.layers)): 363 | model.layers[i].trainable = True 364 | 365 | epoch_step = num_train // batch_size 366 | epoch_step_val = num_val // batch_size 367 | 368 | if epoch_step == 0 or epoch_step_val == 0: 369 | raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") 370 | 371 | train_dataloader.batch_size = batch_size 372 | val_dataloader.batch_size = batch_size 373 | 374 | gen = tf.data.Dataset.from_generator(partial(train_dataloader.generate), (tf.float32, tf.float32)) 375 | gen_val = tf.data.Dataset.from_generator(partial(val_dataloader.generate), (tf.float32, tf.float32)) 376 | 377 | gen = gen.shuffle(buffer_size = batch_size).prefetch(buffer_size = batch_size) 378 | gen_val = gen_val.shuffle(buffer_size = batch_size).prefetch(buffer_size = batch_size) 379 | 380 | if ngpus_per_node > 1: 381 | gen = strategy.experimental_distribute_dataset(gen) 382 | gen_val = strategy.experimental_distribute_dataset(gen_val) 383 | 384 | UnFreeze_flag = True 385 | 386 | lr = lr_scheduler_func(epoch) 387 | K.set_value(optimizer.lr, lr) 388 | 389 | fit_one_epoch(model, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, 390 | end_epoch, save_period, save_dir, strategy) 391 | 392 | train_dataloader.on_epoch_end() 393 | val_dataloader.on_epoch_end() 394 | else: 395 | start_epoch = Init_Epoch 396 | end_epoch = Freeze_Epoch if Freeze_Train else UnFreeze_Epoch 397 | 398 | if ngpus_per_node > 1: 399 | with strategy.scope(): 400 | model.compile(loss = 'categorical_crossentropy', optimizer = optimizer, metrics = ['categorical_accuracy']) 401 | else: 402 | model.compile(loss = 'categorical_crossentropy', optimizer = optimizer, metrics = ['categorical_accuracy']) 403 | #-------------------------------------------------------------------------------# 404 | # 训练参数的设置 405 | # logging 用于设置tensorboard的保存地址 406 | # checkpoint 用于设置权值保存的细节,period用于修改多少epoch保存一次 407 | # lr_scheduler 用于设置学习率下降的方式 408 | # early_stopping 
用于设定早停,val_loss多次不下降自动结束训练,表示模型基本收敛 409 | #-------------------------------------------------------------------------------# 410 | time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S') 411 | log_dir = os.path.join(save_dir, "loss_" + str(time_str)) 412 | logging = TensorBoard(log_dir) 413 | loss_history = LossHistory(log_dir) 414 | checkpoint = ModelCheckpoint(os.path.join(save_dir, "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5"), 415 | monitor = 'val_loss', save_weights_only = True, save_best_only = False, period = save_period) 416 | checkpoint_last = ModelCheckpoint(os.path.join(save_dir, "last_epoch_weights.h5"), 417 | monitor = 'val_loss', save_weights_only = True, save_best_only = False, period = 1) 418 | checkpoint_best = ModelCheckpoint(os.path.join(save_dir, "best_epoch_weights.h5"), 419 | monitor = 'val_loss', save_weights_only = True, save_best_only = True, period = 1) 420 | early_stopping = EarlyStopping(monitor='val_loss', min_delta = 0, patience = 10, verbose = 1) 421 | lr_scheduler = LearningRateScheduler(lr_scheduler_func, verbose = 1) 422 | callbacks = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler] 423 | 424 | if start_epoch < end_epoch: 425 | print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size)) 426 | model.fit( 427 | x = train_dataloader, 428 | steps_per_epoch = epoch_step, 429 | validation_data = val_dataloader, 430 | validation_steps = epoch_step_val, 431 | epochs = end_epoch, 432 | initial_epoch = start_epoch, 433 | use_multiprocessing = True if num_workers > 1 else False, 434 | workers = num_workers, 435 | callbacks = callbacks 436 | ) 437 | #---------------------------------------# 438 | # 如果模型有冻结学习部分 439 | # 则解冻,并设置参数 440 | #---------------------------------------# 441 | if Freeze_Train: 442 | batch_size = Unfreeze_batch_size 443 | start_epoch = Freeze_Epoch if start_epoch < Freeze_Epoch else start_epoch 444 | end_epoch = UnFreeze_Epoch 445 | 446 | #-------------------------------------------------------------------# 447 | # 判断当前batch_size,自适应调整学习率 448 | #-------------------------------------------------------------------# 449 | nbs = 64 450 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 1e-1 451 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 452 | if backbone in ['vit_b_16', 'swin_transformer_tiny', 'swin_transformer_small', 'swin_transformer_base']: 453 | nbs = 256 454 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 1e-1 455 | lr_limit_min = 1e-5 if optimizer_type == 'adam' else 5e-4 456 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) 457 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 458 | #---------------------------------------# 459 | # 获得学习率下降的公式 460 | #---------------------------------------# 461 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) 462 | lr_scheduler = LearningRateScheduler(lr_scheduler_func, verbose = 1) 463 | callbacks = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler] 464 | 465 | for i in range(len(model.layers)): 466 | model.layers[i].trainable = True 467 | if ngpus_per_node > 1: 468 | with strategy.scope(): 469 | model.compile(loss = 'categorical_crossentropy', optimizer = optimizer, metrics = ['categorical_accuracy']) 470 | else: 471 | model.compile(loss = 'categorical_crossentropy', optimizer = optimizer, metrics = ['categorical_accuracy']) 472 | 473 | epoch_step = num_train // batch_size 474 | epoch_step_val 
= num_val // batch_size 475 | 476 | if epoch_step == 0 or epoch_step_val == 0: 477 | raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") 478 | 479 | train_dataloader.batch_size = Unfreeze_batch_size 480 | val_dataloader.batch_size = Unfreeze_batch_size 481 | 482 | print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size)) 483 | model.fit( 484 | x = train_dataloader, 485 | steps_per_epoch = epoch_step, 486 | validation_data = val_dataloader, 487 | validation_steps = epoch_step_val, 488 | epochs = end_epoch, 489 | initial_epoch = start_epoch, 490 | use_multiprocessing = True if num_workers > 1 else False, 491 | workers = num_workers, 492 | callbacks = callbacks 493 | ) 494 | -------------------------------------------------------------------------------- /txt_annotation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import getcwd 3 | 4 | from utils.utils import get_classes 5 | 6 | #-------------------------------------------------------------------# 7 | # classes_path 指向model_data下的txt,与自己训练的数据集相关 8 | # 训练前一定要修改classes_path,使其对应自己的数据集 9 | # txt文件中是自己所要去区分的种类 10 | # 与训练和预测所用的classes_path一致即可 11 | #-------------------------------------------------------------------# 12 | classes_path = 'model_data/cls_classes.txt' 13 | #-------------------------------------------------------# 14 | # datasets_path 指向数据集所在的路径 15 | #-------------------------------------------------------# 16 | datasets_path = 'datasets' 17 | 18 | sets = ["train", "test"] 19 | classes, _ = get_classes(classes_path) 20 | 21 | if __name__ == "__main__": 22 | wd = getcwd() 23 | 24 | for se in sets: 25 | list_file = open('cls_' + se + '.txt', 'w') 26 | 27 | datasets_path_t = os.path.join(datasets_path, se) 28 | types_name = os.listdir(datasets_path_t) 29 | for type_name in types_name: 30 | if type_name not in classes: 31 | continue 32 | cls_id = classes.index(type_name) 33 | 34 | photos_path = os.path.join(datasets_path_t, type_name) 35 | photos_name = os.listdir(photos_path) 36 | for photo_name in photos_name: 37 | _, postfix = os.path.splitext(photo_name) 38 | if postfix not in ['.jpg', '.png', '.jpeg']: 39 | continue 40 | list_file.write(str(cls_id) + ";" + '%s'%(os.path.join(photos_path, photo_name))) 41 | list_file.write('\n') 42 | list_file.close() 43 | 44 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import math 4 | 5 | from tensorflow import keras 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | from matplotlib import pyplot as plt 9 | import numpy as np 10 | import scipy.signal 11 | import tensorflow as tf 12 | from tensorflow.keras import backend as K 13 | 14 | class LossHistory(keras.callbacks.Callback): 15 | def __init__(self, log_dir): 16 | self.log_dir = log_dir 17 | self.losses = [] 18 | self.val_loss = [] 19 | 20 | os.makedirs(self.log_dir) 21 | 22 | def on_epoch_end(self, epoch, logs={}): 23 | if not os.path.exists(self.log_dir): 24 | os.makedirs(self.log_dir) 25 | 26 | self.losses.append(logs.get('loss')) 27 | self.val_loss.append(logs.get('val_loss')) 28 | 29 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: 30 | 
f.write(str(logs.get('loss'))) 31 | f.write("\n") 32 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: 33 | f.write(str(logs.get('val_loss'))) 34 | f.write("\n") 35 | self.loss_plot() 36 | 37 | def loss_plot(self): 38 | iters = range(len(self.losses)) 39 | 40 | plt.figure() 41 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') 42 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') 43 | try: 44 | if len(self.losses) < 25: 45 | num = 5 46 | else: 47 | num = 15 48 | 49 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') 50 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') 51 | except: 52 | pass 53 | 54 | plt.grid(True) 55 | plt.xlabel('Epoch') 56 | plt.ylabel('Loss') 57 | plt.title('A Loss Curve') 58 | plt.legend(loc="upper right") 59 | 60 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) 61 | 62 | plt.cla() 63 | plt.close("all") 64 | 65 | class ExponentDecayScheduler(keras.callbacks.Callback): 66 | def __init__(self, 67 | decay_rate, 68 | verbose=0): 69 | super(ExponentDecayScheduler, self).__init__() 70 | self.decay_rate = decay_rate 71 | self.verbose = verbose 72 | self.learning_rates = [] 73 | 74 | def on_epoch_end(self, batch, logs=None): 75 | learning_rate = K.get_value(self.model.optimizer.lr) * self.decay_rate 76 | K.set_value(self.model.optimizer.lr, learning_rate) 77 | if self.verbose > 0: 78 | print('Setting learning rate to %s.' % (learning_rate)) 79 | 80 | class WarmUpCosineDecayScheduler(keras.callbacks.Callback): 81 | def __init__(self, T_max, eta_min=0, verbose=0): 82 | super(WarmUpCosineDecayScheduler, self).__init__() 83 | self.T_max = T_max 84 | self.eta_min = eta_min 85 | self.verbose = verbose 86 | self.init_lr = 0 87 | self.last_epoch = 0 88 | 89 | def on_train_begin(self, batch, logs=None): 90 | self.init_lr = K.get_value(self.model.optimizer.lr) 91 | 92 | def on_epoch_end(self, batch, logs=None): 93 | learning_rate = self.eta_min + (self.init_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 94 | self.last_epoch += 1 95 | 96 | K.set_value(self.model.optimizer.lr, learning_rate) 97 | if self.verbose > 0: 98 | print('Setting learning rate to %s.' % (learning_rate)) 99 | 100 | class ModelCheckpoint(tf.keras.callbacks.Callback): 101 | def __init__(self, filepath, monitor='val_loss', verbose=0, 102 | save_best_only=False, save_weights_only=False, 103 | mode='auto', period=1): 104 | super(ModelCheckpoint, self).__init__() 105 | self.monitor = monitor 106 | self.verbose = verbose 107 | self.filepath = filepath 108 | self.save_best_only = save_best_only 109 | self.save_weights_only = save_weights_only 110 | self.period = period 111 | self.epochs_since_last_save = 0 112 | 113 | if mode not in ['auto', 'min', 'max']: 114 | warnings.warn('ModelCheckpoint mode %s is unknown, ' 115 | 'fallback to auto mode.' 
% (mode), 116 | RuntimeWarning) 117 | mode = 'auto' 118 | 119 | if mode == 'min': 120 | self.monitor_op = np.less 121 | self.best = np.Inf 122 | elif mode == 'max': 123 | self.monitor_op = np.greater 124 | self.best = -np.Inf 125 | else: 126 | if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): 127 | self.monitor_op = np.greater 128 | self.best = -np.Inf 129 | else: 130 | self.monitor_op = np.less 131 | self.best = np.Inf 132 | 133 | def on_epoch_end(self, epoch, logs=None): 134 | logs = logs or {} 135 | self.epochs_since_last_save += 1 136 | if self.epochs_since_last_save >= self.period: 137 | self.epochs_since_last_save = 0 138 | filepath = self.filepath.format(epoch=epoch + 1, **logs) 139 | if self.save_best_only: 140 | current = logs.get(self.monitor) 141 | if current is None: 142 | warnings.warn('Can save best model only with %s available, ' 143 | 'skipping.' % (self.monitor), RuntimeWarning) 144 | else: 145 | if self.monitor_op(current, self.best): 146 | if self.verbose > 0: 147 | print('\nEpoch %05d: %s improved from %0.5f to %0.5f,' 148 | ' saving model to %s' 149 | % (epoch + 1, self.monitor, self.best, 150 | current, filepath)) 151 | self.best = current 152 | if self.save_weights_only: 153 | self.model.save_weights(filepath, overwrite=True) 154 | else: 155 | self.model.save(filepath, overwrite=True) 156 | else: 157 | if self.verbose > 0: 158 | print('\nEpoch %05d: %s did not improve' % 159 | (epoch + 1, self.monitor)) 160 | else: 161 | if self.verbose > 0: 162 | print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath)) 163 | if self.save_weights_only: 164 | self.model.save_weights(filepath, overwrite=True) 165 | else: 166 | self.model.save(filepath, overwrite=True) 167 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import math 2 | from random import shuffle 3 | 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | from tensorflow import keras 8 | from tensorflow.python.keras.utils.np_utils import to_categorical 9 | 10 | from .utils import cvtColor, preprocess_input 11 | from .utils_aug import CenterCrop, ImageNetPolicy, RandomResizedCrop, Resize 12 | 13 | 14 | class ClsDatasets(keras.utils.Sequence): 15 | def __init__(self, annotation_lines, input_shape, batch_size, num_classes, train, autoaugment_flag=True): 16 | self.annotation_lines = annotation_lines 17 | self.length = len(self.annotation_lines) 18 | 19 | self.input_shape = input_shape 20 | self.batch_size = batch_size 21 | self.num_classes = num_classes 22 | self.train = train 23 | 24 | self.autoaugment_flag = autoaugment_flag 25 | if self.autoaugment_flag: 26 | self.resize_crop = RandomResizedCrop(input_shape) 27 | self.policy = ImageNetPolicy() 28 | 29 | self.resize = Resize(input_shape[0] if input_shape[0] == input_shape[1] else input_shape) 30 | self.center_crop = CenterCrop(input_shape) 31 | 32 | 33 | def __len__(self): 34 | return math.ceil(len(self.annotation_lines) / float(self.batch_size)) 35 | 36 | def __getitem__(self, index): 37 | X_train = [] 38 | Y_train = [] 39 | for i in range(index * self.batch_size, (index + 1) * self.batch_size): 40 | i = i % self.length 41 | 42 | annotation_path = self.annotation_lines[i].split(';')[1].split()[0] 43 | image = Image.open(annotation_path) 44 | #------------------------------# 45 | # 读取图像并转换成RGB图像 46 | #------------------------------# 47 | image = cvtColor(image) 48 | if self.autoaugment_flag: 49 | 
image = self.AutoAugment(image, random=self.train) 50 | else: 51 | image = self.get_random_data(image, self.input_shape, random=self.train) 52 | image = preprocess_input(np.array(image).astype(np.float32)) 53 | 54 | X_train.append(image) 55 | Y_train.append(int(self.annotation_lines[i].split(';')[0])) 56 | 57 | X_train = np.array(X_train) 58 | Y_train = to_categorical(np.array(Y_train), num_classes = self.num_classes) 59 | return X_train, Y_train 60 | 61 | def generate(self): 62 | i = 0 63 | while 1: 64 | X_train = [] 65 | Y_train = [] 66 | for _ in range(self.batch_size): 67 | if i == 0: 68 | np.random.shuffle(self.annotation_lines) 69 | annotation_path = self.annotation_lines[i].split(';')[1].split()[0] 70 | image = Image.open(annotation_path) 71 | #------------------------------# 72 | # 读取图像并转换成RGB图像 73 | #------------------------------# 74 | image = cvtColor(image) 75 | if self.autoaugment_flag: 76 | image = self.AutoAugment(image, random=self.train) 77 | else: 78 | image = self.get_random_data(image, self.input_shape, random=self.train) 79 | image = preprocess_input(np.array(image).astype(np.float32)) 80 | 81 | X_train.append(image) 82 | Y_train.append(int(self.annotation_lines[i].split(';')[0])) 83 | 84 | i = (i + 1) % self.length 85 | 86 | X_train = np.array(X_train) 87 | Y_train = to_categorical(np.array(Y_train), num_classes = self.num_classes) 88 | yield (X_train, Y_train) 89 | 90 | def on_epoch_end(self): 91 | shuffle(self.annotation_lines) 92 | 93 | def rand(self, a=0, b=1): 94 | return np.random.rand()*(b-a) + a 95 | 96 | def get_random_data(self, image, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True): 97 | #------------------------------# 98 | # 获得图像的高宽与目标高宽 99 | #------------------------------# 100 | iw, ih = image.size 101 | h, w = input_shape 102 | 103 | if not random: 104 | scale = min(w/iw, h/ih) 105 | nw = int(iw*scale) 106 | nh = int(ih*scale) 107 | dx = (w-nw)//2 108 | dy = (h-nh)//2 109 | 110 | #---------------------------------# 111 | # 将图像多余的部分加上灰条 112 | #---------------------------------# 113 | image = image.resize((nw,nh), Image.BICUBIC) 114 | new_image = Image.new('RGB', (w,h), (128,128,128)) 115 | new_image.paste(image, (dx, dy)) 116 | image_data = np.array(new_image, np.float32) 117 | 118 | return image_data 119 | 120 | #------------------------------------------# 121 | # 对图像进行缩放并且进行长和宽的扭曲 122 | #------------------------------------------# 123 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter) 124 | scale = self.rand(0.75, 1.5) 125 | if new_ar < 1: 126 | nh = int(scale*h) 127 | nw = int(nh*new_ar) 128 | else: 129 | nw = int(scale*w) 130 | nh = int(nw/new_ar) 131 | image = image.resize((nw,nh), Image.BICUBIC) 132 | 133 | #------------------------------------------# 134 | # 将图像多余的部分加上灰条 135 | #------------------------------------------# 136 | dx = int(self.rand(0, w-nw)) 137 | dy = int(self.rand(0, h-nh)) 138 | new_image = Image.new('RGB', (w,h), (128, 128, 128)) 139 | new_image.paste(image, (dx, dy)) 140 | image = new_image 141 | 142 | #------------------------------------------# 143 | # 翻转图像 144 | #------------------------------------------# 145 | flip = self.rand()<.5 146 | if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT) 147 | 148 | rotate = self.rand()<.5 149 | if rotate: 150 | angle = np.random.randint(-15,15) 151 | a,b = w/2,h/2 152 | M = cv2.getRotationMatrix2D((a,b),angle,1) 153 | image = cv2.warpAffine(np.array(image), M, (w,h), borderValue=[128, 128, 128]) 154 | 155 | image_data = np.array(image, 
np.uint8) 156 | #---------------------------------# 157 | # 对图像进行色域变换 158 | # 计算色域变换的参数 159 | #---------------------------------# 160 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1 161 | #---------------------------------# 162 | # 将图像转到HSV上 163 | #---------------------------------# 164 | hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV)) 165 | dtype = image_data.dtype 166 | #---------------------------------# 167 | # 应用变换 168 | #---------------------------------# 169 | x = np.arange(0, 256, dtype=r.dtype) 170 | lut_hue = ((x * r[0]) % 180).astype(dtype) 171 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) 172 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype) 173 | 174 | image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) 175 | image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB) 176 | return image_data 177 | 178 | def AutoAugment(self, image, random=True): 179 | if not random: 180 | image = self.resize(image) 181 | image = self.center_crop(image) 182 | return image 183 | 184 | #------------------------------------------# 185 | # resize并且随即裁剪 186 | #------------------------------------------# 187 | image = self.resize_crop(image) 188 | 189 | #------------------------------------------# 190 | # 翻转图像 191 | #------------------------------------------# 192 | flip = self.rand()<.5 193 | if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT) 194 | 195 | #------------------------------------------# 196 | # 随机增强 197 | #------------------------------------------# 198 | image = self.policy(image) 199 | return image -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from .utils_aug import resize, center_crop 8 | 9 | 10 | #---------------------------------------------------------# 11 | # 将图像转换成RGB图像,防止灰度图在预测时报错。 12 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 13 | #---------------------------------------------------------# 14 | def cvtColor(image): 15 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: 16 | return image 17 | else: 18 | image = image.convert('RGB') 19 | return image 20 | 21 | #---------------------------------------------------# 22 | # 对输入图像进行resize 23 | #---------------------------------------------------# 24 | def letterbox_image(image, size, letterbox_image): 25 | w, h = size 26 | iw, ih = image.size 27 | if letterbox_image: 28 | '''resize image with unchanged aspect ratio using padding''' 29 | scale = min(w/iw, h/ih) 30 | nw = int(iw*scale) 31 | nh = int(ih*scale) 32 | 33 | image = image.resize((nw,nh), Image.BICUBIC) 34 | new_image = Image.new('RGB', size, (128,128,128)) 35 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 36 | else: 37 | if h == w: 38 | new_image = resize(image, h) 39 | else: 40 | new_image = resize(image, [h ,w]) 41 | new_image = center_crop(new_image, [h ,w]) 42 | return new_image 43 | 44 | #---------------------------------------------------# 45 | # 获得类 46 | #---------------------------------------------------# 47 | def get_classes(classes_path): 48 | with open(classes_path, encoding='utf-8') as f: 49 | class_names = f.readlines() 50 | class_names = [c.strip() for c in class_names] 51 | return class_names, len(class_names) 52 | 53 | #----------------------------------------# 54 | # 预处理训练图片 55 | #----------------------------------------# 56 | def 
preprocess_input(x): 57 | # x /= 127.5 58 | # x -= 1. 59 | x /= 255 60 | x -= np.array([0.485, 0.456, 0.406]) 61 | x /= np.array([0.229, 0.224, 0.225]) 62 | return x 63 | 64 | def show_config(**kwargs): 65 | print('Configurations:') 66 | print('-' * 70) 67 | print('|%25s | %40s|' % ('keys', 'values')) 68 | print('-' * 70) 69 | for key, value in kwargs.items(): 70 | print('|%25s | %40s|' % (str(key), str(value))) 71 | print('-' * 70) 72 | 73 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10): 74 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters): 75 | if iters <= warmup_total_iters: 76 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start 77 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2 78 | ) + warmup_lr_start 79 | elif iters >= total_iters - no_aug_iter: 80 | lr = min_lr 81 | else: 82 | lr = min_lr + 0.5 * (lr - min_lr) * ( 83 | 1.0 84 | + math.cos( 85 | math.pi 86 | * (iters - warmup_total_iters) 87 | / (total_iters - warmup_total_iters - no_aug_iter) 88 | ) 89 | ) 90 | return lr 91 | 92 | def step_lr(lr, decay_rate, step_size, iters): 93 | if step_size < 1: 94 | raise ValueError("step_size must be above 1.") 95 | n = iters // step_size 96 | out_lr = lr * decay_rate ** n 97 | return out_lr 98 | 99 | if lr_decay_type == "cos": 100 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3) 101 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6) 102 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15) 103 | func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter) 104 | else: 105 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) 106 | step_size = total_iters / step_num 107 | func = partial(step_lr, lr, decay_rate, step_size) 108 | 109 | return func 110 | 111 | #-------------------------------------------------------------------------------------------------------------------------------# 112 | # From https://github.com/ckyrkou/Keras_FLOP_Estimator 113 | # Fix lots of bugs 114 | #-------------------------------------------------------------------------------------------------------------------------------# 115 | def net_flops(model, table=False, print_result=True): 116 | if (table == True): 117 | print("\n") 118 | print('%25s | %16s | %16s | %16s | %16s | %6s | %6s' % ( 119 | 'Layer Name', 'Input Shape', 'Output Shape', 'Kernel Size', 'Filters', 'Strides', 'FLOPS')) 120 | print('=' * 120) 121 | 122 | #---------------------------------------------------# 123 | # Total FLOPs 124 | #---------------------------------------------------# 125 | t_flops = 0 126 | factor = 1e9 127 | 128 | for l in model.layers: 129 | try: 130 | #--------------------------------------# 131 | # Initialise the per-layer bookkeeping 132 | #--------------------------------------# 133 | o_shape, i_shape, strides, ks, filters = ('', '', ''), ('', '', ''), (1, 1), (0, 0), 0 134 | flops = 0 135 | #--------------------------------------# 136 | # Get the layer name 137 | #--------------------------------------# 138 | name = l.name 139 | 140 | if ('InputLayer' in str(l)): 141 | i_shape = l.get_input_shape_at(0)[1:4] 142 | o_shape = l.get_output_shape_at(0)[1:4] 143 | 144 | #--------------------------------------# 145 | # Reshape layer 146 | #--------------------------------------# 147 | elif ('Reshape' in str(l)): 148 | i_shape = l.get_input_shape_at(0)[1:4] 149 | o_shape = 
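
`get_lr_scheduler` returns a plain `iters -> lr` function. Sampling it shows the three phases of the `"cos"` schedule: quadratic warmup, cosine decay, then a flat `min_lr` tail. A quick sketch with assumed hyperparameters (how `train.py` feeds the result to the optimizer is outside this snippet):

```python
from utils.utils import get_lr_scheduler

# 100 "iters" (epochs, in this repo's usage), decaying 1e-2 -> 1e-4
lr_func = get_lr_scheduler("cos", lr=1e-2, min_lr=1e-4, total_iters=100)

for it in [0, 3, 50, 95]:
    print(it, round(lr_func(it), 6))
# 0  -> 0.001     warmup starts at warmup_lr_ratio * lr
# 3  -> 0.01      warmup ends after min(max(0.05 * 100, 1), 3) = 3 iters
# 50 -> ~0.0049   midway down the cosine arc
# 95 -> 0.0001    flat min_lr for the last no_aug_iter iters
```
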
l.get_output_shape_at(0)[1:4] 150 | 151 | #--------------------------------------# 152 | # 填充层 153 | #--------------------------------------# 154 | elif ('Padding' in str(l)): 155 | i_shape = l.get_input_shape_at(0)[1:4] 156 | o_shape = l.get_output_shape_at(0)[1:4] 157 | 158 | #--------------------------------------# 159 | # 平铺层 160 | #--------------------------------------# 161 | elif ('Flatten' in str(l)): 162 | i_shape = l.get_input_shape_at(0)[1:4] 163 | o_shape = l.get_output_shape_at(0)[1:4] 164 | 165 | #--------------------------------------# 166 | # 激活函数层 167 | #--------------------------------------# 168 | elif 'Activation' in str(l): 169 | i_shape = l.get_input_shape_at(0)[1:4] 170 | o_shape = l.get_output_shape_at(0)[1:4] 171 | 172 | #--------------------------------------# 173 | # LeakyReLU 174 | #--------------------------------------# 175 | elif 'LeakyReLU' in str(l): 176 | for i in range(len(l._inbound_nodes)): 177 | i_shape = l.get_input_shape_at(i)[1:4] 178 | o_shape = l.get_output_shape_at(i)[1:4] 179 | 180 | flops += i_shape[0] * i_shape[1] * i_shape[2] 181 | 182 | #--------------------------------------# 183 | # 池化层 184 | #--------------------------------------# 185 | elif 'MaxPooling' in str(l): 186 | i_shape = l.get_input_shape_at(0)[1:4] 187 | o_shape = l.get_output_shape_at(0)[1:4] 188 | 189 | #--------------------------------------# 190 | # 池化层 191 | #--------------------------------------# 192 | elif ('AveragePooling' in str(l) and 'Global' not in str(l)): 193 | strides = l.strides 194 | ks = l.pool_size 195 | 196 | for i in range(len(l._inbound_nodes)): 197 | i_shape = l.get_input_shape_at(i)[1:4] 198 | o_shape = l.get_output_shape_at(i)[1:4] 199 | 200 | flops += o_shape[0] * o_shape[1] * o_shape[2] 201 | 202 | #--------------------------------------# 203 | # 全局池化层 204 | #--------------------------------------# 205 | elif ('AveragePooling' in str(l) and 'Global' in str(l)): 206 | for i in range(len(l._inbound_nodes)): 207 | i_shape = l.get_input_shape_at(i)[1:4] 208 | o_shape = l.get_output_shape_at(i)[1:4] 209 | 210 | flops += (i_shape[0] * i_shape[1] + 1) * i_shape[2] 211 | 212 | #--------------------------------------# 213 | # 标准化层 214 | #--------------------------------------# 215 | elif ('BatchNormalization' in str(l)): 216 | for i in range(len(l._inbound_nodes)): 217 | i_shape = l.get_input_shape_at(i)[1:4] 218 | o_shape = l.get_output_shape_at(i)[1:4] 219 | 220 | temp_flops = 1 221 | for i in range(len(i_shape)): 222 | temp_flops *= i_shape[i] 223 | temp_flops *= 2 224 | 225 | flops += temp_flops 226 | 227 | #--------------------------------------# 228 | # 全连接层 229 | #--------------------------------------# 230 | elif ('Dense' in str(l)): 231 | for i in range(len(l._inbound_nodes)): 232 | i_shape = l.get_input_shape_at(i)[1:4] 233 | o_shape = l.get_output_shape_at(i)[1:4] 234 | 235 | temp_flops = 1 236 | for i in range(len(o_shape)): 237 | temp_flops *= o_shape[i] 238 | 239 | if (i_shape[-1] == None): 240 | temp_flops = temp_flops * o_shape[-1] 241 | else: 242 | temp_flops = temp_flops * i_shape[-1] 243 | flops += temp_flops 244 | 245 | #--------------------------------------# 246 | # 普通卷积层 247 | #--------------------------------------# 248 | elif ('Conv2D' in str(l) and 'DepthwiseConv2D' not in str(l) and 'SeparableConv2D' not in str(l)): 249 | strides = l.strides 250 | ks = l.kernel_size 251 | filters = l.filters 252 | bias = 1 if l.use_bias else 0 253 | 254 | for i in range(len(l._inbound_nodes)): 255 | i_shape = l.get_input_shape_at(i)[1:4] 256 | 
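
To sanity-check the `Conv2D` branch above by hand, take one hypothetical layer: a 3x3 convolution with 32 filters and bias, producing a 112x112 output over a 16-channel input:

```python
# One hypothetical layer, counted exactly as in the branch above.
filters, oh, ow = 32, 112, 112          # filters and output spatial size
kh, kw, cin, bias = 3, 3, 16, 1         # kernel, input channels, use_bias
flops = filters * oh * ow * (kh * kw * cin + bias)
print(flops)                            # 58204160 multiply-accumulates
print(2 * flops / 1e9)                  # ~0.116 GFLOPs after the final t_flops * 2
```
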
o_shape = l.get_output_shape_at(i)[1:4] 257 | 258 | if (filters == None): 259 | filters = i_shape[2] 260 | flops += filters * o_shape[0] * o_shape[1] * (ks[0] * ks[1] * i_shape[2] + bias) 261 | 262 | #--------------------------------------# 263 | # Depthwise convolution layer 264 | #--------------------------------------# 265 | elif ('Conv2D' in str(l) and 'DepthwiseConv2D' in str(l) and 'SeparableConv2D' not in str(l)): 266 | strides = l.strides 267 | ks = l.kernel_size 268 | filters = l.filters 269 | bias = 1 if l.use_bias else 0 270 | 271 | for i in range(len(l._inbound_nodes)): 272 | i_shape = l.get_input_shape_at(i)[1:4] 273 | o_shape = l.get_output_shape_at(i)[1:4] 274 | 275 | if (filters == None): 276 | filters = i_shape[2] 277 | flops += filters * o_shape[0] * o_shape[1] * (ks[0] * ks[1] + bias) 278 | 279 | #--------------------------------------# 280 | # Depthwise separable convolution layer 281 | #--------------------------------------# 282 | elif ('Conv2D' in str(l) and 'DepthwiseConv2D' not in str(l) and 'SeparableConv2D' in str(l)): 283 | strides = l.strides 284 | ks = l.kernel_size 285 | filters = l.filters 286 | bias = 1 if l.use_bias else 0 287 | 288 | for i in range(len(l._inbound_nodes)): 289 | i_shape = l.get_input_shape_at(i)[1:4] 290 | o_shape = l.get_output_shape_at(i)[1:4] 291 | 292 | if (filters == None): 293 | filters = i_shape[2] 294 | flops += i_shape[2] * o_shape[0] * o_shape[1] * (ks[0] * ks[1] + bias) + \ 295 | filters * o_shape[0] * o_shape[1] * (1 * 1 * i_shape[2] + bias) 296 | #--------------------------------------# 297 | # A nested model inside the model 298 | #--------------------------------------# 299 | elif 'Model' in str(l): 300 | flops = net_flops(l, print_result=False) 301 | 302 | t_flops += flops 303 | 304 | if (table == True): 305 | print('%25s | %16s | %16s | %16s | %16s | %6s | %5.4f' % ( 306 | name[:25], str(i_shape), str(o_shape), str(ks), str(filters), str(strides), flops)) 307 | 308 | except: 309 | pass 310 | 311 | t_flops = t_flops * 2 312 | if print_result: 313 | show_flops = t_flops / factor 314 | print('Total GFLOPs: %.3fG' % (show_flops)) 315 | return t_flops -------------------------------------------------------------------------------- /utils/utils_aug.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | import math 4 | import warnings 5 | import numpy as np 6 | from PIL import Image, ImageEnhance, ImageOps 7 | 8 | 9 | class ShearX(object): 10 | def __init__(self, fillcolor=(128, 128, 128)): 11 | self.fillcolor = fillcolor 12 | 13 | def __call__(self, x, magnitude): 14 | return x.transform( 15 | x.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 16 | Image.BICUBIC, fillcolor=self.fillcolor) 17 | 18 | 19 | class ShearY(object): 20 | def __init__(self, fillcolor=(128, 128, 128)): 21 | self.fillcolor = fillcolor 22 | 23 | def __call__(self, x, magnitude): 24 | return x.transform( 25 | x.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 26 | Image.BICUBIC, fillcolor=self.fillcolor) 27 | 28 | 29 | class TranslateX(object): 30 | def __init__(self, fillcolor=(128, 128, 128)): 31 | self.fillcolor = fillcolor 32 | 33 | def __call__(self, x, magnitude): 34 | return x.transform( 35 | x.size, Image.AFFINE, (1, 0, magnitude * x.size[0] * random.choice([-1, 1]), 0, 1, 0), 36 | fillcolor=self.fillcolor) 37 | 38 | 39 | class TranslateY(object): 40 | def __init__(self, fillcolor=(128, 128, 128)): 41 | self.fillcolor = fillcolor 42 | 43 | def __call__(self, x, magnitude): 44 | return x.transform( 45 | x.size, Image.AFFINE, (1, 0, 0, 0, 1, 
magnitude * x.size[1] * random.choice([-1, 1])), 46 | fillcolor=self.fillcolor) 47 | 48 | 49 | class Rotate(object): 50 | # from https://stackoverflow.com/questions/ 51 | # 5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 52 | def __call__(self, x, magnitude): 53 | rot = x.convert("RGBA").rotate(magnitude * random.choice([-1, 1])) 54 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(x.mode) 55 | 56 | 57 | class Color(object): 58 | def __call__(self, x, magnitude): 59 | return ImageEnhance.Color(x).enhance(1 + magnitude * random.choice([-1, 1])) 60 | 61 | 62 | class Posterize(object): 63 | def __call__(self, x, magnitude): 64 | return ImageOps.posterize(x, magnitude) 65 | 66 | 67 | class Solarize(object): 68 | def __call__(self, x, magnitude): 69 | return ImageOps.solarize(x, magnitude) 70 | 71 | 72 | class Contrast(object): 73 | def __call__(self, x, magnitude): 74 | return ImageEnhance.Contrast(x).enhance(1 + magnitude * random.choice([-1, 1])) 75 | 76 | 77 | class Sharpness(object): 78 | def __call__(self, x, magnitude): 79 | return ImageEnhance.Sharpness(x).enhance(1 + magnitude * random.choice([-1, 1])) 80 | 81 | 82 | class Brightness(object): 83 | def __call__(self, x, magnitude): 84 | return ImageEnhance.Brightness(x).enhance(1 + magnitude * random.choice([-1, 1])) 85 | 86 | 87 | class AutoContrast(object): 88 | def __call__(self, x, magnitude): 89 | return ImageOps.autocontrast(x) 90 | 91 | 92 | class Equalize(object): 93 | def __call__(self, x, magnitude): 94 | return ImageOps.equalize(x) 95 | 96 | 97 | class Invert(object): 98 | def __call__(self, x, magnitude): 99 | return ImageOps.invert(x) 100 | 101 | 102 | class ImageNetPolicy(object): 103 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 
104 | Example: 105 | >>> policy = ImageNetPolicy() 106 | >>> transformed = policy(image) 107 | Example as a PyTorch Transform: 108 | >>> transform = transforms.Compose([ 109 | >>> transforms.Resize(256), 110 | >>> ImageNetPolicy(), 111 | >>> transforms.ToTensor()]) 112 | """ 113 | def __init__(self, fillcolor=(128, 128, 128)): 114 | self.policies = [ 115 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 116 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 117 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 118 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 119 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 120 | 121 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 122 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 123 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 124 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 125 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 126 | 127 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 128 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 129 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 130 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 131 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 132 | 133 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 134 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 135 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 136 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 137 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 138 | 139 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 140 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 141 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 142 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 143 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 144 | ] 145 | 146 | def __call__(self, img): 147 | policy_idx = random.randint(0, len(self.policies) - 1) 148 | return self.policies[policy_idx](img) 149 | 150 | def __repr__(self): 151 | return "AutoAugment ImageNet Policy" 152 | 153 | class SubPolicy(object): 154 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 155 | ranges = { 156 | "shearX": np.linspace(0, 0.3, 10), 157 | "shearY": np.linspace(0, 0.3, 10), 158 | "translateX": np.linspace(0, 150 / 331, 10), 159 | "translateY": np.linspace(0, 150 / 331, 10), 160 | "rotate": np.linspace(0, 30, 10), 161 | "color": np.linspace(0.0, 0.9, 10), 162 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int), 163 | "solarize": np.linspace(256, 0, 10), 164 | "contrast": np.linspace(0.0, 0.9, 10), 165 | "sharpness": np.linspace(0.0, 0.9, 10), 166 | "brightness": np.linspace(0.0, 0.9, 10), 167 | "autocontrast": [0] * 10, 168 | "equalize": [0] * 10, 169 | "invert": [0] * 10 170 | } 171 | 172 | func = { 173 | "shearX": ShearX(fillcolor=fillcolor), 174 | "shearY": ShearY(fillcolor=fillcolor), 175 | "translateX": TranslateX(fillcolor=fillcolor), 176 | "translateY": TranslateY(fillcolor=fillcolor), 177 | "rotate": Rotate(), 178 | "color": Color(), 179 | "posterize": Posterize(), 180 | "solarize": Solarize(), 181 | "contrast": Contrast(), 182 | "sharpness": Sharpness(), 183 | "brightness": Brightness(), 184 | "autocontrast": AutoContrast(), 185 | "equalize": Equalize(), 186 
| "invert": Invert() 187 | } 188 | 189 | self.p1 = p1 190 | self.operation1 = func[operation1] 191 | self.magnitude1 = ranges[operation1][magnitude_idx1] 192 | self.p2 = p2 193 | self.operation2 = func[operation2] 194 | self.magnitude2 = ranges[operation2][magnitude_idx2] 195 | 196 | def __call__(self, img): 197 | if random.random() < self.p1: 198 | img = self.operation1(img, self.magnitude1) 199 | if random.random() < self.p2: 200 | img = self.operation2(img, self.magnitude2) 201 | return img 202 | 203 | def crop(img, i, j, h, w): 204 | """Crop the given PIL Image. 205 | 206 | Args: 207 | img (PIL Image): Image to be cropped. 208 | i (int): i in (i,j) i.e coordinates of the upper left corner. 209 | j (int): j in (i,j) i.e coordinates of the upper left corner. 210 | h (int): Height of the cropped image. 211 | w (int): Width of the cropped image. 212 | 213 | Returns: 214 | PIL Image: Cropped image. 215 | """ 216 | return img.crop((j, i, j + w, i + h)) 217 | 218 | def resize(img, size, interpolation=Image.BILINEAR): 219 | r"""Resize the input PIL Image to the given size. 220 | 221 | Args: 222 | img (PIL Image): Image to be resized. 223 | size (sequence or int): Desired output size. If size is a sequence like 224 | (h, w), the output size will be matched to this. If size is an int, 225 | the smaller edge of the image will be matched to this number maintaing 226 | the aspect ratio. i.e, if height > width, then image will be rescaled to 227 | :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)` 228 | interpolation (int, optional): Desired interpolation. Default is 229 | ``PIL.Image.BILINEAR`` 230 | 231 | Returns: 232 | PIL Image: Resized image. 233 | """ 234 | if isinstance(size, int): 235 | w, h = img.size 236 | if (w <= h and w == size) or (h <= w and h == size): 237 | return img 238 | if w < h: 239 | ow = size 240 | oh = int(size * h / w) 241 | return img.resize((ow, oh), interpolation) 242 | else: 243 | oh = size 244 | ow = int(size * w / h) 245 | return img.resize((ow, oh), interpolation) 246 | else: 247 | return img.resize(size[::-1], interpolation) 248 | 249 | def center_crop(img, output_size): 250 | if isinstance(output_size, numbers.Number): 251 | output_size = (int(output_size), int(output_size)) 252 | w, h = img.size 253 | th, tw = output_size 254 | i = int(round((h - th) / 2.)) 255 | j = int(round((w - tw) / 2.)) 256 | return crop(img, i, j, th, tw) 257 | 258 | def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): 259 | """Crop the given PIL Image and resize it to desired size. 260 | 261 | Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. 262 | 263 | Args: 264 | img (PIL Image): Image to be cropped. 265 | i (int): i in (i,j) i.e coordinates of the upper left corner 266 | j (int): j in (i,j) i.e coordinates of the upper left corner 267 | h (int): Height of the cropped image. 268 | w (int): Width of the cropped image. 269 | size (sequence or int): Desired output size. Same semantics as ``resize``. 270 | interpolation (int, optional): Desired interpolation. Default is 271 | ``PIL.Image.BILINEAR``. 272 | Returns: 273 | PIL Image: Cropped image. 274 | """ 275 | img = crop(img, i, j, h, w) 276 | img = resize(img, size, interpolation) 277 | return img 278 | 279 | class Resize(object): 280 | """Resize the input PIL Image to the given size. 281 | 282 | Args: 283 | size (sequence or int): Desired output size. If size is a sequence like 284 | (h, w), output size will be matched to this. 
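
`resize` and `center_crop` above follow torchvision's semantics: an `int` size matches the shorter edge and keeps the aspect ratio, while a `(h, w)` sequence is an exact target (hence the `size[::-1]` flip into PIL's `(w, h)` order). A small sketch with assumed sizes:

```python
from PIL import Image

from utils.utils_aug import center_crop, resize

img = Image.new('RGB', (500, 375))      # PIL sizes are (width, height)

# int size: shorter edge -> 224, aspect ratio kept:
# height 375 -> 224, width 500 -> int(224 * 500 / 375) = 298
print(resize(img, 224).size)            # (298, 224)

# sequence size is (h, w); resize flips it into PIL's (w, h) order
print(resize(img, [224, 224]).size)     # (224, 224)

# eval path used by ClsDatasets.AutoAugment(random=False): resize + center crop
print(center_crop(resize(img, 224), [224, 224]).size)   # (224, 224)
```
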
If size is an int, 285 | smaller edge of the image will be matched to this number. 286 | i.e, if height > width, then image will be rescaled to 287 | (size * height / width, size) 288 | interpolation (int, optional): Desired interpolation. Default is 289 | ``PIL.Image.BILINEAR`` 290 | """ 291 | 292 | def __init__(self, size, interpolation=Image.BILINEAR): 293 | self.size = size 294 | self.interpolation = interpolation 295 | 296 | def __call__(self, img): 297 | """ 298 | Args: 299 | img (PIL Image): Image to be scaled. 300 | 301 | Returns: 302 | PIL Image: Rescaled image. 303 | """ 304 | return resize(img, self.size, self.interpolation) 305 | 306 | class CenterCrop(object): 307 | """Crops the given PIL Image at the center. 308 | 309 | Args: 310 | size (sequence or int): Desired output size of the crop. If size is an 311 | int instead of sequence like (h, w), a square crop (size, size) is 312 | made. 313 | """ 314 | 315 | def __init__(self, size): 316 | self.size = size 317 | 318 | def __call__(self, img): 319 | """ 320 | Args: 321 | img (PIL Image): Image to be cropped. 322 | 323 | Returns: 324 | PIL Image: Cropped image. 325 | """ 326 | return center_crop(img, self.size) 327 | 328 | class RandomResizedCrop(object): 329 | """Crop the given PIL Image to random size and aspect ratio. 330 | 331 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 332 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 333 | is finally resized to given size. 334 | This is popularly used to train the Inception networks. 335 | 336 | Args: 337 | size: expected output size of each edge 338 | scale: range of size of the origin size cropped 339 | ratio: range of aspect ratio of the origin aspect ratio cropped 340 | interpolation: Default: PIL.Image.BILINEAR 341 | """ 342 | 343 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 344 | self.size = size 345 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 346 | warnings.warn("range should be of kind (min, max)") 347 | 348 | self.interpolation = interpolation 349 | self.scale = scale 350 | self.ratio = ratio 351 | 352 | @staticmethod 353 | def get_params(img, scale, ratio): 354 | """Get parameters for ``crop`` for a random sized crop. 355 | 356 | Args: 357 | img (PIL Image): Image to be cropped. 358 | scale (tuple): range of size of the origin size cropped 359 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 360 | 361 | Returns: 362 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 363 | sized crop. 
364 | """ 365 | area = img.size[0] * img.size[1] 366 | 367 | for attempt in range(10): 368 | target_area = random.uniform(*scale) * area 369 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 370 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 371 | 372 | w = int(round(math.sqrt(target_area * aspect_ratio))) 373 | h = int(round(math.sqrt(target_area / aspect_ratio))) 374 | 375 | if w <= img.size[0] and h <= img.size[1]: 376 | i = random.randint(0, img.size[1] - h) 377 | j = random.randint(0, img.size[0] - w) 378 | return i, j, h, w 379 | 380 | # Fallback to central crop 381 | in_ratio = img.size[0] / img.size[1] 382 | if (in_ratio < min(ratio)): 383 | w = img.size[0] 384 | h = int(round(w / min(ratio))) 385 | elif (in_ratio > max(ratio)): 386 | h = img.size[1] 387 | w = int(round(h * max(ratio))) 388 | else: # whole image 389 | w = img.size[0] 390 | h = img.size[1] 391 | i = (img.size[1] - h) // 2 392 | j = (img.size[0] - w) // 2 393 | return i, j, h, w 394 | 395 | def __call__(self, img): 396 | """ 397 | Args: 398 | img (PIL Image): Image to be cropped and resized. 399 | 400 | Returns: 401 | PIL Image: Randomly cropped and resized image. 402 | """ 403 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 404 | return resized_crop(img, i, j, h, w, self.size, self.interpolation) -------------------------------------------------------------------------------- /utils/utils_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | from tqdm import tqdm 5 | 6 | 7 | #------------------------------# 8 | # 防止bug 9 | #------------------------------# 10 | def get_train_step_fn(strategy): 11 | @tf.function 12 | def train_step(batch_images, batch_labels, net, optimizer): 13 | with tf.GradientTape() as tape: 14 | #------------------------------# 15 | # 计算loss 16 | #------------------------------# 17 | predict = net([batch_images], training=True) 18 | loss_value = tf.reduce_mean(tf.losses.categorical_crossentropy(batch_labels, predict)) 19 | 20 | grads = tape.gradient(loss_value, net.trainable_variables) 21 | optimizer.apply_gradients(zip(grads, net.trainable_variables)) 22 | acc = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(predict, axis=-1), tf.argmax(batch_labels, axis=-1)), tf.float32)) 23 | return loss_value, acc 24 | 25 | if strategy == None: 26 | return train_step 27 | else: 28 | #----------------------# 29 | # 多gpu训练 30 | #----------------------# 31 | @tf.function 32 | def distributed_train_step(images, targets, net, optimizer): 33 | per_replica_losses, per_replica_acc = strategy.run(train_step, args=(images, targets, net, optimizer,)) 34 | return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None), strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_acc, axis=None) 35 | return distributed_train_step 36 | 37 | #----------------------# 38 | # 防止bug 39 | #----------------------# 40 | def get_val_step_fn(strategy): 41 | @tf.function 42 | def val_step(batch_images, batch_labels, net, optimizer): 43 | predict = net(batch_images) 44 | loss_value = tf.reduce_mean(tf.losses.categorical_crossentropy(batch_labels, predict)) 45 | acc = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(predict, axis=-1), tf.argmax(batch_labels, axis=-1)), tf.float32)) 46 | return loss_value, acc 47 | if strategy == None: 48 | return val_step 49 | else: 50 | #----------------------# 51 | # 多gpu验证 52 | #----------------------# 53 | @tf.function 54 | def distributed_val_step(images, targets, net, optimizer): 55 | 
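
`get_train_step_fn` and `get_val_step_fn` above only close over `strategy`; the caller still has to create the model and optimizer under `strategy.scope()` so the variables are mirrored across replicas. A minimal sketch, assuming a multi-GPU machine and a placeholder model (`train.py` builds the real backbone and a distributed dataset):

```python
import tensorflow as tf

from utils.utils_fit import get_train_step_fn

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Placeholder model; train.py builds the real backbone inside this scope
    net = tf.keras.Sequential([
        tf.keras.layers.GlobalAveragePooling2D(input_shape=(224, 224, 3)),
        tf.keras.layers.Dense(2, activation='softmax'),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2, momentum=0.9)

train_step = get_train_step_fn(strategy)

# A toy batch; in real training the batches come from a distributed dataset
# (strategy.experimental_distribute_dataset) so strategy.run can split them.
images = tf.zeros((8, 224, 224, 3))
labels = tf.one_hot(tf.zeros(8, dtype=tf.int32), depth=2)
loss, acc = train_step(images, labels, net, optimizer)
print(float(loss), float(acc))
```
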
per_replica_losses, per_replica_acc = strategy.run(val_step, args=(images, targets, net, optimizer,)) 56 | return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None), strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_acc, axis=None) 57 | return distributed_val_step 58 | 59 | def fit_one_epoch(net, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, save_period, save_dir, strategy): 60 | train_step = get_train_step_fn(strategy) 61 | val_step = get_val_step_fn(strategy) 62 | 63 | total_loss = 0 64 | total_acc = 0 65 | val_loss = 0 66 | val_acc = 0 67 | print('Start Train') 68 | with tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar: 69 | for iteration, batch in enumerate(gen): 70 | if iteration>=epoch_step: 71 | break 72 | batch_images, batch_labels = batch 73 | 74 | loss_value, acc = train_step(batch_images, batch_labels, net, optimizer) 75 | total_loss += loss_value.numpy() 76 | total_acc += acc.numpy() 77 | 78 | pbar.set_postfix(**{'total_loss' : float(total_loss) / (iteration + 1), 79 | 'acc' : float(total_acc) / (iteration + 1), 80 | 'lr' : optimizer._decayed_lr(tf.float32).numpy()}) 81 | pbar.update(1) 82 | print('Finish Train') 83 | 84 | print('Start Validation') 85 | with tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar: 86 | for iteration, batch in enumerate(gen_val): 87 | if iteration>=epoch_step_val: 88 | break 89 | batch_images, batch_labels = batch 90 | 91 | loss_value, acc = val_step(batch_images, batch_labels, net, optimizer) 92 | val_loss += loss_value.numpy() 93 | val_acc += acc.numpy() 94 | 95 | pbar.set_postfix(**{'val_loss' : float(val_loss) / (iteration + 1), 96 | 'val_acc' : float(val_acc) / (iteration + 1)}) 97 | pbar.update(1) 98 | print('Finish Validation') 99 | 100 | logs = {'loss': total_loss / epoch_step, 'val_loss': val_loss / epoch_step_val} 101 | loss_history.on_epoch_end([], logs) 102 | print('Epoch:'+ str(epoch+1) + '/' + str(Epoch)) 103 | print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val)) 104 | 105 | #-----------------------------------------------# 106 | # Save the weights 107 | #-----------------------------------------------# 108 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 109 | net.save_weights(os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.h5' % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val))) 110 | 111 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): 112 | print('Save best model to best_epoch_weights.h5') 113 | net.save_weights(os.path.join(save_dir, "best_epoch_weights.h5")) 114 | 115 | net.save_weights(os.path.join(save_dir, "last_epoch_weights.h5")) -------------------------------------------------------------------------------- /utils/utils_metrics.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from PIL import Image 7 | 8 | 9 | def evaluteTop1_5(classfication, lines, metrics_out_path): 10 | correct_1 = 0 11 | correct_5 = 0 12 | preds = [] 13 | labels = [] 14 | total = len(lines) 15 | for index, line in enumerate(lines): 16 | annotation_path = line.split(';')[1].split()[0] 17 | x = Image.open(annotation_path) 18 | y = int(line.split(';')[0]) 19 | 20 | pred = classfication.detect_image(x) 21 | pred_1 = np.argmax(pred) 22 | correct_1 += pred_1 
== y 23 | 24 | pred_5 = np.argsort(pred)[::-1] 25 | pred_5 = pred_5[:5] 26 | correct_5 += y in pred_5 27 | 28 | preds.append(pred_1) 29 | labels.append(y) 30 | if index % 100 == 0: 31 | print("[%d/%d]"%(index, total)) 32 | 33 | hist = fast_hist(np.array(labels), np.array(preds), len(classfication.class_names)) 34 | Recall = per_class_Recall(hist) 35 | Precision = per_class_Precision(hist) 36 | 37 | show_results(metrics_out_path, hist, Recall, Precision, classfication.class_names) 38 | return correct_1 / total, correct_5 / total, Recall, Precision 39 | 40 | def fast_hist(a, b, n): 41 | k = (a >= 0) & (a < n) 42 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) 43 | 44 | def per_class_Recall(hist): 45 | return np.diag(hist) / np.maximum(hist.sum(1), 1) 46 | 47 | def per_class_Precision(hist): 48 | return np.diag(hist) / np.maximum(hist.sum(0), 1) 49 | 50 | def adjust_axes(r, t, fig, axes): 51 | bb = t.get_window_extent(renderer=r) 52 | text_width_inches = bb.width / fig.dpi 53 | current_fig_width = fig.get_figwidth() 54 | new_fig_width = current_fig_width + text_width_inches 55 | propotion = new_fig_width / current_fig_width 56 | x_lim = axes.get_xlim() 57 | axes.set_xlim([x_lim[0], x_lim[1] * propotion]) 58 | 59 | def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True): 60 | fig = plt.gcf() 61 | axes = plt.gca() 62 | plt.barh(range(len(values)), values, color='royalblue') 63 | plt.title(plot_title, fontsize=tick_font_size + 2) 64 | plt.xlabel(x_label, fontsize=tick_font_size) 65 | plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size) 66 | r = fig.canvas.get_renderer() 67 | for i, val in enumerate(values): 68 | str_val = " " + str(val) 69 | if val < 1.0: 70 | str_val = " {0:.2f}".format(val) 71 | t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold') 72 | if i == (len(values)-1): 73 | adjust_axes(r, t, fig, axes) 74 | 75 | fig.tight_layout() 76 | fig.savefig(output_path) 77 | if plt_show: 78 | plt.show() 79 | plt.close() 80 | 81 | def show_results(miou_out_path, hist, Recall, Precision, name_classes, tick_font_size = 12): 82 | draw_plot_func(Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(Recall)*100), "Recall", \ 83 | os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False) 84 | print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png")) 85 | 86 | draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision)*100), "Precision", \ 87 | os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False) 88 | print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png")) 89 | 90 | with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f: 91 | writer = csv.writer(f) 92 | writer_list = [] 93 | writer_list.append([' '] + [str(c) for c in name_classes]) 94 | for i in range(len(hist)): 95 | writer_list.append([name_classes[i]] + [str(x) for x in hist[i]]) 96 | writer.writerows(writer_list) 97 | print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv")) 98 | --------------------------------------------------------------------------------
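
A worked example of the histogram-based metrics above, using toy predictions for the two-class cat-vs-dog case:

```python
import numpy as np

from utils.utils_metrics import fast_hist, per_class_Precision, per_class_Recall

labels = np.array([0, 0, 0, 1, 1, 1])   # ground truth: three cats, three dogs
preds  = np.array([0, 0, 1, 1, 1, 0])   # one of each class misclassified

hist = fast_hist(labels, preds, 2)
print(hist)
# [[2 1]
#  [1 2]]   rows = ground truth, columns = predictions

print(per_class_Recall(hist))      # ~[0.667 0.667], diag / row sums
print(per_class_Precision(hist))   # ~[0.667 0.667], diag / column sums
```
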