├── .gitignore ├── README.md ├── applications ├── __init__.py ├── convert.py ├── distillation.py ├── make_curriculum.py ├── test.py ├── train.py └── visualize.py ├── checkpoints └── README.md ├── config.py ├── criterions ├── __init__.py ├── binary_cross_entropy │ ├── __init__.py │ ├── basic_bce.py │ ├── ghm_bce.py │ ├── hybrid_bce.py │ ├── ohm_bce.py │ ├── threshold_bce.py │ └── weighted_bce.py ├── softmax_cross_entropy │ ├── __init__.py │ ├── basic_softmax.py │ ├── ghm_softmax.py │ ├── hybrid_softmax.py │ ├── ohm_softmax.py │ ├── threshold_softmax.py │ └── weighted_softmax.py └── utils.py ├── data ├── test │ ├── cats │ │ ├── cats_00001.jpg │ │ ├── cats_00002.jpg │ │ ├── cats_00003.jpg │ │ ├── cats_00004.jpg │ │ ├── cats_00005.jpg │ │ ├── cats_00006.jpg │ │ ├── cats_00007.jpg │ │ ├── cats_00008.jpg │ │ ├── cats_00009.jpg │ │ └── cats_00010.jpg │ ├── dogs │ │ ├── dogs_00001.jpg │ │ ├── dogs_00002.jpg │ │ ├── dogs_00003.jpg │ │ ├── dogs_00004.jpg │ │ ├── dogs_00005.jpg │ │ ├── dogs_00006.jpg │ │ ├── dogs_00007.jpg │ │ ├── dogs_00008.jpg │ │ ├── dogs_00009.jpg │ │ └── dogs_00010.jpg │ └── panda │ │ ├── panda_00001.jpg │ │ ├── panda_00002.jpg │ │ ├── panda_00003.jpg │ │ ├── panda_00004.jpg │ │ ├── panda_00005.jpg │ │ ├── panda_00006.jpg │ │ ├── panda_00007.jpg │ │ ├── panda_00008.jpg │ │ ├── panda_00009.jpg │ │ └── panda_00010.jpg ├── train │ ├── cats │ │ ├── cats_00001.jpg │ │ ├── cats_00002.jpg │ │ ├── cats_00003.jpg │ │ ├── cats_00004.jpg │ │ ├── cats_00005.jpg │ │ ├── cats_00006.jpg │ │ ├── cats_00007.jpg │ │ ├── cats_00008.jpg │ │ ├── cats_00009.jpg │ │ └── cats_00010.jpg │ ├── dogs │ │ ├── dogs_00001.jpg │ │ ├── dogs_00002.jpg │ │ ├── dogs_00003.jpg │ │ ├── dogs_00004.jpg │ │ ├── dogs_00005.jpg │ │ ├── dogs_00006.jpg │ │ ├── dogs_00007.jpg │ │ ├── dogs_00008.jpg │ │ ├── dogs_00009.jpg │ │ └── dogs_00010.jpg │ └── panda │ │ ├── panda_00001.jpg │ │ ├── panda_00002.jpg │ │ ├── panda_00003.jpg │ │ ├── panda_00004.jpg │ │ ├── panda_00005.jpg │ │ ├── panda_00006.jpg │ │ ├── panda_00007.jpg │ │ ├── panda_00008.jpg │ │ ├── panda_00009.jpg │ │ └── panda_00010.jpg └── val │ ├── cats │ ├── cats_00001.jpg │ ├── cats_00002.jpg │ ├── cats_00003.jpg │ ├── cats_00004.jpg │ ├── cats_00005.jpg │ ├── cats_00006.jpg │ ├── cats_00007.jpg │ ├── cats_00008.jpg │ ├── cats_00009.jpg │ └── cats_00010.jpg │ ├── dogs │ ├── dogs_00001.jpg │ ├── dogs_00002.jpg │ ├── dogs_00003.jpg │ ├── dogs_00004.jpg │ ├── dogs_00005.jpg │ ├── dogs_00006.jpg │ ├── dogs_00007.jpg │ ├── dogs_00008.jpg │ ├── dogs_00009.jpg │ └── dogs_00010.jpg │ └── panda │ ├── panda_00001.jpg │ ├── panda_00002.jpg │ ├── panda_00003.jpg │ ├── panda_00004.jpg │ ├── panda_00005.jpg │ ├── panda_00006.jpg │ ├── panda_00007.jpg │ ├── panda_00008.jpg │ ├── panda_00009.jpg │ └── panda_00010.jpg ├── dataloader ├── __init__.py ├── enhancement │ ├── __init__.py │ ├── autoaugment.py │ ├── mixup.py │ ├── multi_scale.py │ ├── my_augment.py │ └── rescale.py ├── my_dataloader.py └── utils.py ├── demos ├── main.py └── test_images │ ├── cats_00001.jpg │ ├── dogs_00001.jpg │ └── panda_00001.jpg ├── main.py ├── models ├── __init__.py ├── efficientnet │ ├── __init__.py │ ├── components.py │ ├── config.py │ ├── efficientnet.py │ ├── factory.py │ └── utils.py ├── mobilenetv3 │ ├── __init__.py │ ├── factory.py │ └── mobilenetv3.py ├── model_factory.py └── resnest │ ├── __init__.py │ ├── ablation.py │ ├── resnest.py │ ├── resnet.py │ └── splat.py ├── optim ├── __init__.py ├── swa.py └── torchtools │ ├── __init__.py │ ├── lr_scheduler │ ├── __init__.py │ └── delayed.py │ ├── 
nn │ ├── __init__.py │ ├── adain.py │ ├── functional │ │ ├── __init__.py │ │ ├── gradient_penalty.py │ │ ├── perceptual.py │ │ └── vq.py │ ├── gp_loss.py │ ├── mish.py │ ├── perceptual.py │ ├── pixel_normalzation.py │ ├── simple_self_attention.py │ └── vq.py │ └── optim │ ├── __init__.py │ ├── lamb.py │ ├── lookahead.py │ ├── novograd.py │ ├── over9000.py │ ├── radam.py │ ├── ralamb.py │ └── ranger.py ├── pretrained └── README.md ├── requirements.txt ├── utils ├── __init__.py ├── cam_tool │ ├── __init__.py │ ├── cam.py │ ├── grad_cam.py │ ├── grad_cam_plus.py │ └── heatmap.py ├── check_images.py ├── confusion_matrix.py ├── meters.py ├── my_logger.py ├── my_summary.py └── network_viz.py ├── visual_images └── README.md └── z_task_shell ├── 0_check_best_lr.sh ├── 1_print_model_info.sh ├── 2_train_cpu_or_gpu.sh ├── 3_train_distributed.sh ├── 4_evaluate_model.sh ├── 5_visualize_model_layer.sh ├── 6_make_curriculum.sh ├── 7_knowledge_distillation.sh └── 8_convert_to_jit.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

.DS_Store
.idea
data/205087_450106_bundle_archive.zip
logs/
pretrained/*
checkpoints/*
!pretrained/README.md
!checkpoints/README.md

ConfusionMatrix.png
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
# A PyTorch Classification Library

Implemented classification networks:

- [x] the networks shipped with PyTorch: resnet, shufflenet, densenet, mobilenet, mnasnet, etc.;
- [x] MobileNet v3;
- [x] the EfficientNet family;
- [x] the ResNeSt family;

---

## Features

- [x] multiple applications (`applications/`): training, testing, JIT conversion for deployment, model distillation, visualization;
- [x] data augmentation (`dataloader/enhancement`): AutoAugment, a custom augmenter (MyAugment), mixup, multi-scale training;
- [x] a collection of optimizers (`optim`): Adam is currently used, and RAdam is also recommended;
- [x] implementations of different loss strategies (`criterions`): OHM, GHM, weighted loss, etc.;

---

## Project layout

- `applications`: applications such as `test.py, train.py, convert.py`, invoked from `main.py`;
- `checkpoints`: directory for trained model files (may not exist yet);
- `criterions`: custom loss functions;
- `data`: root of the training/test/validation/prediction datasets;
- `dataloader`: data loading, augmentation and preprocessing (ImageNet-style by default);
- `demos`: usage demos; currently `main.py` shows how to run prediction with a model in `jit` format;
- `logs`: TensorBoard logs written during training (may not exist yet);
- `models`: custom model architectures;
- `optim`: recent optimizers not yet available in official PyTorch;
- `utils`: utility scripts: confusion matrix, image validation, model-structure printing, logging, etc.;
- `pretrained`: pretrained model files;
- `config.py`: configuration file;
- `main.py`: main entry point;
- `requirements.txt`: list of project dependencies;

---

## Usage

### Data preparation

Put the data under the `data` folder, split into three sub-folders, `train/test/val`, for the training/test/validation sets;
inside each of them, create one folder per class and place that class's images in it.

Once the data is ready, run the `utils/check_images.py` script to validate the images, so that training is not aborted mid-run by an invalid image.

The final layout looks roughly like:
```
- data
  - train
    - class_0
      - 0.jpg
      - 1.jpg
      - ...
    - class_1
      - ...
    - ..
  - test
    - ...
  - val
    - ...
- dataloader
  - ...
```

### Notes on the key configuration parameters

Notes on the more important parameters in `config.py`:

- `--data`: dataset root containing the `train`, `test` and `val` directories; defaults to the `data/` directory under the current folder;
- `--image_size`: two integers; the pretrained models take square inputs such as [224, 224],
  but you can change it as needed: during preprocessing, images are resized with the aspect ratio preserved and then padded (with 0 by default) to the target input size.
- `--num_classes`: number of classes the model predicts;
- `-b`: batch size, 256 by default; set it according to GPU memory;
- `-j`: number of data-loading worker processes, 16 by default (see `config.py`); set it according to CPU load;
- `--criterion`: loss function; one option is PyTorch's built-in softmax loss, the other is my custom sigmoid loss;
  the sigmoid loss turns the multi-class problem into a multi-label binary classification problem, and I added several custom sigmoid losses such as GHM,
  which can be switched on via `--weighted_loss --ghm_loss --threshold_loss --ohm_loss`;
- `--lr`: initial learning rate; `main.py` uses the Adam optimizer by default; the LR scheduler is currently the `LambdaLR` interface with the custom rule below
  (see `adjust_learning_rate(epoch, args)` in `main.py` for details, and the sketch right after this list):
```
~ warmup: 0.1
~ warmup + int([1.5 * (epochs - warmup)]/4.0): 1,
~ warmup + int([2.5 * (epochs - warmup)]/4.0): 0.1
~ warmup + int([3.5 * (epochs - warmup)]/4.0): 0.01
~ epochs: 0.001
```
- `--warmup`: number of warm-up epochs; during the first `warmup` epochs the learning rate is 0.1 * the initial learning rate;
- `--epochs`: total number of training epochs;
- `--aug`: whether to use data augmentation; the default is my custom pipeline: `dataloader/enhancement/my_augment.py`;
- `--mixup`: mixup augmentation, False by default;
- `--multi_scale`: multi-scale training, False by default;
- `--resume`: path to a weight file used to initialize the model; must be given for `--jit` and `--evaluate`;
- `--jit`: convert the model to JIT format for easier deployment;
- `--evaluate`: evaluate the model on the test set;
- `--knowledge`: pick a dataset on which the teacher model (specified together with the `--resume` option) runs prediction to produce a probability file (the knowledge);
  the file is written to `data/distill.txt`, and the original labels to `data/label.txt`;
- `--distill`: model distillation (requires the teacher's probability file), False by default;
  before training in this mode, run `--knowledge train --resume teacher.pth` on the training set to generate the probability file as the teacher's output;
  probability files are the files matching `distill*.txt` under `data`; if several exist, all are used and their mean serves as the teacher's probabilities that guide the student model;
- `--visual_data`: run the model on the given dataset and visualize the results;
- `--visual_method`: visualization method, one of `cam`, `grad-cam`, `grad-cam++`;
- `--make_curriculum`: produce the curriculum file for curriculum learning;
- `--curriculum_thresholds`: sample thresholds of the different curricula;
- `--curriculum_weights`: loss weights of the samples in the different curricula;
- `--curriculum_learning`: run curriculum learning; sample weights are read from `data/curriculum.txt` and applied to the corresponding samples' losses during training;
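The rule above maps directly onto PyTorch's `LambdaLR`. Below is a minimal sketch of the schedule, assuming the `config.py` defaults `epochs=85`, `warmup=5`, `lr=1e-4`; the authoritative implementation is `adjust_learning_rate(epoch, args)` in `main.py` and may differ in details such as boundary handling:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

epochs, warmup = 85, 5  # --epochs / --warmup defaults in config.py


def lr_lambda(epoch: int) -> float:
    """0.1 during warmup, then 1 -> 0.1 -> 0.01 -> 0.001 over the remaining epochs."""
    if epoch < warmup:
        return 0.1
    boundaries = [warmup + int(r * (epochs - warmup) / 4.0) for r in (1.5, 2.5, 3.5)]
    for boundary, ratio in zip(boundaries, (1.0, 0.1, 0.01)):
        if epoch < boundary:
            return ratio
    return 0.001


model = torch.nn.Linear(8, 3)                              # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # --lr default
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)       # call scheduler.step() once per epoch
```

With `--lr_ratios`/`--lr_steps` set, the same idea generalizes to arbitrary piecewise-constant schedules.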
BTW, the `sample-free` idea is also wired into `models/efficientnet/efficientnet.py`; the code is currently commented out, but you can borrow it if needed.
`sample-free` here means: when I do multi-label binary classification with BCE and want the task to favor certain classes, I set a larger initial value on those classes' biases to raise their initial probabilities.
(For the exact formula, see the paper "Is Sampling Heuristics Necessary in Training Deep Object Detectors".)

See `config.py` for the detailed documentation of all parameters.

---

## Quick start

See the corresponding `z_task_shell/*.sh` scripts.

### Deployment demo

Once a model is trained and you want to run prediction on images with it, use the `main.py` script under the `demos` directory:

```shell
cd demos
python main.py
```
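For reference, loading the exported model only takes a few lines. A minimal, hypothetical sketch (independent of the actual `demos/main.py`; the checkpoint name follows the `checkpoints/jit_{arch}.pt` pattern used by `applications/convert.py`, and real inputs must go through the same resize/pad/normalize preprocessing as training, which is glossed over here):

```python
import torch

# Load the traced model written by `--jit` (path pattern from applications/convert.py).
model = torch.jit.load('checkpoints/jit_efficientnet-b0.pt')
model.eval()

# Dummy 1x3xHxW batch standing in for a preprocessed image
# ([400, 224] is the default --image_size in config.py).
image = torch.rand(1, 3, 400, 224)
with torch.no_grad():
    probs = model(image).sigmoid().squeeze(0)  # sigmoid for the BCE head; use .softmax(dim=0) for CE
print(f'predicted class index: {probs.argmax().item()}')
```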
---

## Reference

[d-li14/mobilenetv3.pytorch](https://github.com/d-li14/mobilenetv3.pytorch)

[lukemelas/EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch)

[zhanghang1989/ResNeSt](https://github.com/zhanghang1989/ResNeSt)

[yizt/Grad-CAM.pytorch](https://github.com/yizt/Grad-CAM.pytorch)

## TODO

- [x] collect the download URLs of the pretrained models (see Reference);
- [ ] OpenVINO model conversion and a matching deployment demo;
-------------------------------------------------------------------------------- /applications/__init__.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
File __init__.py

应用:
    train:训练
    test:测试
    Visualize:可视化
    make_curriculum:课程学习时制作课程权重文件
    distill:蒸馏时制作教师模型概率文件
    convert_to_jit:模型转JIT
"""
from .train import train
from .test import test
from .visualize import Visualize
from .make_curriculum import make_curriculum
from .distillation import distill
from .convert import convert_to_jit
-------------------------------------------------------------------------------- /applications/convert.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
File convert.py

模型转换:使用 torch.jit.trace 转为TorchScript格式
"""
import argparse
import logging

import torch


def convert_to_jit(model: torch.nn.Module, args: argparse.Namespace):
    """
    将模型转为JIT格式,利于部署
    :param model: 待转格式模型
    :param args: 转模型超参
    """
    logging.info('Converting model ...')
    image_height, image_width = args.image_size

    model.eval()
    rand_image = torch.rand(1, 3, image_height, image_width)
    with torch.no_grad():
        with torch.autograd.profiler.profile(use_cuda=args.cuda) as prof:
            model(rand_image)
    logging.info(prof)
    torch_model = torch.jit.trace(model, (rand_image,))
    torch_model.save(f'checkpoints/jit_{args.arch}.pt')
    logging.info('Save with jit trace mode over ~ ')
-------------------------------------------------------------------------------- /applications/distillation.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
File distillation.py

使用模型进行推理,获取概率文件,用于后续的模型蒸馏
"""
import os
import logging
import argparse

import numpy as np
from torch import nn
from torch.utils.data import DataLoader

from .test import test


def distill(distill_loader: DataLoader, model: nn.Module, criterion: nn.Module,
            args: argparse.Namespace, is_confuse_matrix: bool = True):
    """
    获取模型蒸馏时,教师模型的评估的概率文件
    :param distill_loader: 评估的数据集
    :param model: 教师模型
    :param criterion: 评估的指标/损失函数
    :param args: 超参
    :param is_confuse_matrix: 是否对评估的数据集输出混淆矩阵
    """
    _, _, paths_targets_preds_probs = test(distill_loader, model, criterion, args, is_confuse_matrix)

    knowledges = []
    original_labels = []
    for path, target, _, prob in paths_targets_preds_probs:
        prob = ','.join([f'{num:.2f}' for num in prob])
        knowledges.append(f'{path},{prob}\n')

        label = np.eye(args.num_classes, dtype=np.float32)[target]
        label = ','.join([str(num) for num in label])
        original_labels.append(f"{path},{label}\n")

    distill_file_path = os.path.join(args.data, 'distill.txt')
    with open(distill_file_path, 'w+') as knowledge_file:
        knowledge_file.writelines(knowledges)

    distill_file_path = os.path.join(args.data, 'label.txt')
    with open(distill_file_path, 'w+') as label_file:
        label_file.writelines(original_labels)

    logging.info('Finish generating knowledge file for model distillation!')
-------------------------------------------------------------------------------- /applications/make_curriculum.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
File make_curriculum.py

使用模型进行推理,获取课程文件,用于后续的课程学习
"""
import os
import logging
import argparse

from torch import nn
from torch.utils.data import DataLoader

from .test import test


def make_curriculum(curriculum_loader: DataLoader, model: nn.Module, criterion: nn.Module,
                    args: 
argparse.Namespace, is_confuse_matrix: bool = True): 19 | """ 20 | 获取课程学习时,不同难易样本具有不同损失权重的课程文件 21 | :param curriculum_loader: 评估的数据集 22 | :param model: 制作课程的模型 23 | :param criterion: 评估的指标/损失函数 24 | :param args: 超参 25 | :param is_confuse_matrix: 是否对评估的数据集输出混淆矩阵 26 | """ 27 | _, _, paths_targets_preds_probs = test(curriculum_loader, model, criterion, args, is_confuse_matrix) 28 | 29 | if len(args.curriculum_thresholds) != len(args.curriculum_weights): 30 | raise ValueError(f'课程数不确定:课程阈值({args.curriculum_thresholds})与' 31 | f'课程权重({args.curriculum_weights})的长度必须相等!') 32 | 33 | if args.curriculum_thresholds[-1] > 0.0: 34 | logging.warning(f'课程阈值({args.curriculum_thresholds})中没有指定最小阈值(0.0),' 35 | f'可能部分样本不会被添加到该学习的课程之中') 36 | curriculums = [] 37 | # 逐样本划分课程 38 | for path, target, _, prob in paths_targets_preds_probs: 39 | for i, threshold in enumerate(args.curriculum_thresholds): 40 | if prob[target] > threshold: 41 | curriculums.append(f'{path},{args.curriculum_weights[i]}\n') 42 | break 43 | 44 | distill_file_path = os.path.join(args.data, 'curriculum.txt') 45 | with open(distill_file_path, 'w+') as curriculum_file: 46 | curriculum_file.writelines(curriculums) 47 | 48 | logging.info('Finish generating curriculum file for curriculum learning!') 49 | -------------------------------------------------------------------------------- /applications/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File test.py 4 | 5 | 模型测试脚本 6 | """ 7 | import time 8 | import typing 9 | import logging 10 | import argparse 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | import torch.distributed as dist 16 | from torch.utils.data import DataLoader 17 | from sklearn.metrics import classification_report 18 | 19 | from utils import meters, confusion_matrix 20 | from dataloader.my_dataloader import DataLoaderX 21 | 22 | 23 | def test(test_loader: DataLoader, model: nn.Module, criterion: nn.Module, 24 | args: argparse.Namespace, is_confuse_matrix: bool = True) -> \ 25 | (float, float, typing.List[typing.Tuple[(str, int, int, str)]]): 26 | """ 27 | 验证集、测试集 模型评估 28 | :param test_loader: 测试集DataLoader对象 29 | :param model: 待测试模型 30 | :param criterion: 损失函数 31 | :param args: 测试参数 32 | :param is_confuse_matrix: 是否输出混淆矩阵 33 | """ 34 | batch_time = meters.AverageMeter('Time', ':6.3f') 35 | losses = meters.AverageMeter('Loss', ':.4e') 36 | top1 = meters.AverageMeter('Acc@1', ':6.2f') 37 | progress = meters.ProgressMeter( 38 | len(test_loader), batch_time, losses, top1, prefix='Test: ') 39 | 40 | model.eval() 41 | # 图像路径,label,预测类别,概率向量 统计量 42 | all_paths, all_targets, all_preds, all_probs = list(), list(), list(), list() 43 | # 将路径映射为index,这样可以转为tensor,在分布式训练的时候才能使用gather 44 | path_2_index = {path: i for i, (path, _, _) in enumerate(test_loader.dataset.samples)} 45 | index_2_path = {i: path for path, i in path_2_index.items()} 46 | with torch.no_grad(): 47 | end_time = time.time() 48 | for i, (images, targets, paths, weights) in enumerate(test_loader): 49 | images = DataLoaderX.normalize(images, args) 50 | if args.cuda: 51 | images = images.cuda(args.gpu, non_blocking=True) 52 | targets = targets.cuda(args.gpu, non_blocking=True) 53 | weights = weights.cuda(args.gpu, non_blocking=True) 54 | 55 | # 模型预测 56 | outputs = model(images) 57 | loss = criterion(outputs, targets, weights) 58 | 59 | # 统计准确率和损失函数 60 | acc1, pred, target = accuracy(outputs, targets) 61 | # 统计量 62 | losses.update(loss.item(), images.size(0)) 63 | 
top1.update(acc1.item(), images.size(0)) 64 | batch_time.update(time.time() - end_time) 65 | end_time = time.time() 66 | 67 | # 收集结果 68 | all_paths.append(np.array([path_2_index[p] for p in paths], dtype=np.float32)) 69 | all_targets.append(targets.cpu().numpy()) 70 | all_preds.append(pred.cpu().numpy()) 71 | all_probs.append(outputs.sigmoid().cpu().numpy()) 72 | 73 | if i % args.print_freq == 0: 74 | progress.print(i) 75 | all_paths = np.concatenate(all_paths, axis=0) 76 | all_targets = np.concatenate(all_targets, axis=0) 77 | all_preds = np.concatenate(all_preds, axis=0) 78 | all_probs = np.concatenate(all_probs, axis=0) 79 | 80 | if args.distributed: 81 | all_paths = gather_tensors_from_gpus(all_paths, args) 82 | all_targets = gather_tensors_from_gpus(all_targets, args) 83 | all_preds = gather_tensors_from_gpus(all_preds, args) 84 | all_probs = gather_tensors_from_gpus(all_probs, args) 85 | top1 = gather_meters_from_gpus(top1, args) 86 | losses = gather_meters_from_gpus(losses, args) 87 | batch_time = gather_meters_from_gpus(batch_time, args) 88 | 89 | all_paths = [index_2_path[i] for i in all_paths] 90 | logging.info(f'* Acc@1 {top1.avg:.3f} and loss {losses.avg:.3f} with time {batch_time.avg:.3f}') 91 | 92 | if is_confuse_matrix and ((args.gpus <= 1) or (args.gpu == args.gpus - 1)): 93 | # 同一台服务器上多卡训练时,只有最后一张卡保存混淆图 94 | report = classification_report(all_targets, all_preds, target_names=test_loader.dataset.classes) 95 | logging.info(f'分类结果报告:\n{report}',) 96 | 97 | confusion_matrix.plot_confusion_matrix(all_targets, all_preds, test_loader.dataset.classes, 98 | title=f'ConfusionMatrix', is_save=True) 99 | return top1.avg, losses.avg, list(zip(all_paths, all_targets, all_preds, all_probs)) 100 | 101 | 102 | def gather_tensors_from_gpus(array: np.ndarray, args) -> np.ndarray: 103 | """ 104 | 从多个GPU中汇总变量,并拼成新的tensor 105 | :param array: 待汇总变量 106 | :param args: 测试超参 107 | :return 汇总并拼接后的数组 108 | """ 109 | with torch.no_grad(): 110 | tensor = torch.from_numpy(array).cuda(args.gpu) 111 | multi_gpu_tensors = [torch.empty_like(tensor) for _ in range(args.world_size)] 112 | dist.all_gather(multi_gpu_tensors, tensor) 113 | return torch.cat(multi_gpu_tensors, dim=0).detach().cpu().numpy() 114 | 115 | 116 | def gather_meters_from_gpus(meter: meters.AverageMeter, args) -> meters.AverageMeter: 117 | """ 118 | 从多个GPU中汇总统计指标,并返回更新后的指标 119 | :param meter: 待汇总统计指标 120 | :param args: 测试超参 121 | :return 汇总并更新后的指标 122 | """ 123 | meter_array = np.array([[meter.avg, meter.count]], dtype=np.float32) 124 | all_meter = gather_tensors_from_gpus(meter_array, args) 125 | meter.reset() 126 | for item, count in all_meter: 127 | meter.update(item, count) 128 | return meter 129 | 130 | 131 | def accuracy(output, target): 132 | """ 133 | 计算准确率和预测结果 134 | :param output: 分类预测 135 | :param target: 分类标签 136 | """ 137 | with torch.no_grad(): 138 | _, pred = output.max(dim=1) 139 | if pred.dim() != target.dim(): 140 | _, target = target.max(dim=1) 141 | correct = pred.eq(target) 142 | acc = correct.float().mean() * 100 143 | return acc, pred, target 144 | -------------------------------------------------------------------------------- /applications/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File train.py 4 | 5 | 模型训练脚本 6 | """ 7 | import time 8 | import shutil 9 | import logging 10 | 11 | import numpy as np 12 | import torch 13 | from torch import nn 14 | from torch.utils.data import DataLoader 15 | from torch.utils.tensorboard import 
SummaryWriter 16 | from torch.optim.optimizer import Optimizer 17 | import apex 18 | from apex import amp 19 | 20 | from utils import AverageMeter, ProgressMeter 21 | from applications.test import test 22 | from dataloader import MixUp 23 | from dataloader import MultiScale 24 | from dataloader.my_dataloader import DataLoaderX 25 | 26 | 27 | mix_up = None 28 | multi_scale = None 29 | bn_gammas = None 30 | net_weights = None # 排除bias项的weight decay 31 | 32 | 33 | def train(train_loader: DataLoader, val_loader: DataLoader, model: nn.Module, 34 | criterion: nn.Module, optimizer: Optimizer, 35 | scheduler: torch.optim.lr_scheduler._LRScheduler, args): 36 | """ 37 | 训练模型 38 | :param train_loader: 训练集 39 | :param val_loader: 验证集 40 | :param model: 模型 41 | :param criterion: 损失函数 42 | :param optimizer: 优化器 43 | :param args: 训练超参 44 | """ 45 | writer = SummaryWriter(args.logdir) 46 | # writer.add_graph(model, (torch.rand(1, 3, args.image_size[0], args.image_size[1]),)) 47 | global mix_up, multi_scale, bn_gammas, net_weights 48 | if mix_up is None: 49 | mix_up = MixUp(args) 50 | if args.multi_scale and multi_scale is None: 51 | multi_scale = MultiScale(args.image_size) 52 | if bn_gammas is None: 53 | bn_gammas = [m.weight for m in model.modules() 54 | if isinstance(m, nn.BatchNorm2d) or 55 | isinstance(m, nn.SyncBatchNorm) or 56 | isinstance(m, apex.parallel.SyncBatchNorm)] 57 | 58 | if net_weights is None: 59 | net_weights = [param for name, param in model.named_parameters() if name[-4:] != 'bias'] 60 | 61 | best_val_acc1 = 0 62 | learning_rate = 0 63 | for epoch in range(args.epochs): 64 | if args.distributed: 65 | train_loader.sampler.set_epoch(epoch) 66 | learning_rate = scheduler.get_last_lr() 67 | if isinstance(learning_rate, list): 68 | learning_rate = learning_rate[0] 69 | # 训练一个epoch,并在验证集上评估 70 | train_loss, train_acc1 = train_epoch(train_loader, model, criterion, optimizer, epoch, args) 71 | val_acc1, val_loss, _ = test(val_loader, model, criterion, args, is_confuse_matrix=False) 72 | scheduler.step() 73 | # 保存当前及最好的acc@1的checkpoint 74 | is_best = val_acc1 >= best_val_acc1 75 | best_val_acc1 = max(val_acc1, best_val_acc1) 76 | save_checkpoint({ 77 | # 'epoch': epoch + 1, 78 | # 'arch': args.arch, 79 | 'state_dict': model.module.state_dict(), 80 | # 'best_acc1': best_val_acc1, 81 | # 'optimizer': optimizer.state_dict(), 82 | }, is_best, args) 83 | 84 | all_bn_weight = [] 85 | for gamma in bn_gammas: 86 | all_bn_weight.append(gamma.cpu().data.numpy()) 87 | writer.add_histogram('BN gamma', np.concatenate(all_bn_weight, axis=0), epoch) 88 | # writer.add_scalars('Loss', {'Train': train_loss, 'Val': val_loss}, epoch) 89 | # writer.add_scalars('Accuracy', {'Train': train_acc1, 'Val': val_acc1}, epoch) 90 | writer.add_scalar('Train/Loss', train_loss, epoch) 91 | writer.add_scalar('Train/Accuracy', train_acc1, epoch) 92 | writer.add_scalar('Val/Loss', val_loss, epoch) 93 | writer.add_scalar('Val/Accuracy', val_acc1, epoch) 94 | writer.add_scalar('learning rate', learning_rate, epoch) 95 | writer.flush() 96 | writer.close() 97 | logging.info(f'Training Over with lr={learning_rate}~~') 98 | 99 | 100 | def train_epoch(train_loader, model, criterion, optimizer, epoch, args): 101 | """ 102 | 训练模型一个epoch的数据 103 | :param train_loader: 训练集 104 | :param model: 模型 105 | :param criterion: 损失函数 106 | :param optimizer: 优化器 107 | :param epoch: 当前迭代次数 108 | :param args: 训练超参 109 | """ 110 | global mix_up, multi_scale, bn_gammas, net_weights 111 | batch_time = AverageMeter('Time', ':6.3f') 112 | data_time = 
AverageMeter('Data', ':6.3f') 113 | losses = AverageMeter('Loss', ':.4e') 114 | top1 = AverageMeter('Acc@1', ':6.2f') 115 | progress = ProgressMeter(len(train_loader), batch_time, data_time, 116 | losses, top1, prefix=f"Epoch: [{epoch}]") 117 | 118 | # 训练模式 119 | model.train() 120 | end_time = time.time() 121 | for i, (images, targets, _, weights) in enumerate(train_loader): 122 | # 更新数据加载时间度量 123 | data_time.update(time.time() - end_time) 124 | # 只有训练集,才可能进行mixup和multi-scale数据增强 125 | images, targets1, targets2, mix_rate = mix_up(images, targets) 126 | if args.multi_scale: 127 | images = multi_scale(images) 128 | 129 | images = DataLoaderX.normalize(images, args) 130 | if args.cuda: 131 | images = images.cuda(args.gpu, non_blocking=True) 132 | targets1 = targets1.cuda(args.gpu, non_blocking=True) 133 | weights = weights.cuda(args.gpu, non_blocking=True) 134 | if targets2 is not None: 135 | targets2 = targets2.cuda(args.gpu, non_blocking=True) 136 | mix_rate = mix_rate.cuda(args.gpu, non_blocking=True) 137 | 138 | output = model(images) 139 | loss = criterion(output, targets1, weights) 140 | if targets2 is not None: 141 | loss = mix_rate * loss + (1.0 - mix_rate) * criterion(output, targets2, weights) 142 | if mix_rate < 0.5: 143 | targets1 = targets2 144 | 145 | optimizer.zero_grad() 146 | if args.cuda: 147 | with amp.scale_loss(loss, optimizer) as scaled_loss: 148 | scaled_loss.backward() 149 | else: 150 | loss.backward() 151 | 152 | # network slimming 153 | if args.sparsity: 154 | for gamma in bn_gammas: 155 | gamma.data.add_(-torch.sign(gamma.data), 156 | alpha=args.slim * optimizer.param_groups[0]['lr']) 157 | # weight decay 158 | for param in net_weights: 159 | param.data.add_(-param.data, 160 | alpha=args.weight_decay * optimizer.param_groups[0]['lr']) 161 | 162 | optimizer.step() 163 | 164 | # 更新度量 165 | acc1 = accuracy(output, targets1) 166 | losses.update(loss.detach().cpu().item(), images.size(0)) 167 | top1.update(acc1.item(), images.size(0)) 168 | # 更新一个batch训练时间度量 169 | batch_time.update(time.time() - end_time) 170 | end_time = time.time() 171 | 172 | if i % args.print_freq == 0: 173 | progress.print(i) 174 | return losses.avg, top1.avg 175 | 176 | 177 | def accuracy(output, target): 178 | """ 179 | 计算准确率和预测结果 180 | :param output: 分类预测 181 | :param target: 分类标签 182 | """ 183 | with torch.no_grad(): 184 | _, pred = output.max(dim=1) 185 | if pred.dim() != target.dim(): 186 | _, target = target.max(dim=1) 187 | acc = pred.eq(target).float().mean() * 100 188 | return acc 189 | 190 | 191 | def save_checkpoint(state, is_best, args, filename='checkpoints/checkpoint_{}.pth'): 192 | """ 193 | 保存模型 194 | :param state: 模型状态 195 | :param is_best: 模型是否当前测试集准确率最高 196 | :param args: 训练超参 197 | :param filename: 保存的文件名 198 | """ 199 | filename = filename.format(args.arch) 200 | if (args.gpus > 1) and (args.gpu != args.gpus - 1): 201 | # 同一台服务器上多卡训练时,只有最后一张卡保存模型(多卡同时保存到同一位置会混乱) 202 | return 203 | torch.save(state, filename) 204 | if is_best: 205 | shutil.copyfile(filename, f'checkpoints/model_best_{args.arch}.pth') 206 | -------------------------------------------------------------------------------- /applications/visualize.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File visualize.py 4 | 5 | 使用CAM(class activaion mapping)进行可视化 6 | Learning Deep Features for Discriminative Localization 7 | ref: https://github.com/zhoubolei/CAM/blob/master/pytorch_CAM.py 8 | """ 9 | import os 10 | import logging 11 | import 
argparse 12 | 13 | import cv2 14 | import numpy as np 15 | import torch 16 | from torch.utils.data import DataLoader 17 | 18 | from utils import HeatMapTool, CAM, GradCAM, GradCamPlusPlus 19 | 20 | 21 | class Visualize: 22 | 23 | @staticmethod 24 | def visualize(data_loader: DataLoader, model: torch.nn.Module, args: argparse.Namespace): 25 | """ 26 | 获取最后一层卷积层激活函数后的激活输出,根据选取的可视化方法,可视化所有类别的激活图 27 | efficientnet的最后一层卷积层激活函数输出模块名为 _swish_fc,如果要可视化其他模型,这里需要做对应修改 28 | :param data_loader: 待可视化的数据集 29 | :param model: 待可视化模型 30 | :param args: 可视化超参 31 | """ 32 | model.eval() 33 | cam_inferences = [] 34 | if args.visual_method in ('cam', 'all'): 35 | cam_inferences.append(CAM(model, '_swish_fc')) 36 | if args.visual_method in ('grad-cam', 'all'): 37 | cam_inferences.append(GradCAM(model, '_swish_fc')) 38 | if args.visual_method in ('grad-cam++', 'all'): 39 | cam_inferences.append(GradCamPlusPlus(model, '_swish_fc')) 40 | if len(cam_inferences) == 0: 41 | raise NotImplementedError('--visual_method must be in (cam, grad-cam, grad-cam++)') 42 | for i, (images, labels, paths, _) in enumerate(data_loader): 43 | batch_images, batch_cams = [], [] 44 | uint8_images = np.uint8(images.numpy().transpose((0, 2, 3, 1))[..., ::-1] * 255) 45 | for cam_inference in cam_inferences: 46 | outputs, cams = cam_inference(images, args) 47 | batch_cams.append(cams) 48 | batch_images.append(uint8_images) 49 | batch_cams = np.concatenate(batch_cams, axis=2) 50 | batch_images = np.concatenate(batch_images, axis=1) 51 | if args.cuda: 52 | labels = labels.cuda(args.gpu, non_blocking=True) 53 | 54 | match_results, labels, predictions = Visualize._assess(outputs, labels) 55 | for path, image, cams, is_match, label, pred in \ 56 | zip(paths, batch_images, batch_cams, match_results, labels, predictions): 57 | # 预测错误,则需要画出错误的热图和正确的热图 58 | logging.info(f'{is_match} path: {path},{data_loader.dataset.classes[label]},' 59 | f'{data_loader.dataset.classes[pred]}') 60 | cv2.imwrite(f'visual_images/{os.path.basename(path)}', 61 | HeatMapTool.add_heat(image, cams)) 62 | # cv2.imshow(f'{path[10:-4]}-truth', HeatMapTool.add_heat(image, cams)) 63 | # cv2.waitKey(0) 64 | 65 | @staticmethod 66 | def _assess(outputs: torch.Tensor, labels: torch.Tensor) -> (np.ndarray, np.ndarray, np.ndarray): 67 | """ 68 | 评估输出是否预测对标签 69 | :param outputs: 模型输出 70 | :param labels: 标签 71 | :return 是否预测准确,标签,预测结果 72 | """ 73 | with torch.no_grad(): 74 | _, preds = outputs.max(dim=1) 75 | if preds.dim() != labels.dim(): 76 | _, labels = labels.max(dim=1) 77 | result = preds.eq(labels).detach().cpu().numpy() 78 | return result, labels.detach().cpu().numpy(), preds.detach().cpu().numpy() 79 | -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | # 这个文件夹作为训练模型保存的路径 -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File config.py 4 | 5 | 配置文件 6 | """ 7 | import argparse 8 | import numpy as np 9 | 10 | 11 | parser = argparse.ArgumentParser(description='PyTorch Classification Model Training') 12 | 13 | parser.add_argument('--data', default='data/', metavar='DIR', help='数据集路径') 14 | parser.add_argument('-a', '--arch', metavar='ARCH', default='efficientnet-b0', 15 | help='模型结构,默认:efficientnet-b0') 16 | parser.add_argument('--image_size', default=[400, 224], type=int, nargs='+', 
dest='image_size',
                    help='模型输入尺寸[H, W],默认:[400, 224]')
parser.add_argument('--num_classes', default=6, type=int, help='类别数,默认:6')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='数据加载进程数,默认:16')
parser.add_argument('-b', '--batch_size', default=256, type=int, metavar='N',
                    help='训练batch size大小,默认:256')

# 分布式训练相关
parser.add_argument('--seed', default=1234, type=int,
                    help='训练或测试时,使用随机种子保证结果的可复现,默认:1234')
parser.add_argument('--sync_bn', default=False, action='store_true',
                    help='BN同步,默认不使用')
parser.add_argument('--cuda', default=True, dest='cuda', action='store_true',
                    help='是否使用cuda进行模型推理,默认 True,会根据实际机器情况调整')
parser.add_argument('-n', '--nodes', default=1, type=int, help='分布式训练的节点数')
parser.add_argument('-g', '--gpus', default=2, type=int,
                    help='每个节点使用的GPU数量,可通过设置环境变量(CUDA_VISIBLE_DEVICES=1)限制使用哪些/单个GPU')
parser.add_argument('--rank', default=-1, type=int, help='分布式训练的当前节点的序号')
parser.add_argument('--init_method', default='tcp://11.6.127.208:10006', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--logdir', default='logs', type=str, metavar='PATH',
                    help='Tensorboard日志目录,默认 logs')

# 训练过程参数设置
parser.add_argument('--train', default=False, dest='train', action='store_true',
                    help='是否训练,默认:False')
parser.add_argument('--epochs', default=85, type=int, metavar='N',
                    help='训练epoch数,默认:85')
parser.add_argument('--opt', default='adam', type=str, help='优化器,默认:adam')
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, metavar='LR',
                    help='初始学习率,默认:1e-4', dest='lr')
parser.add_argument('--lr_ratios', default=None, type=float, nargs='+',
                    help='初始学习率每step的变化率,如 [1., 0.1] 则是一开始使用初始学习率,后续衰减 * 0.1,'
                         '默认:[0.1, 1., 0.1, 0.01, 0.001]', dest='lr_ratios')
parser.add_argument('--lr_steps', default=None, type=float, nargs='+',
                    help='初始学习率每次衰减的epoch数,如[10, 20]表示在10 epoch衰减,20应该是结束epoch,'
                         '默认:将总epoch分为5段', dest='lr_steps')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='学习率动量')
parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                    help='网络权重衰减正则项,默认: 5e-4', dest='weight_decay')
parser.add_argument('--warmup', default=5, type=int, metavar='W', help='warm-up迭代数')
parser.add_argument('-p', '--print-freq', default=50, type=int, metavar='N',
                    help='训练过程中的信息打印,每隔多少个batch打印一次,默认: 50')
parser.add_argument('--pretrained', default=False, dest='pretrained', action='store_true',
                    help='是否使用预训练模型,默认不使用')
parser.add_argument('--advprop', default=False, action='store_true',
                    help='使用advprop的预训练模型,默认否,主要针对EfficientNet系列')

# 网络参数设置
parser.add_argument('--criterion', default='softmax', type=str,
                    help='使用的损失函数,默认 softmax,可选 bce')
parser.add_argument('--weighted_loss', default=False, dest='weighted_loss', action='store_true',
                    help='损失函数是否使用加权策略,默认否')
parser.add_argument('--threshold_loss', default=False, dest='threshold_loss', action='store_true',
                    help='损失函数是否使用阈值策略,默认否')
parser.add_argument('--ghm_loss', default=False, dest='ghm_loss', action='store_true',
                    help='损失函数是否使用GHM策略,默认否')
parser.add_argument('--ohm_loss', default=False, dest='ohm_loss', action='store_true',
                    help='损失函数是否使用OHM策略,默认否')
parser.add_argument('--hard_ratio', default=0.7, dest='hard_ratio', type=float,
                    help='OHM损失函数中困难样本的比例,默认 0.7')
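# 示例(假设的命令行,真实脚本见 z_task_shell/*.sh):
# 同时启用自定义BCE损失及其加权/GHM策略进行训练:
#   python main.py --train --criterion bce --weighted_loss --ghm_loss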

# 额外的训练技巧参数设置
parser.add_argument('--mixup', default=False, dest='mixup', action='store_true',
                    help='使用mix-up对训练数据进行数据增强,默认 False')
parser.add_argument('--mixup_ratio', default=0.5, dest='mixup_ratio', type=float,
                    help='开启mix-up对训练数据进行数据增强时,使用增强的概率,默认 0.5')
parser.add_argument('--mixup_alpha', default=1.1, dest='mixup_alpha', type=float,
                    help='mix-up时两张图像混合的beta分布的参数,默认 1.1')
parser.add_argument('--aug', default=False, dest='aug', action='store_true',
                    help='进行数据增强,默认 False')
parser.add_argument('--multi_scale', default=False, dest='multi_scale', action='store_true',
                    help='多尺度训练,默认 False')

parser.add_argument('--sparsity', default=False, dest='sparsity', action='store_true',
                    help='是否使用network slimming训练稀疏网络,默认 False')
parser.add_argument('--slim', default=5.e-4, type=float, dest='slim',
                    help='network slimming中BN gamma的权重衰减系数,默认 5.e-4')

# 其他策略的参数设置
parser.add_argument('-e', '--evaluate', dest='evaluate', default=False, action='store_true',
                    help='在测试集上评估模型')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='重加载已训练好的模型 (默认: none)')
parser.add_argument('--jit', dest='jit', default=False, action='store_true',
                    help='将模型转为jit格式')

parser.add_argument('--knowledge', dest='knowledge', default=None, type=str,
                    choices=[None, 'train', 'test', 'val'],
                    help='指定数据集,使用教师模型(配合resume选项指定)对该数据集进行预测,获取概率文件(知识)')
parser.add_argument('--distill', default=False, dest='distill', action='store_true',
                    help='模型蒸馏(需要教师模型输出的概率文件),默认 False')

parser.add_argument('--visual_data', dest='visual_data', default=None, type=str,
                    choices=[None, 'train', 'test', 'val'],
                    help='指定数据集,对模型进行可视化 TIPs:也可查看数据增强方式后的可视化效果,但仅限于train集')
parser.add_argument('--visual_method', dest='visual_method', default='cam', type=str,
                    choices=['all', 'cam', 'grad-cam', 'grad-cam++'], help='模型进行可视化的方法')

parser.add_argument('--make_curriculum', dest='make_curriculum', default=None, type=str,
                    choices=[None, 'train', 'test', 'val'],
                    help='指定数据集,制作课程学习中,不同样本损失权重的课程文件')
parser.add_argument('--curriculum_thresholds', dest='curriculum_thresholds', default=None,
                    type=float, nargs='+',
                    help='样本大于阈值列表(由大到小排序)中相应阈值,会被分配到同一课程中,制作在同一课程文件中')
parser.add_argument('--curriculum_weights', dest='curriculum_weights', default=None,
                    type=float, nargs='+',
                    help='与curriculum_thresholds对应,指定对应课程中样本的权重')
parser.add_argument('--curriculum_learning', dest='curriculum_learning',
                    default=False, action='store_true',
                    help='是否进行课程学习,默认 False')


class CriterionConstant:
    # weighted_bce中损失函数不同label间误分的loss权重,每一行表示该label误分为其他label的损失权重
    weights_for_bce = np.array([
        [1., 0.5, 2.],
        [0.5, 1., 2.],  # 猫(0)狗(1)有时比较相近,误分可以接受,但熊猫误分很严重
        [2., 2., 1.]
    ], dtype=np.float32)
    weights_for_ce = weights_for_bce
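    # 举例说明(假设的batch):one-hot标签 t = [[0., 1., 0.]](狗)时,
    # t @ weights_for_bce 取出第1行,得到各类别的损失权重 [0.5, 1., 2.],
    # 即狗被误分为猫的损失权重为0.5,被误分为熊猫的权重为2
    # (WeightedBCELoss.get_weights 正是这样使用该矩阵:targets @ label_weights)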
    # threshold_bce中损失函数不同label的概率阈值控制
    # 每一行表示该label的阈值下界,低于该阈值则不计算loss
    low_threshold_for_bce = np.array([
        [-0.1, 0.05, 0.05],
        [0.05, -0.1, 0.05],
        [0.05, 0.05, -0.1]
    ], dtype=np.float32)
    low_threshold_for_ce = low_threshold_for_bce
    # 每一行表示该label的阈值上界,高于该阈值则不计算loss
    up_threshold_for_bce = np.array([
        [0.95, 1.1, 1.1],
        [1.1, 0.95, 1.1],
        [1.1, 1.1, 0.95]
    ], dtype=np.float32)
    up_threshold_for_ce = up_threshold_for_bce
-------------------------------------------------------------------------------- /criterions/__init__.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
File __init__.py

损失函数
    softmax_cross_entropy: 多分类交叉熵损失函数及其变种
    binary_cross_entropy: 多标签二分类损失函数及其变种
"""
from .softmax_cross_entropy import (
    CrossEntropyLoss,
    ThresholdCELoss,
    WeightedCELoss,
    GHMCELoss,
    OHMCELoss,
    HybridCELoss
)
from .binary_cross_entropy import (
    MultiLabelBCELoss,
    ThresholdBCELoss,
    WeightedBCELoss,
    GHMBCELoss,
    OHMBCELoss,
    HybridBCELoss,
)
-------------------------------------------------------------------------------- /criterions/binary_cross_entropy/__init__.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
File __init__.py

多标签二分类损失函数及其变体
"""
from .basic_bce import MultiLabelBCELoss
from .ghm_bce import GHMBCELoss
from .hybrid_bce import HybridBCELoss
from .ohm_bce import OHMBCELoss
from .threshold_bce import ThresholdBCELoss
from .weighted_bce import WeightedBCELoss
-------------------------------------------------------------------------------- /criterions/binary_cross_entropy/basic_bce.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
File basic_bce.py

多标签二分类交叉熵损失函数
"""
import torch
from torch import nn
from torch.nn import functional as F

from criterions.utils import decode_to_onehot


class MultiLabelBCELoss(nn.Module):
    """多标签二分类交叉熵损失函数"""

    def __init__(self):
        super(MultiLabelBCELoss, self).__init__()

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """
        多标签二分类交叉熵损失函数
        :param predictions: 网络预测logit输出, (N, num_class)
        :param targets: 多标签二分类label,index形式(N,) 或 one-hot形式(N, num_class)
        :param weights: 每个样本的权重 (N,)
        :return: 损失值
        """
        if targets.dim() != predictions.dim():
            targets = decode_to_onehot(targets, predictions.size(-1))
        if weights is None:
            weights = self.identity_weights(predictions, targets)
        if weights.dim() != predictions.dim():
            weights = weights.repeat(predictions.size(-1), 1).T

        weights *= self.get_weights(predictions.sigmoid().detach(), targets)
        loss = F.binary_cross_entropy_with_logits(predictions, targets, weights, reduction='sum')
        return loss

    def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        每一项多标签二分类交叉熵损失的权重
        :param predictions: 预测的概率矩阵(非logit形式,且detach了), (N, num_class)
        :param targets: 解码为one-hot形式的多标签二分类label,(N, num_class)
        :return: 每一项损失的权重,(N, num_class)
        """
        return 
self.identity_weights(predictions, targets) 47 | 48 | @staticmethod 49 | def identity_weights(predictions: torch.Tensor, _: torch.Tensor) -> torch.Tensor: 50 | """ 51 | 返回与输入矩阵一样大小的、都为1的矩阵,作为单位损失权重 52 | :param predictions: 预测的概率矩阵, (N, num_class) 53 | :param _: 解码为one-hot形式的多标签二分类label,(N, num_class) 54 | :return: 单位损失权重,(N, num_class) 55 | """ 56 | return torch.ones_like(predictions) 57 | -------------------------------------------------------------------------------- /criterions/binary_cross_entropy/ghm_bce.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File ghm_bce.py 4 | 5 | GHM加权的多标签二分类交叉熵损失 GHMCELoss 6 | 基于梯度密度的倒数对损失进行加权(密度越大,损失越小)的BCE损失 7 | """ 8 | import argparse 9 | import logging 10 | 11 | import torch 12 | 13 | from criterions.binary_cross_entropy.basic_bce import MultiLabelBCELoss 14 | 15 | 16 | class GHMBCELoss(MultiLabelBCELoss): 17 | """ GHM Classification Loss. 18 | "Gradient Harmonized Single-stage Detector". 19 | https://arxiv.org/abs/1811.05181 20 | """ 21 | def __init__(self, args: argparse.Namespace, bins: int = 30, momentum: float = 0.75): 22 | """ GHM多标签二分类损失函数 23 | :param args: 训练超参 24 | :param bins: Number of the unit regions for distribution calculation. 25 | :param momentum: The parameter for moving average. 26 | """ 27 | super(GHMBCELoss, self).__init__() 28 | self.bins = bins 29 | self.momentum = momentum 30 | self.edges = torch.arange(bins + 1).float() 31 | if args.cuda: 32 | self.edges = self.edges.cuda(args.gpu) 33 | self.edges /= bins 34 | self.edges[0] -= 1e-6 35 | self.edges[-1] += 1e-6 36 | 37 | if momentum > 0: 38 | self.acc_sum = torch.zeros(bins) 39 | if args.cuda: 40 | self.acc_sum = self.acc_sum.cuda(args.gpu) 41 | 42 | def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 43 | """ 44 | 多标签二分类交叉熵损失 45 | :param predictions: 预测的概率矩阵,(batch_size, label_num) 46 | :param targets: 解码后的多标签二分类label概率矩阵,(batch_size, label_num) 47 | :return: 每一项损失的权重,(N, num_class) 48 | """ 49 | edges = self.edges 50 | mmt = self.momentum 51 | weights = torch.zeros_like(targets) 52 | 53 | # 计算梯度 54 | g = torch.abs(predictions - targets) 55 | total_items = targets.shape[0] * targets.shape[1] 56 | n = 0 # n valid bins 57 | for i in range(self.bins): 58 | inds = (g >= edges[i]) & (g < edges[i+1]) 59 | num_in_bin = inds.sum().item() 60 | if num_in_bin > 0: 61 | if mmt > 0: 62 | self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_bin 63 | weights[inds] = total_items / self.acc_sum[i] 64 | else: 65 | weights[inds] = total_items / num_in_bin 66 | n += 1 67 | 68 | if n > 0: 69 | weights = weights / n 70 | 71 | return weights 72 | 73 | 74 | if __name__ == '__main__': 75 | inputs = torch.tensor([[-3.1781, -2.9444, -3.8918, -4.5951, 4.5951], 76 | [-3.1781, -2.9444, 4.5951, -3.8918, -4.5951]]) 77 | 78 | real_targets = torch.tensor([[0., 0., 0., 0., 1.], 79 | [0., 0., 1., 0., 0.]]) 80 | temp_args = argparse.Namespace() 81 | temp_args.cuda = False 82 | ghm_loss = GHMBCELoss(temp_args) 83 | 84 | loss_value = ghm_loss(inputs.sigmoid().detach(), real_targets, None) 85 | print(f'first loss: {loss_value}') 86 | loss_value = ghm_loss(inputs.sigmoid().detach(), real_targets, None) 87 | print(f'first loss: {loss_value}') 88 | -------------------------------------------------------------------------------- /criterions/binary_cross_entropy/hybrid_bce.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File 
hybrid_bce.py 4 | 5 | 混合策略的多标签二分类交叉熵损失 6 | """ 7 | import argparse 8 | 9 | import torch 10 | 11 | from .basic_bce import MultiLabelBCELoss 12 | from .threshold_bce import ThresholdBCELoss 13 | from .weighted_bce import WeightedBCELoss 14 | from .ohm_bce import OHMBCELoss 15 | from .ghm_bce import GHMBCELoss 16 | 17 | 18 | class HybridBCELoss(MultiLabelBCELoss): 19 | """ 自定义的多标签二分类损失函数,包含weighted loss,threshold loss,ghm loss,ohm loss等功能 """ 20 | 21 | def __init__(self, args: argparse.Namespace): 22 | super(HybridBCELoss, self).__init__() 23 | self.threshold_func = self.identity_weights 24 | self.weighted_func = self.identity_weights 25 | self.ohm_func = self.identity_weights 26 | self.ghm_func = self.identity_weights 27 | 28 | if args.threshold_loss: 29 | self.threshold_func = ThresholdBCELoss(args).get_weights 30 | if args.weighted_loss: 31 | self.weighted_func = WeightedBCELoss(args).get_weights 32 | if args.ohm_loss: 33 | self.ohm_func = OHMBCELoss(args).get_weights 34 | if args.ghm_loss: 35 | self.ghm_func = GHMBCELoss(args).get_weights 36 | 37 | def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 38 | """ 39 | 多标签二分类交叉熵损失 40 | :param predictions: 预测的概率矩阵,(batch_size, label_num) 41 | :param targets: 解码后的多标签二分类label概率矩阵,(batch_size, label_num) 42 | :return: 每一项损失的权重,(N, num_class) 43 | """ 44 | weights = self.threshold_func(predictions, targets) 45 | weights *= self.weighted_func(predictions, targets) 46 | weights *= self.ohm_func(predictions, targets) 47 | weights *= self.ghm_func(predictions, targets) 48 | return weights 49 | -------------------------------------------------------------------------------- /criterions/binary_cross_entropy/ohm_bce.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File ohm_bce.py 4 | 5 | 基于在线困难样本挖掘的多标签二分类交叉熵损失 OHMBCELoss 6 | """ 7 | import argparse 8 | 9 | import torch 10 | 11 | from .basic_bce import MultiLabelBCELoss 12 | 13 | 14 | class OHMBCELoss(MultiLabelBCELoss): 15 | """ 困难样本学习 """ 16 | def __init__(self, args: argparse.Namespace): 17 | super(OHMBCELoss, self).__init__() 18 | self.hard_ratio = args.hard_ratio 19 | 20 | def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 21 | """ 22 | 多标签二分类交叉熵损失 23 | :param predictions: 预测的概率矩阵,(batch_size, label_num) 24 | :param targets: 解码后的多标签二分类label概率矩阵,(batch_size, label_num) 25 | :return: 每一项损失的权重,(N, num_class) 26 | """ 27 | gradients = torch.abs(predictions - targets) 28 | threshold = gradients.flatten().sort()[0][int(targets.numel() * self.hard_ratio)] 29 | weights = (gradients > threshold).float() 30 | return weights 31 | -------------------------------------------------------------------------------- /criterions/binary_cross_entropy/threshold_bce.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File threshold_bce.py 4 | 5 | 带阈值控制的多标签二分类交叉熵损失 ThresholdCELoss 6 | 非目标类需要大于指定阈值才计算loss,目标类需要小于指定阈值才计算loss, 7 | 与制作软标签、hard-example mining有类似思想 8 | """ 9 | import argparse 10 | 11 | import torch 12 | 13 | from .basic_bce import MultiLabelBCELoss 14 | from config import CriterionConstant 15 | 16 | 17 | class ThresholdBCELoss(MultiLabelBCELoss): 18 | """ 借用软标签和困难样本的思想 """ 19 | def __init__(self, args: argparse.Namespace): 20 | super(ThresholdBCELoss, self).__init__() 21 | self.low_threshold = torch.from_numpy(CriterionConstant.low_threshold_for_bce) 22 | self.up_threshold = 
torch.from_numpy(CriterionConstant.up_threshold_for_bce) 23 | if args.cuda: 24 | self.low_threshold = self.low_threshold.cuda(args.gpu) 25 | self.up_threshold = self.up_threshold.cuda(args.gpu) 26 | 27 | def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 28 | """ 29 | 多标签二分类交叉熵损失 30 | :param predictions: 预测的概率矩阵,(batch_size, label_num) 31 | :param targets: 解码后的多标签二分类label概率矩阵,(batch_size, label_num) 32 | :return: 每一项损失的权重,(N, num_class) 33 | """ 34 | low_threshold = targets @ self.low_threshold 35 | up_threshold = targets @ self.up_threshold 36 | weights = ((predictions > low_threshold) & (predictions < up_threshold)).float() 37 | return weights 38 | -------------------------------------------------------------------------------- /criterions/binary_cross_entropy/weighted_bce.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File weighted_bce.py 4 | 5 | 加权多标签二分类交叉熵损失 WeightedCELoss,不同的误分类情况具有不同的损失权重 6 | """ 7 | import argparse 8 | 9 | import torch 10 | 11 | from .basic_bce import MultiLabelBCELoss 12 | from config import CriterionConstant 13 | 14 | 15 | class WeightedBCELoss(MultiLabelBCELoss): 16 | """ 借用偏序loss的思想 """ 17 | def __init__(self, args: argparse.Namespace): 18 | super(WeightedBCELoss, self).__init__() 19 | self.label_weights = torch.from_numpy(CriterionConstant.weights_for_bce) 20 | if args.cuda: 21 | self.label_weights = self.label_weights.cuda(args.gpu) 22 | 23 | def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 24 | """ 25 | 多标签二分类交叉熵损失 26 | :param predictions: 预测的概率矩阵,(batch_size, label_num) 27 | :param targets: 解码后的多标签二分类label概率矩阵,(batch_size, label_num) 28 | :return: 每一项损失的权重,(N, num_class) 29 | """ 30 | return targets @ self.label_weights 31 | -------------------------------------------------------------------------------- /criterions/softmax_cross_entropy/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File __init__.py.py 4 | 5 | softmax交叉熵损失函数及其变体 6 | """ 7 | from .basic_softmax import CrossEntropyLoss 8 | from .ghm_softmax import GHMCELoss 9 | from .hybrid_softmax import HybridCELoss 10 | from .ohm_softmax import OHMCELoss 11 | from .threshold_softmax import ThresholdCELoss 12 | from .weighted_softmax import WeightedCELoss 13 | -------------------------------------------------------------------------------- /criterions/softmax_cross_entropy/basic_softmax.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File basic_softmax.py 4 | 5 | softmax交叉熵损失函数,分类index label和one hot label形式都支持 6 | """ 7 | import torch 8 | from torch import nn 9 | 10 | from criterions.utils import decode_to_onehot 11 | 12 | 13 | class CrossEntropyLoss(nn.Module): 14 | """ 多分类交叉熵损失 """ 15 | def __init__(self): 16 | super(CrossEntropyLoss, self).__init__() 17 | 18 | def forward(self, predictions: torch.Tensor, targets: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: 19 | """ 20 | 多分类交叉熵损失 21 | :param predictions: 预测的logit矩阵,(batch_size, label_num) 22 | :param targets: 多分类label,如果是非one-hot形状(batch_size,),则需要先解码 23 | :param weights: 每个样本的权重 24 | :return: 损失值 25 | """ 26 | if predictions.dim() != targets.dim(): 27 | targets = decode_to_onehot(targets, predictions.size(-1)) 28 | if weights is None: 29 | weights = self.identity_weights(predictions, targets) 30 | if weights.dim() != 
predictions.dim(): 31 | weights = weights.repeat(predictions.size(-1), 1).T 32 | 33 | logged_x_pred = predictions.log_softmax(dim=1) 34 | weights *= self.get_weights(predictions.softmax(dim=1).detach(), targets) 35 | return -torch.sum(targets * logged_x_pred * weights, dim=1).sum() 36 | 37 | def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 38 | """ 39 | 多标签二分类交叉熵损失的权重 40 | :param predictions: 预测的概率矩阵,(batch_size, label_num) 41 | :param targets: 解码后的多标签二分类label概率矩阵,(batch_size, label_num) 42 | :return: 与predictions同维度的权重矩阵,(batch_size, label_num) 43 | """ 44 | return self.identity_weights(predictions, targets) 45 | 46 | @staticmethod 47 | def identity_weights(predictions: torch.Tensor, _: torch.Tensor) -> torch.Tensor: 48 | """ 49 | 多分类交叉熵损失的单位权重的dummy函数 50 | :param predictions: 预测的概率矩阵,(batch_size, label_num) 51 | :param _: 解码后的多标签二分类label概率矩阵,(batch_size, label_num) 52 | :return: 与predictions同维度的单位矩阵 53 | """ 54 | return torch.ones_like(predictions) 55 | -------------------------------------------------------------------------------- /criterions/softmax_cross_entropy/ghm_softmax.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File ghm_softmax.py 4 | 5 | GHM加权的多分类交叉熵损失 GHMCELoss 6 | 基于梯度密度的倒数对损失进行加权(密度越大,损失越小)的BCE损失 7 | """ 8 | import argparse 9 | 10 | import torch 11 | 12 | from criterions.softmax_cross_entropy.basic_softmax import CrossEntropyLoss 13 | 14 | 15 | class GHMCELoss(CrossEntropyLoss): 16 | """ GHM Classification Loss. 17 | "Gradient Harmonized Single-stage Detector". 18 | https://arxiv.org/abs/1811.05181 19 | """ 20 | def __init__(self, args: argparse.Namespace, bins: int = 30, momentum: float = 0.75): 21 | """ 22 | GHM多分类损失函数 23 | :param args: 训练超参 24 | :param bins: Number of the unit regions for distribution calculation. 25 | :param momentum: The parameter for moving average. 
26 | """ 27 | super(GHMCELoss, self).__init__() 28 | self.bins = bins 29 | self.momentum = momentum 30 | self.edges = torch.arange(bins + 1).float() 31 | if args.cuda: 32 | self.edges = self.edges.cuda(args.gpu) 33 | self.edges /= bins 34 | self.edges[0] -= 1e-6 35 | self.edges[-1] += 1e-6 36 | 37 | if momentum > 0: 38 | self.acc_sum = torch.zeros(bins) 39 | if args.cuda: 40 | self.acc_sum = self.acc_sum.cuda(args.gpu) 41 | 42 | def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 43 | """ 44 | 多分类交叉熵损失 45 | :param predictions: 预测的概率矩阵,(batch_size, label_num) 46 | :param targets: 解码后的多分类label概率矩阵,(batch_size, label_num) 47 | :return: 与predictions同维度的权重矩阵,(batch_size, label_num) 48 | """ 49 | edges = self.edges 50 | mmt = self.momentum 51 | weights = torch.zeros_like(targets) 52 | 53 | # 计算梯度 54 | g = torch.abs(predictions - targets) 55 | total_items = targets.shape[0] * targets.shape[1] 56 | n = 0 # n valid bins 57 | for i in range(self.bins): 58 | inds = (g >= edges[i]) & (g < edges[i+1]) 59 | num_in_bin = inds.sum().item() 60 | if num_in_bin > 0: 61 | if mmt > 0: 62 | self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_bin 63 | weights[inds] = total_items / self.acc_sum[i] 64 | else: 65 | weights[inds] = total_items / num_in_bin 66 | n += 1 67 | 68 | if n > 0: 69 | weights = weights / n 70 | 71 | return weights 72 | 73 | 74 | if __name__ == '__main__': 75 | import argparse 76 | inputs = torch.tensor([[-3.1781, -2.9444, -3.8918, -4.5951, 4.5951], 77 | [-3.1781, -2.9444, 4.5951, -3.8918, -4.5951]]) 78 | 79 | real_targets = torch.tensor([[0., 0., 0., 0., 1.], 80 | [0., 0., 1., 0., 0.]]) 81 | temp_args = argparse.Namespace() 82 | temp_args.cuda = False 83 | ghm_loss = GHMCELoss(temp_args) 84 | 85 | loss_value = ghm_loss(inputs.sigmoid().detach(), real_targets, None) 86 | print(f'first loss: {loss_value}') 87 | loss_value = ghm_loss(inputs.sigmoid().detach(), real_targets, None) 88 | print(f'first loss: {loss_value}') 89 | -------------------------------------------------------------------------------- /criterions/softmax_cross_entropy/hybrid_softmax.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File hybrid_softmax.py 4 | 5 | 混合策略的多分类交叉熵损失 6 | """ 7 | import argparse 8 | 9 | import torch 10 | 11 | from .basic_softmax import CrossEntropyLoss 12 | from .threshold_softmax import ThresholdCELoss 13 | from .weighted_softmax import WeightedCELoss 14 | from .ohm_softmax import OHMCELoss 15 | from .ghm_softmax import GHMCELoss 16 | 17 | 18 | class HybridCELoss(CrossEntropyLoss): 19 | """ 自定义的多标签二分类损失函数,包含weighted loss,threshold loss,ghm loss,ohm loss等功能 """ 20 | 21 | def __init__(self, args: argparse.Namespace): 22 | super(HybridCELoss, self).__init__() 23 | self.threshold_func = self.identity_weights 24 | self.weighted_func = self.identity_weights 25 | self.ohm_func = self.identity_weights 26 | self.ghm_func = self.identity_weights 27 | 28 | if args.threshold_loss: 29 | self.threshold_func = ThresholdCELoss(args).get_weights 30 | if args.weighted_loss: 31 | self.weighted_func = WeightedCELoss(args).get_weights 32 | if args.ohm_loss: 33 | self.ohm_func = OHMCELoss(args).get_weights 34 | if args.ghm_loss: 35 | self.ghm_func = GHMCELoss(args).get_weights 36 | 37 | def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 38 | """ 39 | 多标签二分类交叉熵损失 40 | :param predictions: 预测的概率矩阵,(batch_size, label_num) 41 | :param targets: 
--------------------------------------------------------------------------------
/criterions/softmax_cross_entropy/hybrid_softmax.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File hybrid_softmax.py
4 | 
5 | Multi-class cross-entropy loss with hybrid strategies
6 | """
7 | import argparse
8 | 
9 | import torch
10 | 
11 | from .basic_softmax import CrossEntropyLoss
12 | from .threshold_softmax import ThresholdCELoss
13 | from .weighted_softmax import WeightedCELoss
14 | from .ohm_softmax import OHMCELoss
15 | from .ghm_softmax import GHMCELoss
16 | 
17 | 
18 | class HybridCELoss(CrossEntropyLoss):
19 |     """ Custom multi-class loss that can combine weighted loss, threshold loss, GHM loss, OHM loss, etc. """
20 | 
21 |     def __init__(self, args: argparse.Namespace):
22 |         super(HybridCELoss, self).__init__()
23 |         self.threshold_func = self.identity_weights
24 |         self.weighted_func = self.identity_weights
25 |         self.ohm_func = self.identity_weights
26 |         self.ghm_func = self.identity_weights
27 | 
28 |         if args.threshold_loss:
29 |             self.threshold_func = ThresholdCELoss(args).get_weights
30 |         if args.weighted_loss:
31 |             self.weighted_func = WeightedCELoss(args).get_weights
32 |         if args.ohm_loss:
33 |             self.ohm_func = OHMCELoss(args).get_weights
34 |         if args.ghm_loss:
35 |             self.ghm_func = GHMCELoss(args).get_weights
36 | 
37 |     def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
38 |         """
39 |         Weights for the multi-class cross-entropy loss: the product of all enabled strategies
40 |         :param predictions: predicted probability matrix, (batch_size, label_num)
41 |         :param targets: decoded multi-class label probability matrix, (batch_size, label_num)
42 |         :return: weight matrix of the same shape as predictions, (batch_size, label_num)
43 |         """
44 |         weights = self.threshold_func(predictions, targets)
45 |         weights *= self.weighted_func(predictions, targets)
46 |         weights *= self.ohm_func(predictions, targets)
47 |         weights *= self.ghm_func(predictions, targets)
48 |         return weights
49 | 
--------------------------------------------------------------------------------
/criterions/softmax_cross_entropy/ohm_softmax.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File ohm_softmax.py
4 | 
5 | Softmax cross-entropy loss with online hard-example mining; both index labels and one-hot labels are supported
6 | """
7 | import argparse
8 | 
9 | import torch
10 | 
11 | from .basic_softmax import CrossEntropyLoss
12 | 
13 | 
14 | class OHMCELoss(CrossEntropyLoss):
15 |     """ Hard-example learning """
16 |     def __init__(self, args: argparse.Namespace):
17 |         super(OHMCELoss, self).__init__()
18 |         self.hard_ratio = args.hard_ratio
19 | 
20 |     def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
21 |         """
22 |         Weights for the multi-class cross-entropy loss
23 |         :param predictions: predicted probability matrix, (batch_size, label_num)
24 |         :param targets: decoded multi-class label probability matrix, (batch_size, label_num)
25 |         :return: weight matrix of the same shape as predictions, (batch_size, label_num)
26 |         """
27 |         gradients, _ = torch.abs(predictions - targets).max(dim=1)
28 |         ohm_positions = gradients.sort(descending=True)[1][:int(gradients.numel() * self.hard_ratio)]  # largest gradients = hardest samples
29 |         weights = torch.zeros_like(predictions)
30 |         weights[ohm_positions] = 1
31 |         return weights
32 | 
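Because these weights zero out every row except the selected ones, the sort direction matters: hard-example mining must keep the samples with the largest gradients, hence the `descending=True` above (the original ascending sort would have kept the easiest samples instead). A standalone sanity check with made-up numbers:

import torch

gradients = torch.tensor([0.1, 0.9, 0.3, 0.8])  # per-sample max |p - t|
hard_ratio = 0.5
hard_idx = gradients.sort(descending=True)[1][:int(gradients.numel() * hard_ratio)]
print(hard_idx)  # tensor([1, 3]): only the two hardest samples keep a non-zero loss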
--------------------------------------------------------------------------------
/criterions/softmax_cross_entropy/threshold_softmax.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File threshold_softmax.py
4 | 
5 | Multi-class cross-entropy loss with threshold control, ThresholdCELoss
6 | A non-target class contributes to the loss only above a given threshold; the target class only below a given threshold,
7 | similar in spirit to soft-label making and hard-example mining
8 | """
9 | import argparse
10 | 
11 | import torch
12 | 
13 | from .basic_softmax import CrossEntropyLoss
14 | from config import CriterionConstant
15 | 
16 | 
17 | class ThresholdCELoss(CrossEntropyLoss):
18 |     """ Borrows the ideas of soft labels and hard examples """
19 |     def __init__(self, args: argparse.Namespace):
20 |         super(ThresholdCELoss, self).__init__()
21 |         self.low_threshold = torch.from_numpy(CriterionConstant.low_threshold_for_ce)
22 |         self.up_threshold = torch.from_numpy(CriterionConstant.up_threshold_for_ce)
23 |         if args.cuda:
24 |             self.low_threshold = self.low_threshold.cuda(args.gpu)
25 |             self.up_threshold = self.up_threshold.cuda(args.gpu)
26 | 
27 |     def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
28 |         """
29 |         Weights for the multi-class cross-entropy loss
30 |         :param predictions: predicted probability matrix, (batch_size, label_num)
31 |         :param targets: decoded multi-class label probability matrix, (batch_size, label_num)
32 |         :return: weight matrix of the same shape as predictions, (batch_size, label_num)
33 |         """
34 |         low_threshold = targets @ self.low_threshold
35 |         up_threshold = targets @ self.up_threshold
36 |         weights = ((predictions > low_threshold) & (predictions < up_threshold)).float()
37 |         return weights
38 | 
--------------------------------------------------------------------------------
/criterions/softmax_cross_entropy/weighted_softmax.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File weighted_softmax.py
4 | 
5 | Weighted multi-class cross-entropy loss WeightedCELoss: different misclassification patterns carry different loss weights
6 | """
7 | import argparse
8 | 
9 | import torch
10 | 
11 | from .basic_softmax import CrossEntropyLoss
12 | from config import CriterionConstant
13 | 
14 | 
15 | class WeightedCELoss(CrossEntropyLoss):
16 |     """ Borrows the idea of a partial-order loss """
17 |     def __init__(self, args: argparse.Namespace):
18 |         super(WeightedCELoss, self).__init__()
19 |         self.label_weights = torch.from_numpy(CriterionConstant.weights_for_ce)
20 |         if args.cuda:
21 |             self.label_weights = self.label_weights.cuda(args.gpu)
22 | 
23 |     def get_weights(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
24 |         """
25 |         Weights for the multi-class cross-entropy loss
26 |         :param predictions: predicted probability matrix, (batch_size, label_num)
27 |         :param targets: decoded multi-class label probability matrix, (batch_size, label_num)
28 |         :return: weight matrix of the same shape as predictions, (batch_size, label_num)
29 |         """
30 |         return targets @ self.label_weights
31 | 
--------------------------------------------------------------------------------
/criterions/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File utils.py
4 | 
5 | Utilities for the loss functions
6 | decode_to_onehot: convert labels given as class indices into a one-hot probability matrix
7 | """
8 | import torch
9 | 
10 | 
11 | def decode_to_onehot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
12 |     """
13 |     Convert labels given as class indices into a one-hot probability matrix
14 |     :param labels: labels as class indices, e.g. [0, 3, 4] means three consecutive samples have classes 0, 3 and 4
15 |     :param num_classes: number of classes (labels)
16 |     :return: probability matrix in which each sample's label is one-hot encoded
17 |     """
18 |     batch_size = labels.size(0)
19 |     onehot_labels = labels.new_full((batch_size, num_classes), 0)
20 |     onehot_labels = onehot_labels.scatter(1, labels.unsqueeze(1), 1).float()
21 |     return onehot_labels
22 | 
--------------------------------------------------------------------------------
/data/test/cats/cats_00001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/cats/cats_00001.jpg
--------------------------------------------------------------------------------
/data/test/cats/cats_00002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/cats/cats_00002.jpg
--------------------------------------------------------------------------------
/data/test/cats/cats_00003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/cats/cats_00003.jpg
--------------------------------------------------------------------------------
/data/test/cats/cats_00004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/cats/cats_00004.jpg
--------------------------------------------------------------------------------
/data/test/cats/cats_00005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/cats/cats_00005.jpg
--------------------------------------------------------------------------------
/data/test/cats/cats_00006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/cats/cats_00006.jpg -------------------------------------------------------------------------------- /data/test/cats/cats_00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/cats/cats_00007.jpg -------------------------------------------------------------------------------- /data/test/cats/cats_00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/cats/cats_00008.jpg -------------------------------------------------------------------------------- /data/test/cats/cats_00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/cats/cats_00009.jpg -------------------------------------------------------------------------------- /data/test/cats/cats_00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/cats/cats_00010.jpg -------------------------------------------------------------------------------- /data/test/dogs/dogs_00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/dogs/dogs_00001.jpg -------------------------------------------------------------------------------- /data/test/dogs/dogs_00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/dogs/dogs_00002.jpg -------------------------------------------------------------------------------- /data/test/dogs/dogs_00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/dogs/dogs_00003.jpg -------------------------------------------------------------------------------- /data/test/dogs/dogs_00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/dogs/dogs_00004.jpg -------------------------------------------------------------------------------- /data/test/dogs/dogs_00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/dogs/dogs_00005.jpg -------------------------------------------------------------------------------- /data/test/dogs/dogs_00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/dogs/dogs_00006.jpg 
-------------------------------------------------------------------------------- /data/test/dogs/dogs_00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/dogs/dogs_00007.jpg -------------------------------------------------------------------------------- /data/test/dogs/dogs_00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/dogs/dogs_00008.jpg -------------------------------------------------------------------------------- /data/test/dogs/dogs_00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/dogs/dogs_00009.jpg -------------------------------------------------------------------------------- /data/test/dogs/dogs_00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/dogs/dogs_00010.jpg -------------------------------------------------------------------------------- /data/test/panda/panda_00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/panda/panda_00001.jpg -------------------------------------------------------------------------------- /data/test/panda/panda_00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/panda/panda_00002.jpg -------------------------------------------------------------------------------- /data/test/panda/panda_00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/panda/panda_00003.jpg -------------------------------------------------------------------------------- /data/test/panda/panda_00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/panda/panda_00004.jpg -------------------------------------------------------------------------------- /data/test/panda/panda_00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/panda/panda_00005.jpg -------------------------------------------------------------------------------- /data/test/panda/panda_00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/panda/panda_00006.jpg -------------------------------------------------------------------------------- /data/test/panda/panda_00007.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/panda/panda_00007.jpg -------------------------------------------------------------------------------- /data/test/panda/panda_00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/panda/panda_00008.jpg -------------------------------------------------------------------------------- /data/test/panda/panda_00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/panda/panda_00009.jpg -------------------------------------------------------------------------------- /data/test/panda/panda_00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/test/panda/panda_00010.jpg -------------------------------------------------------------------------------- /data/train/cats/cats_00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/cats/cats_00001.jpg -------------------------------------------------------------------------------- /data/train/cats/cats_00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/cats/cats_00002.jpg -------------------------------------------------------------------------------- /data/train/cats/cats_00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/cats/cats_00003.jpg -------------------------------------------------------------------------------- /data/train/cats/cats_00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/cats/cats_00004.jpg -------------------------------------------------------------------------------- /data/train/cats/cats_00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/cats/cats_00005.jpg -------------------------------------------------------------------------------- /data/train/cats/cats_00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/cats/cats_00006.jpg -------------------------------------------------------------------------------- /data/train/cats/cats_00007.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/cats/cats_00007.jpg -------------------------------------------------------------------------------- /data/train/cats/cats_00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/cats/cats_00008.jpg -------------------------------------------------------------------------------- /data/train/cats/cats_00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/cats/cats_00009.jpg -------------------------------------------------------------------------------- /data/train/cats/cats_00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/cats/cats_00010.jpg -------------------------------------------------------------------------------- /data/train/dogs/dogs_00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/dogs/dogs_00001.jpg -------------------------------------------------------------------------------- /data/train/dogs/dogs_00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/dogs/dogs_00002.jpg -------------------------------------------------------------------------------- /data/train/dogs/dogs_00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/dogs/dogs_00003.jpg -------------------------------------------------------------------------------- /data/train/dogs/dogs_00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/dogs/dogs_00004.jpg -------------------------------------------------------------------------------- /data/train/dogs/dogs_00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/dogs/dogs_00005.jpg -------------------------------------------------------------------------------- /data/train/dogs/dogs_00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/dogs/dogs_00006.jpg -------------------------------------------------------------------------------- /data/train/dogs/dogs_00007.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/dogs/dogs_00007.jpg -------------------------------------------------------------------------------- /data/train/dogs/dogs_00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/dogs/dogs_00008.jpg -------------------------------------------------------------------------------- /data/train/dogs/dogs_00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/dogs/dogs_00009.jpg -------------------------------------------------------------------------------- /data/train/dogs/dogs_00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/dogs/dogs_00010.jpg -------------------------------------------------------------------------------- /data/train/panda/panda_00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/panda/panda_00001.jpg -------------------------------------------------------------------------------- /data/train/panda/panda_00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/panda/panda_00002.jpg -------------------------------------------------------------------------------- /data/train/panda/panda_00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/panda/panda_00003.jpg -------------------------------------------------------------------------------- /data/train/panda/panda_00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/panda/panda_00004.jpg -------------------------------------------------------------------------------- /data/train/panda/panda_00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/panda/panda_00005.jpg -------------------------------------------------------------------------------- /data/train/panda/panda_00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/panda/panda_00006.jpg -------------------------------------------------------------------------------- /data/train/panda/panda_00007.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/panda/panda_00007.jpg -------------------------------------------------------------------------------- /data/train/panda/panda_00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/panda/panda_00008.jpg -------------------------------------------------------------------------------- /data/train/panda/panda_00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/panda/panda_00009.jpg -------------------------------------------------------------------------------- /data/train/panda/panda_00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/train/panda/panda_00010.jpg -------------------------------------------------------------------------------- /data/val/cats/cats_00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/cats/cats_00001.jpg -------------------------------------------------------------------------------- /data/val/cats/cats_00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/cats/cats_00002.jpg -------------------------------------------------------------------------------- /data/val/cats/cats_00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/cats/cats_00003.jpg -------------------------------------------------------------------------------- /data/val/cats/cats_00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/cats/cats_00004.jpg -------------------------------------------------------------------------------- /data/val/cats/cats_00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/cats/cats_00005.jpg -------------------------------------------------------------------------------- /data/val/cats/cats_00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/cats/cats_00006.jpg -------------------------------------------------------------------------------- /data/val/cats/cats_00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/cats/cats_00007.jpg 
-------------------------------------------------------------------------------- /data/val/cats/cats_00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/cats/cats_00008.jpg -------------------------------------------------------------------------------- /data/val/cats/cats_00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/cats/cats_00009.jpg -------------------------------------------------------------------------------- /data/val/cats/cats_00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/cats/cats_00010.jpg -------------------------------------------------------------------------------- /data/val/dogs/dogs_00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/dogs/dogs_00001.jpg -------------------------------------------------------------------------------- /data/val/dogs/dogs_00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/dogs/dogs_00002.jpg -------------------------------------------------------------------------------- /data/val/dogs/dogs_00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/dogs/dogs_00003.jpg -------------------------------------------------------------------------------- /data/val/dogs/dogs_00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/dogs/dogs_00004.jpg -------------------------------------------------------------------------------- /data/val/dogs/dogs_00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/dogs/dogs_00005.jpg -------------------------------------------------------------------------------- /data/val/dogs/dogs_00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/dogs/dogs_00006.jpg -------------------------------------------------------------------------------- /data/val/dogs/dogs_00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/dogs/dogs_00007.jpg -------------------------------------------------------------------------------- /data/val/dogs/dogs_00008.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/dogs/dogs_00008.jpg -------------------------------------------------------------------------------- /data/val/dogs/dogs_00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/dogs/dogs_00009.jpg -------------------------------------------------------------------------------- /data/val/dogs/dogs_00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/dogs/dogs_00010.jpg -------------------------------------------------------------------------------- /data/val/panda/panda_00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/panda/panda_00001.jpg -------------------------------------------------------------------------------- /data/val/panda/panda_00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/panda/panda_00002.jpg -------------------------------------------------------------------------------- /data/val/panda/panda_00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/panda/panda_00003.jpg -------------------------------------------------------------------------------- /data/val/panda/panda_00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/panda/panda_00004.jpg -------------------------------------------------------------------------------- /data/val/panda/panda_00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/panda/panda_00005.jpg -------------------------------------------------------------------------------- /data/val/panda/panda_00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/panda/panda_00006.jpg -------------------------------------------------------------------------------- /data/val/panda/panda_00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/panda/panda_00007.jpg -------------------------------------------------------------------------------- /data/val/panda/panda_00008.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/panda/panda_00008.jpg
--------------------------------------------------------------------------------
/data/val/panda/panda_00009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/panda/panda_00009.jpg
--------------------------------------------------------------------------------
/data/val/panda/panda_00010.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/data/val/panda/panda_00010.jpg
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File __init__.py
4 | 
5 | Image preprocessing, data loading, etc.
6 | """
7 | from .enhancement import (
8 |     MyAugment,
9 |     ImageNetPolicy, CIFAR10Policy, SVHNPolicy,
10 |     Rescale,
11 |     MultiScale,
12 |     MixUp
13 | )
14 | from .my_dataloader import load, DataLoaderX
15 | 
--------------------------------------------------------------------------------
/dataloader/enhancement/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File __init__.py
4 | 
5 | Data augmentation
6 | Rescale: aspect-preserving rescaling, applied to a single image
7 | MyAugment: self-defined augmentation, applied to a single image
8 | autoaugment: AutoAugment policies, applied to a single image
9 | MixUp: mixup augmentation, applied to a whole batch
10 | MultiScale: multi-scale training, applied to a whole batch
11 | """
12 | from .rescale import Rescale
13 | from .my_augment import MyAugment
14 | from .autoaugment import ImageNetPolicy, CIFAR10Policy, SVHNPolicy
15 | from .mixup import MixUp
16 | from .multi_scale import MultiScale
17 | 
--------------------------------------------------------------------------------
/dataloader/enhancement/mixup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File mixup.py
4 | 
5 | MixUp data augmentation, applied to a whole batch
6 | """
7 | import typing
8 | import argparse
9 | 
10 | import numpy as np
11 | import torch
12 | 
13 | 
14 | class MixUp:
15 | 
16 |     def __init__(self, args: argparse.Namespace):
17 |         """
18 |         MixUp data augmentation, ref: https://github.com/facebookresearch/mixup-cifar10
19 |         :param args: hyper-parameters
20 |         """
21 |         self.is_mixup = args.mixup
22 |         self.mixup_ratio = args.mixup_ratio
23 |         self.mixup_alpha = args.mixup_alpha
24 |         self.num_classes = args.num_classes
25 | 
26 |     def __call__(self, inputs: torch.FloatTensor, targets: typing.Union[torch.IntTensor, torch.FloatTensor]) \
27 |             -> (torch.FloatTensor, typing.Union[torch.IntTensor, torch.FloatTensor],
28 |                 typing.Optional[torch.FloatTensor], float):
29 |         """
30 |         Mix images and labels by linear interpolation
31 |         :param inputs: image tensor, NCHW
32 |         :param targets: label tensor, (N,) or (N, num_classes)
33 |         :return: images (with or without mixup), labels 1, labels 2 (None if no mixup), and the mix ratio
34 |         """
35 |         if not self.is_mixup or np.random.rand() > self.mixup_ratio:
36 |             return inputs, targets, None, 1.0
37 | 
38 |         # if the labels are class indices, convert them to one-hot encoding, (N, num_classes)
39 |         batch_size, num_channel, image_height, image_width = inputs.shape
40 |         if targets.dim() == 1:
41 |             targets = targets.unsqueeze(1)
42 |             y_onehot = torch.zeros((batch_size, self.num_classes), dtype=torch.float32)
43 |             targets = y_onehot.scatter_(1, targets, 1)
44 |         # build the mixup pairs
45 |         rp2 = torch.randperm(batch_size)
46 |         inputs1 = inputs
47 |         targets1 = targets
48 |         inputs2 = inputs[rp2]
49 |         targets2 = targets[rp2]
50 |         # mix images
51 |         mix_rate = np.random.beta(self.mixup_alpha, self.mixup_alpha)
52 |         mix_rate = torch.tensor(mix_rate).float()
53 |         inputs_shuffle = mix_rate * inputs1 + (1 - mix_rate) * inputs2
54 | 
55 |         return inputs_shuffle, targets1, targets2, mix_rate
56 | 
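A hedged sketch of how a training step can consume the four return values; the loss interpolation in the comments is an assumption about the caller, valid for any criterion that accepts one-hot targets (e.g. the HybridCELoss in this repo):

import argparse

import torch

from dataloader.enhancement.mixup import MixUp

args = argparse.Namespace(mixup=True, mixup_ratio=1.0, mixup_alpha=1.0, num_classes=3)
mixup = MixUp(args)
inputs = torch.rand(4, 3, 224, 224)   # toy batch, NCHW
targets = torch.tensor([0, 1, 2, 0])  # class indices; converted to one-hot inside MixUp
mixed, targets1, targets2, mix_rate = mixup(inputs, targets)
# only the images are mixed; the caller interpolates the two losses with the same ratio:
#   loss = mix_rate * criterion(out, targets1) + (1 - mix_rate) * criterion(out, targets2)
# when mixup is skipped, targets2 is None and mix_rate is 1.0, so the second term vanishes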
--------------------------------------------------------------------------------
/dataloader/enhancement/multi_scale.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File multi_scale.py
4 | 
5 | Rescale a batch of images (keeping aspect ratio) to a random size drawn around a given base size, for multi-scale training; applied to a whole batch
6 | """
7 | import typing
8 | import logging
9 | 
10 | import numpy as np
11 | import torch
12 | from torch import nn
13 | from scipy.stats import truncnorm
14 | 
15 | from dataloader.utils import get_rescale_size
16 | 
17 | 
18 | class MultiScale:
19 | 
20 |     def __init__(self, output_size: typing.Union[int, tuple, list], h_w_ratio: float = 1.8):
21 |         """
22 |         Rescale images to a size randomly drawn from a normal distribution around the given size
23 |         :param output_size: base size; an int is the width, with the height computed from h_w_ratio;
24 |                             a tuple means (H, W), from which h_w_ratio is derived
25 |         :param h_w_ratio: height : width ratio
26 |         """
27 |         assert isinstance(output_size, (int, tuple, list))
28 |         if isinstance(output_size, (tuple, list)):
29 |             h_w_ratio = 1.0 * output_size[0] / output_size[1]
30 |             output_size = output_size[1]
31 |         logging.info(f'>>>>>>>>>>>> MultiScale Mode, Use H:W={h_w_ratio}, base width={output_size}')
32 |         self.h_w_ratio = h_w_ratio
33 | 
34 |         # sample the image width from a normal distribution truncated at 2 sigma = +/-15%
35 |         range_ratio = 0.15
36 |         norm_delta = 2.0
37 |         mean = output_size
38 |         sigma = range_ratio * mean / norm_delta
39 |         self.generator = truncnorm(-norm_delta, norm_delta, loc=mean, scale=sigma)
40 |         # plt.hist(self.generator.rvs(1000))
41 | 
42 |     def __call__(self, images: torch.FloatTensor) -> torch.Tensor:
43 |         """
44 |         Randomly rescale a batch of images, keeping aspect ratio
45 |         :param images: a batch of images, N3HW
46 |         :return: the randomly rescaled batch
47 |         """
48 |         h, w = images.shape[2:4]
49 |         # randomly generate the target rescale size
50 |         target_w = int(np.round(self.generator.rvs()))
51 |         target_h = int(np.round(target_w * self.h_w_ratio))
52 | 
53 |         # resize & padding
54 |         (new_h, new_w), (left, right, top, bottom) = get_rescale_size(h, w, target_h, target_w)
55 |         images = nn.functional.interpolate(images, size=(new_h, new_w),
56 |                                            mode='bilinear', align_corners=False)
57 |         images = nn.functional.pad(images, [left, right, top, bottom])
58 | 
59 |         return images
60 | 
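Usage sketch: the transform acts on a whole N3HW batch, so it belongs after the DataLoader rather than in the per-image transform list; the target height/width vary by roughly +/-15% between calls, so the printed shape is one possible outcome:

import torch

from dataloader.enhancement.multi_scale import MultiScale

scale = MultiScale((400, 224))      # base size (H, W); h_w_ratio is derived as 400/224
batch = torch.rand(8, 3, 400, 224)  # N3HW
rescaled = scale(batch)             # e.g. torch.Size([8, 3, 412, 231]): resized + zero-padded
print(rescaled.shape)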
--------------------------------------------------------------------------------
/dataloader/enhancement/my_augment.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File my_augment.py
4 | 
5 | Self-defined data augmentation, applied to a single image
6 | """
7 | import logging
8 | 
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | import imgaug as ia
12 | from imgaug import augmenters as iaa
13 | 
14 | 
15 | class MyAugment:
16 | 
17 |     def __init__(self):
18 |         self.seq = iaa.Sequential(
19 |             [
20 |                 iaa.Fliplr(0.5),
21 |                 iaa.Sometimes(0.5, iaa.Crop(percent=(0, 0.1))),
22 | 
23 |                 iaa.Sometimes(0.5, iaa.Affine(
24 |                     rotate=(-20, 20),  # rotate by +/-20 degrees
25 |                     # shear=(-16, 16),  # shear by +/-16 degrees, rectangle -> parallelogram
26 |                     # order=[0, 1],  # use nearest-neighbour or bilinear interpolation
27 |                     cval=0,  # fill value
28 |                     mode=ia.ALL  # how to fill the area outside the image
29 |                 )),
30 | 
31 |                 # apply 0 to 3 of the following augmentations
32 |                 iaa.SomeOf((0, 3),
33 |                            [
34 |                                iaa.Sometimes(0.8, iaa.OneOf([
35 |                                    iaa.GaussianBlur((0, 2.0)),  # Gaussian blur
36 |                                    iaa.AverageBlur(k=(1, 5)),  # average blur, frosted-glass look
37 |                                ])),
38 | 
39 |                                # either motion blur or "beautify" smoothing
40 |                                iaa.Sometimes(0.8, iaa.OneOf([
41 |                                    iaa.MotionBlur(k=(3, 11)),  # motion blur
42 |                                    iaa.BilateralBlur(d=(1, 5),
43 |                                                      sigma_color=(10, 250),
44 |                                                      sigma_space=(10, 250)),  # bilateral filter, skin smoothing
45 |                                ])),
46 | 
47 |                                # imitate snowflakes
48 |                                iaa.Sometimes(0.8, iaa.OneOf([
49 |                                    iaa.SaltAndPepper(p=(0., 0.03)),
50 |                                    iaa.AdditiveGaussianNoise(loc=0, scale=(0., 0.05 * 255), per_channel=False)
51 |                                ])),
52 | 
53 |                                # contrast
54 |                                iaa.Sometimes(0.8, iaa.LinearContrast((0.6, 1.4), per_channel=0.5)),
55 | 
56 |                                # sharpen
57 |                                iaa.Sometimes(0.8, iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5))),
58 | 
59 |                                # overall brightness
60 |                                iaa.Sometimes(0.8, iaa.OneOf([
61 |                                    # additive adjustment
62 |                                    iaa.AddToBrightness((-30, 30)),
63 |                                    # multiplicative adjustment
64 |                                    iaa.MultiplyBrightness((0.5, 1.5)),
65 |                                    # additive & multiplicative
66 |                                    iaa.MultiplyAndAddToBrightness(mul=(0.5, 1.5), add=(-30, 30)),
67 |                                ])),
68 | 
69 |                                # saturation
70 |                                iaa.Sometimes(0.8, iaa.OneOf([
71 |                                    iaa.AddToSaturation((-75, 75)),
72 |                                    iaa.MultiplySaturation((0., 3.)),
73 |                                ])),
74 | 
75 |                                # hue
76 |                                iaa.Sometimes(0.8, iaa.OneOf([
77 |                                    iaa.AddToHue((-255, 255)),
78 |                                    iaa.MultiplyHue((-3.0, 3.0)),
79 |                                ])),
80 | 
81 |                                # clouds / fog
82 |                                # iaa.Sometimes(0.3, iaa.Clouds()),
83 | 
84 |                                # cartoon effect
85 |                                # iaa.Sometimes(0.01, iaa.Cartoon()),
86 |                            ],
87 |                            random_order=True
88 |                            )
89 |             ],
90 |             random_order=True
91 |         )
92 | 
93 |     def __call__(self, img: np.ndarray) -> np.ndarray:
94 |         """
95 |         Augment a single BGR image read by cv2
96 |         :param img: image in cv2 BGR format, (h, w, 3)
97 |         :return: the augmented image, (h, w, 3)
98 |         """
99 |         image_aug = self.seq.augment_image(img)
100 |         return image_aug
101 | 
102 |     def __repr__(self):
103 |         return 'Self-defined Augment Policy'
104 | 
105 | 
106 | if __name__ == '__main__':
107 |     import argparse
108 | 
109 |     import cv2
110 |     import torchvision
111 |     from torchvision import transforms
112 |     from torch.utils.data import DataLoader
113 | 
114 |     from dataloader.enhancement.multi_scale import MultiScale
115 |     from dataloader.enhancement.rescale import Rescale
116 |     from dataloader.enhancement.mixup import MixUp
117 | 
118 |     image_size = (400, 224)
119 |     scale = MultiScale(image_size)
120 |     mixup = MixUp(argparse.Namespace(mixup=True, mixup_ratio=1.0, mixup_alpha=1.0, num_classes=3))
121 | 
122 |     def show_images(dataset, is_multi_scale=True, is_mixup=True, col=4):
123 |         loader = DataLoader(dataset, batch_size=col, shuffle=False,
124 |                             num_workers=0, pin_memory=False)
125 |         for images, labels in loader:
126 |             if is_multi_scale:
127 |                 images = scale(images)
128 |             if is_mixup:
129 |                 images, _, _, _ = mixup(images, labels)
130 |             images = images.permute(0, 2, 3, 1).numpy()[..., ::-1]
131 |             logging.info(f'size: {images[0].shape}')
132 |             images = np.hstack(images)
133 | 
134 |             plt.imshow(images)
135 |             plt.axis('off')
136 |             plt.show()
137 | 
138 |     data_set = torchvision.datasets.ImageFolder('data/train', loader=cv2.imread,
139 |                                                 transform=transforms.Compose([
140 |                                                     MyAugment(),
141 |                                                     Rescale(image_size),
142 |                                                     transforms.ToTensor(),
143 |                                                 ]))
144 |     show_images(data_set)
145 | 
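Besides the dataset demo in the `__main__` block above, the policy can also be applied directly to a single image; in this sketch a random array stands in for a real `cv2.imread` result:

import numpy as np

from dataloader.enhancement.my_augment import MyAugment

augment = MyAugment()
bgr = (np.random.rand(224, 224, 3) * 255).astype(np.uint8)  # stand-in for a BGR photo
augmented = augment(bgr)  # same shape and dtype, randomly transformed
print(augmented.shape, augmented.dtype)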
--------------------------------------------------------------------------------
/dataloader/enhancement/rescale.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File rescale.py
4 | 
5 | Rescale an image to the given size keeping aspect ratio, padding the rest with 0; applied to a single image
6 | """
7 | import typing
8 | 
9 | import cv2
10 | import numpy as np
11 | 
12 | from dataloader.utils import get_rescale_size
13 | 
14 | 
15 | class Rescale:
16 | 
17 |     def __init__(self, output_size: typing.Union[int, tuple, list]):
18 |         """
19 |         Rescale the image to the given size keeping aspect ratio, padding the rest with 0
20 |         :param output_size: the target size after aspect-preserving rescaling
21 |         """
22 |         assert isinstance(output_size, (int, tuple, list))
23 |         if isinstance(output_size, int):
24 |             output_size = (output_size, output_size)
25 |         self.output_size = output_size
26 | 
27 |     def __call__(self, image: np.ndarray) -> np.ndarray:
28 |         """
29 |         Rescale a single BGR image read by cv2, keeping aspect ratio and padding the rest with 0
30 |         :param image: image in cv2 BGR format, (h, w, 3)
31 |         :return: the rescaled image, (h, w, 3)
32 |         """
33 |         h, w = image.shape[:2]
34 |         target_h, target_w = self.output_size[0], self.output_size[1]
35 |         (new_h, new_w), (left, right, top, bottom) = get_rescale_size(h, w, target_h, target_w)
36 | 
37 |         # aspect-preserving resize
38 |         image = cv2.resize(image, (new_w, new_h))
39 |         # padding
40 |         image = cv2.copyMakeBorder(image, top, bottom, left, right,
41 |                                    cv2.BORDER_CONSTANT, value=[0, 0, 0])
42 |         return image
43 | 
--------------------------------------------------------------------------------
/dataloader/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File utils.py
4 | 
5 | Utility functions
6 | """
7 | 
8 | 
9 | def get_rescale_size(src_h: int, src_w: int, target_h: int, target_w: int) -> \
10 |         ((int, int), (int, int, int, int)):
11 |     """
12 |     Scale by the longer side keeping aspect ratio; pad the shorter side with 0
13 |     :param src_h: source height
14 |     :param src_w: source width
15 |     :param target_h: target height
16 |     :param target_w: target width
17 |     :return: (new height, new width), (left, right, top, bottom padding widths)
18 |     """
19 |     # aspect-preserving scale
20 |     scale = max(src_h / target_h, src_w / target_w)
21 |     new_h, new_w = int(src_h / scale), int(src_w / scale)
22 |     # padding
23 |     left_more_pad, top_more_pad = 0, 0
24 |     if new_w % 2 != 0:
25 |         left_more_pad = 1
26 |     if new_h % 2 != 0:
27 |         top_more_pad = 1
28 |     left = right = (target_w - new_w) // 2
29 |     top = bottom = (target_h - new_h) // 2
30 |     left += left_more_pad
31 |     top += top_more_pad
32 |     return (new_h, new_w), (left, right, top, bottom)
33 | 
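A worked example of the helper, using the repo's default 400x224 (HxW) target; the numbers follow directly from the formulas above:

from dataloader.utils import get_rescale_size

# a 300x500 (HxW) source image: the width is the limiting side,
# scale = max(300/400, 500/224) = 500/224 ~= 2.232
(new_h, new_w), (left, right, top, bottom) = get_rescale_size(300, 500, 400, 224)
print((new_h, new_w))              # (134, 224): shrunk by the width ratio
print((left, right, top, bottom))  # (0, 0, 133, 133): zero-padding above and below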
--------------------------------------------------------------------------------
/demos/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File main.py
4 | 
5 | Image-recognition classifier demo
6 | """
7 | import typing
8 | from functools import partial
9 | 
10 | import cv2
11 | import numpy as np
12 | import torch
13 | from torchvision import transforms
14 | 
15 | 
16 | class Classifier:
17 |     """ The deployed classifier """
18 |     MODEL_WEIGHT_PATH = '../checkpoints/jit_efficientnet_b0.pt'
19 |     IMAGE_SIZE = (400, 224)
20 | 
21 |     def __init__(self):
22 |         """ Initialize the model and the preprocessor """
23 |         torch.set_num_threads(1)
24 |         torch.set_flush_denormal(True)
25 |         self.model = torch.jit.load(self.MODEL_WEIGHT_PATH)
26 |         self.model.eval()
27 |         self.preprocess = transforms.Compose([
28 |             Rescale(self.IMAGE_SIZE),
29 |             partial(cv2.cvtColor, code=cv2.COLOR_BGR2RGB),
30 |             transforms.ToTensor(),
31 |             transforms.Normalize(mean=[0.485, 0.456, 0.406],
32 |                                  std=[0.229, 0.224, 0.225]),
33 |         ])
34 | 
35 |     def recognize(self, image: np.ndarray) -> np.ndarray:
36 |         """
37 |         Recognize an image
38 |         :param image: numpy array in OpenCV BGR format
39 |         :return: the per-class probabilities
40 |         """
41 |         image = self.preprocess(image)
42 | 
43 |         self.model.eval()
44 |         with torch.no_grad():
45 |             output = self.model(image.unsqueeze(0))[0]
46 |             probabilities = output.sigmoid().detach().numpy()
47 |         return probabilities
48 | 
49 | 
50 | class Rescale:
51 |     """ Rescale images to the given size keeping aspect ratio """
52 | 
53 |     def __init__(self, output_size: typing.Union[int, tuple, list]):
54 |         """
55 |         Rescale the image to the given size keeping aspect ratio, padding the rest with 0
56 |         :param output_size: the target size after aspect-preserving rescaling
57 |         """
58 |         assert isinstance(output_size, (int, tuple, list))
59 |         if isinstance(output_size, int):
60 |             output_size = (output_size, output_size)
61 |         self.output_size = output_size
62 | 
63 |     def __call__(self, image: np.ndarray) -> np.ndarray:
64 |         """
65 |         Rescale a single BGR image read by cv2, keeping aspect ratio and padding the rest with 0
66 |         :param image: image in cv2 BGR format, (h, w, 3)
67 |         :return: the rescaled image, (h, w, 3)
68 |         """
69 |         h, w = image.shape[:2]
70 |         target_h, target_w = self.output_size[0], self.output_size[1]
71 |         (new_h, new_w), (left, right, top, bottom) = self.get_rescale_size(h, w, target_h, target_w)
72 | 
73 |         # aspect-preserving resize
74 |         image = cv2.resize(image, (new_w, new_h))
75 |         # padding
76 |         image = cv2.copyMakeBorder(image, top, bottom, left, right,
77 |                                    cv2.BORDER_CONSTANT, value=[0, 0, 0])
78 |         return image
79 | 
80 |     @staticmethod
81 |     def get_rescale_size(src_h: int, src_w: int, target_h: int, target_w: int) -> \
82 |             ((int, int), (int, int, int, int)):
83 |         """
84 |         Scale by the longer side keeping aspect ratio; pad the shorter side with 0
85 |         :param src_h: source height
86 |         :param src_w: source width
87 |         :param target_h: target height
88 |         :param target_w: target width
89 |         :return: (new height, new width), (left, right, top, bottom padding widths)
90 |         """
91 |         # aspect-preserving scale
92 |         scale = max(src_h / target_h, src_w / target_w)
93 |         new_h, new_w = int(src_h / scale), int(src_w / scale)
94 |         # padding
95 |         left_more_pad, top_more_pad = 0, 0
96 |         if new_w % 2 != 0:
97 |             left_more_pad = 1
98 |         if new_h % 2 != 0:
99 |             top_more_pad = 1
100 |         left = right = (target_w - new_w) // 2
101 |         top = bottom = (target_h - new_h) // 2
102 |         left += left_more_pad
103 |         top += top_more_pad
104 |         return (new_h, new_w), (left, right, top, bottom)
105 | 
106 | 
107 | if __name__ == '__main__':
108 |     import os
109 |     recognizer = Classifier()
110 |     root_dir = 'test_images'
111 |     for image_name in os.listdir(root_dir):
112 |         image_path = os.path.join(root_dir, image_name)
113 |         image = cv2.imread(image_path)
114 |         result = recognizer.recognize(image)
115 |         print(f'{image_name}: {result}')
116 | 
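The preprocessing pipeline is the part of this demo most worth reusing on its own; a standalone sketch of its input/output shapes, with random data in place of a real photo and `Rescale` as defined in this file:

from functools import partial

import cv2
import numpy as np
from torchvision import transforms

preprocess = transforms.Compose([
    Rescale((400, 224)),
    partial(cv2.cvtColor, code=cv2.COLOR_BGR2RGB),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
bgr = (np.random.rand(300, 500, 3) * 255).astype(np.uint8)  # any-sized BGR input
tensor = preprocess(bgr)
print(tensor.shape)  # torch.Size([3, 400, 224]): CHW, ready for model(tensor.unsqueeze(0))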
--------------------------------------------------------------------------------
/demos/test_images/cats_00001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/demos/test_images/cats_00001.jpg
--------------------------------------------------------------------------------
/demos/test_images/dogs_00001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/demos/test_images/dogs_00001.jpg
--------------------------------------------------------------------------------
/demos/test_images/panda_00001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/PyTorch-Image-Classification/13030bd157a499b80d1860b8b654a66224eaf475/demos/test_images/panda_00001.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File main.py
4 | 
5 | Main entry point
6 | """
7 | import os
8 | import random
9 | import datetime
10 | import logging
11 | import traceback
12 | from functools import partial
13 | 
14 | import numpy as np
15 | import torch
16 | import torch.backends.cudnn as cudnn
17 | from torch.optim.lr_scheduler import LambdaLR
18 | import torch.multiprocessing as mp
19 | import torch.distributed as dist
20 | import apex
21 | from apex import amp
22 | 
23 | import models as my_models
24 | from config import parser
25 | import utils
26 | import dataloader
27 | import applications
28 | import criterions
29 | from optim.torchtools.optim import (
30 |     AdamW, RAdam, RangerLars, Ralamb,
31 |     Novograd, LookaheadAdam, Ranger
32 | )
33 | 
34 | 
35 | def main():
36 |     """
37 |     Three run modes:
38 |     1. single-CPU mode;
39 |     2. single-GPU mode;
40 |     3. distributed mode: multi-node multi-GPU, or single-node multi-GPU.
41 |     Advantages of distributed mode: 1. synchronized BN is supported; 2. each DDP worker has its own process, which trains faster and balances GPU memory.
42 |     """
43 |     args = parser.parse_args()
44 |     # choose the run mode according to the machine and the hyper-parameters
45 |     num_gpus_available = torch.cuda.device_count()
46 |     if num_gpus_available == 0:
47 |         args.gpus = 0
48 |     elif args.gpus > num_gpus_available:
49 |         raise ValueError(f'--gpus(-g {args.gpus}) cannot be greater than the number of available devices({num_gpus_available})')
50 | 
51 |     # adjust the world size by the number of GPUs per node
52 |     args.world_size = args.gpus * args.nodes
53 |     if not args.cuda or args.world_size == 0:
54 |         # 1. CPU mode
55 |         args.cuda = False
56 |         args.gpus = 0
57 |         args.distributed = False
58 |     elif args.world_size == 1:
59 |         # 2. single-GPU mode
60 |         args.distributed = False
61 |     elif args.world_size > 1:
62 |         # 3. distributed mode
63 |         args.distributed = True
64 |     else:
65 |         raise ValueError(f'Check config parameters --nodes/-n={args.nodes} and --gpus/-g={args.gpus}!')
66 | 
67 |     if args.distributed and args.gpus > 1:
68 |         # use torch.multiprocessing.spawn to launch distributed processes
69 |         mp.spawn(main_worker, nprocs=args.gpus, args=(args,))
70 |     else:
71 |         # Simply call main_worker function
72 |         main_worker(0, args)
73 | 
74 | 
75 | def main_worker(gpu, args):
76 |     """
77 |     Model training, testing, JIT conversion, and distillation-file generation
78 |     :param gpu: id of the GPU to run on
79 |     :param args: runtime hyper-parameters
80 |     """
81 |     args.gpu = gpu
82 |     utils.generate_logger(f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{gpu}.log")
83 |     logging.info(f'args: {args}')
84 | 
85 |     # reproducibility
86 |     if args.seed is not None:
87 |         random.seed(args.seed)
88 |         torch.manual_seed(args.seed)
89 |         cudnn.deterministic = True
90 |         logging.warning('You have chosen to seed training. '
91 |                         'This will turn on the CUDNN deterministic setting, '
92 |                         'which can slow down your training considerably! '
93 |                         'You may see unexpected behavior when restarting '
94 |                         'from checkpoints.')
95 | 
96 |     if args.cuda:
97 |         logging.info(f"Use GPU: {args.gpu} ~")
98 |         if args.distributed:
99 |             args.rank = args.rank * args.gpus + gpu
100 |             dist.init_process_group(backend='nccl', init_method=args.init_method,
101 |                                     world_size=args.world_size, rank=args.rank)
102 |     else:
103 |         logging.info(f"Use CPU ~")
104 | 
105 |     # create/load the model; to use a pretrained model, download it beforehand into the pretrained folder, named after the network
106 |     logging.info(f"=> creating model '{args.arch}'")
107 |     model = my_models.get_model(args.arch, args.pretrained, num_classes=args.num_classes)
108 | 
109 |     # reload a previously trained model
110 |     if args.resume:
111 |         if os.path.isfile(args.resume):
112 |             logging.info(f"=> loading checkpoint '{args.resume}'")
113 |             checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
114 |             acc = model.load_state_dict(checkpoint['state_dict'], strict=True)
115 |             logging.info(f'missing keys of models: {acc.missing_keys}')
116 |             del checkpoint
117 |         else:
118 |             raise Exception(f"No checkpoint found at '{args.resume}' to be resumed")
119 | 
120 |     # model info
121 |     image_height, image_width = args.image_size
122 |     logging.info(f'Model {args.arch} input size: ({image_height}, {image_width})')
123 |     utils.summary(size=(image_height, image_width), channel=3, model=model)
124 | 
125 |     # model conversion: to torch.jit.script
126 |     if args.jit:
127 |         if not args.resume:
128 |             raise Exception('Option --resume must be specified!')
129 |         applications.convert_to_jit(model, args=args)
130 |         return
131 | 
132 |     if args.criterion == 'softmax':
133 |         criterion = criterions.HybridCELoss(args=args)  # hybrid-strategy multi-class classification
134 |     elif args.criterion == 'bce':
135 |         criterion = criterions.HybridBCELoss(args=args)  # hybrid-strategy multi-label binary classification
136 |     else:
137 |         raise NotImplementedError(f'No such loss function: {args.criterion}')
138 | 
139 |     if args.cuda:
140 |         if args.distributed and args.sync_bn:
141 |             model = apex.parallel.convert_syncbn_model(model)
142 |         torch.cuda.set_device(args.gpu)
143 |         model.cuda(args.gpu)
144 |         criterion = criterion.cuda(args.gpu)
145 | 
146 |     if args.knowledge in ('train', 'test', 'val'):
147 |         torch.set_flush_denormal(True)
148 |         distill_loader = dataloader.load(args, name=args.knowledge)
149 |         applications.distill(distill_loader, model, criterion, args, is_confuse_matrix=True)
150 |         return
151 | 
152 |     if args.make_curriculum in ('train', 'test', 'val'):
153 |         torch.set_flush_denormal(True)
154 |         curriculum_loader = dataloader.load(args, name=args.make_curriculum)
155 |         applications.make_curriculum(curriculum_loader, model, criterion, args, is_confuse_matrix=True)
156 |         return
157 | 
158 |     if args.visual_data in ('train', 'test', 'val'):
159 |         torch.set_flush_denormal(True)
160 |         test_loader = dataloader.load(args, name=args.visual_data)
161 |         applications.Visualize.visualize(test_loader, model, args)
162 |         return
163 | 
164 |     # optimizer
165 |     opt_set = {
166 |         'sgd': partial(torch.optim.SGD, momentum=args.momentum),
167 |         'adam': torch.optim.Adam, 'adamw': AdamW,
168 |         'radam': RAdam, 'ranger': Ranger, 'lookaheadadam': LookaheadAdam,
169 |         'ralamb': Ralamb, 'rangerlars': RangerLars,
170 |         'novograd': Novograd,
171 |     }
172 |     optimizer = opt_set[args.opt](model.parameters(), lr=args.lr)  # weight decay is applied in train
173 |     # stochastic weight averaging optimizer
174 |     # from optim.swa import SWA
175 |     # optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
176 | 
177 |     # mixed-precision training
178 |     if args.cuda:
179 |         model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
180 | 
181 |     if args.distributed:
182 |         model = apex.parallel.DistributedDataParallel(model)
183 |     else:
184 |         model = torch.nn.DataParallel(model)
185 | 
186 |     if args.train:
187 |         train_loader = dataloader.load(args, 'train')
188 |         val_loader = dataloader.load(args, 'val')
189 |         scheduler = LambdaLR(optimizer,
190 |                              lambda epoch: adjust_learning_rate(epoch, args=args))
191 |         applications.train(train_loader, val_loader, model, criterion, optimizer, scheduler, args)
192 |         args.evaluate = True
193 | 
194 |     if args.evaluate:
195 |         torch.set_flush_denormal(True)
196 |         test_loader = dataloader.load(args, name='test')
197 |         acc, loss, paths_targets_preds_probs = applications.test(test_loader, model,
198 |                                                                  criterion, args, is_confuse_matrix=True)
199 |         logging.info(f'Evaluation: * Acc@1 {acc:.3f} and loss {loss:.3f}.')
200 |         logging.info('Evaluation Result:\n')
201 |         for path, target, pred, prob in paths_targets_preds_probs:
202 |             logging.info(path + ' ' + str(target) + ' ' + str(pred) + ' ' + ','.join([f'{num:.2f}' for num in prob]))
203 |         logging.info('Evaluation Over~')
204 | 
205 | 
206 | def adjust_learning_rate(epoch, args):
207 |     """ Adjust the learning-rate multiplier for each epoch according to the warmup setting, the epoch count, and the configured ratios
208 |     :param epoch: current epoch
209 |     :param args: provides the warmup epochs and learning-rate settings
210 |     """
211 |     # lr_rates = [0.1, 1., 10., 100., 1e-10]
212 |     # epochs = [2, 4, 6, 8, 10]
213 |     lr_ratios = np.array([0.1, 1., 0.1, 0.01, 0.001])
214 |     epoch_step = (args.epochs - args.warmup) / 4.0
215 |     epochs = np.array([args.warmup,
216 |                        args.warmup + int(1.5 * epoch_step),
217 |                        args.warmup + int(2.5 * epoch_step),
218 |                        args.warmup + int(3.5 * epoch_step),
219 |                        args.epochs])
220 |     if args.lr_ratios is not None:
221 |         lr_ratios = np.array(args.lr_ratios)
222 |     if args.lr_steps is not None:
223 |         epochs = np.array(args.lr_steps)
224 | 
225 |     for i, e in enumerate(epochs):
226 |         if e > epoch:
227 |             return lr_ratios[i]
228 |         elif e == epoch:
229 |             next_rate = lr_ratios[i]
230 |             if len(lr_ratios) > i + 1:
231 |                 next_rate = lr_ratios[i + 1]
232 |             logging.info(f'===== lr decay rate: {lr_ratios[i]} -> {next_rate} =====')
233 | 
234 |     return lr_ratios[-1]
235 | 
236 | 
237 | if __name__ == '__main__':
238 |     try:
239 |         main()
240 |     except Exception as e:
241 |         logging.error(traceback.format_exc())
242 |         raise e
243 | 
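For concreteness: with hypothetical settings `--epochs 60 --warmup 5` and no overrides, the breakpoints come out as epochs = [5, 25, 39, 53, 60], and the function returns a multiplier on args.lr, which is exactly what LambdaLR expects. A sketch meant to run from within this module, where adjust_learning_rate is defined:

import argparse

args = argparse.Namespace(epochs=60, warmup=5, lr_ratios=None, lr_steps=None)
for epoch in (0, 5, 24, 25, 39, 53):
    print(epoch, adjust_learning_rate(epoch, args))
# 0 -> 0.1 (warmup), 5..24 -> 1.0, 25..38 -> 0.1, 39..52 -> 0.01, 53+ -> 0.001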
apex.parallel.DistributedDataParallel(model) 183 | else: 184 | model = torch.nn.DataParallel(model) 185 | 186 | if args.train: 187 | train_loader = dataloader.load(args, 'train') 188 | val_loader = dataloader.load(args, 'val') 189 | scheduler = LambdaLR(optimizer, 190 | lambda epoch: adjust_learning_rate(epoch, args=args)) 191 | applications.train(train_loader, val_loader, model, criterion, optimizer, scheduler, args) 192 | args.evaluate = True 193 | 194 | if args.evaluate: 195 | torch.set_flush_denormal(True) 196 | test_loader = dataloader.load(args, name='test') 197 | acc, loss, paths_targets_preds_probs = applications.test(test_loader, model, 198 | criterion, args, is_confuse_matrix=True) 199 | logging.info(f'Evaluation: * Acc@1 {acc:.3f} and loss {loss:.3f}.') 200 | logging.info(f'Evaluation Result:\n') 201 | for path, target, pred, prob in paths_targets_preds_probs: 202 | logging.info(path + ' ' + str(target) + ' ' + str(pred) + ' ' + ','.join([f'{num:.2f}' for num in prob])) 203 | logging.info('Evaluation Over~') 204 | 205 | 206 | def adjust_learning_rate(epoch, args): 207 | """ 根据warmup设置、迭代代数、设置的学习率,调整每一代的学习率 208 | :param epoch: 当前epoch数 209 | :param args: 使用warmup代数、学习率 210 | """ 211 | # lr_rates = [0.1, 1., 10., 100., 1e-10] 212 | # epochs = [2, 4, 6, 8, 10] 213 | lr_ratios = np.array([0.1, 1., 0.1, 0.01, 0.001]) 214 | epoch_step = (args.epochs - args.warmup) / 4.0 215 | epochs = np.array([args.warmup, 216 | args.warmup + int(1.5 * epoch_step), 217 | args.warmup + int(2.5 * epoch_step), 218 | args.warmup + int(3.5 * epoch_step), 219 | args.epochs]) 220 | if args.lr_ratios is not None: 221 | lr_ratios = np.array(args.lr_ratios) 222 | if args.lr_steps is not None: 223 | epochs = np.array(args.lr_steps) 224 | 225 | for i, e in enumerate(epochs): 226 | if e > epoch: 227 | return lr_ratios[i] 228 | elif e == epoch: 229 | next_rate = lr_ratios[i] 230 | if len(lr_ratios) > i + 1: 231 | next_rate = lr_ratios[i + 1] 232 | logging.info(f'===== lr decay rate: {lr_ratios[i]} -> {next_rate} =====') 233 | 234 | return lr_ratios[-1] 235 | 236 | 237 | if __name__ == '__main__': 238 | try: 239 | main() 240 | except Exception as e: 241 | logging.error(traceback.format_exc()) 242 | raise e 243 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File __init__.py.py 4 | 5 | 模型结构定义脚本 6 | """ 7 | from .model_factory import get_model 8 | -------------------------------------------------------------------------------- /models/efficientnet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File __init__.py.py 4 | 5 | EfficientNet模型结构定义脚本 6 | """ 7 | from .factory import * 8 | -------------------------------------------------------------------------------- /models/efficientnet/components.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File components.py 4 | 5 | EfficientNet网络的一些组件 6 | """ 7 | import math 8 | from functools import partial 9 | 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | 14 | from .config import BlockArgs, GlobalParams 15 | from .utils import drop_connect 16 | 17 | 18 | class MBConvBlock(nn.Module): 19 | """ Mobile Inverted Residual Bottleneck Block """ 20 | 21 | def __init__(self, block_args: BlockArgs, global_params: 
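For intuition, the default schedule that adjust_learning_rate feeds to LambdaLR works out as follows (a worked example assuming `--epochs 100 --warmup 4` and no `--lr_ratios`/`--lr_steps` overrides; these flag values are purely illustrative):

    # epoch_step = (100 - 4) / 4.0 = 24.0
    # breakpoints = [4, 4 + int(1.5*24), 4 + int(2.5*24), 4 + int(3.5*24), 100] = [4, 40, 64, 88, 100]
    # resulting LR multiplier per epoch range:
    #   epochs  0-3  -> 0.1   (warmup)
    #   epochs  4-39 -> 1.0
    #   epochs 40-63 -> 0.1
    #   epochs 64-87 -> 0.01
    #   epochs 88+   -> 0.001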
-------------------------------------------------------------------------------- /models/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File __init__.py
4 | 
5 | Model architecture definitions
6 | """
7 | from .model_factory import get_model
-------------------------------------------------------------------------------- /models/efficientnet/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File __init__.py
4 | 
5 | EfficientNet architecture definitions
6 | """
7 | from .factory import *
-------------------------------------------------------------------------------- /models/efficientnet/components.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File components.py
4 | 
5 | Building blocks of the EfficientNet network
6 | """
7 | import math
8 | from functools import partial
9 | 
10 | import torch
11 | from torch import nn
12 | from torch.nn import functional as F
13 | 
14 | from .config import BlockArgs, GlobalParams
15 | from .utils import drop_connect
16 | 
17 | 
18 | class MBConvBlock(nn.Module):
19 | """ Mobile Inverted Residual Bottleneck Block """
20 | 
21 | def __init__(self, block_args: BlockArgs, global_params: GlobalParams):
22 | """
23 | The basic MBConv block of EfficientNet (mobile inverted residual bottleneck)
24 | :param block_args: hyper-parameters of this block
25 | :param global_params: hyper-parameters of the whole network
26 | """
27 | super().__init__()
28 | self._block_args = block_args
29 | self._bn_mom = 1 - global_params.batch_norm_momentum
30 | self._bn_eps = global_params.batch_norm_epsilon
31 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
32 | self.id_skip = block_args.id_skip  # skip connection and drop connect
33 | 
34 | Conv2d = partial(Conv2dStaticSamePadding, image_size=global_params.image_size)
35 | 
36 | # Expansion phase
37 | inp = self._block_args.input_filters  # number of input channels
38 | oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of expanded feature channels
39 | if self._block_args.expand_ratio != 1:
40 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
41 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
42 | 
43 | # Depthwise convolution phase
44 | k = self._block_args.kernel_size
45 | s = self._block_args.stride
46 | self._depthwise_conv = Conv2d(
47 | in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
48 | kernel_size=k, stride=s, bias=False)
49 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
50 | 
51 | # Squeeze and Excitation layer, if desired
52 | if self.has_se:
53 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
54 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
55 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
56 | 
57 | # Output phase
58 | final_oup = self._block_args.output_filters
59 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
60 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
61 | self._swish0 = Swish()
62 | self._swish1 = Swish()
63 | self._swish2 = Swish()
64 | 
65 | def forward(self, inputs: torch.Tensor, drop_connect_rate: float = None) -> torch.Tensor:
66 | """
67 | Forward pass of the MBConv block
68 | :param inputs: input tensor
69 | :param drop_connect_rate: drop connect rate (float, between 0 and 1)
70 | :return: features of block
71 | """
72 | 
73 | # Expansion and Depthwise Convolution
74 | x = inputs
75 | if self._block_args.expand_ratio != 1:
76 | x = self._swish0(self._bn0(self._expand_conv(inputs)))
77 | x = self._swish1(self._bn1(self._depthwise_conv(x)))
78 | 
79 | # Squeeze and Excitation
80 | if self.has_se:
81 | x_squeezed = F.adaptive_avg_pool2d(x, [1, 1])
82 | x_squeezed = self._se_expand(self._swish2(self._se_reduce(x_squeezed)))
83 | x = torch.sigmoid(x_squeezed) * x
84 | 
85 | x = self._bn2(self._project_conv(x))
86 | 
87 | # Skip connection and drop connect
88 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
89 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
90 | if drop_connect_rate:
91 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
92 | x = x + inputs  # skip connection
93 | return x
94 | 
95 | 
96 | class Swish(nn.Module):
97 | """ Swish activation function """
98 | def forward(self, x):
99 | return x * torch.sigmoid(x)
100 | 
101 | 
102 | class Conv2dStaticSamePadding(nn.Conv2d):
103 | """ 2D convolution that emulates TensorFlow's 'SAME' padding (for a fixed input image size) """
104 | 
105 | def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
106 | super().__init__(in_channels, out_channels, kernel_size, **kwargs)
107 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
108 | 
109 | # Calculate padding based on image size and save it
110 | assert image_size is not None
111 | ih, iw = image_size if type(image_size) == list else [image_size, image_size]
112 | kh, kw = self.weight.size()[-2:]
113 | sh, sw = self.stride
114 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
115 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
116 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
117 | if pad_h > 0 or pad_w > 0:
118 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
119 | pad_h // 2, pad_h - pad_h // 2))
120 | else:
121 | self.static_padding = Identity()
122 | 
123 | def forward(self, x: torch.Tensor) -> torch.Tensor:
124 | x = self.static_padding(x)
125 | x = F.conv2d(x, self.weight, self.bias, self.stride,
126 | self.padding, self.dilation, self.groups)
127 | return x
128 | 
129 | 
130 | class Identity(nn.Module):
131 | def __init__(self, ):
132 | super(Identity, self).__init__()
133 | 
134 | def forward(self, input_x):
135 | return input_x
-------------------------------------------------------------------------------- /models/efficientnet/config.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File config.py
4 | 
5 | Configuration of the EfficientNet family
6 | """
7 | import re
8 | import typing
9 | import collections
10 | 
11 | 
12 | # Parameters for the entire model (stem, all blocks, and head)
13 | GlobalParams = collections.namedtuple('GlobalParams', [
14 | 'batch_norm_momentum', 'batch_norm_epsilon',
15 | 'dropout_rate', 'drop_connect_rate',
16 | 'width_coefficient', 'depth_coefficient',
17 | 'image_size', 'num_classes',
18 | 'depth_divisor', 'min_depth'])
19 | 
20 | # Parameters for an individual model block
21 | BlockArgs = collections.namedtuple('BlockArgs', [
22 | 'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
23 | 'input_filters', 'output_filters', 'se_ratio', 'id_skip'])
24 | 
25 | # Change namedtuple defaults
26 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
27 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
28 | 
29 | 
30 | class EfficientNetConfig:
31 | 
32 | # Network hyper-parameters of the differently sized EfficientNets
33 | SEARCH_PARAMS = {
34 | # Coefficients: width, depth, res, dropout
35 | 'efficientnet-b0': (1.0, 1.0, [224, 224], 0.2),
36 | 'efficientnet-b1': (1.0, 1.1, [240, 240], 0.2),
37 | 'efficientnet-b2': (1.1, 1.2, [260, 260], 0.3),
38 | 'efficientnet-b3': (1.2, 1.4, [300, 300], 0.3),
39 | 'efficientnet-b4': (1.4, 1.8, [380, 380], 0.4),
40 | 'efficientnet-b5': (1.6, 2.2, [456, 456], 0.4),
41 | 'efficientnet-b6': (1.8, 2.6, [528, 528], 0.5),
42 | 'efficientnet-b7': (2.0, 3.1, [600, 600], 0.5),
43 | 'efficientnet-b8': (2.2, 3.6, [672, 672], 0.5),
44 | 'efficientnet-l2': (4.3, 5.3, [800, 800], 0.5),
45 | }
46 | 
47 | # Structural hyper-parameters of the 7 EfficientNet block groups
48 | BLOCKS_ARGS = [
49 | # num_repeat, kernel_size, stride, expand_ratio,
50 | # input_filters, output_filters, se_ratio, id_skip (True by default, False when 'noskip' is present)
51 | 'r1_k3_s11_e1_i32_o16_se0.25',
52 | 'r2_k3_s22_e6_i16_o24_se0.25',
53 | 'r2_k5_s22_e6_i24_o40_se0.25',
54 | 'r3_k3_s22_e6_i40_o80_se0.25',
55 | 'r3_k5_s11_e6_i80_o112_se0.25',
56 | 'r4_k5_s22_e6_i112_o192_se0.25',
57 | 'r1_k3_s11_e6_i192_o320_se0.25',
58 | ]
59 | 
60 | @classmethod
61 | def check_model_name_is_valid(cls, model_name):
62 | """ Check whether an EfficientNet model name is valid """
63 | if model_name not in cls.SEARCH_PARAMS:
64 | raise ValueError(f'model_name should be one of: {cls.SEARCH_PARAMS.keys()}.')
65 | 
66 | @classmethod
67 | def get_model_params(cls, model_name, override_params):
68 | """ Get the block definitions (blocks_args) and the model hyper-parameters (global_params) for the given model name """
69 | if model_name.startswith('efficientnet'):
70 | w, d, s, p = cls.get_search_params(model_name)
71 | # note: all models have drop connect rate = 0.2
72 | blocks_args, global_params = cls._get_model_params(
73 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
74 | else:
75 | raise NotImplementedError(f'model name is not pre-defined: {model_name}')
76 | 
77 | if override_params:
78 | global_params = global_params._replace(**override_params)
79 | return blocks_args, global_params
80 | 
81 | @classmethod
82 | def get_search_params(cls, model_name: str) -> \
83 | typing.Tuple[float, float, typing.List[int], float]:
84 | """
85 | Get the architecture-search hyper-parameters of the given EfficientNet
86 | :param model_name: model name
87 | :return: the searched hyper-parameters (width, depth, resolution, dropout)
88 | """
89 | return cls.SEARCH_PARAMS[model_name]
90 | 
91 | @classmethod
92 | def _get_model_params(
93 | cls, width_coefficient: float = None, depth_coefficient: float = None,
94 | dropout_rate: float = 0.2, drop_connect_rate: float = 0.2,
95 | image_size: typing.List[int] = None, num_classes: int = 1000
96 | ) -> (typing.List[BlockArgs], GlobalParams):
97 | """
98 | Build the per-block hyper-parameter list and the global model hyper-parameters
99 | :param width_coefficient: block width coefficient
100 | :param depth_coefficient: block depth coefficient
101 | :param dropout_rate: dropout rate of the network's final fc layer
102 | :param drop_connect_rate: drop-connect rate of the conv layers
103 | :param image_size: input image resolution, [H, W]
104 | :param num_classes: number of classes output by the classifier
105 | :returns
106 | blocks_args: hyper-parameters of every block
107 | global_params: hyper-parameters of the whole model
108 | """
109 | blocks_args = BlockDecoder.decode(cls.BLOCKS_ARGS)
110 | 
111 | global_params = GlobalParams(
112 | batch_norm_momentum=0.99,
113 | batch_norm_epsilon=1e-3,
114 | dropout_rate=dropout_rate,
115 | drop_connect_rate=drop_connect_rate,
116 | num_classes=num_classes,
117 | width_coefficient=width_coefficient,
118 | depth_coefficient=depth_coefficient,
119 | depth_divisor=8,
120 | min_depth=None,
121 | image_size=image_size,
122 | )
123 | 
124 | return blocks_args, global_params
125 | 
126 | 
127 | class BlockDecoder:
128 | """ Decodes/encodes the hyper-parameters of EfficientNet's basic blocks """
129 | 
130 | @staticmethod
131 | def _decode_block_string(block_string: str) -> BlockArgs:
132 | """
133 | Parse a block hyper-parameter string into a BlockArgs namedtuple
134 | :param block_string: string encoding of the block hyper-parameters
135 | :return: the decoded BlockArgs
136 | """
137 | assert isinstance(block_string, str)
138 | 
139 | ops = block_string.split('_')
140 | options = {}
141 | for op in ops:
142 | splits = re.split(r'(\d.*)', op)
143 | if len(splits) >= 2:
144 | key, value = splits[:2]
145 | options[key] = value
146 | 
147 | # Check stride
148 | assert (('s' in options and len(options['s']) == 1) or
149 | (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
150 | 
151 | return BlockArgs(
152 | kernel_size=int(options['k']),
153 | num_repeat=int(options['r']),
154 | input_filters=int(options['i']),
155 | output_filters=int(options['o']),
156 | expand_ratio=int(options['e']),
157 | id_skip=('noskip' not in block_string),
158 | se_ratio=float(options['se']) if 'se' in options else None,
159 | stride=[int(options['s'][0])])
160 | 
161 | @staticmethod
162 | def _encode_block_string(block: BlockArgs) -> str:
163 | """Encodes a block to a string."""
164 | args = [
165 | 'r%d' % block.num_repeat,
166 | 'k%d' % block.kernel_size,
167 | 's%d%d' % (block.stride[0], block.stride[0]),  # stride is stored as a one-element list by _decode_block_string
168 | 'e%s' % block.expand_ratio,
169 | 'i%d' % block.input_filters,
170 | 'o%d' % block.output_filters
171 | ]
172 | if block.se_ratio is not None and 0 < block.se_ratio <= 1:
173 | args.append('se%s' % block.se_ratio)
174 | if block.id_skip is False:
175 | args.append('noskip')
176 | return '_'.join(args)
177 | 
178 | @staticmethod
179 | def decode(string_list: typing.List[str]) -> typing.List[BlockArgs]:
180 | """
181 | Decodes a list of string notations to specify blocks inside the network.
182 | 
183 | :param string_list: a list of strings, each string is a notation of block
184 | :return: a list of BlockArgs namedtuples of block args
185 | """
186 | assert isinstance(string_list, list)
187 | blocks_args = []
188 | for block_string in string_list:
189 | blocks_args.append(BlockDecoder._decode_block_string(block_string))
190 | return blocks_args
191 | 
192 | @staticmethod
193 | def encode(blocks_args: typing.List[BlockArgs]) -> typing.List[str]:
194 | """
195 | Encodes a list of BlockArgs to a list of strings.
196 | 
197 | :param blocks_args: a list of BlockArgs namedtuples of block args
198 | :return: a list of strings, each string is a notation of block
199 | """
200 | block_strings = []
201 | for block in blocks_args:
202 | block_strings.append(BlockDecoder._encode_block_string(block))
203 | return block_strings
204 | 
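For reference, a quick sketch of the block-string grammar that BlockDecoder handles, using the first entry of BLOCKS_ARGS (the result shown is reconstructed from _decode_block_string above):

    from models.efficientnet.config import BlockDecoder

    args = BlockDecoder.decode(['r1_k3_s11_e1_i32_o16_se0.25'])[0]
    # BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=1,
    #           input_filters=32, output_filters=16, se_ratio=0.25, id_skip=True)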
-------------------------------------------------------------------------------- /models/efficientnet/efficientnet.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File efficientnet.py
4 | 
5 | """
6 | import typing
7 | from functools import partial
8 | 
9 | from torch import nn
10 | 
11 | from .config import BlockArgs, GlobalParams, EfficientNetConfig
12 | from .utils import round_filters, round_repeats, load_pretrained_weights
13 | from .components import Swish, Conv2dStaticSamePadding, Identity, MBConvBlock
14 | 
15 | 
16 | class EfficientNet(nn.Module):
17 | """ The EfficientNet model """
18 | 
19 | def __init__(self, blocks_args: typing.List[BlockArgs] = None,
20 | global_params: GlobalParams = None):
21 | super().__init__()
22 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
23 | assert len(blocks_args) > 0, 'blocks_args must not be empty'
24 | self._global_params = global_params
25 | self._blocks_args = blocks_args
26 | 
27 | Conv2d = partial(Conv2dStaticSamePadding, image_size=global_params.image_size)
28 | 
29 | # BatchNorm hyper-parameters
30 | bn_mom = 1 - self._global_params.batch_norm_momentum
31 | bn_eps = self._global_params.batch_norm_epsilon
32 | 
33 | # Stem network
34 | in_channels = 3  # rgb
35 | out_channels = round_filters(32, self._global_params)  # number of stem feature channels
36 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
37 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
38 | 
39 | # Build blocks
40 | self._blocks = nn.ModuleList()
41 | block_args = None
42 | for block_args in self._blocks_args:
43 | 
44 | # Update the block's input/output filter counts according to the global hyper-parameters
45 | block_args = block_args._replace(
46 | input_filters=round_filters(block_args.input_filters, self._global_params),
47 | output_filters=round_filters(block_args.output_filters, self._global_params),
48 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
49 | )
50 | 
51 | # Only the first block of each group handles the stride and the filter-count change
52 | self._blocks.append(MBConvBlock(block_args, self._global_params))
53 | if block_args.num_repeat > 1:
54 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
55 | for _ in range(block_args.num_repeat - 1):
56 | self._blocks.append(MBConvBlock(block_args, self._global_params))
57 | 
58 | # Head
59 | in_channels = block_args.output_filters  # features of final block
60 | out_channels = round_filters(1280, self._global_params)
61 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
62 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
63 | 
64 | # Final fully connected layer
65 | self._swish_head = Swish()
66 | self._swish_fc = Swish()
67 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
68 | self._dropout = nn.Dropout(self._global_params.dropout_rate)
69 | 
70 | self._fc = nn.Linear(out_channels, self._global_params.num_classes)
71 | # sample-free
72 | # pos_bias = np.array([47775, 377470, 22176, 21935, 21584], dtype=np.float32)  # number of positives per class
73 | # neg_bias = np.array([47775, 100000, 22176, 21935, 21584], dtype=np.float32)  # number of negatives per class
74 | # bias = -np.log(neg_bias / pos_bias)
75 | # self._fc.bias = torch.nn.Parameter(torch.from_numpy(bias))
76 | 
77 | def extract_features(self, inputs):
78 | """ Return the output of the last convolutional layer """
79 | 
80 | # Stem
81 | x = self._swish_head(self._bn0(self._conv_stem(inputs)))
82 | 
83 | # Blocks
84 | for idx, block in enumerate(self._blocks):
85 | drop_connect_rate = self._global_params.drop_connect_rate
86 | if drop_connect_rate:
87 | drop_connect_rate *= float(idx) / len(self._blocks)
88 | x = block(x, drop_connect_rate=drop_connect_rate)
89 | 
90 | # Head
91 | x = self._swish_fc(self._bn1(self._conv_head(x)))
92 | 
93 | return x
94 | 
95 | def forward(self, inputs):
96 | """ Extract the final feature map via extract_features, then apply pooling and the fc layer to produce the logits """
97 | bs = inputs.size(0)
98 | # Convolution layers
99 | x = self.extract_features(inputs)
100 | 
101 | # Pooling and final linear layer
102 | x = self._avg_pooling(x)
103 | x = x.view(bs, -1)
104 | x = self._dropout(x)
105 | x = self._fc(x)
106 | return x
107 | 
108 | @classmethod
109 | def from_name(cls, model_name, **override_params):
110 | """
111 | Build the model corresponding to a name within the EfficientNet family
112 | :param model_name: name of the EfficientNet variant
113 | :param override_params: custom hyper-parameter overrides for the model
114 | """
115 | EfficientNetConfig.check_model_name_is_valid(model_name)
116 | blocks_args, global_params = EfficientNetConfig.get_model_params(model_name, override_params)
117 | return cls(blocks_args, global_params)
118 | 
119 | @classmethod
120 | def from_pretrained(cls, model_name, num_classes=1000, in_channels=3, adv_prop=False):
121 | """
122 | Build the model and load pretrained weights
123 | :param model_name: model name
124 | :param num_classes: number of output classes
125 | :param in_channels: number of input image channels
126 | :param adv_prop: use the adv_prop pretrained weights
127 | """
128 | model = cls.from_name(model_name, num_classes=num_classes)
129 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), adv_prop=adv_prop)
130 | if in_channels != 3:
131 | Conv2d = partial(Conv2dStaticSamePadding, image_size=model._global_params.image_size)
132 | out_channels = round_filters(32, model._global_params)
133 | model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
134 | return model
135 | 
136 | @classmethod
137 | def get_image_size(cls, model_name):
138 | """
139 | Get the expected input image size
140 | :param model_name: model name
141 | """
142 | EfficientNetConfig.check_model_name_is_valid(model_name)
143 | _, _, res, _ = EfficientNetConfig.get_search_params(model_name)
144 | return res
145 | 
-------------------------------------------------------------------------------- /models/efficientnet/factory.py: 
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File factory.py
4 | 
5 | Factory functions for the EfficientNet family
6 | """
7 | from .efficientnet import EfficientNet
8 | 
9 | __all__ = ['efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3',
10 | 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7',
11 | 'efficientnet_b8']
12 | 
13 | 
14 | def load_efficientnet(name, num_classes, adv_prop, pretrained):
15 | if pretrained:
16 | model = EfficientNet.from_pretrained(name,
17 | num_classes=num_classes,
18 | adv_prop=adv_prop)
19 | else:
20 | model = EfficientNet.from_name(name, num_classes=num_classes)
21 | return model
22 | 
23 | 
24 | def efficientnet_b0(pretrained=False, **kwargs):
25 | num_classes = kwargs.get('num_classes') or 1000
26 | adv_prop = kwargs.get('adv_prop') or False
27 | model = load_efficientnet('efficientnet-b0', num_classes, adv_prop, pretrained)
28 | return model
29 | 
30 | 
31 | def efficientnet_b1(pretrained=False, **kwargs):
32 | num_classes = kwargs.get('num_classes') or 1000
33 | adv_prop = kwargs.get('adv_prop') or False
34 | model = load_efficientnet('efficientnet-b1', num_classes, adv_prop, pretrained)
35 | return model
36 | 
37 | 
38 | def efficientnet_b2(pretrained=False, **kwargs):
39 | num_classes = kwargs.get('num_classes') or 1000
40 | adv_prop = kwargs.get('adv_prop') or False
41 | model = load_efficientnet('efficientnet-b2', num_classes, adv_prop, pretrained)
42 | return model
43 | 
44 | 
45 | def efficientnet_b3(pretrained=False, **kwargs):
46 | num_classes = kwargs.get('num_classes') or 1000
47 | adv_prop = kwargs.get('adv_prop') or False
48 | model = load_efficientnet('efficientnet-b3', num_classes, adv_prop, pretrained)
49 | return model
50 | 
51 | 
52 | def efficientnet_b4(pretrained=False, **kwargs):
53 | num_classes = kwargs.get('num_classes') or 1000
54 | adv_prop = kwargs.get('adv_prop') or False
55 | model = load_efficientnet('efficientnet-b4', num_classes, adv_prop, pretrained)
56 | return model
57 | 
58 | 
59 | def efficientnet_b5(pretrained=False, **kwargs):
60 | num_classes = kwargs.get('num_classes') or 1000
61 | adv_prop = kwargs.get('adv_prop') or False
62 | model = load_efficientnet('efficientnet-b5', num_classes, adv_prop, pretrained)
63 | return model
64 | 
65 | 
66 | def efficientnet_b6(pretrained=False, **kwargs):
67 | num_classes = kwargs.get('num_classes') or 1000
68 | adv_prop = kwargs.get('adv_prop') or False
69 | model = load_efficientnet('efficientnet-b6', num_classes, adv_prop, pretrained)
70 | return model
71 | 
72 | 
73 | def efficientnet_b7(pretrained=False, **kwargs):
74 | num_classes = kwargs.get('num_classes') or 1000
75 | adv_prop = kwargs.get('adv_prop') or False
76 | model = load_efficientnet('efficientnet-b7', num_classes, adv_prop, pretrained)
77 | return model
78 | 
79 | 
80 | def efficientnet_b8(pretrained=False, **kwargs):
81 | num_classes = kwargs.get('num_classes') or 1000
82 | adv_prop = kwargs.get('adv_prop') or False
83 | model = load_efficientnet('efficientnet-b8', num_classes, adv_prop, pretrained)
84 | return model
85 | 
-------------------------------------------------------------------------------- /models/efficientnet/utils.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File utils.py
4 | 
5 | EfficientNet utilities
6 | """
7 | import os
8 | import math
9 | import logging
10 | 
11 | import torch
12 | 
13 | from .config import GlobalParams
14 | 
15 | 
16 | def round_filters(filters: int, global_params: GlobalParams) -> int:
17 | """ Calculate and round number of filters based on depth multiplier. """
18 | multiplier = global_params.width_coefficient
19 | if not multiplier:
20 | return filters
21 | divisor = global_params.depth_divisor
22 | min_depth = global_params.min_depth
23 | filters *= multiplier
24 | min_depth = min_depth or divisor
25 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
26 | if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
27 | new_filters += divisor
28 | return int(new_filters)
29 | 
30 | 
31 | def round_repeats(repeats: int, global_params: GlobalParams) -> int:
32 | """ Round number of repeats based on depth multiplier. """
33 | multiplier = global_params.depth_coefficient
34 | if not multiplier:
35 | return repeats
36 | return int(math.ceil(multiplier * repeats))
37 | 
38 | 
39 | def drop_connect(inputs: torch.Tensor, p: float,
40 | training: bool) -> torch.Tensor:
41 | """ Drop connect. """
42 | if not training:
43 | return inputs
44 | batch_size = inputs.shape[0]
45 | keep_prob = 1 - p
46 | random_tensor = keep_prob
47 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
48 | binary_tensor = torch.floor(random_tensor)
49 | output = inputs / keep_prob * binary_tensor
50 | return output
51 | 
52 | 
53 | def load_pretrained_weights(model, model_name, load_fc=True, adv_prop=False):
54 | """
55 | Load pretrained weights
56 | :param model: the model
57 | :param model_name: model name
58 | :param load_fc: whether to reuse the fc layer
59 | :param adv_prop: whether to use the adv_prop pretrained weights
60 | """
61 | model_root = 'pretrained'
62 | if adv_prop:
63 | model_root = os.path.join(model_root, 'advprop')
64 | model_path = os.path.join(model_root, f'{model_name}.pth')
65 | state_dict = torch.load(model_path)
66 | 
67 | if load_fc:
68 | model.load_state_dict(state_dict)
69 | else:
70 | state_dict.pop('_fc.weight')
71 | state_dict.pop('_fc.bias')
72 | res = model.load_state_dict(state_dict, strict=False)
73 | assert all([key.find('_fc') != -1 for key in set(res.missing_keys)]), 'issue loading pretrained weights'
74 | del state_dict
75 | logging.info(f'Loaded pretrained weights for {model_path}')
76 | 
-------------------------------------------------------------------------------- /models/mobilenetv3/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File __init__.py
4 | 
5 | MobileNet v3 architecture definitions
6 | """
7 | from .factory import *
-------------------------------------------------------------------------------- /models/mobilenetv3/factory.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File factory.py
4 | 
5 | Factory functions for MobileNet v3
6 | """
7 | import logging
8 | 
9 | import torch
10 | from .mobilenetv3 import MobileNetV3
11 | 
12 | 
13 | __all__ = ['mobilenetv3_large', 'mobilenetv3_small']
14 | 
15 | 
16 | def load_pretrained(model, model_path, load_fc):
17 | """
18 | Load pretrained weights
19 | :param model: the model
20 | :param model_path: path of the pretrained weight file
21 | :param load_fc: whether to load the final fully connected layer
22 | """
23 | state_dict = torch.load(model_path)
24 | 
25 | if load_fc:
26 | model.load_state_dict(state_dict)
27 | else:
28 | state_dict.pop('classifier.3.weight')
29 | state_dict.pop('classifier.3.bias')
30 | res = model.load_state_dict(state_dict, strict=False)
31 | assert all([key.find('classifier.3') != -1 for key in set(res.missing_keys)]), \
32 | 'issue loading pretrained weights'
33 | del state_dict
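# Note: the pop / strict=False / assert sequence above is the standard recipe for
# reusing an ImageNet backbone with a new classification head: the classifier
# weights are dropped before loading, and the assert then confirms that only
# 'classifier.3' keys ended up missing.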
34 | logging.info(f'Loaded pretrained weights for {model_path}') 35 | 36 | 37 | def mobilenetv3_large(pretrained=False, **kwargs): 38 | """ 构造 MobileNetV3-Large model """ 39 | cfgs = [ 40 | # k, t, c, SE, HS, s 41 | [3, 1, 16, 0, 0, 1], 42 | [3, 4, 24, 0, 0, 2], 43 | [3, 3, 24, 0, 0, 1], 44 | [5, 3, 40, 1, 0, 2], 45 | [5, 3, 40, 1, 0, 1], 46 | [5, 3, 40, 1, 0, 1], 47 | [3, 6, 80, 0, 1, 2], 48 | [3, 2.5, 80, 0, 1, 1], 49 | [3, 2.3, 80, 0, 1, 1], 50 | [3, 2.3, 80, 0, 1, 1], 51 | [3, 6, 112, 1, 1, 1], 52 | [3, 6, 112, 1, 1, 1], 53 | [5, 6, 160, 1, 1, 2], 54 | [5, 6, 160, 1, 1, 1], 55 | [5, 6, 160, 1, 1, 1] 56 | ] 57 | model = MobileNetV3(cfgs, mode='large', **kwargs) 58 | 59 | if pretrained: 60 | load_pretrained(model, 'pretrained/mobilenetv3_large.pth', 61 | load_fc=(kwargs.get('num_classes') in (1000, None))) 62 | return model 63 | 64 | 65 | def mobilenetv3_small(pretrained=False, **kwargs): 66 | """ 构造 MobileNetV3-Small模型 """ 67 | 68 | cfgs = [ 69 | # k, t, c, SE, HS, s 70 | [3, 1, 16, 1, 0, 2], 71 | [3, 4.5, 24, 0, 0, 2], 72 | [3, 3.67, 24, 0, 0, 1], 73 | [5, 4, 40, 1, 1, 2], 74 | [5, 6, 40, 1, 1, 1], 75 | [5, 6, 40, 1, 1, 1], 76 | [5, 3, 48, 1, 1, 1], 77 | [5, 3, 48, 1, 1, 1], 78 | [5, 6, 96, 1, 1, 2], 79 | [5, 6, 96, 1, 1, 1], 80 | [5, 6, 96, 1, 1, 1], 81 | ] 82 | model = MobileNetV3(cfgs, mode='small', **kwargs) 83 | 84 | if pretrained: 85 | load_pretrained(model, 'pretrained/mobilenetv3_small.pth', 86 | load_fc=(kwargs.get('num_classes') in (1000, None))) 87 | return model 88 | -------------------------------------------------------------------------------- /models/mobilenetv3/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File mobilenetv3.py 4 | 5 | Creates a MobileNetV3 Model as defined in: 6 | Searching for MobileNetV3 7 | arXiv preprint arXiv:1905.02244. 8 | """ 9 | import torch.nn as nn 10 | import math 11 | 12 | BN_MOMENTUM = 0.01 13 | 14 | 15 | def _make_divisible(v, divisor, min_value=None): 16 | """ 17 | This function is taken from the original tf repo. 18 | It ensures that all layers have a channel number that is divisible by 8 19 | It can be seen here: 20 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 21 | :param v: 22 | :param divisor: 23 | :param min_value: 24 | :return: 25 | """ 26 | if min_value is None: 27 | min_value = divisor 28 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 29 | # Make sure that round down does not go down by more than 10%. 
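# e.g. _make_divisible(37, 8) -> int(37 + 4) // 8 * 8 = 40; the guard below only
# kicks in when rounding to the nearest multiple would lose more than 10% of v.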
30 | if new_v < 0.9 * v: 31 | new_v += divisor 32 | return new_v 33 | 34 | 35 | class h_sigmoid(nn.Module): 36 | def __init__(self, inplace=True): 37 | super(h_sigmoid, self).__init__() 38 | self.relu = nn.ReLU6(inplace=inplace) 39 | 40 | def forward(self, x): 41 | return self.relu(x + 3) / 6 42 | 43 | 44 | class h_swish(nn.Module): 45 | def __init__(self, inplace=True): 46 | super(h_swish, self).__init__() 47 | self.sigmoid = h_sigmoid(inplace=inplace) 48 | 49 | def forward(self, x): 50 | return x * self.sigmoid(x) 51 | 52 | 53 | class SELayer(nn.Module): 54 | def __init__(self, channel, reduction=4): 55 | super(SELayer, self).__init__() 56 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 57 | self.fc = nn.Sequential( 58 | nn.Linear(channel, _make_divisible(channel // reduction, 8)), 59 | nn.ReLU(inplace=True), 60 | nn.Linear(_make_divisible(channel // reduction, 8), channel), 61 | h_sigmoid() 62 | ) 63 | 64 | def forward(self, x): 65 | b, c, _, _ = x.size() 66 | y = self.avg_pool(x).view(b, c) 67 | y = self.fc(y).view(b, c, 1, 1) 68 | return x * y 69 | 70 | 71 | def conv_3x3_bn(inp, oup, stride): 72 | return nn.Sequential( 73 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 74 | nn.BatchNorm2d(oup, momentum=BN_MOMENTUM), 75 | h_swish() 76 | ) 77 | 78 | 79 | def conv_1x1_bn(inp, oup): 80 | return nn.Sequential( 81 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 82 | nn.BatchNorm2d(oup, momentum=BN_MOMENTUM), 83 | h_swish() 84 | ) 85 | 86 | 87 | class InvertedResidual(nn.Module): 88 | def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs): 89 | super(InvertedResidual, self).__init__() 90 | assert stride in [1, 2] 91 | 92 | self.identity = stride == 1 and inp == oup 93 | 94 | if inp == hidden_dim: 95 | self.conv = nn.Sequential( 96 | # dw 97 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False), 98 | nn.BatchNorm2d(hidden_dim, momentum=BN_MOMENTUM), 99 | h_swish() if use_hs else nn.ReLU(inplace=True), 100 | # Squeeze-and-Excite 101 | SELayer(hidden_dim) if use_se else nn.Identity(), 102 | # pw-linear 103 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 104 | nn.BatchNorm2d(oup, momentum=BN_MOMENTUM), 105 | ) 106 | else: 107 | self.conv = nn.Sequential( 108 | # pw 109 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 110 | nn.BatchNorm2d(hidden_dim, momentum=BN_MOMENTUM), 111 | h_swish() if use_hs else nn.ReLU(inplace=True), 112 | # dw 113 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False), 114 | nn.BatchNorm2d(hidden_dim, momentum=BN_MOMENTUM), 115 | # Squeeze-and-Excite 116 | SELayer(hidden_dim) if use_se else nn.Identity(), 117 | h_swish() if use_hs else nn.ReLU(inplace=True), 118 | # pw-linear 119 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 120 | nn.BatchNorm2d(oup, momentum=BN_MOMENTUM), 121 | ) 122 | 123 | def forward(self, x): 124 | if self.identity: 125 | return x + self.conv(x) 126 | else: 127 | return self.conv(x) 128 | 129 | 130 | class MobileNetV3(nn.Module): 131 | def __init__(self, cfgs, mode, num_classes=1000, width_mult=1.): 132 | super(MobileNetV3, self).__init__() 133 | # setting of inverted residual blocks 134 | self.cfgs = cfgs 135 | assert mode in ['large', 'small'] 136 | 137 | # building first layer 138 | input_channel = _make_divisible(16 * width_mult, 8) 139 | layers = [conv_3x3_bn(3, input_channel, 2)] 140 | # building inverted residual blocks 141 | block = InvertedResidual 142 | for k, t, c, use_se, use_hs, s in self.cfgs: 143 
| output_channel = _make_divisible(c * width_mult, 8) 144 | exp_size = _make_divisible(input_channel * t, 8) 145 | layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs)) 146 | input_channel = output_channel 147 | self.features = nn.Sequential(*layers) 148 | # building last several layers 149 | self.conv = conv_1x1_bn(input_channel, exp_size) 150 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 151 | output_channel = {'large': 1280, 'small': 1024} 152 | output_channel = _make_divisible(output_channel[mode] * width_mult, 8) if width_mult > 1.0 else output_channel[mode] 153 | self.classifier = nn.Sequential( 154 | nn.Linear(exp_size, output_channel), 155 | h_swish(), 156 | nn.Dropout(0.2), 157 | nn.Linear(output_channel, num_classes), 158 | ) 159 | 160 | self._initialize_weights() 161 | 162 | def forward(self, x): 163 | x = self.features(x) 164 | x = self.conv(x) 165 | x = self.avgpool(x) 166 | x = x.view(x.size(0), -1) 167 | x = self.classifier(x) 168 | return x 169 | 170 | def _initialize_weights(self): 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv2d): 173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 174 | m.weight.data.normal_(0, math.sqrt(2. / n)) 175 | if m.bias is not None: 176 | m.bias.data.zero_() 177 | elif isinstance(m, nn.BatchNorm2d): 178 | m.weight.data.fill_(1) 179 | m.bias.data.zero_() 180 | elif isinstance(m, nn.Linear): 181 | n = m.weight.size(1) 182 | m.weight.data.normal_(0, 0.01) 183 | m.bias.data.zero_() 184 | -------------------------------------------------------------------------------- /models/model_factory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File model_factory.py 4 | 5 | PyTorch官方提供的预定义模型及自定义模型 6 | """ 7 | import logging 8 | 9 | import torch 10 | from torchvision import models 11 | from .efficientnet import ( 12 | efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3, 13 | efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7, 14 | efficientnet_b8 15 | ) 16 | from .mobilenetv3 import mobilenetv3_small, mobilenetv3_large 17 | from .resnest import ( 18 | resnest50, resnest101, resnest200, resnest269, 19 | resnest50_fast_1s1x64d, resnest50_fast_2s1x64d, resnest50_fast_4s1x64d, 20 | resnest50_fast_1s2x40d, resnest50_fast_2s2x40d, resnest50_fast_4s2x40d, 21 | resnest50_fast_1s4x24d 22 | ) 23 | 24 | 25 | models_map = { 26 | 'efficientnet_b0': efficientnet_b0, 'efficientnet_b1': efficientnet_b1, 27 | 'efficientnet_b2': efficientnet_b2, 'efficientnet_b3': efficientnet_b3, 28 | 'efficientnet_b4': efficientnet_b4, 'efficientnet_b5': efficientnet_b5, 29 | 'efficientnet_b6': efficientnet_b6, 'efficientnet_b7': efficientnet_b7, 30 | 'efficientnet_b8': efficientnet_b8, 31 | 32 | 'mobilenetv3_small': mobilenetv3_small, 'mobilenetv3_large': mobilenetv3_large, 33 | 34 | 'resnest50': resnest50, 'resnest101': resnest101, 'resnest200': resnest200, 'resnest269': resnest269, 35 | 'resnest50_fast_1s1x64d': resnest50_fast_1s1x64d, 'resnest50_fast_2s1x64d': resnest50_fast_2s1x64d, 36 | 'resnest50_fast_4s1x64d': resnest50_fast_4s1x64d, 'resnest50_fast_1s2x40d': resnest50_fast_1s2x40d, 37 | 'resnest50_fast_2s2x40d': resnest50_fast_2s2x40d, 'resnest50_fast_4s2x40d': resnest50_fast_4s2x40d, 38 | 'resnest50_fast_1s4x24d': resnest50_fast_1s4x24d, 39 | } 40 | 41 | 42 | def get_model(name, pretrained=False, **kwargs): 43 | """ 44 | 获取指定名称的模型 45 | :param name: 指定模型名称 46 | :param pretrained: 是否加载预训练模型 47 | :param kwargs: num_classes等 
48 | :return 指定名称的模型 49 | """ 50 | if name in models_map: 51 | model = models_map[name](pretrained, **kwargs) 52 | else: 53 | model = models.__dict__[name](**kwargs) 54 | if pretrained: 55 | model_path = f'pretrained/{name}.pth' 56 | state_dict = torch.load(model_path) 57 | state_dict.pop('fc.weight') 58 | state_dict.pop('fc.bias') 59 | acc = model.load_state_dict(state_dict, strict=False) 60 | del state_dict 61 | assert set(acc.missing_keys) == {'fc.weight', 'fc.bias'}, 'issue loading pretrained weights' 62 | logging.info(f"=> using pre-trained model '{model_path}'") 63 | 64 | return model 65 | -------------------------------------------------------------------------------- /models/resnest/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File __init__.py.py 4 | 5 | ResNeSt模型结构定义脚本 6 | ref: https://github.com/zhanghang1989/ResNeSt 7 | """ 8 | from .resnest import * 9 | from .ablation import * 10 | -------------------------------------------------------------------------------- /models/resnest/ablation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 | # Created by: Hang Zhang 4 | # Email: zhanghang0704@gmail.com 5 | # Copyright (c) 2020 6 | # 7 | # LICENSE file in the root directory of this source tree 8 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 9 | """ResNeSt ablation study models""" 10 | from .resnet import ResNet, Bottleneck, load_pretrained 11 | 12 | __all__ = ['resnest50_fast_1s1x64d', 'resnest50_fast_2s1x64d', 'resnest50_fast_4s1x64d', 13 | 'resnest50_fast_1s2x40d', 'resnest50_fast_2s2x40d', 'resnest50_fast_4s2x40d', 14 | 'resnest50_fast_1s4x24d'] 15 | 16 | 17 | def resnest50_fast_1s1x64d(pretrained=False, **kwargs): 18 | model = ResNet(Bottleneck, [3, 4, 6, 3], 19 | radix=1, groups=1, bottleneck_width=64, 20 | deep_stem=True, stem_width=32, avg_down=True, 21 | avd=True, avd_first=True, **kwargs) 22 | if pretrained: 23 | load_pretrained(model, 'pretrained/resnest50_fast_1s1x64d.pth', 24 | load_fc=(kwargs.get('num_classes') in (1000, None))) 25 | return model 26 | 27 | 28 | def resnest50_fast_2s1x64d(pretrained=False, **kwargs): 29 | model = ResNet(Bottleneck, [3, 4, 6, 3], 30 | radix=2, groups=1, bottleneck_width=64, 31 | deep_stem=True, stem_width=32, avg_down=True, 32 | avd=True, avd_first=True, **kwargs) 33 | if pretrained: 34 | load_pretrained(model, 'pretrained/resnest50_fast_2s1x64d.pth', 35 | load_fc=(kwargs.get('num_classes') in (1000, None))) 36 | return model 37 | 38 | 39 | def resnest50_fast_4s1x64d(pretrained=False, **kwargs): 40 | model = ResNet(Bottleneck, [3, 4, 6, 3], 41 | radix=4, groups=1, bottleneck_width=64, 42 | deep_stem=True, stem_width=32, avg_down=True, 43 | avd=True, avd_first=True, **kwargs) 44 | if pretrained: 45 | load_pretrained(model, 'pretrained/resnest50_fast_4s1x64d.pth', 46 | load_fc=(kwargs.get('num_classes') in (1000, None))) 47 | return model 48 | 49 | 50 | def resnest50_fast_1s2x40d(pretrained=False, **kwargs): 51 | model = ResNet(Bottleneck, [3, 4, 6, 3], 52 | radix=1, groups=2, bottleneck_width=40, 53 | deep_stem=True, stem_width=32, avg_down=True, 54 | avd=True, avd_first=True, **kwargs) 55 | if pretrained: 56 | load_pretrained(model, 'pretrained/resnest50_fast_1s2x40d.pth', 57 | load_fc=(kwargs.get('num_classes') in (1000, None))) 58 | return model 59 | 60 | 61 | def 
resnest50_fast_2s2x40d(pretrained=False, **kwargs): 62 | model = ResNet(Bottleneck, [3, 4, 6, 3], 63 | radix=2, groups=2, bottleneck_width=40, 64 | deep_stem=True, stem_width=32, avg_down=True, 65 | avd=True, avd_first=True, **kwargs) 66 | if pretrained: 67 | load_pretrained(model, 'pretrained/resnest50_fast_2s2x40d.pth', 68 | load_fc=(kwargs.get('num_classes') in (1000, None))) 69 | return model 70 | 71 | 72 | def resnest50_fast_4s2x40d(pretrained=False, **kwargs): 73 | model = ResNet(Bottleneck, [3, 4, 6, 3], 74 | radix=4, groups=2, bottleneck_width=40, 75 | deep_stem=True, stem_width=32, avg_down=True, 76 | avd=True, avd_first=True, **kwargs) 77 | if pretrained: 78 | load_pretrained(model, 'pretrained/resnest50_fast_4s2x40d.pth', 79 | load_fc=(kwargs.get('num_classes') in (1000, None))) 80 | return model 81 | 82 | 83 | def resnest50_fast_1s4x24d(pretrained=False, **kwargs): 84 | model = ResNet(Bottleneck, [3, 4, 6, 3], 85 | radix=1, groups=4, bottleneck_width=24, 86 | deep_stem=True, stem_width=32, avg_down=True, 87 | avd=True, avd_first=True, **kwargs) 88 | if pretrained: 89 | load_pretrained(model, 'pretrained/resnest50_fast_1s4x24d.pth', 90 | load_fc=(kwargs.get('num_classes') in (1000, None))) 91 | return model 92 | -------------------------------------------------------------------------------- /models/resnest/resnest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 | # Created by: Hang Zhang 4 | # Email: zhanghang0704@gmail.com 5 | # Copyright (c) 2020 6 | # 7 | # LICENSE file in the root directory of this source tree 8 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 9 | """ResNeSt models""" 10 | from .resnet import ResNet, Bottleneck, load_pretrained 11 | 12 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269', ] 13 | 14 | 15 | def resnest50(pretrained=False, **kwargs): 16 | model = ResNet(Bottleneck, [3, 4, 6, 3], 17 | radix=2, groups=1, bottleneck_width=64, 18 | deep_stem=True, stem_width=32, avg_down=True, 19 | avd=True, avd_first=False, **kwargs) 20 | if pretrained: 21 | load_pretrained(model, 'pretrained/resnest50.pth', 22 | load_fc=(kwargs.get('num_classes') in (1000, None))) 23 | return model 24 | 25 | 26 | def resnest101(pretrained=False, **kwargs): 27 | model = ResNet(Bottleneck, [3, 4, 23, 3], 28 | radix=2, groups=1, bottleneck_width=64, 29 | deep_stem=True, stem_width=64, avg_down=True, 30 | avd=True, avd_first=False, **kwargs) 31 | if pretrained: 32 | load_pretrained(model, 'pretrained/resnest101.pth', 33 | load_fc=(kwargs.get('num_classes') in (1000, None))) 34 | return model 35 | 36 | 37 | def resnest200(pretrained=False, **kwargs): 38 | model = ResNet(Bottleneck, [3, 24, 36, 3], 39 | radix=2, groups=1, bottleneck_width=64, 40 | deep_stem=True, stem_width=64, avg_down=True, 41 | avd=True, avd_first=False, **kwargs) 42 | if pretrained: 43 | load_pretrained(model, 'pretrained/resnest200.pth', 44 | load_fc=(kwargs.get('num_classes') in (1000, None))) 45 | return model 46 | 47 | 48 | def resnest269(pretrained=False, **kwargs): 49 | model = ResNet(Bottleneck, [3, 30, 48, 8], 50 | radix=2, groups=1, bottleneck_width=64, 51 | deep_stem=True, stem_width=64, avg_down=True, 52 | avd=True, avd_first=False, **kwargs) 53 | if pretrained: 54 | load_pretrained(model, 'pretrained/resnest269.pth', 55 | load_fc=(kwargs.get('num_classes') in (1000, None))) 56 | return model 57 | 
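All the ResNeSt and ablation factories above follow one convention: pretrained=True loads pretrained/<name>.pth, and the fc layer is only restored when num_classes is 1000 (or unset). A minimal usage sketch, assuming the repository root is on PYTHONPATH and an illustrative 3-class task:

    import torch
    from models.resnest import resnest50

    # num_classes != 1000, so a pretrained load would skip the fc weights
    model = resnest50(pretrained=False, num_classes=3)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))  # shape: (1, 3)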
-------------------------------------------------------------------------------- /models/resnest/splat.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Split-Attention模块 """ 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU 8 | from torch.nn.modules.utils import _pair 9 | 10 | __all__ = ['SplAtConv2d'] 11 | 12 | 13 | class SplAtConv2d(Module): 14 | """Split-Attention Conv2d 15 | """ 16 | 17 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0), 18 | dilation=(1, 1), groups=1, bias=True, 19 | radix=2, reduction_factor=4, 20 | rectify=False, rectify_avg=False, norm_layer=None, 21 | dropblock_prob=0.0, **kwargs): 22 | super(SplAtConv2d, self).__init__() 23 | padding = _pair(padding) 24 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 25 | self.rectify_avg = rectify_avg 26 | inter_channels = max(in_channels * radix // reduction_factor, 32) 27 | self.radix = radix 28 | self.cardinality = groups 29 | self.channels = channels 30 | self.dropblock_prob = dropblock_prob 31 | if self.rectify: 32 | from rfconv import RFConv2d 33 | self.conv = RFConv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation, 34 | groups=groups * radix, bias=bias, average_mode=rectify_avg, **kwargs) 35 | else: 36 | self.conv = Conv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation, 37 | groups=groups * radix, bias=bias, **kwargs) 38 | self.use_bn = norm_layer is not None 39 | if self.use_bn: 40 | self.bn0 = norm_layer(channels * radix) 41 | self.relu = ReLU(inplace=True) 42 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 43 | if self.use_bn: 44 | self.bn1 = norm_layer(inter_channels) 45 | self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality) 46 | if dropblock_prob > 0.0: 47 | self.dropblock = DropBlock2D(dropblock_prob, 3) 48 | self.rsoftmax = rSoftMax(radix, groups) 49 | 50 | def forward(self, x): 51 | x = self.conv(x) 52 | if self.use_bn: 53 | x = self.bn0(x) 54 | if self.dropblock_prob > 0.0: 55 | x = self.dropblock(x) 56 | x = self.relu(x) 57 | 58 | batch, rchannel = x.shape[:2] 59 | if self.radix > 1: 60 | splited = torch.split(x, rchannel // self.radix, dim=1) 61 | gap = sum(splited) 62 | else: 63 | gap = x 64 | gap = F.adaptive_avg_pool2d(gap, 1) 65 | gap = self.fc1(gap) 66 | 67 | if self.use_bn: 68 | gap = self.bn1(gap) 69 | gap = self.relu(gap) 70 | 71 | atten = self.fc2(gap) 72 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 73 | 74 | if self.radix > 1: 75 | attens = torch.split(atten, rchannel // self.radix, dim=1) 76 | out = sum([att * split for (att, split) in zip(attens, splited)]) 77 | else: 78 | out = atten * x 79 | return out.contiguous() 80 | 81 | 82 | class rSoftMax(nn.Module): 83 | def __init__(self, radix, cardinality): 84 | super().__init__() 85 | self.radix = radix 86 | self.cardinality = cardinality 87 | 88 | def forward(self, x): 89 | batch = x.size(0) 90 | if self.radix > 1: 91 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 92 | x = F.softmax(x, dim=1) 93 | x = x.reshape(batch, -1) 94 | else: 95 | x = torch.sigmoid(x) 96 | return x 97 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File 
__init__.py
4 | 
5 | """
-------------------------------------------------------------------------------- /optim/torchtools/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File __init__.py
4 | 
5 | """
6 | from . import optim, nn, lr_scheduler
7 | 
-------------------------------------------------------------------------------- /optim/torchtools/lr_scheduler/__init__.py: --------------------------------------------------------------------------------
1 | from .delayed import DelayerScheduler, DelayedCosineAnnealingLR
2 | 
-------------------------------------------------------------------------------- /optim/torchtools/lr_scheduler/delayed.py: --------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
2 | 
3 | class DelayerScheduler(_LRScheduler):
4 | """ Starts with a flat lr schedule until it reaches N epochs, then applies a scheduler
5 | 
6 | Args:
7 | optimizer (Optimizer): Wrapped optimizer.
8 | delay_epochs: number of epochs to keep the initial lr before starting to apply the scheduler
9 | after_scheduler: after delay_epochs, use this scheduler (e.g. ReduceLROnPlateau)
10 | """
11 | 
12 | def __init__(self, optimizer, delay_epochs, after_scheduler):
13 | self.delay_epochs = delay_epochs
14 | self.after_scheduler = after_scheduler
15 | self.finished = False
16 | super().__init__(optimizer)
17 | 
18 | def get_lr(self):
19 | if self.last_epoch >= self.delay_epochs:
20 | if not self.finished:
21 | self.after_scheduler.base_lrs = self.base_lrs
22 | self.finished = True
23 | return self.after_scheduler.get_lr()
24 | 
25 | return self.base_lrs
26 | 
27 | def step(self, epoch=None):
28 | if self.finished:
29 | if epoch is None:
30 | self.after_scheduler.step(None)
31 | else:
32 | self.after_scheduler.step(epoch - self.delay_epochs)
33 | else:
34 | return super(DelayerScheduler, self).step(epoch)
35 | 
36 | def DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs):
37 | base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs)
38 | return DelayerScheduler(optimizer, delay_epochs, base_scheduler)
-------------------------------------------------------------------------------- /optim/torchtools/nn/__init__.py: --------------------------------------------------------------------------------
1 | from .mish import Mish
2 | from .simple_self_attention import SimpleSelfAttention
3 | from .vq import VectorQuantize, Binarize
4 | from .gp_loss import GPLoss
5 | from .pixel_normalzation import PixelNorm
6 | from .perceptual import TVLoss
7 | from .adain import AdaIN
-------------------------------------------------------------------------------- /optim/torchtools/nn/adain.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | class AdaIN(nn.Module):
5 | def __init__(self, n_channels):
6 | super(AdaIN, self).__init__()
7 | self.norm = nn.InstanceNorm2d(n_channels)
8 | 
9 | def forward(self, image, style):
10 | factor, bias = style.view(style.size(0), style.size(1), 1, 1).chunk(2, dim=1)
11 | result = self.norm(image) * factor + bias
12 | return result
-------------------------------------------------------------------------------- /optim/torchtools/nn/functional/__init__.py: --------------------------------------------------------------------------------
1 | from .vq import vector_quantize, binarize
2 | from .gradient_penalty 
import gradient_penalty 3 | from .perceptual import total_variation -------------------------------------------------------------------------------- /optim/torchtools/nn/functional/gradient_penalty.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN WITH FEW MODIFICATIONS FROM https://github.com/caogang/wgan-gp 3 | # ORIGINAL PAPER https://arxiv.org/pdf/1704.00028.pdf 4 | #### 5 | 6 | import torch 7 | from torch import autograd 8 | 9 | def gradient_penalty(netD, real_data, fake_data, l=10): 10 | batch_size = real_data.size(0) 11 | alpha = real_data.new_empty((batch_size, 1, 1, 1)).uniform_(0, 1) 12 | alpha = alpha.expand_as(real_data) 13 | 14 | interpolates = alpha * real_data + ((1 - alpha) * fake_data) 15 | interpolates = autograd.Variable(interpolates, requires_grad=True) 16 | 17 | disc_interpolates = netD(interpolates) 18 | 19 | gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, 20 | grad_outputs=real_data.new_ones(disc_interpolates.size()), 21 | create_graph=True, retain_graph=True, only_inputs=True)[0] 22 | 23 | gradients = gradients.view(gradients.size(0), -1) 24 | gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) 25 | gradient_penalty = ((gradients_norm - 1) ** 2).mean() * l 26 | 27 | return gradient_penalty -------------------------------------------------------------------------------- /optim/torchtools/nn/functional/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def total_variation(X, reduction='sum'): 4 | tv_h = torch.abs(X[:, :, :, 1:] - X[:, :, :, :-1]) 5 | tv_v = torch.abs(X[:, :, 1:] - X[:, :, :-1]) 6 | 7 | tv = torch.mean(tv_h) + torch.mean(tv_v) if reduction == 'mean' else torch.sum(tv_h) + torch.sum(tv_v) 8 | 9 | return tv -------------------------------------------------------------------------------- /optim/torchtools/nn/functional/vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | class vector_quantize(Function): 5 | @staticmethod 6 | def forward(ctx, x, codebook): 7 | with torch.no_grad(): 8 | codebook_sqr = torch.sum(codebook ** 2, dim=1) 9 | x_sqr = torch.sum(x ** 2, dim=1, keepdim=True) 10 | 11 | dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) 12 | _, indices = dist.min(dim=1) 13 | 14 | ctx.save_for_backward(indices, codebook) 15 | ctx.mark_non_differentiable(indices) 16 | 17 | nn = torch.index_select(codebook, 0, indices) 18 | return nn, indices 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output, grad_indices): 22 | grad_inputs, grad_codebook = None, None 23 | 24 | if ctx.needs_input_grad[0]: 25 | grad_inputs = grad_output.clone() 26 | if ctx.needs_input_grad[1]: 27 | # Gradient wrt. 
the codebook 28 | indices, codebook = ctx.saved_tensors 29 | 30 | grad_codebook = torch.zeros_like(codebook) 31 | grad_codebook.index_add_(0, indices, grad_output) 32 | 33 | return (grad_inputs, grad_codebook) 34 | 35 | 36 | class binarize(Function): 37 | @staticmethod 38 | def forward(ctx, x, threshold=0.5): 39 | with torch.no_grad(): 40 | binarized = (x > threshold).float() 41 | ctx.mark_non_differentiable(binarized) 42 | 43 | return binarized 44 | 45 | @staticmethod 46 | def backward(ctx, grad_output): 47 | grad_inputs = None 48 | 49 | if ctx.needs_input_grad[0]: 50 | grad_inputs = grad_output.clone() 51 | 52 | return grad_inputs -------------------------------------------------------------------------------- /optim/torchtools/nn/gp_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .functional import gradient_penalty 4 | 5 | class GPLoss(nn.Module): 6 | def __init__(self, discriminator, l=10): 7 | super(GPLoss, self).__init__() 8 | self.discriminator = discriminator 9 | self.l = l 10 | 11 | def forward(self, real_data, fake_data): 12 | return gradient_penalty(self.discriminator, real_data, fake_data, self.l) -------------------------------------------------------------------------------- /optim/torchtools/nn/mish.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/lessw2020/mish 3 | # ORIGINAL PAPER https://arxiv.org/abs/1908.08681v1 4 | #### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F #(uncomment if needed,but you likely already have it) 9 | 10 | #Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function" 11 | #https://arxiv.org/abs/1908.08681v1 12 | #implemented for PyTorch / FastAI by lessw2020 13 | #github: https://github.com/lessw2020/mish 14 | 15 | class Mish(nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def forward(self, x): 20 | #inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!) 
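# i.e. mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))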
21 | return x * torch.tanh(F.softplus(x)) -------------------------------------------------------------------------------- /optim/torchtools/nn/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .functional import total_variation 4 | 5 | class TVLoss(nn.Module): 6 | def __init__(self, reduction='sum', alpha=1e-4): 7 | super(TVLoss, self).__init__() 8 | self.reduction = reduction 9 | self.alpha = alpha 10 | 11 | def forward(self, x): 12 | return total_variation(x, reduction=self.reduction) * self.alpha -------------------------------------------------------------------------------- /optim/torchtools/nn/pixel_normalzation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class PixelNorm(nn.Module): 5 | def __init__(self): 6 | super(PixelNorm, self).__init__() 7 | 8 | def forward(self, x): 9 | return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8) -------------------------------------------------------------------------------- /optim/torchtools/nn/simple_self_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch, math, sys 3 | 4 | #### 5 | # CODE TAKEN FROM https://github.com/sdoria/SimpleSelfAttention 6 | #### 7 | 8 | # Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py 9 | def conv1d(ni, no, ks=1, stride=1, padding=0, bias=False): 10 | "Create and initialize a `nn.Conv1d` layer with spectral normalization." 11 | conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias) 12 | nn.init.kaiming_normal_(conv.weight) 13 | if bias: conv.bias.data.zero_() 14 | return nn.utils.spectral_norm(conv) 15 | 16 | # Adapted from SelfAttention layer at https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py 17 | # Inspired by https://arxiv.org/pdf/1805.08318.pdf 18 | class SimpleSelfAttention(nn.Module): 19 | 20 | def __init__(self, n_in, ks=1, sym=False): 21 | super().__init__() 22 | self.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False) 23 | self.gamma = nn.Parameter(torch.Tensor([0.])) 24 | self.sym = sym 25 | self.n_in = n_in 26 | 27 | def forward(self, x): 28 | if self.sym: 29 | # symmetry hack by https://github.com/mgrankin 30 | c = self.conv.weight.view(self.n_in, self.n_in) 31 | c = (c + c.t()) / 2 32 | self.conv.weight = c.view(self.n_in, self.n_in, 1) 33 | 34 | size = x.size() 35 | x = x.view(*size[:2], -1) # (C,N) 36 | 37 | # changed the order of multiplication to avoid O(N^2) complexity 38 | # (x*xT)*(W*x) instead of (x*(xT*(W*x))) 39 | 40 | convx = self.conv(x) # (C,C) * (C,N) = (C,N) => O(NC^2) 41 | xxT = torch.bmm(x, x.permute(0,2,1).contiguous()) # (C,N) * (N,C) = (C,C) => O(NC^2) 42 | o = torch.bmm(xxT, convx) # (C,C) * (C,N) = (C,N) => O(NC^2) 43 | o = self.gamma * o + x 44 | 45 | return o.view(*size).contiguous() -------------------------------------------------------------------------------- /optim/torchtools/nn/vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .functional.vq import vector_quantize, binarize 4 | 5 | class VectorQuantize(nn.Module): 6 | def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): 7 | """ 8 | Takes an input of variable size (as long as the last dimension matches the
embedding size). 9 | Returns one tensor containing the nearest-neighbour embeddings for each of the inputs, 10 | with the same size as the input, the vq and commitment components of the loss as a tuple 11 | in the second output, and the indices of the quantized vectors in the third: 12 | quantized, (vq_loss, commit_loss), indices 13 | """ 14 | super(VectorQuantize, self).__init__() 15 | 16 | self.codebook = nn.Embedding(k, embedding_size) 17 | self.codebook.weight.data.uniform_(-1./k, 1./k) 18 | self.vq = vector_quantize.apply 19 | 20 | self.ema_decay = ema_decay 21 | self.ema_loss = ema_loss 22 | if ema_loss: 23 | self.register_buffer('ema_element_count', torch.ones(k)) 24 | self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight)) 25 | 26 | def _laplace_smoothing(self, x, epsilon): 27 | n = torch.sum(x) 28 | return ((x + epsilon) / (n + x.size(0) * epsilon) * n) 29 | 30 | def _updateEMA(self, z_e_x, indices): 31 | mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() 32 | elem_count = mask.sum(dim=0) 33 | weight_sum = torch.mm(mask.t(), z_e_x) 34 | 35 | self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) 36 | self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) 37 | self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) 38 | 39 | self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) 40 | 41 | def idx2vq(self, idx, dim=-1): 42 | q_idx = self.codebook(idx) 43 | if dim != -1: 44 | q_idx = q_idx.transpose(-1, dim) 45 | return q_idx 46 | 47 | def forward(self, x, get_losses=True, dim=-1): 48 | if dim != -1: 49 | x = x.transpose(dim, -1) 50 | z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x 51 | z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) 52 | vq_loss, commit_loss = None, None 53 | if self.ema_loss and self.training: 54 | self._updateEMA(z_e_x.detach(), indices.detach()) 55 | # re-select the embeddings (with grad) after updating the codebook, for a more accurate commitment loss 56 | z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) 57 | if get_losses: 58 | vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean() 59 | commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() 60 | 61 | z_q_x = z_q_x.view(x.shape) 62 | if dim != -1: 63 | z_q_x = z_q_x.transpose(dim, -1) 64 | return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) 65 | 66 | class Binarize(nn.Module): 67 | def __init__(self, threshold=0.5): 68 | """ 69 | Takes an input of any size.
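Gradients are passed through unchanged (a straight-through estimator), since the thresholding itself has zero gradient almost everywhere.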
70 | Returns an output of the same size but with its values binarized (0 if the input is below the threshold, 1 if it is above). 71 | """ 72 | super(Binarize, self).__init__() 73 | 74 | self.bin = binarize.apply 75 | self.threshold = threshold 76 | 77 | def forward(self, x): 78 | return self.bin(x, self.threshold) -------------------------------------------------------------------------------- /optim/torchtools/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .radam import RAdam, PlainRAdam, AdamW 2 | from .ranger import Ranger 3 | from .lookahead import Lookahead, LookaheadAdam 4 | from .over9000 import Over9000, RangerLars 5 | from .ralamb import Ralamb 6 | from .novograd import Novograd 7 | from .lamb import Lamb 8 | -------------------------------------------------------------------------------- /optim/torchtools/optim/lamb.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 5 | import collections 6 | 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | 10 | try: 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 14 | """Log a histogram of trust-ratio scalars across layers.""" 15 | results = collections.defaultdict(list) 16 | for group in optimizer.param_groups: 17 | for p in group['params']: 18 | state = optimizer.state[p] 19 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 20 | if i in state: 21 | results[i].append(state[i]) 22 | 23 | for k, v in results.items(): 24 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 25 | except ModuleNotFoundError: 26 | print("To use log_lamb_rs, please run 'pip install tensorboard'; TensorBoard must also be running to see the results.") 27 | 28 | class Lamb(Optimizer): 29 | r"""Implements the Lamb algorithm. 30 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 31 | Arguments: 32 | params (iterable): iterable of parameters to optimize or dicts defining 33 | parameter groups 34 | lr (float, optional): learning rate (default: 1e-3) 35 | betas (Tuple[float, float], optional): coefficients used for computing 36 | running averages of the gradient and its square (default: (0.9, 0.999)) 37 | eps (float, optional): term added to the denominator to improve 38 | numerical stability (default: 1e-6) 39 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 40 | adam (bool, optional): always use trust ratio = 1, which turns this into 41 | Adam. Useful for comparison purposes. 42 | ..
_Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 43 | https://arxiv.org/abs/1904.00962 44 | """ 45 | 46 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 47 | weight_decay=0, adam=False): 48 | if not 0.0 <= lr: 49 | raise ValueError("Invalid learning rate: {}".format(lr)) 50 | if not 0.0 <= eps: 51 | raise ValueError("Invalid epsilon value: {}".format(eps)) 52 | if not 0.0 <= betas[0] < 1.0: 53 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 54 | if not 0.0 <= betas[1] < 1.0: 55 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 56 | defaults = dict(lr=lr, betas=betas, eps=eps, 57 | weight_decay=weight_decay) 58 | self.adam = adam 59 | super(Lamb, self).__init__(params, defaults) 60 | 61 | def step(self, closure=None): 62 | """Performs a single optimization step. 63 | Arguments: 64 | closure (callable, optional): A closure that reevaluates the model 65 | and returns the loss. 66 | """ 67 | loss = None 68 | if closure is not None: 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | for p in group['params']: 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.') 78 | 79 | state = self.state[p] 80 | 81 | # State initialization 82 | if len(state) == 0: 83 | state['step'] = 0 84 | # Exponential moving average of gradient values 85 | state['exp_avg'] = torch.zeros_like(p.data) 86 | # Exponential moving average of squared gradient values 87 | state['exp_avg_sq'] = torch.zeros_like(p.data) 88 | 89 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | # Decay the first and second moment running average coefficient 95 | # m_t 96 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 97 | # v_t 98 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 99 | 100 | # Paper v3 does not use debiasing. 101 | # bias_correction1 = 1 - beta1 ** state['step'] 102 | # bias_correction2 = 1 - beta2 ** state['step'] 103 | # Fold the bias correction into the lr to avoid a broadcast. 104 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 105 | 106 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 107 | 108 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 109 | if group['weight_decay'] != 0: 110 | adam_step.add_(group['weight_decay'], p.data) 111 | 112 | adam_norm = adam_step.pow(2).sum().sqrt() 113 | if weight_norm == 0 or adam_norm == 0: 114 | trust_ratio = 1 115 | else: 116 | trust_ratio = weight_norm / adam_norm 117 | state['weight_norm'] = weight_norm 118 | state['adam_norm'] = adam_norm 119 | state['trust_ratio'] = trust_ratio 120 | if self.adam: 121 | trust_ratio = 1 122 | 123 | p.data.add_(-step_size * trust_ratio, adam_step) 124 | 125 | return loss -------------------------------------------------------------------------------- /optim/torchtools/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/lonePatient/lookahead_pytorch 3 | # Original paper: https://arxiv.org/abs/1907.08610 4 | #### 5 | # Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py 6 | 7 | """ Lookahead Optimizer Wrapper.
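Keeps a slow copy of each parameter; every `lookahead_k` steps the slow weights move a fraction `lookahead_alpha` toward the fast weights and are copied back into the model.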
8 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 9 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 10 | """ 11 | import torch 12 | from torch.optim import Adam 13 | from torch.optim.optimizer import Optimizer 14 | from collections import defaultdict 15 | 16 | 17 | class Lookahead(Optimizer): 18 | def __init__(self, base_optimizer, alpha=0.5, k=6): 19 | if not 0.0 <= alpha <= 1.0: 20 | raise ValueError(f'Invalid slow update rate: {alpha}') 21 | if not 1 <= k: 22 | raise ValueError(f'Invalid lookahead steps: {k}') 23 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 24 | self.base_optimizer = base_optimizer 25 | self.param_groups = self.base_optimizer.param_groups 26 | self.defaults = base_optimizer.defaults 27 | self.defaults.update(defaults) 28 | self.state = defaultdict(dict) 29 | # manually add our defaults to the param groups 30 | for name, default in defaults.items(): 31 | for group in self.param_groups: 32 | group.setdefault(name, default) 33 | 34 | def update_slow(self, group): 35 | for fast_p in group["params"]: 36 | if fast_p.grad is None: 37 | continue 38 | param_state = self.state[fast_p] 39 | if 'slow_buffer' not in param_state: 40 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 41 | param_state['slow_buffer'].copy_(fast_p.data) 42 | slow = param_state['slow_buffer'] 43 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 44 | fast_p.data.copy_(slow) 45 | 46 | def sync_lookahead(self): 47 | for group in self.param_groups: 48 | self.update_slow(group) 49 | 50 | def step(self, closure=None): 51 | # print(self.k) 52 | # assert id(self.param_groups) == id(self.base_optimizer.param_groups) 53 | loss = self.base_optimizer.step(closure) 54 | for group in self.param_groups: 55 | group['lookahead_step'] += 1 56 | if group['lookahead_step'] % group['lookahead_k'] == 0: 57 | self.update_slow(group) 58 | return loss 59 | 60 | def state_dict(self): 61 | fast_state_dict = self.base_optimizer.state_dict() 62 | slow_state = { 63 | (id(k) if isinstance(k, torch.Tensor) else k): v 64 | for k, v in self.state.items() 65 | } 66 | fast_state = fast_state_dict['state'] 67 | param_groups = fast_state_dict['param_groups'] 68 | return { 69 | 'state': fast_state, 70 | 'slow_state': slow_state, 71 | 'param_groups': param_groups, 72 | } 73 | 74 | def load_state_dict(self, state_dict): 75 | fast_state_dict = { 76 | 'state': state_dict['state'], 77 | 'param_groups': state_dict['param_groups'], 78 | } 79 | self.base_optimizer.load_state_dict(fast_state_dict) 80 | 81 | # We want to restore the slow state, but share param_groups reference 82 | # with base_optimizer. 
This is a bit redundant, but it is the least code 83 | slow_state_new = False 84 | if 'slow_state' not in state_dict: 85 | print('Loading state_dict from optimizer without Lookahead applied.') 86 | state_dict['slow_state'] = defaultdict(dict) 87 | slow_state_new = True 88 | slow_state_dict = { 89 | 'state': state_dict['slow_state'], 90 | 'param_groups': state_dict['param_groups'], # redundant, but it saves code 91 | } 92 | super(Lookahead, self).load_state_dict(slow_state_dict) 93 | self.param_groups = self.base_optimizer.param_groups # make both ref the same container 94 | if slow_state_new: 95 | # reapply defaults to catch missing lookahead-specific ones 96 | for name, default in self.defaults.items(): 97 | for group in self.param_groups: 98 | group.setdefault(name, default) 99 | 100 | 101 | def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs): 102 | adam = Adam(params, *args, **kwargs) 103 | return Lookahead(adam, alpha, k) 104 | -------------------------------------------------------------------------------- /optim/torchtools/optim/over9000.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 5 | from .lookahead import Lookahead 6 | from .ralamb import Ralamb 7 | 8 | 9 | # RAdam + LARS + Lookahead 10 | 11 | # Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py 12 | # RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20 13 | 14 | def Over9000(params, alpha=0.5, k=6, *args, **kwargs): 15 | ralamb = Ralamb(params, *args, **kwargs) 16 | return Lookahead(ralamb, alpha, k) 17 | 18 | 19 | RangerLars = Over9000 20 | -------------------------------------------------------------------------------- /optim/torchtools/optim/ralamb.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 5 | import math 6 | import torch 7 | from torch.optim.optimizer import Optimizer 8 | 9 | # RAdam + LARS 10 | class Ralamb(Optimizer): 11 | 12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 13 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 14 | self.buffer = [[None, None, None] for ind in range(10)] 15 | super(Ralamb, self).__init__(params, defaults) 16 | 17 | def __setstate__(self, state): 18 | super(Ralamb, self).__setstate__(state) 19 | 20 | def step(self, closure=None): 21 | 22 | loss = None 23 | if closure is not None: 24 | loss = closure() 25 | 26 | for group in self.param_groups: 27 | 28 | for p in group['params']: 29 | if p.grad is None: 30 | continue 31 | grad = p.grad.data.float() 32 | if grad.is_sparse: 33 | raise RuntimeError('Ralamb does not support sparse gradients') 34 | 35 | p_data_fp32 = p.data.float() 36 | 37 | state = self.state[p] 38 | 39 | if len(state) == 0: 40 | state['step'] = 0 41 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 42 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 43 | else: 44 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 45 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 46 | 47 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 48 | beta1, beta2 = group['betas'] 49 | 50 | # Decay the first and second moment running average coefficient 51 | # m_t 52 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 53 | # v_t 54 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2,
grad, grad) 55 | 56 | state['step'] += 1 57 | buffered = self.buffer[int(state['step'] % 10)] 58 | 59 | if state['step'] == buffered[0]: 60 | N_sma, radam_step_size = buffered[1], buffered[2] 61 | else: 62 | buffered[0] = state['step'] 63 | beta2_t = beta2 ** state['step'] 64 | N_sma_max = 2 / (1 - beta2) - 1 65 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 66 | buffered[1] = N_sma 67 | 68 | # more conservative since it's an approximated value 69 | if N_sma >= 5: 70 | radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 71 | else: 72 | radam_step_size = 1.0 / (1 - beta1 ** state['step']) 73 | buffered[2] = radam_step_size 74 | 75 | if group['weight_decay'] != 0: 76 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 77 | 78 | # more conservative since it's an approximated value 79 | radam_step = p_data_fp32.clone() 80 | if N_sma >= 5: 81 | denom = exp_avg_sq.sqrt().add_(group['eps']) 82 | radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom) 83 | else: 84 | radam_step.add_(-radam_step_size * group['lr'], exp_avg) 85 | 86 | radam_norm = radam_step.pow(2).sum().sqrt() 87 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 88 | if weight_norm == 0 or radam_norm == 0: 89 | trust_ratio = 1 90 | else: 91 | trust_ratio = weight_norm / radam_norm 92 | 93 | state['weight_norm'] = weight_norm 94 | state['adam_norm'] = radam_norm 95 | state['trust_ratio'] = trust_ratio 96 | 97 | if N_sma >= 5: 98 | p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom) 99 | else: 100 | p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg) 101 | 102 | p.data.copy_(p_data_fp32) 103 | 104 | return loss -------------------------------------------------------------------------------- /optim/torchtools/optim/ranger.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 3 | # Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d 4 | #### 5 | 6 | from .lookahead import Lookahead 7 | from .radam import RAdam 8 | 9 | 10 | def Ranger(params, alpha=0.5, k=6, betas=(.95, 0.999), *args, **kwargs): 11 | radam = RAdam(params, betas=betas, *args, **kwargs) 12 | return Lookahead(radam, alpha, k) 13 | -------------------------------------------------------------------------------- /pretrained/README.md: -------------------------------------------------------------------------------- 1 | # This folder stores the pretrained models -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.0 2 | torchvision==0.6.0 3 | prefetch_generator==1.0.1 4 | scipy==1.4.1 5 | opencv-python==4.2.0.34 6 | matplotlib==3.2.1 7 | imgaug==0.4.0 8 | pandas==1.0.4 9 | seaborn==0.10.1 10 | scikit-learn==0.23.1 11 | apex==0.1 12 | tensorboard==2.2.2 13 | git+https://github.com/szagoruyko/pytorchviz 14 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File __init__.py 4 | 5 | Utilities 6 | """ 7 | from .confusion_matrix import plot_confusion_matrix 8 | from .meters import
AverageMeter, ProgressMeter 9 | from .my_summary import summary 10 | from .my_logger import generate_logger 11 | from .cam_tool import HeatMapTool, CAM, GradCAM, GradCamPlusPlus 12 | -------------------------------------------------------------------------------- /utils/cam_tool/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File __init__.py 4 | 5 | Visualization tools: the CAM, Grad-CAM and Grad-CAM++ family 6 | """ 7 | from .heatmap import HeatMapTool 8 | from .cam import CAM 9 | from .grad_cam import GradCAM 10 | from .grad_cam_plus import GradCamPlusPlus 11 | -------------------------------------------------------------------------------- /utils/cam_tool/cam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File cam.py 4 | 5 | CAM visualization 6 | """ 7 | import numpy as np 8 | import torch 9 | 10 | from dataloader.my_dataloader import DataLoaderX 11 | 12 | 13 | class CAM: 14 | """ 15 | Class Activation Mapping visualization 16 | """ 17 | 18 | def __init__(self, model, module_name): 19 | """ 20 | Register a forward hook on the given module of the model 21 | :param model: the model 22 | :param module_name: name of the network's last convolutional module 23 | """ 24 | self.model = model 25 | self.model.eval() 26 | self.features = None 27 | # weight matrix of the last fully connected layer, (num_classes, num_conv_channel) 28 | self.fc_weights = np.squeeze(list(model.parameters())[-2].data.cpu().numpy()) 29 | # capture the forward output of the last convolutional layer 30 | self.hook = getattr(model, module_name).register_forward_hook(self._hook_fn) 31 | 32 | def _hook_fn(self, module, input, output): 33 | """ 34 | Hook function for custom handling 35 | :param module: the whole module; inner layers can be manipulated here if needed 36 | :param input: module input 37 | :param output: module output 38 | """ 39 | self.features = output.detach().cpu().numpy() 40 | print("features shape:{}".format(output.size())) 41 | 42 | def remove(self): 43 | """ Remove the hook """ 44 | self.hook.remove() 45 | 46 | def __call__(self, images, args) -> (torch.tensor, np.ndarray): 47 | """ 48 | Compute the CAM maps 49 | :param images: pytorch tensor, [N,3,H,W] 50 | :param args: hyperparameters; mainly the cuda and preprocessing options 51 | :return: model outputs (N, num_classes), CAM maps (N, num_classes, H_conv, W_conv) 52 | """ 53 | with torch.no_grad(): 54 | images = DataLoaderX.normalize(images, args) # normalize the images 55 | if args.cuda: 56 | images = images.cuda(args.gpu, non_blocking=True) 57 | outputs = self.model(images) 58 | cams = self._calc_cam() 59 | return outputs, cams 60 | 61 | def _calc_cam(self) -> np.ndarray: 62 | """ 63 | Compute the CAMs, cam = sum_nc(w * features) 64 | :return every CAM for every image in the batch, (N, num_classes, H, W) 65 | """ 66 | bz, nc, h, w = self.features.shape # last conv features of the batch, (N, num_conv_channel, H, W) 67 | batch_cams = self.fc_weights @ self.features.transpose(1, 0, 2, 3).reshape(nc, -1) # C(NHW) 68 | batch_cams = batch_cams.reshape(-1, bz, h, w).transpose(1, 0, 2, 3) # NCHW 69 | # normalize each image separately 70 | batch_cams -= np.min(batch_cams, axis=(1, 2, 3), keepdims=True) 71 | batch_cams /= np.max(batch_cams, axis=(1, 2, 3), keepdims=True) 72 | batch_cam_imgs = np.uint8(255 * batch_cams) 73 | return batch_cam_imgs 74 | -------------------------------------------------------------------------------- /utils/cam_tool/grad_cam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File grad_cam.py 4 | 5 | Grad-CAM visualization 6 | """ 7 | import numpy as np 8 | 9 | from dataloader.my_dataloader import DataLoaderX 10 | 11 | 12 | class GradCAM(object): 13 | """ 14 | 1: the network's weights are not updated; the input requires gradient 15 | 2: backpropagate from the score of the target class 16 | """ 17 | 18 | def __init__(self, net,
layer_name): 19 | self.net = net 20 | self.net.eval() 21 | self.layer_name = layer_name 22 | self.features = None 23 | self.gradients = None 24 | self.handlers = [] 25 | self._register_hook() 26 | 27 | def _register_hook(self): 28 | handler = getattr(self.net, self.layer_name).register_forward_hook(self._get_features_hook) 29 | self.handlers.append(handler) 30 | handler = getattr(self.net, self.layer_name).register_backward_hook(self._get_grads_hook) 31 | self.handlers.append(handler) 32 | 33 | def _get_features_hook(self, module, input, output): 34 | self.features = output.cpu().data.numpy() # [N,C,H,W] 35 | print("features shape:{}".format(self.features.shape)) 36 | 37 | def _get_grads_hook(self, module, input_grad, output_grad): 38 | """ 39 | :param input_grad: tuple, input_grad[0]: None 40 | input_grad[1]: weight 41 | input_grad[2]: bias 42 | :param output_grad: tuple of length 1 43 | :return: 44 | """ 45 | self.gradients = output_grad[0].cpu().data.numpy() # [N,C,H,W] 46 | print("gradients shape:{}".format(self.gradients.shape)) 47 | 48 | def remove_handlers(self): 49 | for handle in self.handlers: 50 | handle.remove() 51 | 52 | def __call__(self, images, args): 53 | """ 54 | Compute the Grad-CAM maps 55 | :param images: pytorch tensor, [N,3,H,W] 56 | :param args: hyperparameters; mainly the cuda and preprocessing options 57 | :return: 58 | """ 59 | images = DataLoaderX.normalize(images, args) # normalize the images 60 | if args.cuda: 61 | images = images.cuda(args.gpu, non_blocking=True) 62 | # forward pass to get the feature maps 63 | outputs = self.net(images) # [N, num_classes] 64 | # backward pass to get the gradients and compute the weight of each feature map 65 | weights = self._calc_grad(outputs) # [N, num_classes, C] 66 | # grad-cam 67 | batch_cam_imgs = self._calc_grad_cam(weights) 68 | return outputs, batch_cam_imgs 69 | 70 | def _calc_grad(self, output): 71 | """ 72 | :param output: model output, [N, num_classes] 73 | :return mean of the gradient maps 74 | """ 75 | weights = [] # [N, num_classes] 76 | for index in range(output.size(1)): 77 | self.net.zero_grad() 78 | targets = output[:, index].sum() 79 | targets.backward(retain_graph=True) 80 | weight = np.mean(self.gradients, axis=(2, 3)) # [N, C] 81 | weights.append(weight) 82 | weights = np.stack(weights, axis=1) # [N, num_classes, C] 83 | return weights 84 | 85 | def _calc_grad_cam(self, weights): 86 | """ 87 | :param weights: per-class gradient weight for each feature map, [N, num_classes, C] 88 | :return 89 | """ 90 | bz, nc, h, w = self.features.shape # [N,C,H,W] 91 | batch_cams = [] # [N, num_classes, H, W] 92 | for i in range(bz): 93 | cams = weights[i] @ self.features[i].reshape(nc, -1) 94 | cams = cams.reshape(-1, h, w) 95 | batch_cams.append(cams) 96 | batch_cams = np.array(batch_cams, dtype=np.float32) # [N, num_classes, H, W] 97 | # batch_cams = np.maximum(batch_cams, 0) # ReLU 98 | # normalize the values 99 | batch_cams -= np.min(batch_cams, axis=(1, 2, 3), keepdims=True) 100 | batch_cams /= np.max(batch_cams, axis=(1, 2, 3), keepdims=True) 101 | batch_cam_imgs = np.uint8(255 * batch_cams) 102 | return batch_cam_imgs 103 | -------------------------------------------------------------------------------- /utils/cam_tool/grad_cam_plus.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File grad_cam_plus.py 4 | 5 | Grad-CAM++ visualization 6 | """ 7 | import numpy as np 8 | 9 | from .grad_cam import GradCAM 10 | 11 | 12 | class GradCamPlusPlus(GradCAM): 13 | def __init__(self, net, layer_name): 14 | super(GradCamPlusPlus, self).__init__(net, layer_name) 15 | 16 | def _calc_grad(self, output): 17 | """ 18 | :param output: model output, [N, num_classes] 19 | :return mean of the gradient maps 20 | """ 21 | weights
= [] # [N, num_classes] 22 | for index in range(output.size(1)): 23 | self.net.zero_grad() 24 | targets = output[:, index].sum() 25 | targets.backward(retain_graph=True) 26 | 27 | gradients = np.maximum(self.gradients, 0.) # ReLU, [N,C,H,W] 28 | # indicates = np.where(gradients > 0, 1., 0.) # indicator function 29 | norm_factor = np.sum(gradients, axis=(2, 3), keepdims=True) # normalization 30 | # alpha = indicates / (norm_factor + 1e-7) 31 | weight = np.sum(gradients / (norm_factor + 1e-7), axis=(2, 3)) # [N, C] 32 | weights.append(weight) 33 | weights = np.stack(weights, axis=1) # [N, num_classes, C] 34 | return weights 35 | -------------------------------------------------------------------------------- /utils/cam_tool/heatmap.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File heatmap.py 4 | 5 | Draw heatmaps 6 | """ 7 | import cv2 8 | import numpy as np 9 | 10 | 11 | class HeatMapTool: 12 | 13 | @staticmethod 14 | def add_heat(image: np.ndarray, cams: np.ndarray) -> np.ndarray: 15 | """ 16 | Overlay heatmaps on an image 17 | :param image: a single image, BGR uint8 18 | :param cams: class activation maps, grayscale uint8 19 | :return: the original image and the heatmap overlays, concatenated along the width 20 | """ 21 | h, w, nc = image.shape 22 | cam_images = [image] 23 | for cam in cams: 24 | heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET) 25 | heatmap = cv2.resize(heatmap, (w, h)) 26 | cam_image = heatmap * 0.5 + image * 0.5 27 | cam_images.append(cam_image) 28 | return np.concatenate(cam_images, axis=1).astype(np.uint8) 29 | -------------------------------------------------------------------------------- /utils/check_images.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File check_images.py 4 | 5 | Check the validity of the images under the given folders 6 | """ 7 | import logging 8 | 9 | import os 10 | from PIL import Image 11 | 12 | 13 | def check(image_paths): 14 | """ 15 | Check whether the image files are readable and valid 16 | :param image_paths: paths of the image files 17 | """ 18 | bad_paths = list() 19 | for image_path in image_paths: 20 | try: 21 | img = Image.open(image_path) 22 | img.resize((360, 640)) 23 | except Exception: 24 | bad_paths.append(image_path) 25 | return bad_paths 26 | 27 | 28 | def check_set(set_dir): 29 | """ 30 | Check whether all image files in an image dataset are readable and valid 31 | :param set_dir: path of the image dataset folder 32 | """ 33 | for mid_dir in os.listdir(set_dir): 34 | image_dir = os.path.join(set_dir, mid_dir) 35 | image_paths = [os.path.join(image_dir, image_name) 36 | for image_name in os.listdir(image_dir)] 37 | logging.info(check(image_paths)) 38 | 39 | 40 | if __name__ == '__main__': 41 | for _set_dir in ['train', 'test', 'val']: 42 | check_set(_set_dir) 43 | -------------------------------------------------------------------------------- /utils/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File confusion_matrix.py 4 | 5 | Confusion matrix 6 | """ 7 | import os 8 | import numpy as np 9 | import pandas as pd 10 | import matplotlib.pyplot as plt 11 | import seaborn as sns 12 | from sklearn.metrics import confusion_matrix 13 | 14 | 15 | def plot_confusion_matrix(y_true, y_pred, labels, title='ConfusionMatrix', is_save=False): 16 | """ 17 | Plot the confusion matrix 18 | :param y_true: ground-truth class labels 19 | :param y_pred: predicted class labels 20 | :param labels: list of class labels 21 | :param title: title of the figure 22 | :param is_save: whether to save the figure 23 | :return: 24 | """ 25 | if labels: 26 | y_true = [labels[int(i)] for i in y_true] 27 | y_pred = [labels[int(i)] for i in y_pred] 28 | # compute the confusion matrix: the y axis is true, the x axis is predicted 29 | conf_matrix =
confusion_matrix(y_true, y_pred, labels=labels) 30 | conf_matrix_pred_sum = np.sum(conf_matrix, axis=0, keepdims=True).astype(float) + 1e-7 31 | conf_matrix_percent = conf_matrix / conf_matrix_pred_sum * 100 # percentages within each predicted column 32 | 33 | annot = np.empty_like(conf_matrix).astype(str) 34 | nrows, ncols = conf_matrix.shape 35 | for i in range(nrows): 36 | for j in range(ncols): 37 | c = conf_matrix[i, j] 38 | p = conf_matrix_percent[i, j] 39 | if i == j: 40 | s = conf_matrix_pred_sum[0][i] 41 | # annot[i, j] = '%.2f%%\n%d/%d' % (p, c, s) 42 | annot[i, j] = '%.2f%%\n%d' % (p, c) 43 | elif c == 0: 44 | annot[i, j] = '' 45 | else: 46 | annot[i, j] = '%.2f%%\n%d' % (p, c) 47 | 48 | # draw the confusion matrix plot 49 | conf_matrix = pd.DataFrame(conf_matrix, index=labels, columns=labels, dtype='float') 50 | fig = plt.figure(figsize=(10, 10)) 51 | ax = fig.gca() 52 | # Oranges,Oranges_r,YlGnBu,Blues,RdBu, PuRd ... 53 | sns.heatmap(conf_matrix, annot=annot, fmt='', ax=ax, cmap='YlGnBu', 54 | annot_kws={"size": 11}, linewidths=0.5) 55 | # configure the axes 56 | ax.set_xticklabels(ax.get_xticklabels(), rotation=25, fontsize=10) 57 | ax.xaxis.set_ticks_position('none') 58 | ax.set_yticklabels(ax.get_yticklabels(), rotation=25, fontsize=10) 59 | ax.yaxis.set_ticks_position('none') 60 | 61 | plt.title(title, size=18) 62 | plt.xlabel('Predicted', size=16) 63 | plt.ylabel('Actual', size=16) 64 | plt.tight_layout() 65 | if is_save: 66 | plt.savefig(os.path.join('.', title+'.png')) 67 | else: 68 | plt.show() 69 | 70 | 71 | if __name__ == '__main__': 72 | y_predict = np.random.randint(low=0, high=10, size=(100,)) 73 | y_truth = np.random.randint(low=0, high=10, size=(100,)) 74 | y_labels = [str(i)+'s' for i in range(10)] 75 | plot_confusion_matrix(y_truth, y_predict, y_labels) 76 | -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File meters.py 4 | 5 | Metrics: recording and printing 6 | """ 7 | import logging 8 | 9 | 10 | class AverageMeter: 11 | """ Computes and stores the average and current value of a metric """ 12 | def __init__(self, name, fmt=':f'): 13 | self.name = name # metric name 14 | self.fmt = fmt # print format of the metric 15 | self.val = 0 # current value 16 | self.avg = 0 # running average 17 | self.sum = 0 # sum of all recorded values 18 | self.count = 0 # number of recorded values 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | def __str__(self): 33 | fmtstr = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})' 34 | return fmtstr.format(**self.__dict__) 35 | 36 | 37 | class ProgressMeter: 38 | """ Progress-bar style printing of metrics """ 39 | def __init__(self, num_batches, *meters, prefix=""): 40 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 41 | self.meters = meters 42 | self.prefix = prefix 43 | 44 | def print(self, batch): 45 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 46 | entries += [str(meter) for meter in self.meters] 47 | logging.info('\t'.join(entries)) 48 | 49 | @staticmethod 50 | def _get_batch_fmtstr(num_batches): 51 | num_digits = len(str(num_batches // 1)) 52 | fmt = f'{{:{str(num_digits)}d}}' 53 | return f'[{fmt}/{fmt.format(num_batches)}]' 54 | -------------------------------------------------------------------------------- /utils/my_logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File my_logger.py 4 | 5
| Logging utilities 6 | """ 7 | import logging 8 | from logging.handlers import RotatingFileHandler 9 | 10 | 11 | def generate_logger(filename, **log_params): 12 | """ 13 | Create a logger object for recording logs 14 | :param filename: name of the log file 15 | :param log_params: logging parameters 16 | :return: 17 | """ 18 | level = log_params.setdefault('level', logging.INFO) 19 | 20 | logger = logging.getLogger() 21 | logger.setLevel(level=level) 22 | formatter = logging.Formatter('%(asctime)s %(filename)s:%(lineno)d %(levelname)s %(message)s') 23 | # define a RotatingFileHandler: at most 3 backup log files, each up to 10 MB 24 | file_handler = RotatingFileHandler(filename, maxBytes=10 * 1024 * 1024, backupCount=3) 25 | file_handler.setFormatter(formatter) 26 | # console output 27 | console = logging.StreamHandler() 28 | console.setFormatter(formatter) 29 | 30 | logger.addHandler(file_handler) 31 | logger.addHandler(console) 32 | -------------------------------------------------------------------------------- /utils/network_viz.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Network visualization with graphviz """ 3 | import torch 4 | from torchviz import make_dot 5 | 6 | import models as my_models 7 | 8 | 9 | if __name__ == '__main__': 10 | model = my_models.get_model('mobilenetv3_large') 11 | width = 224 12 | vis_graph = make_dot(model(torch.randn((1, 3, width, width))), 13 | params=dict(model.named_parameters())) 14 | vis_graph.view() 15 | -------------------------------------------------------------------------------- /visual_images/README.md: -------------------------------------------------------------------------------- 1 | # Sample image visualization results -------------------------------------------------------------------------------- /z_task_shell/0_check_best_lr.sh: -------------------------------------------------------------------------------- 1 | # Before training, estimate the maximum learning rate 2 | CUDA_VISIBLE_DEVICES=0 python main.py --data data/ --train \ 3 | --arch efficientnet_b0 --num_classes 3 \ 4 | --criterion bce \ 5 | --opt sgd \ 6 | --lr_ratios 0.01 0.1 1. 10. 100.
1.e3 1.e4 1.e5 1.e6 \ 7 | --lr_steps 3 6 9 12 15 18 21 24 27 \ 8 | --epoches 27 --warmup -1 \ 9 | -b 128 -j 16 \ 10 | --image_size 400 224 \ 11 | --aug --pretrained \ 12 | --gpus 1 --nodes 1 13 | -------------------------------------------------------------------------------- /z_task_shell/1_print_model_info.sh: -------------------------------------------------------------------------------- 1 | # Print the model info 2 | python main.py --arch efficientnet_b0 --num_classes 3 --image_size 400 224 3 | -------------------------------------------------------------------------------- /z_task_shell/2_train_cpu_or_gpu.sh: -------------------------------------------------------------------------------- 1 | # Training on a single GPU or the CPU 2 | # --curriculum_learning enables curriculum learning 3 | # --distill enables distillation 4 | CUDA_VISIBLE_DEVICES=0 python main.py --train \ 5 | --arch efficientnet_b0 --num_classes 3 \ 6 | --criterion bce --weighted_loss --ohm_loss --ghm_loss --threshold_loss \ 7 | --opt adam \ 8 | --epoches 65 --warmup 5 \ 9 | -b 128 -j 16 \ 10 | --image_size 400 224 \ 11 | --mixup --multi_scale --aug --pretrained \ 12 | --gpus 1 --nodes 1 13 | -------------------------------------------------------------------------------- /z_task_shell/3_train_distributed.sh: -------------------------------------------------------------------------------- 1 | # Distributed training 2 | # --curriculum_learning enables curriculum learning 3 | # --distill enables distillation 4 | CUDA_VISIBLE_DEVICES=0,1 python main.py --train \ 5 | --arch efficientnet_b0 --num_classes 3 \ 6 | --criterion bce --weighted_loss --ohm_loss --ghm_loss --threshold_loss \ 7 | --opt adam \ 8 | --epoches 65 --warmup 5 \ 9 | -b 128 -j 16 \ 10 | --image_size 400 224 \ 11 | --mixup --multi_scale --aug --pretrained \ 12 | --init_method tcp://111.111.111.111:11111 \ 13 | --sync_bn --gpus 2 --nodes 2 --rank 0 -------------------------------------------------------------------------------- /z_task_shell/4_evaluate_model.sh: -------------------------------------------------------------------------------- 1 | # Evaluate the model on one or more GPUs (falls back to the CPU if none is available) 2 | python main.py --data data/ -e --arch efficientnet_b0 --num_classes 3 \ 3 | --criterion bce --image_size 400 224 \ 4 | --batch_size 256 -j 16 \ 5 | --resume checkpoints/model_best_efficientnet_b0.pth -g 1 -n 1 6 | -------------------------------------------------------------------------------- /z_task_shell/5_visualize_model_layer.sh: -------------------------------------------------------------------------------- 1 | # Model and data visualization 2 | python main.py --data data --visual_data test/ --visual_method all \ 3 | --arch efficientnet_b0 --num_classes 3 \ 4 | --criterion bce \ 5 | -b 5 -j 0 \ 6 | --image_size 400 224 \ 7 | --resume checkpoints/model_best_efficientnet_b0.pth 8 | -------------------------------------------------------------------------------- /z_task_shell/6_make_curriculum.sh: -------------------------------------------------------------------------------- 1 | # For curriculum learning, build the per-sample loss-weight file 2 | python main.py --data data --make_curriculum train \ 3 | --curriculum_thresholds 0.7 0.5 0.3 0.0 \ 4 | --curriculum_weights 1 0.7 0.4 0.1 \ 5 | --criterion bce \ 6 | --image_size 400 224 -b 256 -j 16 \ 7 | --resume checkpoints/model_best_efficientnet_b0.pth \ 8 | -g 1 -n 1 9 | -------------------------------------------------------------------------------- /z_task_shell/7_knowledge_distillation.sh: -------------------------------------------------------------------------------- 1 | # For knowledge distillation, prepare the teacher model's probability file 2 | python main.py --data data --knowledge train \ 3 | --arch efficientnet_b0 --num_classes 3 \ 4 | --criterion bce \ 5
| --image_size 400 224 -b 256 -j 16 \ 6 | --resume checkpoints/model_best_efficientnet_b0.pth \ 7 | -g 1 -n 1 8 | -------------------------------------------------------------------------------- /z_task_shell/8_convert_to_jit.sh: -------------------------------------------------------------------------------- 1 | # Convert the given model checkpoint to JIT format 2 | python main.py --jit --arch efficientnet_b0 --num_classes 3 --image_size 400 224 \ 3 | --resume checkpoints/model_best_efficientnet_b0.pth -g 0 4 | --------------------------------------------------------------------------------