├── .idea ├── .gitignore ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── pytorch-image-classifier-collection.iml └── vcs.xml ├── LICENSE ├── README.md ├── config └── config.yaml ├── data ├── cat │ ├── 1.jpg │ ├── 10.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ ├── 7.jpg │ ├── 8.jpg │ └── 9.jpg └── dog │ ├── 1.jpg │ ├── 10.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ ├── 6.jpg │ ├── 7.jpg │ ├── 8.jpg │ └── 9.jpg ├── infer.py ├── infer_prune_model.py ├── model ├── dataset │ ├── __pycache__ │ │ └── dataset.cpython-38.pyc │ └── dataset.py ├── loss │ ├── __pycache__ │ │ └── loss_fun.cpython-38.pyc │ └── loss_fun.py ├── net │ ├── __pycache__ │ │ └── net.cpython-38.pyc │ └── net.py ├── optimizer │ ├── __pycache__ │ │ └── optim.cpython-38.pyc │ └── optim.py └── utils │ ├── __pycache__ │ └── utils.cpython-38.pyc │ └── utils.py ├── pack_tools ├── pytorch_onnx_infer.py └── pytorch_to_onnx.py ├── prune_model └── pruning_model.py └── train.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 38 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/pytorch-image-classifier-collection.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 木兰宽松许可证, 第2版 2 | 3 | 木兰宽松许可证, 第2版 4 | 2020年1月 http://license.coscl.org.cn/MulanPSL2 5 | 6 | 7 | 您对“软件”的复制、使用、修改及分发受木兰宽松许可证,第2版(“本许可证”)的如下条款的约束: 8 | 9 | 0. 定义 10 | 11 | “软件”是指由“贡献”构成的许可在“本许可证”下的程序和相关文档的集合。 12 | 13 | “贡献”是指由任一“贡献者”许可在“本许可证”下的受版权法保护的作品。 14 | 15 | “贡献者”是指将受版权法保护的作品许可在“本许可证”下的自然人或“法人实体”。 16 | 17 | “法人实体”是指提交贡献的机构及其“关联实体”。 18 | 19 | “关联实体”是指,对“本许可证”下的行为方而言,控制、受控制或与其共同受控制的机构,此处的控制是指有受控方或共同受控方至少50%直接或间接的投票权、资金或其他有价证券。 20 | 21 | 1. 授予版权许可 22 | 23 | 每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的版权许可,您可以复制、使用、修改、分发其“贡献”,不论修改与否。 24 | 25 | 2. 授予专利许可 26 | 27 | 每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的(根据本条规定撤销除外)专利许可,供您制造、委托制造、使用、许诺销售、销售、进口其“贡献”或以其他方式转移其“贡献”。前述专利许可仅限于“贡献者”现在或将来拥有或控制的其“贡献”本身或其“贡献”与许可“贡献”时的“软件”结合而将必然会侵犯的专利权利要求,不包括对“贡献”的修改或包含“贡献”的其他结合。如果您或您的“关联实体”直接或间接地,就“软件”或其中的“贡献”对任何人发起专利侵权诉讼(包括反诉或交叉诉讼)或其他专利维权行动,指控其侵犯专利权,则“本许可证”授予您对“软件”的专利许可自您提起诉讼或发起维权行动之日终止。 28 | 29 | 3. 无商标许可 30 | 31 | “本许可证”不提供对“贡献者”的商品名称、商标、服务标志或产品名称的商标许可,但您为满足第4条规定的声明义务而必须使用除外。 32 | 33 | 4. 分发限制 34 | 35 | 您可以在任何媒介中将“软件”以源程序形式或可执行形式重新分发,不论修改与否,但您必须向接收者提供“本许可证”的副本,并保留“软件”中的版权、商标、专利及免责声明。 36 | 37 | 5. 免责声明与责任限制 38 | 39 | “软件”及其中的“贡献”在提供时不带任何明示或默示的担保。在任何情况下,“贡献者”或版权所有者不对任何人因使用“软件”或其中的“贡献”而引发的任何直接或间接损失承担责任,不论因何种原因导致或者基于何种法律理论,即使其曾被建议有此种损失的可能性。 40 | 41 | 6. 
语言 42 | “本许可证”以中英文双语表述,中英文版本具有同等法律效力。如果中英文版本存在任何冲突不一致,以中文版为准。 43 | 44 | 条款结束 45 | 46 | 如何将木兰宽松许可证,第2版,应用到您的软件 47 | 48 | 如果您希望将木兰宽松许可证,第2版,应用到您的新软件,为了方便接收者查阅,建议您完成如下三步: 49 | 50 | 1, 请您补充如下声明中的空白,包括软件名、软件的首次发表年份以及您作为版权人的名字; 51 | 52 | 2, 请您在软件包的一级目录下创建以“LICENSE”为名的文件,将整个许可证文本放入该文件中; 53 | 54 | 3, 请将如下声明文本放入每个源文件的头部注释中。 55 | 56 | Copyright (c) [Year] [name of copyright holder] 57 | [Software Name] is licensed under Mulan PSL v2. 58 | You can use this software according to the terms and conditions of the Mulan PSL v2. 59 | You may obtain a copy of Mulan PSL v2 at: 60 | http://license.coscl.org.cn/MulanPSL2 61 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 62 | See the Mulan PSL v2 for more details. 63 | 64 | 65 | Mulan Permissive Software License,Version 2 66 | 67 | Mulan Permissive Software License,Version 2 (Mulan PSL v2) 68 | January 2020 http://license.coscl.org.cn/MulanPSL2 69 | 70 | Your reproduction, use, modification and distribution of the Software shall be subject to Mulan PSL v2 (this License) with the following terms and conditions: 71 | 72 | 0. Definition 73 | 74 | Software means the program and related documents which are licensed under this License and comprise all Contribution(s). 75 | 76 | Contribution means the copyrightable work licensed by a particular Contributor under this License. 77 | 78 | Contributor means the Individual or Legal Entity who licenses its copyrightable work under this License. 79 | 80 | Legal Entity means the entity making a Contribution and all its Affiliates. 81 | 82 | Affiliates means entities that control, are controlled by, or are under common control with the acting entity under this License, ‘control’ means direct or indirect ownership of at least fifty percent (50%) of the voting power, capital or other securities of controlled or commonly controlled entity. 
83 | 84 | 1. Grant of Copyright License 85 | 86 | Subject to the terms and conditions of this License, each Contributor hereby grants to you a perpetual, worldwide, royalty-free, non-exclusive, irrevocable copyright license to reproduce, use, modify, or distribute its Contribution, with modification or not. 87 | 88 | 2. Grant of Patent License 89 | 90 | Subject to the terms and conditions of this License, each Contributor hereby grants to you a perpetual, worldwide, royalty-free, non-exclusive, irrevocable (except for revocation under this Section) patent license to make, have made, use, offer for sale, sell, import or otherwise transfer its Contribution, where such patent license is only limited to the patent claims owned or controlled by such Contributor now or in future which will be necessarily infringed by its Contribution alone, or by combination of the Contribution with the Software to which the Contribution was contributed. The patent license shall not apply to any modification of the Contribution, and any other combination which includes the Contribution. If you or your Affiliates directly or indirectly institute patent litigation (including a cross claim or counterclaim in a litigation) or other patent enforcement activities against any individual or entity by alleging that the Software or any Contribution in it infringes patents, then any patent license granted to you under this License for the Software shall terminate as of the date such litigation or activity is filed or taken. 91 | 92 | 3. No Trademark License 93 | 94 | No trademark license is granted to use the trade names, trademarks, service marks, or product names of Contributor, except as required to fulfill notice requirements in Section 4. 95 | 96 | 4. 
Distribution Restriction 97 | 98 | You may distribute the Software in any medium with or without modification, whether in source or executable forms, provided that you provide recipients with a copy of this License and retain copyright, patent, trademark and disclaimer statements in the Software. 99 | 100 | 5. Disclaimer of Warranty and Limitation of Liability 101 | 102 | THE SOFTWARE AND CONTRIBUTION IN IT ARE PROVIDED WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED. IN NO EVENT SHALL ANY CONTRIBUTOR OR COPYRIGHT HOLDER BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE SOFTWARE OR THE CONTRIBUTION IN IT, NO MATTER HOW IT’S CAUSED OR BASED ON WHICH LEGAL THEORY, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 103 | 104 | 6. Language 105 | 106 | THIS LICENSE IS WRITTEN IN BOTH CHINESE AND ENGLISH, AND THE CHINESE VERSION AND ENGLISH VERSION SHALL HAVE THE SAME LEGAL EFFECT. IN THE CASE OF DIVERGENCE BETWEEN THE CHINESE AND ENGLISH VERSIONS, THE CHINESE VERSION SHALL PREVAIL. 107 | 108 | END OF THE TERMS AND CONDITIONS 109 | 110 | How to Apply the Mulan Permissive Software License,Version 2 (Mulan PSL v2) to Your Software 111 | 112 | To apply the Mulan PSL v2 to your work, for easy identification by recipients, you are suggested to complete following three steps: 113 | 114 | i Fill in the blanks in following statement, including insert your software name, the year of the first publication of your software, and your name identified as the copyright owner; 115 | 116 | ii Create a file named “LICENSE” which contains the whole context of this License in the first directory of your software package; 117 | 118 | iii Attach the statement to the appropriate annotated syntax at the beginning of each source file. 119 | 120 | 121 | Copyright (c) [Year] [name of copyright holder] 122 | [Software Name] is licensed under Mulan PSL v2. 
123 | You can use this software according to the terms and conditions of the Mulan PSL v2. 124 | You may obtain a copy of Mulan PSL v2 at: 125 | http://license.coscl.org.cn/MulanPSL2 126 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 127 | See the Mulan PSL v2 for more details. 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Image-Classifier-Collection 2 | 3 | #### 介绍 4 | ============================== 5 | 6 | ​ 支持多模型工程化的图像分类器 7 | 8 | ============================== 9 | 10 | #### 软件架构 11 | Pytorch+opencv 12 | 13 | #### 模型支持架构 14 | 15 | ##### 模型 16 | 17 | | - | - | - | - | 18 | | :------------------: | :----------------: | :----------------: | :----------------: | 19 | | resnet18 | resnet34 | resnet50 | resnet101 | 20 | | resnet152 | resnext101_32x8d | resnext50_32x4d | wide_resnet50_2 | 21 | | wide_resnet101_2 | densenet121 | densenet161 | densenet169 | 22 | | densenet201 | vgg11 | vgg13 | vgg13_bn | 23 | | vgg19 | vgg19_bn | vgg16 | vgg16_bn | 24 | | inception_v3 | mobilenet_v2 | mobilenet_v3_small | mobilenet_v3_large | 25 | | shufflenet_v2_x0_5 | shufflenet_v2_x1_0 | shufflenet_v2_x1_5 | shufflenet_v2_x2_0 | 26 | | alexnet | googlenet | mnasnet0_5 | mnasnet1_0 | 27 | | mnasnet1_3 | mnasnet0_75 | squeezenet1_0 | squeezenet1_1 | 28 | | efficientnet-b0(0-7) | | | | 29 | 30 | ##### 损失函数 31 | 32 | | - | - | - | - | 33 | | :--: | :--: | :-------: | :-----------: | 34 | | mse | l1 | smooth_l1 | cross_entropy | 35 | 36 | ##### 优化器 37 | 38 | | - | - | - | - | 39 | | :----: | :-----: | :------: | :--------: | 40 | | SGD | ASGD | Adam | AdamW | 41 | | Adamax | Adagrad | Adadelta | SparseAdam | 42 | | LBFGS | Rprop | RMSprop | | 43 | 44 | 45 | #### 安装教程 46 | 47 | 1. 
pytorch>=1.5,其余依赖库(如 opencv-python、efficientnet_pytorch)请自行安装。 48 | 49 | #### 使用说明 50 | 51 | 1. 配置文件config/config.yaml 52 | 53 | ``` 54 | data_dir: "./data/" #数据集存放地址 55 | train_rate: 0.8 #数据集划分,训练集比例 56 | image_size: 128 #输入网络图像大小 57 | net_type: "shufflenet_v2_x1_0" 58 | pretrained: True #是否添加预训练权重 59 | batch_size: 4 #批次 60 | init_lr: 0.01 #初始学习率 61 | optimizer: 'Adam' #优化器 62 | class_names: [ 'cat','dog' ] #你的类别名称,必须和data文件夹下的类别子文件夹名一致 63 | epochs: 10 #训练总轮次 64 | loss_type: "mse" # mse / l1 / smooth_l1 / cross_entropy #损失函数 65 | model_dir: "./shufflenet_v2_x1_0/weight/" #权重存放地址 66 | log_dir: "./shufflenet_v2_x1_0/logs/" # tensorboard可视化文件存放地址 67 | ``` 68 | 69 | 2. 模型训练 70 | 71 | ``` 72 | # 第一次训练 73 | python train.py 74 | 75 | # 基于自己未训练完成的模型继续训练 76 | python train.py --weights_path 模型保存路径 77 | ``` 78 | 79 | 3. 模型推理(务必通过 --weights_path 指定训练好的权重,否则将使用未训练的模型) 80 | 81 | ``` 82 | # 检测图片 83 | python infer.py image --weights_path 模型权重路径 --image_path 图片地址 84 | 85 | # 检测视频 86 | python infer.py video --weights_path 模型权重路径 --video_path 视频地址 87 | 88 | # 检测摄像头 89 | python infer.py camera --weights_path 模型权重路径 --camera_id 摄像头id,默认为0 90 | ``` 91 | 92 | 4. 部署 93 | 94 | 1. onnx打包部署 95 | 96 | ``` 97 | # onnx打包 98 | python pack_tools/pytorch_to_onnx.py --config_path 配置文件地址 --weights_path 模型权重存放地址 99 | 100 | # onnx推理部署 101 | # 检测图片 102 | python pack_tools/pytorch_onnx_infer.py image --config_path 配置文件地址 --onnx_path 打包完成的onnx包地址 --image_path 图片地址 103 | 104 | # 检测视频 105 | python pack_tools/pytorch_onnx_infer.py video --config_path 配置文件地址 --onnx_path 打包完成的onnx包地址 --video_path 视频地址 106 | 107 | # 检测摄像头 108 | python pack_tools/pytorch_onnx_infer.py camera --config_path 配置文件地址 --onnx_path 打包完成的onnx包地址 --camera_id 摄像头id,默认为0 109 | ``` 110 | 111 | 5. 模型剪枝、量化压缩加速 112 | 113 | 1. 
模型剪枝微调 114 | 115 | ``` 116 | # 模型剪枝微调 117 | python prune_model/pruning_model.py --weight_path 已训练好的模型权重地址 --prune_type 修剪模型的方式,支持:l1filter,l2filter,fpgm --sparsity 模型稀疏化比例 --finetune_epoches 微调模型的轮次数 --dummy_input 输入模型的形状,例如:(10,3,128,128) 118 | 119 | # 剪枝模型推理部署 120 | # 检测图片 121 | python infer_prune_model.py image --prune_weights_path 剪枝后的模型权重路径 --image_path 图片地址 122 | 123 | # 检测视频 124 | python infer_prune_model.py video --prune_weights_path 剪枝后的模型权重路径 --video_path 视频地址 125 | 126 | # 检测摄像头 127 | python infer_prune_model.py camera --prune_weights_path 剪枝后的模型权重路径 --camera_id 摄像头id,默认为0 128 | ``` 129 | 130 | 131 | #### 参与贡献 132 | 133 | 作者:qiaofengsheng 134 | 135 | B站地址:https://space.bilibili.com/241747799 136 | 137 | github地址:https://github.com/qiaofengsheng/Pytorch-Image-Classifier-Collection.git 138 | 139 | gitee地址:https://gitee.com/qiaofengsheng/pytorch-image-classifier-collection.git -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | data_dir: "./data/" #数据集存放地址 2 | train_rate: 0.8 #数据集划分,训练集比例 3 | image_size: 128 #输入网络图像大小 4 | net_type: "efficientnet-b0" 5 | #支持模型[resnet18,resnet34,resnet50,resnet101,resnet152,resnext101_32x8d,resnext50_32x4d,wide_resnet50_2,wide_resnet101_2, 6 | # densenet121,densenet161,densenet169,densenet201,vgg11,vgg13,vgg13_bn,vgg19,vgg19_bn,vgg16,vgg16_bn,inception_v3, 7 | # mobilenet_v2,mobilenet_v3_small,mobilenet_v3_large,shufflenet_v2_x0_5,shufflenet_v2_x1_0,shufflenet_v2_x1_5, 8 | # shufflenet_v2_x2_0,alexnet,googlenet,mnasnet0_5,mnasnet1_0,mnasnet1_3,mnasnet0_75,squeezenet1_0,squeezenet1_1] 9 | # efficientnet-b0 ...
efficientnet-b7 10 | pretrained: False #是否添加预训练权重 11 | batch_size: 50 #批次 12 | init_lr: 0.01 #初始学习率 13 | optimizer: 'Adam' #优化器 [SGD,ASGD,Adam,AdamW,Adamax,Adagrad,Adadelta,SparseAdam,LBFGS,Rprop,RMSprop] 14 | class_names: [ 'cat','dog' ] #你的类别名称,必须和data文件夹下的类别文件名一样 15 | epochs: 15 #训练总轮次 16 | loss_type: "cross_entropy" # mse / l1 / smooth_l1 / cross_entropy #损失函数 17 | model_dir: "./efficientnet-b0/weight/" #权重存放地址 18 | log_dir: "./efficientnet-b0/logs/" # tensorboard可视化文件存放地址 -------------------------------------------------------------------------------- /data/cat/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/1.jpg -------------------------------------------------------------------------------- /data/cat/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/10.jpg -------------------------------------------------------------------------------- /data/cat/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/2.jpg -------------------------------------------------------------------------------- /data/cat/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/3.jpg -------------------------------------------------------------------------------- /data/cat/4.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/4.jpg -------------------------------------------------------------------------------- /data/cat/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/5.jpg -------------------------------------------------------------------------------- /data/cat/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/7.jpg -------------------------------------------------------------------------------- /data/cat/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/8.jpg -------------------------------------------------------------------------------- /data/cat/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/9.jpg -------------------------------------------------------------------------------- /data/dog/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/1.jpg -------------------------------------------------------------------------------- /data/dog/10.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/10.jpg -------------------------------------------------------------------------------- /data/dog/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/2.jpg -------------------------------------------------------------------------------- /data/dog/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/3.jpg -------------------------------------------------------------------------------- /data/dog/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/4.jpg -------------------------------------------------------------------------------- /data/dog/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/5.jpg -------------------------------------------------------------------------------- /data/dog/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/6.jpg -------------------------------------------------------------------------------- /data/dog/7.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/7.jpg -------------------------------------------------------------------------------- /data/dog/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/8.jpg -------------------------------------------------------------------------------- /data/dog/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/9.jpg -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | _*_coding:utf-8 _*_ 3 | @Time :2022/1/29 10:52 4 | @Author : qiaofengsheng 5 | @File :infer.py 6 | @Software :PyCharm 7 | ''' 8 | import os 9 | 10 | from PIL import Image, ImageDraw, ImageFont 11 | import cv2 12 | import torch 13 | from model.utils import utils 14 | from torchvision import transforms 15 | from model.net.net import * 16 | import argparse 17 | 18 | parse = argparse.ArgumentParser('infer models') 19 | parse.add_argument('demo', type=str, help='推理类型支持:image/video/camera') 20 | parse.add_argument('--weights_path', type=str, default='', help='模型权重路径') 21 | parse.add_argument('--image_path', type=str, default='', help='图片存放路径') 22 | parse.add_argument('--video_path', type=str, default='', help='视频路径') 23 | parse.add_argument('--camera_id', type=int, default=0, help='摄像头id') 24 | 25 | 26 | class ModelInfer: 27 | def __init__(self, config, weights_path): 28 | self.config = config 29 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | 31 | self.transform = transforms.Compose([ 32 | 
transforms.ToTensor() 33 | ]) 34 | self.net = ClassifierNet(self.config['net_type'], len(self.config['class_names']), 35 | False).to(self.device) 36 | if weights_path is not None: 37 | if os.path.exists(weights_path): 38 | self.net.load_state_dict(torch.load(weights_path)) 39 | print('successfully loading model weights!') 40 | else: 41 | print('no loading model weights!') 42 | else: 43 | print('please input weights_path!') 44 | exit(0) 45 | self.net.eval() 46 | 47 | def image_infer(self, image_path): 48 | image = Image.open(image_path) 49 | image_data = utils.keep_shape_resize(image, self.config['image_size']) 50 | image_data = self.transform(image_data) 51 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device) 52 | out = self.net(image_data) 53 | out = torch.argmax(out) 54 | result = self.config['class_names'][int(out)] 55 | draw = ImageDraw.Draw(image) 56 | font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35) 57 | draw.text((10, 10), result, font=font, fill='red') 58 | image.show() 59 | 60 | def video_infer(self, video_path): 61 | cap = cv2.VideoCapture(video_path) 62 | while True: 63 | _, frame = cap.read() 64 | if _: 65 | image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 66 | image_data = Image.fromarray(image_data) 67 | image_data = utils.keep_shape_resize(image_data, self.config['image_size']) 68 | image_data = self.transform(image_data) 69 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device) 70 | out = self.net(image_data) 71 | out = torch.argmax(out) 72 | result = self.config['class_names'][int(out)] 73 | cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2) 74 | cv2.imshow('frame', frame) 75 | if cv2.waitKey(24) & 0XFF == ord('q'): 76 | break 77 | else: 78 | break 79 | 80 | def camera_infer(self, camera_id): 81 | cap = cv2.VideoCapture(camera_id) 82 | while True: 83 | _, frame = cap.read() 84 | h, w, c = frame.shape 85 | if _: 86 | image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 
87 | image_data = Image.fromarray(image_data) 88 | image_data = utils.keep_shape_resize(image_data, self.config['image_size']) 89 | image_data = self.transform(image_data) 90 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device) 91 | out = self.net(image_data) 92 | out = torch.argmax(out) 93 | result = self.config['class_names'][int(out)] 94 | cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2) 95 | cv2.imshow('frame', frame) 96 | if cv2.waitKey(24) & 0XFF == ord('q'): 97 | break 98 | else: 99 | break 100 | 101 | 102 | if __name__ == '__main__': 103 | args = parse.parse_args() 104 | config = utils.load_config_util('config/config.yaml') 105 | model = ModelInfer(config, args.weights_path) 106 | if args.demo == 'image': 107 | model.image_infer(args.image_path) 108 | elif args.demo == 'video': 109 | model.video_infer(args.video_path) 110 | elif args.demo == 'camera': 111 | model.camera_infer(args.camera_id) 112 | else: 113 | exit(0) 114 | -------------------------------------------------------------------------------- /infer_prune_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ==================板块功能描述==================== 3 | @Time :2022/2/8 10:04 4 | @Author : qiaofengsheng 5 | @File :infer_prune_model.py 6 | @Software :PyCharm 7 | @description: 8 | ================================================ 9 | ''' 10 | ''' 11 | _*_coding:utf-8 _*_ 12 | @Time :2022/1/29 10:52 13 | @Author : qiaofengsheng 14 | @File :infer.py 15 | @Software :PyCharm 16 | ''' 17 | import os 18 | 19 | from PIL import Image, ImageDraw, ImageFont 20 | import cv2 21 | import torch 22 | from model.utils import utils 23 | from torchvision import transforms 24 | from model.net.net import * 25 | import argparse 26 | 27 | parse = argparse.ArgumentParser('infer models') 28 | parse.add_argument('demo', type=str, help='推理类型支持:image/video/camera') 29 | parse.add_argument('--prune_weights_path', type=str, 
default='', help='剪枝后的模型权重路径') 30 | parse.add_argument('--image_path', type=str, default='', help='图片存放路径') 31 | parse.add_argument('--video_path', type=str, default='', help='视频路径') 32 | parse.add_argument('--camera_id', type=int, default=0, help='摄像头id') 33 | 34 | 35 | class ModelInfer: 36 | def __init__(self, config, weights_path): 37 | self.config = config 38 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | 40 | self.transform = transforms.Compose([ 41 | transforms.ToTensor() 42 | ]) 43 | self.net = torch.load(weights_path) 44 | self.net.eval() 45 | 46 | def image_infer(self, image_path): 47 | image = Image.open(image_path) 48 | image_data = utils.keep_shape_resize(image, self.config['image_size']) 49 | image_data = self.transform(image_data) 50 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device) 51 | out = self.net(image_data) 52 | out = torch.argmax(out) 53 | result = self.config['class_names'][int(out)] 54 | draw = ImageDraw.Draw(image) 55 | font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35) 56 | draw.text((10, 10), result, font=font, fill='red') 57 | image.show() 58 | 59 | def video_infer(self, video_path): 60 | cap = cv2.VideoCapture(video_path) 61 | while True: 62 | _, frame = cap.read() 63 | if _: 64 | image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 65 | image_data = Image.fromarray(image_data) 66 | image_data = utils.keep_shape_resize(image_data, self.config['image_size']) 67 | image_data = self.transform(image_data) 68 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device) 69 | out = self.net(image_data) 70 | out = torch.argmax(out) 71 | result = self.config['class_names'][int(out)] 72 | cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2) 73 | cv2.imshow('frame', frame) 74 | if cv2.waitKey(24) & 0XFF == ord('q'): 75 | break 76 | else: 77 | break 78 | 79 | def camera_infer(self, camera_id): 80 | cap = cv2.VideoCapture(camera_id) 81 | while 
True: 82 | _, frame = cap.read() 83 | h, w, c = frame.shape 84 | if _: 85 | image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 86 | image_data = Image.fromarray(image_data) 87 | image_data = utils.keep_shape_resize(image_data, self.config['image_size']) 88 | image_data = self.transform(image_data) 89 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device) 90 | out = self.net(image_data) 91 | out = torch.argmax(out) 92 | result = self.config['class_names'][int(out)] 93 | cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2) 94 | cv2.imshow('frame', frame) 95 | if cv2.waitKey(24) & 0XFF == ord('q'): 96 | break 97 | else: 98 | break 99 | 100 | 101 | if __name__ == '__main__': 102 | args = parse.parse_args() 103 | config = utils.load_config_util('config/config.yaml') 104 | model = ModelInfer(config, args.prune_weights_path) 105 | if args.demo == 'image': 106 | model.image_infer(args.image_path) 107 | elif args.demo == 'video': 108 | model.video_infer(args.video_path) 109 | elif args.demo == 'camera': 110 | model.camera_infer(args.camera_id) 111 | else: 112 | exit(0) 113 | -------------------------------------------------------------------------------- /model/dataset/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/model/dataset/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /model/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | _*_coding:utf-8 _*_ 3 | @Time :2022/1/28 19:00 4 | @Author : qiaofengsheng 5 | @File :dataset.py 6 | @Software :PyCharm 7 | ''' 8 | import os 9 | 10 | from PIL import Image 11 | from torch.utils.data import * 12 | from model.utils import utils 13 | from torchvision import 
transforms


