├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── pytorch-image-classifier-collection.iml
│   └── vcs.xml
├── LICENSE
├── README.md
├── config
│   └── config.yaml
├── data
│   ├── cat
│   │   ├── 1.jpg
│   │   ├── 10.jpg
│   │   ├── 2.jpg
│   │   ├── 3.jpg
│   │   ├── 4.jpg
│   │   ├── 5.jpg
│   │   ├── 7.jpg
│   │   ├── 8.jpg
│   │   └── 9.jpg
│   └── dog
│       ├── 1.jpg
│       ├── 10.jpg
│       ├── 2.jpg
│       ├── 3.jpg
│       ├── 4.jpg
│       ├── 5.jpg
│       ├── 6.jpg
│       ├── 7.jpg
│       ├── 8.jpg
│       └── 9.jpg
├── infer.py
├── infer_prune_model.py
├── model
│   ├── dataset
│   │   ├── __pycache__
│   │   │   └── dataset.cpython-38.pyc
│   │   └── dataset.py
│   ├── loss
│   │   ├── __pycache__
│   │   │   └── loss_fun.cpython-38.pyc
│   │   └── loss_fun.py
│   ├── net
│   │   ├── __pycache__
│   │   │   └── net.cpython-38.pyc
│   │   └── net.py
│   ├── optimizer
│   │   ├── __pycache__
│   │   │   └── optim.cpython-38.pyc
│   │   └── optim.py
│   └── utils
│       ├── __pycache__
│       │   └── utils.cpython-38.pyc
│       └── utils.py
├── pack_tools
│   ├── pytorch_onnx_infer.py
│   └── pytorch_to_onnx.py
├── prune_model
│   └── pruning_model.py
└── train.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 木兰宽松许可证, 第2版
2 |
3 | 木兰宽松许可证, 第2版
4 | 2020年1月 http://license.coscl.org.cn/MulanPSL2
5 |
6 |
7 | 您对“软件”的复制、使用、修改及分发受木兰宽松许可证,第2版(“本许可证”)的如下条款的约束:
8 |
9 | 0. 定义
10 |
11 | “软件”是指由“贡献”构成的许可在“本许可证”下的程序和相关文档的集合。
12 |
13 | “贡献”是指由任一“贡献者”许可在“本许可证”下的受版权法保护的作品。
14 |
15 | “贡献者”是指将受版权法保护的作品许可在“本许可证”下的自然人或“法人实体”。
16 |
17 | “法人实体”是指提交贡献的机构及其“关联实体”。
18 |
19 | “关联实体”是指,对“本许可证”下的行为方而言,控制、受控制或与其共同受控制的机构,此处的控制是指有受控方或共同受控方至少50%直接或间接的投票权、资金或其他有价证券。
20 |
21 | 1. 授予版权许可
22 |
23 | 每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的版权许可,您可以复制、使用、修改、分发其“贡献”,不论修改与否。
24 |
25 | 2. 授予专利许可
26 |
27 | 每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的(根据本条规定撤销除外)专利许可,供您制造、委托制造、使用、许诺销售、销售、进口其“贡献”或以其他方式转移其“贡献”。前述专利许可仅限于“贡献者”现在或将来拥有或控制的其“贡献”本身或其“贡献”与许可“贡献”时的“软件”结合而将必然会侵犯的专利权利要求,不包括对“贡献”的修改或包含“贡献”的其他结合。如果您或您的“关联实体”直接或间接地,就“软件”或其中的“贡献”对任何人发起专利侵权诉讼(包括反诉或交叉诉讼)或其他专利维权行动,指控其侵犯专利权,则“本许可证”授予您对“软件”的专利许可自您提起诉讼或发起维权行动之日终止。
28 |
29 | 3. 无商标许可
30 |
31 | “本许可证”不提供对“贡献者”的商品名称、商标、服务标志或产品名称的商标许可,但您为满足第4条规定的声明义务而必须使用除外。
32 |
33 | 4. 分发限制
34 |
35 | 您可以在任何媒介中将“软件”以源程序形式或可执行形式重新分发,不论修改与否,但您必须向接收者提供“本许可证”的副本,并保留“软件”中的版权、商标、专利及免责声明。
36 |
37 | 5. 免责声明与责任限制
38 |
39 | “软件”及其中的“贡献”在提供时不带任何明示或默示的担保。在任何情况下,“贡献者”或版权所有者不对任何人因使用“软件”或其中的“贡献”而引发的任何直接或间接损失承担责任,不论因何种原因导致或者基于何种法律理论,即使其曾被建议有此种损失的可能性。
40 |
41 | 6. 语言
42 | “本许可证”以中英文双语表述,中英文版本具有同等法律效力。如果中英文版本存在任何冲突不一致,以中文版为准。
43 |
44 | 条款结束
45 |
46 | 如何将木兰宽松许可证,第2版,应用到您的软件
47 |
48 | 如果您希望将木兰宽松许可证,第2版,应用到您的新软件,为了方便接收者查阅,建议您完成如下三步:
49 |
50 | 1, 请您补充如下声明中的空白,包括软件名、软件的首次发表年份以及您作为版权人的名字;
51 |
52 | 2, 请您在软件包的一级目录下创建以“LICENSE”为名的文件,将整个许可证文本放入该文件中;
53 |
54 | 3, 请将如下声明文本放入每个源文件的头部注释中。
55 |
56 | Copyright (c) [Year] [name of copyright holder]
57 | [Software Name] is licensed under Mulan PSL v2.
58 | You can use this software according to the terms and conditions of the Mulan PSL v2.
59 | You may obtain a copy of Mulan PSL v2 at:
60 | http://license.coscl.org.cn/MulanPSL2
61 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
62 | See the Mulan PSL v2 for more details.
63 |
64 |
65 | Mulan Permissive Software License,Version 2
66 |
67 | Mulan Permissive Software License,Version 2 (Mulan PSL v2)
68 | January 2020 http://license.coscl.org.cn/MulanPSL2
69 |
70 | Your reproduction, use, modification and distribution of the Software shall be subject to Mulan PSL v2 (this License) with the following terms and conditions:
71 |
72 | 0. Definition
73 |
74 | Software means the program and related documents which are licensed under this License and comprise all Contribution(s).
75 |
76 | Contribution means the copyrightable work licensed by a particular Contributor under this License.
77 |
78 | Contributor means the Individual or Legal Entity who licenses its copyrightable work under this License.
79 |
80 | Legal Entity means the entity making a Contribution and all its Affiliates.
81 |
82 | Affiliates means entities that control, are controlled by, or are under common control with the acting entity under this License, ‘control’ means direct or indirect ownership of at least fifty percent (50%) of the voting power, capital or other securities of controlled or commonly controlled entity.
83 |
84 | 1. Grant of Copyright License
85 |
86 | Subject to the terms and conditions of this License, each Contributor hereby grants to you a perpetual, worldwide, royalty-free, non-exclusive, irrevocable copyright license to reproduce, use, modify, or distribute its Contribution, with modification or not.
87 |
88 | 2. Grant of Patent License
89 |
90 | Subject to the terms and conditions of this License, each Contributor hereby grants to you a perpetual, worldwide, royalty-free, non-exclusive, irrevocable (except for revocation under this Section) patent license to make, have made, use, offer for sale, sell, import or otherwise transfer its Contribution, where such patent license is only limited to the patent claims owned or controlled by such Contributor now or in future which will be necessarily infringed by its Contribution alone, or by combination of the Contribution with the Software to which the Contribution was contributed. The patent license shall not apply to any modification of the Contribution, and any other combination which includes the Contribution. If you or your Affiliates directly or indirectly institute patent litigation (including a cross claim or counterclaim in a litigation) or other patent enforcement activities against any individual or entity by alleging that the Software or any Contribution in it infringes patents, then any patent license granted to you under this License for the Software shall terminate as of the date such litigation or activity is filed or taken.
91 |
92 | 3. No Trademark License
93 |
94 | No trademark license is granted to use the trade names, trademarks, service marks, or product names of Contributor, except as required to fulfill notice requirements in Section 4.
95 |
96 | 4. Distribution Restriction
97 |
98 | You may distribute the Software in any medium with or without modification, whether in source or executable forms, provided that you provide recipients with a copy of this License and retain copyright, patent, trademark and disclaimer statements in the Software.
99 |
100 | 5. Disclaimer of Warranty and Limitation of Liability
101 |
102 | THE SOFTWARE AND CONTRIBUTION IN IT ARE PROVIDED WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED. IN NO EVENT SHALL ANY CONTRIBUTOR OR COPYRIGHT HOLDER BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE SOFTWARE OR THE CONTRIBUTION IN IT, NO MATTER HOW IT’S CAUSED OR BASED ON WHICH LEGAL THEORY, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
103 |
104 | 6. Language
105 |
106 | THIS LICENSE IS WRITTEN IN BOTH CHINESE AND ENGLISH, AND THE CHINESE VERSION AND ENGLISH VERSION SHALL HAVE THE SAME LEGAL EFFECT. IN THE CASE OF DIVERGENCE BETWEEN THE CHINESE AND ENGLISH VERSIONS, THE CHINESE VERSION SHALL PREVAIL.
107 |
108 | END OF THE TERMS AND CONDITIONS
109 |
110 | How to Apply the Mulan Permissive Software License,Version 2 (Mulan PSL v2) to Your Software
111 |
112 | To apply the Mulan PSL v2 to your work, for easy identification by recipients, you are suggested to complete following three steps:
113 |
114 | i Fill in the blanks in following statement, including insert your software name, the year of the first publication of your software, and your name identified as the copyright owner;
115 |
116 | ii Create a file named “LICENSE” which contains the whole context of this License in the first directory of your software package;
117 |
118 | iii Attach the statement to the appropriate annotated syntax at the beginning of each source file.
119 |
120 |
121 | Copyright (c) [Year] [name of copyright holder]
122 | [Software Name] is licensed under Mulan PSL v2.
123 | You can use this software according to the terms and conditions of the Mulan PSL v2.
124 | You may obtain a copy of Mulan PSL v2 at:
125 | http://license.coscl.org.cn/MulanPSL2
126 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
127 | See the Mulan PSL v2 for more details.
128 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch-Image-Classifier-Collection
2 |
3 | #### Introduction
4 | ==============================
5 |
6 | An image classifier that supports many model architectures and is built for easy engineering use
7 |
8 | ==============================
9 |
10 | #### Software architecture
11 | PyTorch + OpenCV
12 |
13 | #### Supported model architectures
14 |
15 | ##### Models
16 |
17 | | - | - | - | - |
18 | | :------------------: | :----------------: | :----------------: | :----------------: |
19 | | resnet18 | resnet34 | resnet50 | resnet101 |
20 | | resnet152 | resnext101_32x8d | resnext50_32x4d | wide_resnet50_2 |
21 | | wide_resnet101_2 | densenet121 | densenet161 | densenet169 |
22 | | densenet201 | vgg11 | vgg13 | vgg13_bn |
23 | | vgg19 | vgg19_bn | vgg16 | vgg16_bn |
24 | | inception_v3 | mobilenet_v2 | mobilenet_v3_small | mobilenet_v3_large |
25 | | shufflenet_v2_x0_5 | shufflenet_v2_x1_0 | shufflenet_v2_x1_5 | shufflenet_v2_x2_0 |
26 | | alexnet | googlenet | mnasnet0_5 | mnasnet1_0 |
27 | | mnasnet1_3 | mnasnet0_75 | squeezenet1_0 | squeezenet1_1 |
28 | | efficientnet-b0 ~ efficientnet-b7 |                    |                    |                    |
29 |
30 | ##### Loss functions
31 |
32 | | - | - | - | - |
33 | | :--: | :--: | :-------: | :-----------: |
34 | | mse | l1 | smooth_l1 | cross_entropy |
35 |
36 | ##### Optimizers
37 |
38 | | - | - | - | - |
39 | | :----: | :-----: | :------: | :--------: |
40 | | SGD | ASGD | Adam | AdamW |
41 | | Adamax | Adagrad | Adadelta | SparseAdam |
42 | | LBFGS | Rprop | RMSprop | |
43 |
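   |
   | The names in these tables are consumed directly by the helper classes under model/. A minimal sketch of how they fit together, mirroring the calls made in train.py (the parameter values here are illustrative):
   |
   | ```
   | from model.net.net import ClassifierNet
   | from model.loss.loss_fun import Loss
   | from model.optimizer.optim import Optimizer
   |
   | # any model / loss / optimizer name from the tables above is accepted
   | net = ClassifierNet(net_type='resnet18', num_classes=2, pretrained=False)
   | criterion = Loss('cross_entropy').get_loss_fun()
   | optimizer = Optimizer(net, 'Adam').get_optimizer()
   | ```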
44 |
45 | #### Installation
46 |
47 | 1. PyTorch >= 1.5 is all you need; install the remaining dependencies yourself. A suggested set is shown below.
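   |
   | A minimal dependency install, inferred from the imports used across this repo (an unpinned suggestion, not a list fixed by the author):
   |
   | ```
   | pip install torch torchvision opencv-python pillow pyyaml tqdm tensorboard efficientnet_pytorch onnx onnxruntime nni numpy
   | ```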
48 |
49 | #### Usage
50 |
51 | 1. Configuration file: config/config.yaml
52 |
53 | ```
54 | data_dir: "./data/" # dataset root directory
55 | train_rate: 0.8 # train/test split: fraction of the data used for training
56 | image_size: 128 # input image size fed to the network
57 | net_type: "shufflenet_v2_x1_0"
58 | pretrained: True # whether to load pretrained weights
59 | batch_size: 4 # batch size
60 | init_lr: 0.01 # initial learning rate
61 | optimizer: 'Adam' # optimizer
62 | class_names: [ 'cat','dog' ] # your class names; must match the folder names under data/
63 | epochs: 10 # total number of training epochs
64 | loss_type: "mse" # loss function: mse / l1 / smooth_l1 / cross_entropy
65 | model_dir: "./shufflenet_v2_x1_0/weight/" # directory where weights are saved
66 | log_dir: "./shufflenet_v2_x1_0/logs/" # directory for tensorboard logs
67 | ```
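   |
   | The file is read by load_config_util in model/utils/utils.py, which returns a plain dict; a quick check from the repo root:
   |
   | ```
   | from model.utils.utils import load_config_util
   |
   | config = load_config_util('config/config.yaml')
   | print(config['net_type'], config['class_names'])
   | ```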
68 |
69 | 2. Model training
70 |
71 | ```
72 | # train from scratch
73 | python train.py
74 |
75 | # resume training from an unfinished checkpoint
76 | python train.py --weights_path <saved weights path>
77 | ```
78 |
79 | 3. Model inference
80 |
81 | ```
82 | # run on an image
83 | python infer.py image --image_path <image path>
84 |
85 | # run on a video
86 | python infer.py video --video_path <video path>
87 |
88 | # run on a camera
89 | python infer.py camera --camera_id <camera id>
90 | ```
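   |
   | The same inference entry point can also be driven from Python. A small sketch using the ModelInfer class defined in infer.py (the weights path below is a placeholder for a checkpoint produced by train.py):
   |
   | ```
   | from infer import ModelInfer
   | from model.utils import utils
   |
   | config = utils.load_config_util('config/config.yaml')
   | model = ModelInfer(config, './efficientnet-b0/weight/best.pth')  # placeholder checkpoint path
   | model.image_infer('data/cat/1.jpg')
   | ```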
91 |
92 | 4. Deployment
93 |
94 | 1. ONNX export and deployment
95 |
96 | ```
97 | # export to onnx
98 | python pack_tools/pytorch_to_onnx.py --config_path <config file path> --weights_path <model weights path>
99 |
100 | # onnx inference
101 | # run on an image
102 | python pack_tools/pytorch_onnx_infer.py image --config_path <config file path> --onnx_path <exported onnx path> --image_path <image path>
103 |
104 | # run on a video
105 | python pack_tools/pytorch_onnx_infer.py video --config_path <config file path> --onnx_path <exported onnx path> --video_path <video path>
106 |
107 | # run on a camera
108 | python pack_tools/pytorch_onnx_infer.py camera --config_path <config file path> --onnx_path <exported onnx path> --camera_id <camera id, default 0>
109 | ```
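   |
   | For reference, what pytorch_onnx_infer.py does under the hood is a standard onnxruntime call; a stripped-down sketch (the .onnx filename follows net_type, as produced by pytorch_to_onnx.py):
   |
   | ```
   | import numpy as np
   | import onnxruntime
   |
   | session = onnxruntime.InferenceSession('shufflenet_v2_x1_0.onnx')
   | dummy = np.random.randn(1, 3, 128, 128).astype(np.float32)  # batch axis is dynamic
   | out = session.run(None, {session.get_inputs()[0].name: dummy})
   | print(np.argmax(out[0], axis=1))
   | ```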
110 |
111 | 5. Model pruning and compression for speed-up
112 |
113 | 1. Prune and fine-tune
114 |
115 | ```
116 | # prune the model, then fine-tune it
117 | python prune_model/pruning_model.py --weight_path <trained weights path> --prune_type <pruning method: l1filter / l2filter / fpgm> --sparsity <sparsity ratio> --finetune_epoches <fine-tuning epochs> --dummy_input <model input shape, e.g. (10,3,128,128)>
118 |
119 | # pruned-model inference
120 | # run on an image
121 | python infer_prune_model.py image --prune_weights_path <pruned weights path> --image_path <image path>
122 |
123 | # run on a video
124 | python infer_prune_model.py video --prune_weights_path <pruned weights path> --video_path <video path>
125 |
126 | # run on a camera
127 | python infer_prune_model.py camera --prune_weights_path <pruned weights path> --camera_id <camera id, default 0>
128 | ```
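   |
   | Note: the pruning script imports from nni.algorithms.compression.pytorch.pruning, a module layout that ships with the NNI 2.x line; if a newer nni release rejects the import, pinning an older version should work (an assumption based on NNI's API history, not something specified in this repo):
   |
   | ```
   | pip install nni==2.6
   | ```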
129 |
130 |
131 | #### Contributing
132 |
133 | Author: qiaofengsheng
134 |
135 | Bilibili: https://space.bilibili.com/241747799
136 |
137 | GitHub: https://github.com/qiaofengsheng/Pytorch-Image-Classifier-Collection.git
138 |
139 | Gitee: https://gitee.com/qiaofengsheng/pytorch-image-classifier-collection.git
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | data_dir: "./data/" # dataset root directory
2 | train_rate: 0.8 # train/test split: fraction of the data used for training
3 | image_size: 128 # input image size fed to the network
4 | net_type: "efficientnet-b0"
5 | # supported models: [resnet18,resnet34,resnet50,resnet101,resnet152,resnext101_32x8d,resnext50_32x4d,wide_resnet50_2,wide_resnet101_2,
6 | # densenet121,densenet161,densenet169,densenet201,vgg11,vgg13,vgg13_bn,vgg19,vgg19_bn,vgg16,vgg16_bn,inception_v3,
7 | # mobilenet_v2,mobilenet_v3_small,mobilenet_v3_large,shufflenet_v2_x0_5,shufflenet_v2_x1_0,shufflenet_v2_x1_5,
8 | # shufflenet_v2_x2_0,alexnet,googlenet,mnasnet0_5,mnasnet1_0,mnasnet1_3,mnasnet0_75,squeezenet1_0,squeezenet1_1]
9 | # efficientnet-b0 ... efficientnet-b7
10 | pretrained: False # whether to load pretrained weights
11 | batch_size: 50 # batch size
12 | init_lr: 0.01 # initial learning rate
13 | optimizer: 'Adam' # optimizer [SGD,ASGD,Adam,AdamW,Adamax,Adagrad,Adadelta,SparseAdam,LBFGS,Rprop,RMSprop]
14 | class_names: [ 'cat','dog' ] # your class names; must match the folder names under data/
15 | epochs: 15 # total number of training epochs
16 | loss_type: "cross_entropy" # loss function: mse / l1 / smooth_l1 / cross_entropy
17 | model_dir: "./efficientnet-b0/weight/" # directory where weights are saved
18 | log_dir: "./efficientnet-b0/logs/" # directory for tensorboard logs
--------------------------------------------------------------------------------
/data/cat/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/1.jpg
--------------------------------------------------------------------------------
/data/cat/10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/10.jpg
--------------------------------------------------------------------------------
/data/cat/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/2.jpg
--------------------------------------------------------------------------------
/data/cat/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/3.jpg
--------------------------------------------------------------------------------
/data/cat/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/4.jpg
--------------------------------------------------------------------------------
/data/cat/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/5.jpg
--------------------------------------------------------------------------------
/data/cat/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/7.jpg
--------------------------------------------------------------------------------
/data/cat/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/8.jpg
--------------------------------------------------------------------------------
/data/cat/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/cat/9.jpg
--------------------------------------------------------------------------------
/data/dog/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/1.jpg
--------------------------------------------------------------------------------
/data/dog/10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/10.jpg
--------------------------------------------------------------------------------
/data/dog/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/2.jpg
--------------------------------------------------------------------------------
/data/dog/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/3.jpg
--------------------------------------------------------------------------------
/data/dog/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/4.jpg
--------------------------------------------------------------------------------
/data/dog/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/5.jpg
--------------------------------------------------------------------------------
/data/dog/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/6.jpg
--------------------------------------------------------------------------------
/data/dog/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/7.jpg
--------------------------------------------------------------------------------
/data/dog/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/8.jpg
--------------------------------------------------------------------------------
/data/dog/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/data/dog/9.jpg
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | '''
2 | _*_coding:utf-8 _*_
3 | @Time :2022/1/29 10:52
4 | @Author : qiaofengsheng
5 | @File :infer.py
6 | @Software :PyCharm
7 | '''
8 | import os
9 |
10 | from PIL import Image, ImageDraw, ImageFont
11 | import cv2
12 | import torch
13 | from model.utils import utils
14 | from torchvision import transforms
15 | from model.net.net import *
16 | import argparse
17 |
18 | parse = argparse.ArgumentParser('infer models')
19 | parse.add_argument('demo', type=str, help='inference mode: image / video / camera')
20 | parse.add_argument('--weights_path', type=str, default='', help='path to the model weights')
21 | parse.add_argument('--image_path', type=str, default='', help='path to the input image')
22 | parse.add_argument('--video_path', type=str, default='', help='path to the input video')
23 | parse.add_argument('--camera_id', type=int, default=0, help='camera id')
24 |
25 |
26 | class ModelInfer:
27 | def __init__(self, config, weights_path):
28 | self.config = config
29 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30 |
31 | self.transform = transforms.Compose([
32 | transforms.ToTensor()
33 | ])
34 | self.net = ClassifierNet(self.config['net_type'], len(self.config['class_names']),
35 | False).to(self.device)
36 | if weights_path is not None:
37 | if os.path.exists(weights_path):
38 |                 self.net.load_state_dict(torch.load(weights_path, map_location=self.device))  # map to CPU or GPU as available
39 | print('successfully loading model weights!')
40 | else:
41 |                 print('weights file not found, no model weights loaded!')
42 | else:
43 | print('please input weights_path!')
44 | exit(0)
45 | self.net.eval()
46 |
47 | def image_infer(self, image_path):
48 | image = Image.open(image_path)
49 | image_data = utils.keep_shape_resize(image, self.config['image_size'])
50 | image_data = self.transform(image_data)
51 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
52 | out = self.net(image_data)
53 | out = torch.argmax(out)
54 | result = self.config['class_names'][int(out)]
55 | draw = ImageDraw.Draw(image)
56 |         font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)  # Windows-only font path; point this at a local .ttf on Linux/macOS
57 | draw.text((10, 10), result, font=font, fill='red')
58 | image.show()
59 |
60 | def video_infer(self, video_path):
61 | cap = cv2.VideoCapture(video_path)
62 | while True:
63 | _, frame = cap.read()
64 | if _:
65 | image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
66 | image_data = Image.fromarray(image_data)
67 | image_data = utils.keep_shape_resize(image_data, self.config['image_size'])
68 | image_data = self.transform(image_data)
69 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
70 | out = self.net(image_data)
71 | out = torch.argmax(out)
72 | result = self.config['class_names'][int(out)]
73 | cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
74 | cv2.imshow('frame', frame)
75 | if cv2.waitKey(24) & 0XFF == ord('q'):
76 | break
77 | else:
78 | break
79 |
80 | def camera_infer(self, camera_id):
81 | cap = cv2.VideoCapture(camera_id)
82 | while True:
83 |             _, frame = cap.read()
84 |             if _:  # read the shape only after a successful grab, otherwise frame is None
85 |                 h, w, c = frame.shape  # frame dimensions (currently unused)
86 | image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
87 | image_data = Image.fromarray(image_data)
88 | image_data = utils.keep_shape_resize(image_data, self.config['image_size'])
89 | image_data = self.transform(image_data)
90 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
91 | out = self.net(image_data)
92 | out = torch.argmax(out)
93 | result = self.config['class_names'][int(out)]
94 | cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
95 | cv2.imshow('frame', frame)
96 | if cv2.waitKey(24) & 0XFF == ord('q'):
97 | break
98 | else:
99 | break
100 |
101 |
102 | if __name__ == '__main__':
103 | args = parse.parse_args()
104 | config = utils.load_config_util('config/config.yaml')
105 | model = ModelInfer(config, args.weights_path)
106 | if args.demo == 'image':
107 | model.image_infer(args.image_path)
108 | elif args.demo == 'video':
109 | model.video_infer(args.video_path)
110 | elif args.demo == 'camera':
111 | model.camera_infer(args.camera_id)
112 | else:
113 | exit(0)
114 |
--------------------------------------------------------------------------------
/infer_prune_model.py:
--------------------------------------------------------------------------------
1 | '''
2 | ==================Module description====================
3 | @Time :2022/2/8 10:04
4 | @Author : qiaofengsheng
5 | @File :infer_prune_model.py
6 | @Software :PyCharm
7 | @description: inference with a pruned model (image / video / camera).
8 |               Mirrors infer.py, except that the pruned network is loaded
9 |               as a complete module via torch.load instead of being rebuilt
10 |              from ClassifierNet plus a state dict.
11 | ================================================
12 | '''
13 |
14 |
15 |
16 |
17 | import os
18 |
19 | from PIL import Image, ImageDraw, ImageFont
20 | import cv2
21 | import torch
22 | from model.utils import utils
23 | from torchvision import transforms
24 | from model.net.net import *
25 | import argparse
26 |
27 | parse = argparse.ArgumentParser('infer models')
28 | parse.add_argument('demo', type=str, help='inference mode: image / video / camera')
29 | parse.add_argument('--prune_weights_path', type=str, default='', help='path to the pruned model weights')
30 | parse.add_argument('--image_path', type=str, default='', help='path to the input image')
31 | parse.add_argument('--video_path', type=str, default='', help='path to the input video')
32 | parse.add_argument('--camera_id', type=int, default=0, help='camera id')
33 |
34 |
35 | class ModelInfer:
36 | def __init__(self, config, weights_path):
37 | self.config = config
38 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39 |
40 | self.transform = transforms.Compose([
41 | transforms.ToTensor()
42 | ])
43 |         self.net = torch.load(weights_path, map_location=self.device)  # the pruned model is saved as a whole module, not a state dict
44 | self.net.eval()
45 |
46 | def image_infer(self, image_path):
47 | image = Image.open(image_path)
48 | image_data = utils.keep_shape_resize(image, self.config['image_size'])
49 | image_data = self.transform(image_data)
50 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
51 | out = self.net(image_data)
52 | out = torch.argmax(out)
53 | result = self.config['class_names'][int(out)]
54 | draw = ImageDraw.Draw(image)
55 |         font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)  # Windows-only font path; point this at a local .ttf on Linux/macOS
56 | draw.text((10, 10), result, font=font, fill='red')
57 | image.show()
58 |
59 | def video_infer(self, video_path):
60 | cap = cv2.VideoCapture(video_path)
61 | while True:
62 | _, frame = cap.read()
63 | if _:
64 | image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65 | image_data = Image.fromarray(image_data)
66 | image_data = utils.keep_shape_resize(image_data, self.config['image_size'])
67 | image_data = self.transform(image_data)
68 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
69 | out = self.net(image_data)
70 | out = torch.argmax(out)
71 | result = self.config['class_names'][int(out)]
72 | cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
73 | cv2.imshow('frame', frame)
74 | if cv2.waitKey(24) & 0XFF == ord('q'):
75 | break
76 | else:
77 | break
78 |
79 | def camera_infer(self, camera_id):
80 | cap = cv2.VideoCapture(camera_id)
81 | while True:
82 |             _, frame = cap.read()
83 |             if _:  # read the shape only after a successful grab, otherwise frame is None
84 |                 h, w, c = frame.shape  # frame dimensions (currently unused)
85 | image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86 | image_data = Image.fromarray(image_data)
87 | image_data = utils.keep_shape_resize(image_data, self.config['image_size'])
88 | image_data = self.transform(image_data)
89 | image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
90 | out = self.net(image_data)
91 | out = torch.argmax(out)
92 | result = self.config['class_names'][int(out)]
93 | cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
94 | cv2.imshow('frame', frame)
95 | if cv2.waitKey(24) & 0XFF == ord('q'):
96 | break
97 | else:
98 | break
99 |
100 |
101 | if __name__ == '__main__':
102 | args = parse.parse_args()
103 | config = utils.load_config_util('config/config.yaml')
104 | model = ModelInfer(config, args.prune_weights_path)
105 | if args.demo == 'image':
106 | model.image_infer(args.image_path)
107 | elif args.demo == 'video':
108 | model.video_infer(args.video_path)
109 | elif args.demo == 'camera':
110 | model.camera_infer(args.camera_id)
111 | else:
112 | exit(0)
113 |
--------------------------------------------------------------------------------
/model/dataset/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/model/dataset/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/model/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | _*_coding:utf-8 _*_
3 | @Time :2022/1/28 19:00
4 | @Author : qiaofengsheng
5 | @File :dataset.py
6 | @Software :PyCharm
7 | '''
8 | import os
9 |
10 | from PIL import Image
11 | from torch.utils.data import *
12 | from model.utils import utils
13 | from torchvision import transforms
14 |
15 |
16 | class ClassDataset(Dataset):
17 | def __init__(self, data_dir, config):
18 | self.config = config
19 | self.transform = transforms.Compose([
20 |             transforms.RandomRotation(60),  # note: this augmentation hits every sample, including the test split carved out in train.py
21 | transforms.ToTensor()
22 | ])
23 | self.dataset = []
24 | class_dirs = os.listdir(data_dir)
25 | for class_dir in class_dirs:
26 | image_names = os.listdir(os.path.join(data_dir, class_dir))
27 | for image_name in image_names:
28 | self.dataset.append(
29 | [os.path.join(data_dir, class_dir, image_name),
30 | int(config['class_names'].index(class_dir))])
31 |
32 | def __len__(self):
33 | return len(self.dataset)
34 |
35 | def __getitem__(self, index):
36 | data = self.dataset[index]
37 | image_path, image_label = data
38 | image = Image.open(image_path)
39 | image = utils.keep_shape_resize(image, self.config['image_size'])
40 | return self.transform(image), image_label
41 |
--------------------------------------------------------------------------------
/model/loss/__pycache__/loss_fun.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/model/loss/__pycache__/loss_fun.cpython-38.pyc
--------------------------------------------------------------------------------
/model/loss/loss_fun.py:
--------------------------------------------------------------------------------
1 | '''
2 | _*_coding:utf-8 _*_
3 | @Time :2022/1/28 19:05
4 | @Author : qiaofengsheng
5 | @File :loss_fun.py
6 | @Software :PyCharm
7 | '''
8 |
9 | from torch import nn
10 |
11 |
12 | class Loss:
13 | def __init__(self, loss_type='mse'):
14 | self.loss_fun = nn.MSELoss()
15 | if loss_type == 'mse':
16 | self.loss_fun = nn.MSELoss()
17 | elif loss_type == 'l1':
18 | self.loss_fun = nn.L1Loss()
19 | elif loss_type == 'smooth_l1':
20 | self.loss_fun = nn.SmoothL1Loss()
21 | elif loss_type == 'cross_entropy':
22 | self.loss_fun = nn.CrossEntropyLoss()
23 |
24 | def get_loss_fun(self):
25 | return self.loss_fun
26 |
--------------------------------------------------------------------------------
/model/net/__pycache__/net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/model/net/__pycache__/net.cpython-38.pyc
--------------------------------------------------------------------------------
/model/net/net.py:
--------------------------------------------------------------------------------
1 | '''
2 | _*_coding:utf-8 _*_
3 | @Time :2022/1/28 19:05
4 | @Author : qiaofengsheng
5 | @File :net.py
6 | @Software :PyCharm
7 | '''
8 |
9 | import torch
10 | from torchvision import models
11 | from torch import nn
12 | from efficientnet_pytorch import EfficientNet
13 |
14 |
15 | class ClassifierNet(nn.Module):
16 | def __init__(self, net_type='resnet18', num_classes=10, pretrained=False):
17 | super(ClassifierNet, self).__init__()
18 |         self.layer = None  # note: most torchvision models cannot combine pretrained=True with num_classes != 1000; use pretrained=False for custom class counts
19 | if net_type == 'resnet18': self.layer = nn.Sequential(models.resnet18(pretrained=pretrained,num_classes=num_classes), )
20 | if net_type == 'resnet34': self.layer = nn.Sequential(models.resnet34(pretrained=pretrained,num_classes=num_classes), )
21 | if net_type == 'resnet50': self.layer = nn.Sequential(models.resnet50(pretrained=pretrained,num_classes=num_classes), )
22 | if net_type == 'resnet101': self.layer = nn.Sequential(models.resnet101(pretrained=pretrained,num_classes=num_classes), )
23 | if net_type == 'resnet152': self.layer = nn.Sequential(models.resnet152(pretrained=pretrained,num_classes=num_classes), )
24 | if net_type == 'resnext101_32x8d': self.layer = nn.Sequential(models.resnext101_32x8d(pretrained=pretrained,num_classes=num_classes), )
25 | if net_type == 'resnext50_32x4d': self.layer = nn.Sequential(models.resnext50_32x4d(pretrained=pretrained,num_classes=num_classes), )
26 | if net_type == 'wide_resnet50_2': self.layer = nn.Sequential(models.wide_resnet50_2(pretrained=pretrained,num_classes=num_classes), )
27 | if net_type == 'wide_resnet101_2': self.layer = nn.Sequential(models.wide_resnet101_2(pretrained=pretrained,num_classes=num_classes), )
28 | if net_type == 'densenet121': self.layer = nn.Sequential(models.densenet121(pretrained=pretrained,num_classes=num_classes), )
29 | if net_type == 'densenet161': self.layer = nn.Sequential(models.densenet161(pretrained=pretrained,num_classes=num_classes), )
30 | if net_type == 'densenet169': self.layer = nn.Sequential(models.densenet169(pretrained=pretrained,num_classes=num_classes), )
31 | if net_type == 'densenet201': self.layer = nn.Sequential(models.densenet201(pretrained=pretrained,num_classes=num_classes), )
32 | if net_type == 'vgg11': self.layer = nn.Sequential(models.vgg11(pretrained=pretrained,num_classes=num_classes), )
33 | if net_type == 'vgg13': self.layer = nn.Sequential(models.vgg13(pretrained=pretrained,num_classes=num_classes), )
34 | if net_type == 'vgg13_bn': self.layer = nn.Sequential(models.vgg13_bn(pretrained=pretrained,num_classes=num_classes), )
35 | if net_type == 'vgg19': self.layer = nn.Sequential(models.vgg19(pretrained=pretrained,num_classes=num_classes), )
36 | if net_type == 'vgg19_bn': self.layer = nn.Sequential(models.vgg19_bn(pretrained=pretrained,num_classes=num_classes), )
37 | if net_type == 'vgg16': self.layer = nn.Sequential(models.vgg16(pretrained=pretrained,num_classes=num_classes), )
38 | if net_type == 'vgg16_bn': self.layer = nn.Sequential(models.vgg16_bn(pretrained=pretrained,num_classes=num_classes), )
39 | if net_type == 'inception_v3': self.layer = nn.Sequential(models.inception_v3(pretrained=pretrained,num_classes=num_classes), )
40 | if net_type == 'mobilenet_v2': self.layer = nn.Sequential(models.mobilenet_v2(pretrained=pretrained,num_classes=num_classes), )
41 | if net_type == 'mobilenet_v3_small': self.layer = nn.Sequential(
42 | models.mobilenet_v3_small(pretrained=pretrained,num_classes=num_classes), )
43 | if net_type == 'mobilenet_v3_large': self.layer = nn.Sequential(
44 | models.mobilenet_v3_large(pretrained=pretrained,num_classes=num_classes), )
45 | if net_type == 'shufflenet_v2_x0_5': self.layer = nn.Sequential(
46 | models.shufflenet_v2_x0_5(pretrained=pretrained,num_classes=num_classes), )
47 | if net_type == 'shufflenet_v2_x1_0': self.layer = nn.Sequential(
48 | models.shufflenet_v2_x1_0(pretrained=pretrained,num_classes=num_classes), )
49 | if net_type == 'shufflenet_v2_x1_5': self.layer = nn.Sequential(
50 | models.shufflenet_v2_x1_5(pretrained=pretrained,num_classes=num_classes), )
51 | if net_type == 'shufflenet_v2_x2_0': self.layer = nn.Sequential(
52 | models.shufflenet_v2_x2_0(pretrained=pretrained,num_classes=num_classes), )
53 | if net_type == 'alexnet':
54 | self.layer = nn.Sequential(models.alexnet(pretrained=pretrained,num_classes=num_classes), )
55 | if net_type == 'googlenet':
56 | self.layer = nn.Sequential(models.googlenet(pretrained=pretrained,num_classes=num_classes), )
57 | if net_type == 'mnasnet0_5':
58 | self.layer = nn.Sequential(models.mnasnet0_5(pretrained=pretrained,num_classes=num_classes), )
59 | if net_type == 'mnasnet1_0':
60 | self.layer = nn.Sequential(models.mnasnet1_0(pretrained=pretrained,num_classes=num_classes), )
61 | if net_type == 'mnasnet1_3':
62 | self.layer = nn.Sequential(models.mnasnet1_3(pretrained=pretrained,num_classes=num_classes), )
63 | if net_type == 'mnasnet0_75':
64 | self.layer = nn.Sequential(models.mnasnet0_75(pretrained=pretrained,num_classes=num_classes), )
65 | if net_type == 'squeezenet1_0':
66 | self.layer = nn.Sequential(models.squeezenet1_0(pretrained=pretrained,num_classes=num_classes), )
67 | if net_type == 'squeezenet1_1':
68 | self.layer = nn.Sequential(models.squeezenet1_1(pretrained=pretrained,num_classes=num_classes), )
69 |         if net_type in ['efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', 'efficientnet-b4',
70 |                         'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7']:  # b0 through b7, as listed in config.yaml
71 | if pretrained:
72 | self.layer = nn.Sequential(EfficientNet.from_pretrained(net_type,num_classes=num_classes))
73 | else:
74 | self.layer = nn.Sequential(EfficientNet.from_name(net_type,num_classes=num_classes))
75 |
76 | def forward(self, x):
77 | return self.layer(x)
78 |
79 | if __name__ == '__main__':
80 |     net = ClassifierNet('mnasnet1_0', pretrained=False)
81 |     x = torch.randn(1, 3, 125, 125)
82 | print(net(x).shape)
83 |
--------------------------------------------------------------------------------
/model/optimizer/__pycache__/optim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/model/optimizer/__pycache__/optim.cpython-38.pyc
--------------------------------------------------------------------------------
/model/optimizer/optim.py:
--------------------------------------------------------------------------------
1 | '''
2 | _*_coding:utf-8 _*_
3 | @Time :2022/1/28 19:06
4 | @Author : qiaofengsheng
5 | @File :optim.py
6 | @Software :PyCharm
7 | '''
8 |
9 | from torch import optim
10 |
11 |
12 | class Optimizer:
13 |     def __init__(self, net, opt_type='Adam'):  # note: init_lr from the config is never passed in; each optimizer uses its library default (SGD hardcodes lr=0.01 below)
14 |         super(Optimizer, self).__init__()
15 |         self.opt = optim.Adam(net.parameters())  # fallback when opt_type is not recognized
16 | if opt_type == 'SGD':
17 | self.opt = optim.SGD(net.parameters(), lr=0.01)
18 | elif opt_type == 'ASGD':
19 | self.opt = optim.ASGD(net.parameters())
20 | elif opt_type == 'Adam':
21 | self.opt = optim.Adam(net.parameters())
22 | elif opt_type == 'AdamW':
23 | self.opt = optim.AdamW(net.parameters())
24 | elif opt_type == 'Adamax':
25 | self.opt = optim.Adamax(net.parameters())
26 | elif opt_type == 'Adagrad':
27 | self.opt = optim.Adagrad(net.parameters())
28 | elif opt_type == 'Adadelta':
29 | self.opt = optim.Adadelta(net.parameters())
30 | elif opt_type == 'SparseAdam':
31 | self.opt = optim.SparseAdam(net.parameters())
32 | elif opt_type == 'LBFGS':
33 | self.opt = optim.LBFGS(net.parameters())
34 | elif opt_type == 'Rprop':
35 | self.opt = optim.Rprop(net.parameters())
36 | elif opt_type == 'RMSprop':
37 | self.opt = optim.RMSprop(net.parameters())
38 |
39 | def get_optimizer(self):
40 | return self.opt
41 |
--------------------------------------------------------------------------------
/model/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/Pytorch-Image-Classifier-Collection/b95a07451c6c169639af9a4f6f5e074055570828/model/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/model/utils/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | _*_coding:utf-8 _*_
3 | @Time :2022/1/28 19:58
4 | @Author : qiaofengsheng
5 | @File :utils.py
6 | @Software :PyCharm
7 | '''
8 | import torch
9 | import yaml
10 | from PIL import Image
11 | from torch.nn.functional import one_hot
12 |
13 |
14 | def load_config_util(config_path):
15 |     with open(config_path, 'r', encoding='utf-8') as config_file:
16 |         config_data = yaml.safe_load(config_file)  # safe_load: yaml.load without an explicit Loader is deprecated
17 |     return config_data
18 |
19 |
20 | def keep_shape_resize(frame, size=256):  # pad a PIL image to a square (preserving aspect ratio), then resize to size x size
21 | w, h = frame.size
22 | temp = max(w, h)
23 | mask = Image.new('RGB', (temp, temp), (0, 0, 0))
24 | if w >= h:
25 | position = (0, (w - h) // 2)
26 | else:
27 | position = ((h - w) // 2, 0)
28 | mask.paste(frame, position)
29 | mask = mask.resize((size, size))
30 | return mask
31 |
32 |
33 | def label_one_hot(label):
34 | return one_hot(torch.tensor(label))
35 |
--------------------------------------------------------------------------------
/pack_tools/pytorch_onnx_infer.py:
--------------------------------------------------------------------------------
1 | '''
2 | _*_coding:utf-8 _*_
3 | @Time :2022/1/30 10:28
4 | @Author : qiaofengsheng
5 | @File :pytorch_onnx_infer.py
6 | @Software :PyCharm
7 | '''
8 | import os
9 | import sys
10 |
11 | import numpy as np
12 |
13 | sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
14 | import cv2
15 | import onnxruntime
16 | import argparse
17 | from PIL import Image, ImageDraw, ImageFont
18 | from torchvision import transforms
19 | import torch
20 | from model.utils import utils
21 |
22 | parse = argparse.ArgumentParser(description='onnx model infer!')
23 | parse.add_argument('demo', type=str, help='inference mode: image / video / camera')
24 | parse.add_argument('--config_path', type=str, help='path to the config file')
25 | parse.add_argument('--onnx_path', type=str, default=None, help='path to the exported onnx model')
26 | parse.add_argument('--image_path', type=str, default='', help='path to the input image')
27 | parse.add_argument('--video_path', type=str, default='', help='path to the input video')
28 | parse.add_argument('--camera_id', type=int, default=0, help='camera id')
29 | parse.add_argument('--device', type=str, default='cpu', help="device, 'cpu' by default (the GPU path is not implemented yet)")
30 |
31 |
32 | def to_numpy(tensor):
33 | return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
34 |
35 |
36 | def onnx_infer_image(args, config):
37 | ort_session = onnxruntime.InferenceSession(args.onnx_path)
38 | transform = transforms.Compose([transforms.ToTensor()])
39 | image = Image.open(args.image_path)
40 | image_data = utils.keep_shape_resize(image, config['image_size'])
41 | image_data = transform(image_data)
42 | image_data = torch.unsqueeze(image_data, dim=0)
43 | if args.device == 'cpu':
44 | ort_input = {ort_session.get_inputs()[0].name: to_numpy(image_data)}
45 | ort_out = ort_session.run(None, ort_input)
46 | out = np.argmax(ort_out[0], axis=1)
47 | result = config['class_names'][int(out)]
48 | draw = ImageDraw.Draw(image)
49 |         font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)  # Windows-only font path; point this at a local .ttf on Linux/macOS
50 | draw.text((10, 10), result, font=font, fill='red')
51 | image.show()
52 | elif args.device == 'cuda':
53 | pass
54 | else:
55 | exit(0)
56 |
57 |
58 | def onnx_infer_video(args, config):
59 | ort_session = onnxruntime.InferenceSession(args.onnx_path)
60 | transform = transforms.Compose([transforms.ToTensor()])
61 | cap = cv2.VideoCapture(args.video_path)
62 | while True:
63 | _, frame = cap.read()
64 | if _:
65 | image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
66 | image_data = Image.fromarray(image_data)
67 | image_data = utils.keep_shape_resize(image_data, config['image_size'])
68 | image_data = transform(image_data)
69 | image_data = torch.unsqueeze(image_data, dim=0)
70 | if args.device == 'cpu':
71 | ort_input = {ort_session.get_inputs()[0].name: to_numpy(image_data)}
72 | ort_out = ort_session.run(None, ort_input)
73 | out = np.argmax(ort_out[0], axis=1)
74 | result = config['class_names'][int(out)]
75 | cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
76 | cv2.imshow('frame', frame)
77 | if cv2.waitKey(24) & 0XFF == ord('q'):
78 | break
79 | elif args.device == 'cuda':
80 | pass
81 | else:
82 | exit(0)
83 | else:
84 | exit(0)
85 |
86 |
87 | def onnx_infer_camera(args, config):
88 | ort_session = onnxruntime.InferenceSession(args.onnx_path)
89 | transform = transforms.Compose([transforms.ToTensor()])
90 | cap = cv2.VideoCapture(args.camera_id)
91 | while True:
92 | _, frame = cap.read()
93 | if _:
94 | image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
95 | image_data = Image.fromarray(image_data)
96 | image_data = utils.keep_shape_resize(image_data, config['image_size'])
97 | image_data = transform(image_data)
98 | image_data = torch.unsqueeze(image_data, dim=0)
99 | if args.device == 'cpu':
100 | ort_input = {ort_session.get_inputs()[0].name: to_numpy(image_data)}
101 | ort_out = ort_session.run(None, ort_input)
102 | out = np.argmax(ort_out[0], axis=1)
103 | result = config['class_names'][int(out)]
104 | cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
105 | cv2.imshow('frame', frame)
106 | if cv2.waitKey(24) & 0XFF == ord('q'):
107 | break
108 | elif args.device == 'cuda':
109 | pass
110 | else:
111 | exit(0)
112 | else:
113 | exit(0)
114 |
115 |
116 | if __name__ == '__main__':
117 | args = parse.parse_args()
118 | config = utils.load_config_util(args.config_path)
119 | if args.demo == 'image':
120 | onnx_infer_image(args, config)
121 | elif args.demo == 'video':
122 | onnx_infer_video(args, config)
123 | elif args.demo == 'camera':
124 | onnx_infer_camera(args, config)
125 | else:
126 | exit(0)
127 |
--------------------------------------------------------------------------------
/pack_tools/pytorch_to_onnx.py:
--------------------------------------------------------------------------------
1 | '''
2 | _*_coding:utf-8 _*_
3 | @Time :2022/1/29 19:00
4 | @Author : qiaofengsheng
5 | @File :pytorch_to_onnx.py
6 | @Software :PyCharm
7 | '''
8 | import os
9 | import sys
10 |
11 | sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
12 | import numpy as np
13 | import torch.onnx
14 | import torch.cuda
15 | import onnx, onnxruntime
16 | from model.net.net import *
17 | from model.utils import utils
18 | import argparse
19 |
20 | parse = argparse.ArgumentParser(description='pack onnx model')
21 | parse.add_argument('--config_path', type=str, default='', help='path to the config file')
22 | parse.add_argument('--weights_path', type=str, default='', help='path to the model weights file')
23 |
24 |
25 | def pack_onnx(model_path, config):
26 | model = ClassifierNet(config['net_type'], len(config['class_names']),
27 | False)
28 | map_location = lambda storage, loc: storage
29 | if torch.cuda.is_available():
30 | map_location = None
31 | model.load_state_dict(torch.load(model_path, map_location=map_location))
32 | model.eval()
33 | batch_size = 1
34 |     input = torch.randn(batch_size, 3, 128, 128, requires_grad=True)  # spatial size must match image_size in the config
35 | output = model(input)
36 | torch.onnx.export(model,
37 | input,
38 | config['net_type'] + '.onnx',
39 | export_params=True,
40 | opset_version=11,
41 | do_constant_folding=True,
42 | input_names=['input'],
43 | output_names=['output'],
44 | dynamic_axes={
45 | 'input': {0: 'batch_size'},
46 | 'output': {0: 'batch_size'}
47 | }
48 | )
49 |     print('ONNX export succeeded!')
50 | output = model(input)
51 | onnx_model = onnx.load(config['net_type'] + '.onnx')
52 | onnx.checker.check_model(onnx_model)
53 | ort_session = onnxruntime.InferenceSession(config['net_type'] + '.onnx')
54 |
55 | def to_numpy(tensor):
56 |         return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
57 |
58 | ort_input = {ort_session.get_inputs()[0].name: to_numpy(input)}
59 | ort_output = ort_session.run(None, ort_input)
60 |
61 | np.testing.assert_allclose(to_numpy(output), ort_output[0], rtol=1e-03, atol=1e-05)
62 | print("Exported model has been tested with ONNXRuntime, and the result looks good!")
63 |
64 |
65 | if __name__ == '__main__':
66 | args = parse.parse_args()
67 | config = utils.load_config_util(args.config_path)
68 | pack_onnx(args.weights_path, config)
69 |
--------------------------------------------------------------------------------
/prune_model/pruning_model.py:
--------------------------------------------------------------------------------
1 | '''
2 | ==================Module description====================
3 | @Time :2022/2/7 15:55
4 | @Author : qiaofengsheng
5 | @File :pruning_model.py
6 | @Software :PyCharm
7 | @description: model pruning and compression.
8 |               Supports the FPGMPruner, L1FilterPruner and L2FilterPruner methods;
9 |               the other imported pruners are not wired up yet
10 | ================================================
11 | '''
12 |
13 | # the trainer / evaluator helpers used to fine-tune the pruned model are defined below
14 | import sys
15 | import os
16 |
17 | import torch
18 |
19 | sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
20 | import argparse
21 | from nni.compression.pytorch import ModelSpeedup
22 | from model.dataset.dataset import *
23 | from model.loss.loss_fun import Loss
24 | from model.optimizer.optim import Optimizer
25 | from model.utils.utils import *
26 | from model.net.net import *
27 | from nni.algorithms.compression.pytorch.pruning import (
28 | FPGMPruner,
29 | L1FilterPruner,
30 | L2FilterPruner,
31 | LevelPruner,
32 | SlimPruner,
33 | AGPPruner,
34 | TaylorFOWeightFilterPruner,
35 | ActivationMeanRankFilterPruner,
36 | ActivationAPoZRankFilterPruner
37 | )
38 |
39 | prune_save_path = 'prune_model/prune_checkpoints'
40 | if not os.path.exists(prune_save_path):
41 | os.makedirs(prune_save_path)
42 |
43 | parse = argparse.ArgumentParser('pruning model')
44 | parse.add_argument('--weight_path', type=str, default=None, help='path to the trained model weights')
45 | parse.add_argument('--prune_type', type=str, default='l1filter', help='pruning method, one of: l1filter, l2filter, fpgm')
46 | parse.add_argument('--sparsity', type=float, default=0.5, help='target sparsity ratio')
47 | parse.add_argument('--op_names', type=str, nargs='+', default=None, help='names of the layers to prune; by default every Conv2d is pruned. Example: --op_names conv1 conv2')
48 | parse.add_argument('--finetune_epoches', type=int, default=10, help='number of fine-tuning epochs')
49 | parse.add_argument('--dummy_input', type=str, required=True, help='model input shape, e.g. (10,3,128,128)')
50 |
51 |
52 | def trainer(model, train_loader, optimizer, criterion, epoch, device):
53 | model = model.to(device)
54 | model.train()
55 | for idx, (image_data, target) in enumerate(train_loader):
56 | image_data, target = image_data.to(device), target.to(device)
57 | optimizer.zero_grad()
58 | output = model(image_data)
59 | train_loss = criterion(output, target)
60 | train_loss.backward()
61 | optimizer.step()
62 | if idx % 20 == 0:
63 |             print(f'epoch {epoch} -- batch {idx} -- train_loss : {train_loss.item()}')
64 |
65 |
66 | # evaluation function used during fine-tuning
67 | def evaluator(model, test_loader, criterion, epoch, device):
68 | model = model.to(device)
69 | index = 0
70 | test_loss = 0
71 | acc = 0
72 | with torch.no_grad():
73 | model.eval()
74 | for image_data, target in test_loader:
75 | image_data, target = image_data.to(device), target.to(device)
76 | output = model(image_data)
77 | test_loss += criterion(output, target)
78 | pred = output.argmax(dim=1)
79 | acc += torch.mean(torch.eq(pred, target).float()).item()
80 | index += 1
81 | test_loss /= index
82 | acc /= index
83 |     print(f'epoch {epoch} -- Average test_loss : {test_loss} -- Average Accuracy : {acc}')
84 | return acc
85 |
86 |
87 | def prune_tools(args, model, train_loader, test_loader, criterion, optimizer, device):
88 | model.load_state_dict(torch.load(args.weight_path))
89 | model = model.to(device)
90 | print(model)
91 | if args.op_names is None:
92 | config_list = [{
93 | 'sparsity': args.sparsity,
94 | 'op_types': ['Conv2d']
95 | }]
96 | else:
97 | config_list = [{
98 | 'sparsity': args.sparsity,
99 |             'op_types': ['Conv2d'],
100 | 'op_names': args.op_names
101 | }]
102 | prune_type = {
103 | 'l1filter': L1FilterPruner,
104 | 'l2filter': L2FilterPruner,
105 | 'fpgm': FPGMPruner
106 | }
107 |     # prune the model
108 |     pruner = prune_type[args.prune_type](model, config_list)
109 |     pruner.compress()
110 |
111 |     # export the sparsified model and the pruning masks
112 |     pruner.export_model(model_path=os.path.join(prune_save_path, 'sparsity_model.pth'),
113 |                         mask_path=os.path.join(prune_save_path, 'mask_model.pth'))
114 |
115 |     # unwrap the pruner to get the underlying model back
116 |     pruner._unwrap_model()
117 |     # model speed-up: physically strip the masked channels
118 |     dummy_input = args.dummy_input.split(',')  # parse a shape string such as "(10,3,128,128)"
119 |     n, c, h, w = int(dummy_input[0][1:]), int(dummy_input[1]), int(dummy_input[2]), int(dummy_input[3][:-1])
120 | m_speedup = ModelSpeedup(model, dummy_input=torch.randn(n, c, h, w).to(device),
121 | masks_file=os.path.join(prune_save_path, 'mask_model.pth'))
122 | m_speedup.speedup_model()
123 |
124 |     # fine-tune the pruned model
125 | best_acc = 0
126 | for epoch in range(1, args.finetune_epoches + 1):
127 | trainer(model, train_loader, optimizer, criterion, epoch, device)
128 | acc = evaluator(model, test_loader, criterion, epoch, device)
129 | if acc > best_acc:
130 | torch.save(model, os.path.join(prune_save_path, 'pruned_model.pth'))
131 | print('successfully save pruned_model weights!')
132 | best_acc = acc
133 | else:
134 | continue
135 |     print(f'accuracy after fine-tuning : {best_acc * 100}%')
136 |
137 |
138 | if __name__ == '__main__':
139 | args = parse.parse_args()
140 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
141 | config = load_config_util('config/config.yaml')
142 | dataset = ClassDataset(config['data_dir'], config)
143 | train_dataset, test_dataset = random_split(dataset,
144 | [int(len(dataset) * config['train_rate']),
145 | len(dataset) - int(
146 | len(dataset) * config['train_rate'])]
147 | )
148 | train_data_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
149 | test_data_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)
150 | model = ClassifierNet(config['net_type'], len(config['class_names']), False)
151 | model.load_state_dict(torch.load(args.weight_path))
152 | criterion = Loss(config['loss_type']).get_loss_fun()
153 | optimizer = Optimizer(model, config['optimizer']).get_optimizer()
154 | prune_tools(args, model, train_data_loader, test_data_loader, criterion, optimizer, device)
155 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | _*_coding:utf-8 _*_
3 | @Time :2022/1/29 10:52
4 | @Author : qiaofengsheng
5 | @File :train.py
6 | @Software :PyCharm
7 | '''
8 | import os.path
9 | import time
10 | import torch
11 | import tqdm
12 | from torch.utils.tensorboard import SummaryWriter
13 | from model.net.net import ClassifierNet
14 | from model.loss.loss_fun import *
15 | from model.optimizer.optim import *
16 | from model.dataset.dataset import *
17 | import argparse
18 |
19 | parse = argparse.ArgumentParser(description='train_demo of argparse')
20 | parse.add_argument('--weights_path', default=None, help='checkpoint to resume training from')
21 |
22 |
23 | class Train:
24 | def __init__(self, config):
25 | self.config = config
26 | if not os.path.exists(config['model_dir']):
27 | os.makedirs(config['model_dir'])
28 | self.summary_writer = SummaryWriter(config['log_dir'])
29 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30 | self.net = ClassifierNet(self.config['net_type'], len(self.config['class_names']),
31 | self.config['pretrained']).to(self.device)
32 | self.loss_fun = Loss(self.config['loss_type']).get_loss_fun()
33 | self.optimizer = Optimizer(self.net, self.config['optimizer']).get_optimizer()
34 | self.dataset = ClassDataset(self.config['data_dir'], config)
35 | self.train_dataset, self.test_dataset = random_split(self.dataset,
36 | [int(len(self.dataset) * config['train_rate']),
37 | len(self.dataset) - int(
38 | len(self.dataset) * config['train_rate'])]
39 | )
40 | self.train_data_loader = DataLoader(self.train_dataset, batch_size=self.config['batch_size'], shuffle=True)
41 | self.test_data_loader = DataLoader(self.test_dataset, batch_size=self.config['batch_size'], shuffle=True)
42 |
43 | def train(self, weights_path):
44 |         print(f'device: {self.device}  train set: {len(self.train_dataset)}  test set: {len(self.test_dataset)}')
45 | if weights_path is not None:
46 | if os.path.exists(weights_path):
47 | self.net.load_state_dict(torch.load(weights_path))
48 | print('successfully loading model weights!')
49 | else:
50 |                 print('weights file not found, no model weights loaded')
51 | temp_acc = 0
52 | for epoch in range(1, self.config['epochs'] + 1):
53 | self.net.train()
54 | with tqdm.tqdm(self.train_data_loader) as t1:
55 | for i, (image_data, image_label) in enumerate(self.train_data_loader):
56 | image_data, image_label = image_data.to(self.device), image_label.to(self.device)
57 | out = self.net(image_data)
58 | if self.config['loss_type'] == 'cross_entropy':
59 | train_loss = self.loss_fun(out, image_label)
60 | else:
61 | train_loss = self.loss_fun(out, utils.label_one_hot(image_label).type(torch.FloatTensor).to(
62 | self.device))
63 |                     t1.set_description(f'Train-Epoch {epoch}, batch {i} : ')
64 |                     t1.set_postfix(train_loss=train_loss.item())
65 |                     time.sleep(0.1)  # only to keep the progress bar readable; safe to remove
66 | t1.update(1)
67 | self.optimizer.zero_grad()
68 | train_loss.backward()
69 | self.optimizer.step()
70 | if i % 10 == 0:
71 | torch.save(self.net.state_dict(), os.path.join(self.config['model_dir'], 'last.pth'))
72 | self.summary_writer.add_scalar('train_loss', train_loss.item(), epoch)
73 |
74 | self.net.eval()
75 | acc, temp = 0, 0
76 | with torch.no_grad():
77 | with tqdm.tqdm(self.test_data_loader) as t2:
78 | for j, (image_data, image_label) in enumerate(self.test_data_loader):
79 | image_data, image_label = image_data.to(self.device), image_label.to(self.device)
80 | out = self.net(image_data)
81 | if self.config['loss_type'] == 'cross_entropy':
82 | test_loss = self.loss_fun(out, image_label)
83 | else:
84 | test_loss = self.loss_fun(out, utils.label_one_hot(image_label).type(torch.FloatTensor).to(
85 | self.device))
86 | out = torch.argmax(out, dim=1)
87 | test_acc = torch.mean(torch.eq(out, image_label).float()).item()
88 | acc += test_acc
89 | temp += 1
90 |                         t2.set_description(f'Test-Epoch {epoch}, batch {j} : ')
91 |                         t2.set_postfix(test_loss=test_loss.item(), test_acc=test_acc)
92 |                         time.sleep(0.1)  # only to keep the progress bar readable; safe to remove
93 |                         t2.update(1)
94 |             print(f'Test-Epoch {epoch} accuracy : {acc / temp}')
95 | if (acc / temp) > temp_acc:
96 | temp_acc = acc / temp
97 | torch.save(self.net.state_dict(), os.path.join(self.config['model_dir'], 'best.pth'))
98 | else:
99 | torch.save(self.net.state_dict(), os.path.join(self.config['model_dir'], 'last.pth'))
100 | self.summary_writer.add_scalar('test_loss', test_loss.item(), epoch)
101 | self.summary_writer.add_scalar('test_acc', acc / temp, epoch)
102 |
103 |
104 | if __name__ == '__main__':
105 | args = parse.parse_args()
106 | config = utils.load_config_util('config/config.yaml')
107 | Train(config).train(args.weights_path)
108 |
--------------------------------------------------------------------------------