class ClassDataset(Dataset):
    """Image-classification dataset read from ``data_dir/<class_name>/<image file>``.

    Every sub-directory of ``data_dir`` is one class. Its label is the index of the
    directory name in ``config['class_names']``; ``list.index`` raises ``ValueError``
    for a directory that is not listed there.

    NOTE(review): train.py and pruning_model.py split a single instance with
    ``random_split``, so the random rotation below is also applied to the test split.
    """

    def __init__(self, data_dir, config):
        self.config = config
        self.transform = transforms.Compose([
            transforms.RandomRotation(60),
            transforms.ToTensor()
        ])
        # List of [image_path, label] pairs.
        self.dataset = []
        for class_dir in os.listdir(data_dir):
            class_path = os.path.join(data_dir, class_dir)
            # Skip stray files next to the class folders (e.g. .DS_Store); listing
            # them as directories would raise NotADirectoryError.
            if not os.path.isdir(class_path):
                continue
            label = int(config['class_names'].index(class_dir))
            for image_name in os.listdir(class_path):
                self.dataset.append([os.path.join(class_path, image_name), label])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for the sample at ``index``."""
        image_path, image_label = self.dataset[index]
        # The context manager closes the file handle; convert('RGB') gives a
        # 3-channel image for grayscale / palette / RGBA files.
        with Image.open(image_path) as raw_image:
            image = utils.keep_shape_resize(raw_image.convert('RGB'), self.config['image_size'])
        return self.transform(image), image_label

# --------------------------------------------------------------------------------
# /model/loss/__pycache__/loss_fun.cpython-38.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/model/loss/__pycache__/loss_fun.cpython-38.pyc
# --------------------------------------------------------------------------------
# /model/loss/loss_fun.py:
# --------------------------------------------------------------------------------
'''
_*_coding:utf-8 _*_
@Time :2022/1/28 19:05
@Author : qiaofengsheng
@File :loss_fun.py
@Software :PyCharm
'''

from torch import nn


class Loss:
    """Build a torch loss module from its configuration name.

    Supported ``loss_type`` values are 'mse', 'l1', 'smooth_l1' and
    'cross_entropy'; any other value falls back to ``nn.MSELoss``.
    """

    # Name -> loss class; the selected class is instantiated once in __init__.
    _LOSS_TYPES = {
        'mse': nn.MSELoss,
        'l1': nn.L1Loss,
        'smooth_l1': nn.SmoothL1Loss,
        'cross_entropy': nn.CrossEntropyLoss,
    }

    def __init__(self, loss_type='mse'):
        self.loss_fun = self._LOSS_TYPES.get(loss_type, nn.MSELoss)()

    def get_loss_fun(self):
        """Return the instantiated loss module."""
        return self.loss_fun

# --------------------------------------------------------------------------------
# /model/net/__pycache__/net.cpython-38.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/model/net/__pycache__/net.cpython-38.pyc
# --------------------------------------------------------------------------------
# /model/net/net.py:
# --------------------------------------------------------------------------------
'''
_*_coding:utf-8 _*_
@Time :2022/1/28 19:05
@Author : qiaofengsheng
@File :net.py
@Software :PyCharm
'''

import torch
from torchvision import models
from torch import nn
from efficientnet_pytorch import EfficientNet

# torchvision.models constructor names accepted by ClassifierNet.
_TORCHVISION_NET_TYPES = (
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'resnext101_32x8d', 'resnext50_32x4d',
    'wide_resnet50_2', 'wide_resnet101_2',
    'densenet121', 'densenet161', 'densenet169', 'densenet201',
    'vgg11', 'vgg13', 'vgg13_bn', 'vgg19', 'vgg19_bn', 'vgg16', 'vgg16_bn',
    'inception_v3',
    'mobilenet_v2', 'mobilenet_v3_small', 'mobilenet_v3_large',
    'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
    'alexnet', 'googlenet',
    'mnasnet0_5', 'mnasnet1_0', 'mnasnet1_3', 'mnasnet0_75',
    'squeezenet1_0', 'squeezenet1_1',
)

# EfficientNet variants built through efficientnet_pytorch.
_EFFICIENTNET_TYPES = (
    'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
    'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6',
)


class ClassifierNet(nn.Module):
    """Wrap a torchvision / EfficientNet backbone that outputs ``num_classes`` logits.

    The backbone is stored as ``self.layer[0]`` inside an ``nn.Sequential``, which
    keeps the ``layer.0.*`` state_dict keys used by existing checkpoints.

    Raises:
        ValueError: if ``net_type`` is not one of the supported names.
    """

    def __init__(self, net_type='resnet18', num_classes=10, pretrained=False):
        super(ClassifierNet, self).__init__()
        if net_type in _TORCHVISION_NET_TYPES:
            # NOTE(review): torchvision's pretrained weights are for 1000 classes;
            # confirm pretrained=True works with other num_classes in the installed
            # version. inception_v3 / googlenet may also return auxiliary-output
            # tuples in train mode.
            backbone = getattr(models, net_type)(pretrained=pretrained, num_classes=num_classes)
        elif net_type in _EFFICIENTNET_TYPES:
            if pretrained:
                backbone = EfficientNet.from_pretrained(net_type, num_classes=num_classes)
            else:
                backbone = EfficientNet.from_name(net_type, num_classes=num_classes)
        else:
            # Fail at construction instead of with an opaque TypeError in forward().
            raise ValueError(f'Unsupported net_type: {net_type!r}')
        self.layer = nn.Sequential(backbone)

    def forward(self, x):
        return self.layer(x)


if __name__ == '__main__':
    net = ClassifierNet('mnasnet1_0', pretrained=False)
    x = torch.randn(1, 3, 125, 125)
    print(net(x).shape)

# --------------------------------------------------------------------------------
# /model/optimizer/__pycache__/optim.cpython-38.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/model/optimizer/__pycache__/optim.cpython-38.pyc
# --------------------------------------------------------------------------------
# /model/optimizer/optim.py:
# --------------------------------------------------------------------------------
'''
_*_coding:utf-8 _*_
@Time :2022/1/28 19:06
@Author : qiaofengsheng
@File :optim.py
@Software :PyCharm
'''

from torch import optim


class Optimizer:
    """Build a torch optimizer over ``net.parameters()`` from its configuration name.

    'SGD' is created with lr=0.01; the other supported names use their library
    defaults; an unrecognised ``opt_type`` falls back to Adam.
    """

    # Name -> optimizer class for everything except SGD (which needs an explicit lr here).
    _OPT_TYPES = {
        'ASGD': optim.ASGD,
        'Adam': optim.Adam,
        'AdamW': optim.AdamW,
        'Adamax': optim.Adamax,
        'Adagrad': optim.Adagrad,
        'Adadelta': optim.Adadelta,
        'SparseAdam': optim.SparseAdam,
        'LBFGS': optim.LBFGS,
        'Rprop': optim.Rprop,
        'RMSprop': optim.RMSprop,
    }

    def __init__(self, net, opt_type='Adam'):
        super(Optimizer, self).__init__()
        if opt_type == 'SGD':
            self.opt = optim.SGD(net.parameters(), lr=0.01)
        else:
            self.opt = self._OPT_TYPES.get(opt_type, optim.Adam)(net.parameters())

    def get_optimizer(self):
        """Return the constructed optimizer."""
        return self.opt

# --------------------------------------------------------------------------------
# /model/utils/__pycache__/utils.cpython-38.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/model/utils/__pycache__/utils.cpython-38.pyc
# --------------------------------------------------------------------------------
# /model/utils/utils.py:
# --------------------------------------------------------------------------------
'''
_*_coding:utf-8 _*_
@Time :2022/1/28 19:58
@Author : qiaofengsheng
@File :utils.py
@Software :PyCharm
'''
import torch
import yaml
from PIL import Image
from torch.nn.functional import one_hot


def load_config_util(config_path):
    """Parse the YAML file at ``config_path`` and return the resulting object.

    ``yaml.load()`` without an explicit Loader is a TypeError on PyYAML >= 6 and can
    construct arbitrary Python objects on older versions; ``safe_load`` builds only
    plain types. The ``with`` block also closes the file, which used to leak.
    """
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return yaml.safe_load(config_file)


def keep_shape_resize(frame, size=256):
    """Paste a PIL image onto a black square canvas of side max(w, h), then resize to (size, size).

    The image is placed at the left/top edge of the canvas and offset by half the
    side difference along the shorter axis.
    """
    w, h = frame.size
    temp = max(w, h)
    mask = Image.new('RGB', (temp, temp), (0, 0, 0))
    if w >= h:
        position = (0, (w - h) // 2)
    else:
        position = ((h - w) // 2, 0)
    mask.paste(frame, position)
    mask = mask.resize((size, size))
    return mask


def label_one_hot(label, num_classes=-1):
    """One-hot encode integer class labels.

    Args:
        label: an int, a sequence of ints, or an integer tensor.
        num_classes: width of the encoding. The default -1 infers it from the
            largest label present, which yields a narrower tensor whenever the
            highest class is missing from ``label``; pass the real class count
            when the result must match model outputs.
    """
    # as_tensor reuses an existing tensor instead of copying it (torch.tensor warns on tensors).
    return one_hot(torch.as_tensor(label), num_classes=num_classes)
# --------------------------------------------------------------------------------
# /pack_tools/pytorch_onnx_infer.py:
# --------------------------------------------------------------------------------
'''
_*_coding:utf-8 _*_
@Time :2022/1/30 10:28
@Author : qiaofengsheng
@File :pytorch_onnx_infer.py
@Software :PyCharm
'''
import os
import sys

import numpy as np

sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
import cv2
import onnxruntime
import argparse
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
import torch
from model.utils import utils

parse = argparse.ArgumentParser(description='onnx model infer!')
parse.add_argument('demo', type=str, help='推理类型支持:image/video/camera')
parse.add_argument('--config_path', type=str, help='配置文件存放地址')
parse.add_argument('--onnx_path', type=str, default=None, help='onnx包存放路径')
parse.add_argument('--image_path', type=str, default='', help='图片存放路径')
parse.add_argument('--video_path', type=str, default='', help='视频路径')
parse.add_argument('--camera_id', type=int, default=0, help='摄像头id')
parse.add_argument('--device', type=str, default='cpu', help='默认设备cpu (暂未完善GPU代码)')

# Font used to draw the predicted label in image mode (this path exists only on Windows).
_LABEL_FONT_PATH = r"C:\Windows\Fonts\BRITANIC.TTF"


def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array, detaching it first if it tracks gradients."""
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def _classify(ort_session, transform, pil_image, config):
    """Run the ONNX classifier on one PIL image and return the predicted class name."""
    image_data = utils.keep_shape_resize(pil_image, config['image_size'])
    image_data = torch.unsqueeze(transform(image_data), dim=0)
    ort_input = {ort_session.get_inputs()[0].name: to_numpy(image_data)}
    ort_out = ort_session.run(None, ort_input)
    # argmax over axis 1 gives one index per batch row; take row 0 explicitly (the
    # batch size is 1) instead of calling int() on a 1-element array, which newer
    # NumPy releases deprecate.
    out = np.argmax(ort_out[0], axis=1)[0]
    return config['class_names'][int(out)]


def _load_label_font(size=35):
    """Return the label font, or PIL's built-in font when the font file cannot be opened."""
    try:
        return ImageFont.truetype(_LABEL_FONT_PATH, size)
    except OSError:
        # The hard-coded font path does not exist outside Windows.
        return ImageFont.load_default()


def onnx_infer_image(args, config):
    """Classify ``args.image_path`` with the ONNX model and show it with the label drawn on."""
    ort_session = onnxruntime.InferenceSession(args.onnx_path)
    transform = transforms.Compose([transforms.ToTensor()])
    image = Image.open(args.image_path)
    if args.device == 'cpu':
        result = _classify(ort_session, transform, image, config)
        draw = ImageDraw.Draw(image)
        draw.text((10, 10), result, font=_load_label_font(), fill='red')
        image.show()
    elif args.device == 'cuda':
        pass  # GPU inference is not implemented yet (see the --device help text).
    else:
        sys.exit(0)


def _onnx_infer_stream(args, config, source):
    """Classify and display frames read from ``cv2.VideoCapture(source)``.

    Stops when the stream ends or 'q' is pressed; the capture is released and the
    OpenCV windows are closed in every case.
    """
    ort_session = onnxruntime.InferenceSession(args.onnx_path)
    transform = transforms.Compose([transforms.ToTensor()])
    cap = cv2.VideoCapture(source)
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if args.device == 'cpu':
                image_data = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                result = _classify(ort_session, transform, image_data, config)
                cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
                cv2.imshow('frame', frame)
                if cv2.waitKey(24) & 0XFF == ord('q'):
                    break
            elif args.device == 'cuda':
                pass  # GPU inference is not implemented yet (see the --device help text).
            else:
                sys.exit(0)
    finally:
        cap.release()
        cv2.destroyAllWindows()


def onnx_infer_video(args, config):
    """Classify every frame of the video at ``args.video_path``."""
    _onnx_infer_stream(args, config, args.video_path)


def onnx_infer_camera(args, config):
    """Classify frames from the camera with id ``args.camera_id``."""
    _onnx_infer_stream(args, config, args.camera_id)


if __name__ == '__main__':
    args = parse.parse_args()
    config = utils.load_config_util(args.config_path)
    if args.demo == 'image':
        onnx_infer_image(args, config)
    elif args.demo == 'video':
        onnx_infer_video(args, config)
    elif args.demo == 'camera':
        onnx_infer_camera(args, config)
    else:
        sys.exit(0)

# --------------------------------------------------------------------------------
# /pack_tools/pytorch_to_onnx.py:
# --------------------------------------------------------------------------------
'''
_*_coding:utf-8 _*_
@Time :2022/1/29 19:00
@Author : qiaofengsheng
@File :pytorch_to_onnx.py
@Software :PyCharm
'''
import os
import sys

sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
import numpy as np
import torch.onnx
import torch.cuda
import onnx, onnxruntime
from model.net.net import *
from model.utils import utils
import argparse

parse = argparse.ArgumentParser(description='pack onnx model')
parse.add_argument('--config_path', type=str, default='', help='配置文件存放地址')
parse.add_argument('--weights_path', type=str, default='', help='模型权重文件地址')


def pack_onnx(model_path, config):
    """Export the trained classifier to ``<net_type>.onnx`` and check it against ONNX Runtime.

    The dummy input has shape ``(1, 3, image_size, image_size)``, where
    ``image_size`` comes from the config (128 when absent). Only the batch axis is
    exported as dynamic.
    """
    model = ClassifierNet(config['net_type'], len(config['class_names']),
                          False)
    # Keep deserialized storages on the CPU unless CUDA is available.
    map_location = lambda storage, loc: storage
    if torch.cuda.is_available():
        map_location = None
    model.load_state_dict(torch.load(model_path, map_location=map_location))
    model.eval()
    batch_size = 1
    # Height/width are fixed in the exported graph, so trace at the resolution that
    # inference resizes to; a hard-coded 128x128 made ONNX Runtime reject inputs
    # whenever config['image_size'] != 128.
    image_size = config.get('image_size', 128)
    dummy_input = torch.randn(batch_size, 3, image_size, image_size, requires_grad=True)
    output = model(dummy_input)
    onnx_path = config['net_type'] + '.onnx'
    torch.onnx.export(model,
                      dummy_input,
                      onnx_path,
                      export_params=True,
                      opset_version=11,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={
                          'input': {0: 'batch_size'},
                          'output': {0: 'batch_size'}
                      }
                      )
    print('onnx打包成功!')
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    ort_session = onnxruntime.InferenceSession(onnx_path)

    def to_numpy(tensor):
        # .numpy must be *called* on both branches; the no-grad branch used to
        # return the bound method instead of an array.
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    ort_input = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}
    ort_output = ort_session.run(None, ort_input)

    np.testing.assert_allclose(to_numpy(output), ort_output[0], rtol=1e-03, atol=1e-05)
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")


if __name__ == '__main__':
    args = parse.parse_args()
    config = utils.load_config_util(args.config_path)
    pack_onnx(args.weights_path, config)

# --------------------------------------------------------------------------------
# /prune_model/pruning_model.py:
# --------------------------------------------------------------------------------
'''
==================Module description====================
@Time :2022/2/7 15:55
@Author : qiaofengsheng
@File :pruning_model.py
@Software :PyCharm
@description: model pruning and quantization/compression.
    Supports the FPGMPruner, L1FilterPruner and L2FilterPruner pruning methods;
    other pruning methods are still to be implemented.
================================================
'''

# The fine-tuning training/evaluation helpers are defined after the imports.
import sys
import os

import torch

sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
import argparse
from nni.compression.pytorch import ModelSpeedup
from model.dataset.dataset import *
from model.loss.loss_fun import Loss
from model.optimizer.optim import Optimizer
from model.utils.utils import *
from model.net.net import *
from nni.algorithms.compression.pytorch.pruning import (
    FPGMPruner,
    L1FilterPruner,
    L2FilterPruner,
    LevelPruner,
    SlimPruner,
    AGPPruner,
    TaylorFOWeightFilterPruner,
    ActivationMeanRankFilterPruner,
    ActivationAPoZRankFilterPruner
)

prune_save_path = 'prune_model/prune_checkpoints'
if not os.path.exists(prune_save_path):
    os.makedirs(prune_save_path)


def _parse_op_names(text):
    """Parse ``--op_names`` given as ``conv1,conv2`` or ``["conv1","conv2"]`` into a list.

    argparse's previous ``type=list`` split the argument into single characters.
    """
    names = [name.strip().strip('\'"') for name in text.strip().strip('[]').split(',')]
    return [name for name in names if name]


def _parse_shape(text):
    """Parse a shape string such as ``(10,3,128,128)`` or ``10,3,128,128`` into a list of ints."""
    return [int(dim) for dim in text.strip().strip('()[]').split(',') if dim.strip()]


parse = argparse.ArgumentParser('pruning model')
parse.add_argument('--weight_path', type=str, default=None, help='已训练好的模型权重地址')
parse.add_argument('--prune_type', type=str, default='l1filter', help='修剪模型的方式,支持:l1filter,l2filter,fpgm')
parse.add_argument('--sparsity', type=float, default=0.5, help='模型稀疏化比例')
parse.add_argument('--op_names', type=_parse_op_names, default=None, help='指定修剪哪些layer名字,默认修剪全部的Conv2d,输入为样例:["conv1","conv2"]')
parse.add_argument('--finetune_epoches', type=int, default=10, help='微调模型的轮次数')
parse.add_argument('--dummy_input', type=str, required=True, help='输入模型的形状,例如:(10,3,128,128)')


def trainer(model, train_loader, optimizer, criterion, epoch, device):
    """Run one fine-tuning epoch over ``train_loader``, printing the loss every 20 batches.

    NOTE(review): raw integer targets are passed to ``criterion``, so this assumes a
    classification loss such as cross_entropy — confirm for other loss_type values.
    """
    model = model.to(device)
    model.train()
    for idx, (image_data, target) in enumerate(train_loader):
        image_data, target = image_data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(image_data)
        train_loss = criterion(output, target)
        train_loss.backward()
        optimizer.step()
        if idx % 20 == 0:
            print(f'第{epoch}轮--第{idx}批次--train_loss : {train_loss.item()}')


def evaluator(model, test_loader, criterion, epoch, device):
    """Return the mean per-batch accuracy of ``model`` on ``test_loader`` (0 for an empty loader)."""
    model = model.to(device)
    index = 0
    test_loss = 0
    acc = 0
    with torch.no_grad():
        model.eval()
        for image_data, target in test_loader:
            image_data, target = image_data.to(device), target.to(device)
            output = model(image_data)
            test_loss += criterion(output, target)
            pred = output.argmax(dim=1)
            acc += torch.mean(torch.eq(pred, target).float()).item()
            index += 1
    if index == 0:
        # Empty test split: avoid ZeroDivisionError.
        print(f'Epoch {epoch}: test set is empty, skipping evaluation')
        return 0
    test_loss /= index
    acc /= index
    print(f'第{epoch}轮--Average test_loss : {test_loss} -- Average Accuracy : {acc}')
    return acc


def prune_tools(args, model, train_loader, test_loader, criterion, optimizer, device):
    """Prune ``model`` with NNI, speed it up, fine-tune it and save the best whole model.

    ``optimizer`` supplies only the optimizer class and hyper-parameters; a fresh
    instance is built over the sped-up model's parameters.
    """
    # map_location allows GPU-trained weights to be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(args.weight_path, map_location=device))
    model = model.to(device)
    print(model)
    if not args.op_names:
        config_list = [{
            'sparsity': args.sparsity,
            'op_types': ['Conv2d']
        }]
    else:
        # NNI's key is 'op_types' (plural); the previous 'op_type' was not a valid key.
        config_list = [{
            'sparsity': args.sparsity,
            'op_types': ['Conv2d'],
            'op_names': args.op_names
        }]
    prune_type = {
        'l1filter': L1FilterPruner,
        'l2filter': L2FilterPruner,
        'fpgm': FPGMPruner
    }
    if args.prune_type not in prune_type:
        raise ValueError(f'Unsupported prune_type {args.prune_type!r}; choose from {sorted(prune_type)}')
    # Prune the model.
    pruner = prune_type[args.prune_type](model, config_list)
    pruner.compress()

    # Export the sparse model weights and the pruning masks.
    pruner.export_model(model_path=os.path.join(prune_save_path, 'sparsity_model.pth'),
                        mask_path=os.path.join(prune_save_path, 'mask_model.pth'))

    # Unwrap the pruner's module wrappers so ModelSpeedup sees the plain model.
    pruner._unwrap_model()
    # Speed up: shrink the masked layers according to the exported masks.
    n, c, h, w = _parse_shape(args.dummy_input)
    m_speedup = ModelSpeedup(model, dummy_input=torch.randn(n, c, h, w).to(device),
                             masks_file=os.path.join(prune_save_path, 'mask_model.pth'))
    m_speedup.speedup_model()

    # Speedup swaps pruned layers for new, smaller modules, so an optimizer built
    # before pruning would keep updating parameters no longer in the model. Rebuild
    # it with the same class and hyper-parameters over the current parameters.
    optimizer = type(optimizer)(model.parameters(), **optimizer.defaults)

    # Fine-tune the pruned model, keeping the checkpoint with the best test accuracy.
    best_acc = 0
    for epoch in range(1, args.finetune_epoches + 1):
        trainer(model, train_loader, optimizer, criterion, epoch, device)
        acc = evaluator(model, test_loader, criterion, epoch, device)
        if acc > best_acc:
            torch.save(model, os.path.join(prune_save_path, 'pruned_model.pth'))
            print('successfully save pruned_model weights!')
            best_acc = acc
    print(f'微调后的模型准确率为 : {best_acc * 100}%')


if __name__ == '__main__':
    args = parse.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = load_config_util('config/config.yaml')
    dataset = ClassDataset(config['data_dir'], config)
    train_size = int(len(dataset) * config['train_rate'])
    train_dataset, test_dataset = random_split(dataset, [train_size, len(dataset) - train_size])
    train_data_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    test_data_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)
    model = ClassifierNet(config['net_type'], len(config['class_names']), False)
    model.load_state_dict(torch.load(args.weight_path, map_location=device))
    criterion = Loss(config['loss_type']).get_loss_fun()
    optimizer = Optimizer(model, config['optimizer']).get_optimizer()
    prune_tools(args, model, train_data_loader, test_data_loader, criterion, optimizer, device)

# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
'''
_*_coding:utf-8 _*_
@Time :2022/1/29 10:52
@Author : qiaofengsheng
@File :train.py
@Software :PyCharm
'''
import os.path
import time
import torch
import tqdm
from torch.utils.tensorboard import SummaryWriter
from model.net.net import ClassifierNet
from model.loss.loss_fun import *
from model.optimizer.optim import *
from model.dataset.dataset import *
import argparse

parse = argparse.ArgumentParser(description='train_demo of argparse')
parse.add_argument('--weights_path', default=None)


class Train:
    """Train a ClassifierNet described by ``config`` and write checkpoints to ``config['model_dir']``.

    ``last.pth`` is saved every 10 training batches and after any epoch whose test
    accuracy does not improve on the best so far; ``best.pth`` is saved after an
    epoch that does improve it.
    """

    def __init__(self, config):
        self.config = config
        if not os.path.exists(config['model_dir']):
            os.makedirs(config['model_dir'])
        self.summary_writer = SummaryWriter(config['log_dir'])
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = ClassifierNet(self.config['net_type'], len(self.config['class_names']),
                                 self.config['pretrained']).to(self.device)
        self.loss_fun = Loss(self.config['loss_type']).get_loss_fun()
        self.optimizer = Optimizer(self.net, self.config['optimizer']).get_optimizer()
        self.dataset = ClassDataset(self.config['data_dir'], config)
        train_size = int(len(self.dataset) * config['train_rate'])
        self.train_dataset, self.test_dataset = random_split(self.dataset,
                                                             [train_size, len(self.dataset) - train_size])
        self.train_data_loader = DataLoader(self.train_dataset, batch_size=self.config['batch_size'], shuffle=True)
        self.test_data_loader = DataLoader(self.test_dataset, batch_size=self.config['batch_size'], shuffle=True)

    def _batch_loss(self, out, image_label):
        """Compute the configured loss for one batch of logits and integer labels."""
        if self.config['loss_type'] == 'cross_entropy':
            return self.loss_fun(out, image_label)
        # mse / l1 / smooth_l1 need float one-hot targets shaped like ``out``. Pass
        # num_classes explicitly: otherwise one_hot infers the width from the largest
        # label in the batch and the shape mismatches when that class is absent.
        target = torch.nn.functional.one_hot(image_label, num_classes=len(self.config['class_names']))
        return self.loss_fun(out, target.float())

    def train(self, weights_path):
        """Run ``config['epochs']`` epochs, optionally resuming from ``weights_path``."""
        print(f'device:{self.device} 训练集:{len(self.train_dataset)} 测试集:{len(self.test_dataset)}')
        if weights_path is not None:
            if os.path.exists(weights_path):
                # map_location lets GPU-saved weights load on a CPU-only machine.
                self.net.load_state_dict(torch.load(weights_path, map_location=self.device))
                print('successfully loading model weights!')
            else:
                print('no loading model weights')
        temp_acc = 0
        global_step = 0  # TensorBoard x-axis for train_loss (one tick per batch)
        for epoch in range(1, self.config['epochs'] + 1):
            self.net.train()
            with tqdm.tqdm(self.train_data_loader) as t1:
                for i, (image_data, image_label) in enumerate(t1):
                    image_data, image_label = image_data.to(self.device), image_label.to(self.device)
                    out = self.net(image_data)
                    train_loss = self._batch_loss(out, image_label)
                    t1.set_description(f'Train-Epoch {epoch} 轮 {i} 批次 : ')
                    t1.set_postfix(train_loss=train_loss.item())
                    self.optimizer.zero_grad()
                    train_loss.backward()
                    self.optimizer.step()
                    if i % 10 == 0:
                        torch.save(self.net.state_dict(), os.path.join(self.config['model_dir'], 'last.pth'))
                        self.summary_writer.add_scalar('train_loss', train_loss.item(), global_step)
                    global_step += 1

            self.net.eval()
            acc, temp, loss_sum = 0, 0, 0.0
            with torch.no_grad():
                with tqdm.tqdm(self.test_data_loader) as t2:
                    for j, (image_data, image_label) in enumerate(t2):
                        image_data, image_label = image_data.to(self.device), image_label.to(self.device)
                        out = self.net(image_data)
                        test_loss = self._batch_loss(out, image_label)
                        out = torch.argmax(out, dim=1)
                        test_acc = torch.mean(torch.eq(out, image_label).float()).item()
                        acc += test_acc
                        loss_sum += test_loss.item()
                        temp += 1
                        t2.set_description(f'Test-Epoch {epoch} 轮 {j} 批次 : ')
                        t2.set_postfix(test_loss=test_loss.item(), test_acc=test_acc)
            if temp == 0:
                # Empty test split: nothing to score, so just keep the latest weights.
                torch.save(self.net.state_dict(), os.path.join(self.config['model_dir'], 'last.pth'))
                continue
            epoch_acc = acc / temp
            print(f'Test-Epoch {epoch} 轮准确率为 : {epoch_acc}')
            if epoch_acc > temp_acc:
                temp_acc = epoch_acc
                torch.save(self.net.state_dict(), os.path.join(self.config['model_dir'], 'best.pth'))
            else:
                torch.save(self.net.state_dict(), os.path.join(self.config['model_dir'], 'last.pth'))
            # Log the epoch's mean test loss rather than only the last batch's value.
            self.summary_writer.add_scalar('test_loss', loss_sum / temp, epoch)
            self.summary_writer.add_scalar('test_acc', epoch_acc, epoch)
        self.summary_writer.flush()


if __name__ == '__main__':
    args = parse.parse_args()
    config = utils.load_config_util('config/config.yaml')
    Train(config).train(args.weights_path)
# --------------------------------------------------------------------------------